Skip to content

Commit 4bad6d6

Browse files
committed
Support Decimal32/64 types
1 parent 5bb476f commit 4bad6d6

File tree

12 files changed

+623
-86
lines changed

12 files changed

+623
-86
lines changed

datafusion/common/src/cast.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
2323
use crate::{downcast_value, Result};
2424
use arrow::array::{
25-
BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray,
26-
DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array,
27-
LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array,
25+
BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray,
26+
DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
27+
Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray,
28+
UInt16Array,
2829
};
2930
use arrow::{
3031
array::{
@@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> {
9798
Ok(downcast_value!(array, UInt64Array))
9899
}
99100

101+
// Downcast Array to Decimal32Array (delegates to the `downcast_value!` macro and returns the borrowed array wrapped in `Ok`)
102+
pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> {
103+
Ok(downcast_value!(array, Decimal32Array))
104+
}
105+
106+
// Downcast Array to Decimal64Array (delegates to the `downcast_value!` macro and returns the borrowed array wrapped in `Ok`)
107+
pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> {
108+
Ok(downcast_value!(array, Decimal64Array))
109+
}
110+
100111
// Downcast Array to Decimal128Array
101112
pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
102113
Ok(downcast_value!(array, Decimal128Array))

datafusion/common/src/scalar/mod.rs

Lines changed: 338 additions & 36 deletions
Large diffs are not rendered by default.

datafusion/core/tests/fuzz_cases/record_batch_generator.rs

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@ use std::sync::Arc;
2020
use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch};
2121
use arrow::datatypes::{
2222
ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type,
23-
Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
24-
DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type,
25-
Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
26-
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema,
27-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
28-
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
23+
Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType,
24+
DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field,
25+
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
26+
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
27+
Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
28+
Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
2929
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
3030
UInt8Type,
3131
};
3232
use arrow_schema::{
3333
DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
34-
DECIMAL256_MAX_SCALE,
34+
DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
3535
};
3636
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
3737
use rand::{rng, rngs::StdRng, Rng, SeedableRng};
@@ -104,6 +104,20 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec<ColumnDescr> {
104104
"duration_nanosecond",
105105
DataType::Duration(TimeUnit::Nanosecond),
106106
),
107+
ColumnDescr::new("decimal32", {
108+
let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION);
109+
let scale: i8 = rng.random_range(
110+
i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE),
111+
);
112+
DataType::Decimal32(precision, scale)
113+
}),
114+
ColumnDescr::new("decimal64", {
115+
let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION);
116+
let scale: i8 = rng.random_range(
117+
i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE),
118+
);
119+
DataType::Decimal64(precision, scale)
120+
}),
107121
ColumnDescr::new("decimal128", {
108122
let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION);
109123
let scale: i8 = rng.random_range(
@@ -682,6 +696,32 @@ impl RecordBatchGenerator {
682696
_ => unreachable!(),
683697
}
684698
}
699+
DataType::Decimal32(precision, scale) => {
700+
generate_decimal_array!(
701+
self,
702+
num_rows,
703+
max_num_distinct,
704+
null_pct,
705+
batch_gen_rng,
706+
array_gen_rng,
707+
precision,
708+
scale,
709+
Decimal32Type
710+
)
711+
}
712+
DataType::Decimal64(precision, scale) => {
713+
generate_decimal_array!(
714+
self,
715+
num_rows,
716+
max_num_distinct,
717+
null_pct,
718+
batch_gen_rng,
719+
array_gen_rng,
720+
precision,
721+
scale,
722+
Decimal64Type
723+
)
724+
}
685725
DataType::Decimal128(precision, scale) => {
686726
generate_decimal_array!(
687727
self,

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
use crate::signature::TypeSignature;
1919
use arrow::datatypes::{
2020
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
21-
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
21+
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
22+
DECIMAL64_MAX_PRECISION,
2223
};
2324

2425
use datafusion_common::{internal_err, plan_err, Result};
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
150151
DataType::Int64 => Ok(DataType::Int64),
151152
DataType::UInt64 => Ok(DataType::UInt64),
152153
DataType::Float64 => Ok(DataType::Float64),
154+
DataType::Decimal32(precision, scale) => {
155+
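// In Spark, the result type is DECIMAL(min(38, precision + 10), s). NOTE(review): this arm caps the precision at DECIMAL32_MAX_PRECISION yet returns Decimal128 — confirm the cap and the result type are meant to differ.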
156+
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
157+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
158+
Ok(DataType::Decimal128(new_precision, *scale))
159+
}
160+
DataType::Decimal64(precision, scale) => {
161+
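// In Spark, the result type is DECIMAL(min(38, precision + 10), s). NOTE(review): this arm caps the precision at DECIMAL64_MAX_PRECISION yet returns Decimal128 — confirm the cap and the result type are meant to differ.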
162+
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
164+
Ok(DataType::Decimal128(new_precision, *scale))
165+
}
153166
DataType::Decimal128(precision, scale) => {
154167
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
155168
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -222,6 +235,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
222235
/// Internal sum type of an average
223236
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
224237
match arg_type {
238+
DataType::Decimal32(precision, scale) => {
239+
// In Spark, the sum type of avg is DECIMAL(min(38, precision + 10), s); here the sum stays Decimal32, so the precision is capped at DECIMAL32_MAX_PRECISION instead of 38.
240+
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
241+
Ok(DataType::Decimal32(new_precision, *scale))
242+
}
243+
DataType::Decimal64(precision, scale) => {
244+
// In Spark, the sum type of avg is DECIMAL(min(38, precision + 10), s); here the sum stays Decimal64, so the precision is capped at DECIMAL64_MAX_PRECISION instead of 38.
245+
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
246+
Ok(DataType::Decimal64(new_precision, *scale))
247+
}
225248
DataType::Decimal128(precision, scale) => {
226249
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
227250
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
@@ -249,7 +272,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
249272
_ => matches!(
250273
arg_type,
251274
arg_type if NUMERICS.contains(arg_type)
252-
|| matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
275+
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
253276
),
254277
}
255278
}
@@ -262,7 +285,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
262285
_ => matches!(
263286
arg_type,
264287
arg_type if NUMERICS.contains(arg_type)
265-
|| matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _))
288+
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
266289
),
267290
}
268291
}
@@ -297,6 +320,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
297320
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
298321
fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
299322
match &data_type {
323+
DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
324+
DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
300325
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
301326
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
302327
d if d.is_numeric() => Ok(DataType::Float64),

datafusion/expr/src/type_coercion/mod.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
5151
| DataType::Float16
5252
| DataType::Float32
5353
| DataType::Float64
54+
| DataType::Decimal32(_, _)
55+
| DataType::Decimal64(_, _)
5456
| DataType::Decimal128(_, _)
5557
| DataType::Decimal256(_, _),
5658
)
@@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool {
8991

9092
/// Determine whether the given data type `dt` is a `Decimal`, i.e. one of `Decimal32`, `Decimal64`, `Decimal128` or `Decimal256`.
9193
pub fn is_decimal(dt: &DataType) -> bool {
92-
matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
94+
matches!(
95+
dt,
96+
DataType::Decimal32(_, _)
97+
| DataType::Decimal64(_, _)
98+
| DataType::Decimal128(_, _)
99+
| DataType::Decimal256(_, _)
100+
)
93101
}

datafusion/functions-aggregate/src/average.rs

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ use arrow::array::{
2424

2525
use arrow::compute::sum;
2626
use arrow::datatypes::{
27-
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
28-
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
29-
DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
27+
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
28+
Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
29+
DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
30+
UInt64Type,
3031
};
3132
use datafusion_common::{
3233
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
@@ -128,6 +129,28 @@ impl AggregateUDFImpl for Avg {
128129
} else {
129130
match (&data_type, acc_args.return_field.data_type()) {
130131
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
132+
(
133+
Decimal32(sum_precision, sum_scale),
134+
Decimal32(target_precision, target_scale),
135+
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
136+
sum: None,
137+
count: 0,
138+
sum_scale: *sum_scale,
139+
sum_precision: *sum_precision,
140+
target_precision: *target_precision,
141+
target_scale: *target_scale,
142+
})),
143+
(
144+
Decimal64(sum_precision, sum_scale),
145+
Decimal64(target_precision, target_scale),
146+
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
147+
sum: None,
148+
count: 0,
149+
sum_scale: *sum_scale,
150+
sum_precision: *sum_precision,
151+
target_precision: *target_precision,
152+
target_scale: *target_scale,
153+
})),
131154
(
132155
Decimal128(sum_precision, sum_scale),
133156
Decimal128(target_precision, target_scale),
@@ -202,7 +225,11 @@ impl AggregateUDFImpl for Avg {
202225
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
203226
matches!(
204227
args.return_field.data_type(),
205-
DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
228+
DataType::Float64
229+
| DataType::Decimal32(_, _)
230+
| DataType::Decimal64(_, _)
231+
| DataType::Decimal128(_, _)
232+
| DataType::Duration(_)
206233
) && !args.is_distinct
207234
}
208235

@@ -222,6 +249,44 @@ impl AggregateUDFImpl for Avg {
222249
|sum: f64, count: u64| Ok(sum / count as f64),
223250
)))
224251
}
252+
(
253+
Decimal32(_sum_precision, sum_scale),
254+
Decimal32(target_precision, target_scale),
255+
) => {
256+
let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
257+
*sum_scale,
258+
*target_precision,
259+
*target_scale,
260+
)?;
261+
262+
let avg_fn =
263+
move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
264+
265+
Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
266+
&data_type,
267+
args.return_field.data_type(),
268+
avg_fn,
269+
)))
270+
}
271+
(
272+
Decimal64(_sum_precision, sum_scale),
273+
Decimal64(target_precision, target_scale),
274+
) => {
275+
let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
276+
*sum_scale,
277+
*target_precision,
278+
*target_scale,
279+
)?;
280+
281+
let avg_fn =
282+
move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
283+
284+
Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
285+
&data_type,
286+
args.return_field.data_type(),
287+
avg_fn,
288+
)))
289+
}
225290
(
226291
Decimal128(_sum_precision, sum_scale),
227292
Decimal128(target_precision, target_scale),

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ use arrow::array::{
3030
use arrow::buffer::{BooleanBuffer, NullBuffer};
3131
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
3232
use arrow::datatypes::{
33-
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef,
34-
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
35-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
36-
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
37-
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
38-
UInt8Type,
33+
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type,
34+
Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type,
35+
Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType,
36+
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
37+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
38+
UInt32Type, UInt64Type, UInt8Type,
3939
};
4040
use datafusion_common::cast::as_boolean_array;
4141
use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx};
@@ -234,6 +234,8 @@ impl AggregateUDFImpl for FirstValue {
234234
DataType::Float32 => create_accumulator::<Float32Type>(args),
235235
DataType::Float64 => create_accumulator::<Float64Type>(args),
236236

237+
DataType::Decimal32(_, _) => create_accumulator::<Decimal32Type>(args),
238+
DataType::Decimal64(_, _) => create_accumulator::<Decimal64Type>(args),
237239
DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
238240
DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
239241

0 commit comments

Comments
 (0)