Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 99 additions & 11 deletions datafusion/functions-table/src/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,21 @@ impl GenerateSeriesTable {
step,
include_end,
name,
} => Arc::new(RwLock::new(GenericSeriesState {
schema: self.schema(),
start: *start,
end: *end,
step: *step,
current: *start,
batch_size,
include_end: *include_end,
name,
})),
} => {
let (end, include_end) =
normalize_int64_series(*start, *end, *step, *include_end);

Arc::new(RwLock::new(GenericSeriesState {
schema: self.schema(),
start: *start,
end,
step: *step,
current: *start,
batch_size,
include_end,
name,
}))
}
GenSeriesArgs::TimestampArgs {
start,
end,
Expand Down Expand Up @@ -391,6 +396,14 @@ impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
.should_stop(self.end.clone(), &self.step, self.include_end)
{
buf.push(self.current.to_value_type());
if self
.current
.should_stop(self.end.clone(), &self.step, false)
{
self.current.advance(&self.step)?;
break;
}

self.current.advance(&self.step)?;
}

Expand Down Expand Up @@ -433,6 +446,47 @@ fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool {
}
}

/// Rewrites the `(end, include_end)` pair of an int64 series into a form
/// whose bound is the last value the series produces.
///
/// When `last_int64_series_value` yields a value, that value becomes the new
/// end and the bound is marked inclusive. When it yields `None`, the caller's
/// `end` and `include_end` are returned untouched.
fn normalize_int64_series(
    start: i64,
    end: i64,
    step: i64,
    include_end: bool,
) -> (i64, bool) {
    match last_int64_series_value(start, end, step, include_end) {
        // The final element is known: make it the (inclusive) bound.
        Some(final_value) => (final_value, true),
        // Nothing to normalize; keep the original bound as given.
        None => (end, include_end),
    }
}

/// Computes the last value produced by an int64 series that begins at
/// `start` and moves by `step` towards `end`.
///
/// `end` itself counts as part of the series only when `include_end` is set;
/// otherwise the effective limit is one unit short of `end` in the direction
/// of travel.
///
/// Returns `None` when `step` is zero (the series never advances, so there is
/// no last value) or when `reach_end_int64` already returns `true` for
/// `start`. Otherwise returns `Some(start + k * step)` for the largest
/// non-negative `k` that keeps the value within the limit.
fn last_int64_series_value(
    start: i64,
    end: i64,
    step: i64,
    include_end: bool,
) -> Option<i64> {
    // Guard the zero step explicitly: without it the descending branch below
    // would divide by `step_abs == 0` and panic.
    if step == 0 || reach_end_int64(start, end, step, include_end) {
        return None;
    }

    // Widen to i128 so that `end +/- 1`, the distance between the bounds and
    // `-step` (for `i64::MIN`) cannot overflow.
    let (start, end, step) = (i128::from(start), i128::from(end), i128::from(step));

    let last = if step > 0 {
        let limit = if include_end { end } else { end - 1 };
        // Truncating division gives the number of whole steps that fit.
        let steps = (limit - start) / step;
        start + steps * step
    } else {
        let step_abs = -step;
        let limit = if include_end { end } else { end + 1 };
        let steps = (start - limit) / step_abs;
        start - steps * step_abs
    };

    // `last` is `start` moved by a whole number of steps without passing
    // `limit`, so it is expected to sit between the two original i64 bounds
    // and the narrowing cast to be lossless.
    Some(last as i64)
}

fn validate_interval_step(step: IntervalMonthDayNano) -> Result<()> {
if step.months == 0 && step.days == 0 && step.nanoseconds == 0 {
return plan_err!("Step interval cannot be zero");
Expand Down Expand Up @@ -760,11 +814,12 @@ impl TableFunctionImpl for RangeFunc {
mod generate_series_tests {
use std::sync::Arc;

use arrow::array::Int64Array;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_plan::memory::LazyBatchGenerator;

use crate::generate_series::GenericSeriesState;
use crate::generate_series::{GenericSeriesState, normalize_int64_series};

#[test]
fn test_generic_series_state_reset() -> Result<()> {
Expand All @@ -791,4 +846,37 @@ mod generate_series_tests {

Ok(())
}

#[test]
fn test_normalize_int64_series_for_exclusive_upper_bound() {
    // Stepping by 3 from 0 with 10 excluded visits 0, 3, 6, 9, so the
    // normalized bound is expected to be 9 and marked inclusive.
    let normalized = normalize_int64_series(0, 10, 3, false);
    let expected = (9, true);

    assert_eq!(normalized, expected);
}

#[test]
fn test_generate_series_state_stops_before_integer_overflow() -> Result<()> {
    // Start one below `i64::MAX` with a step of 2: the first value is valid,
    // but the next one is not representable as an i64.
    let first = i64::MAX - 1;
    let step = 2;
    let (end, include_end) = normalize_int64_series(first, i64::MAX, step, false);

    let mut state = GenericSeriesState::<i64> {
        schema: Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])),
        start: first,
        end,
        step,
        current: first,
        batch_size: 8192,
        include_end,
        name: "test",
    };

    // The first batch must exist and hold exactly the starting value.
    let batch = state.generate_next_batch()?.expect("missing batch");
    let column = batch.column(0);
    let values = column
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("int64 array");
    assert_eq!(values.values(), &[i64::MAX - 1]);

    // After that the generator must report exhaustion rather than an error.
    let next = state.generate_next_batch()?;
    assert!(next.is_none());

    Ok(())
}
}
8 changes: 8 additions & 0 deletions datafusion/sqllogictest/test_files/table_functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,14 @@ SELECT * FROM generate_series(1, 2, 3, 4)
statement error DataFusion error: Error during planning: Argument \#1 must be an INTEGER, TIMESTAMP, DATE or NULL, got Utf8
SELECT * FROM generate_series('foo', 'bar')

# Regression test: generate_series with a step that would overflow i64 after the last
# included value must return the reachable values rather than an error, matching
# PostgreSQL/DuckDB behavior.
query I
SELECT * FROM generate_series(9223372036854775806, 9223372036854775807, 2)
----
9223372036854775806

# UDF and UDTF `generate_series` can be used simultaneously
query ? rowsort
SELECT generate_series(1, t1.end) FROM generate_series(3, 5) as t1(end)
Expand Down
Loading