[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support decimal128 without casting to double #328

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 84 additions & 41 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use arrow::{
record_batch::RecordBatch,
};

use num::cast::AsPrimitive;
use num::{cast::AsPrimitive, ToPrimitive};

/// A pointer to the Arrow record batch for the table function.
#[repr(C)]
Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
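// duckdb/src/main/capi/helper-c.cpp does not support decimal
// NOTE(review): Decimal128 is mapped to Decimal below, so this remark now appears to apply only to Decimal256 (still mapped to Double) — confirm.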
// DataType::Decimal128(_, _) => Decimal,
// DataType::Decimal256(_, _) => Decimal,
DataType::Decimal128(_, _) => Double,
DataType::Decimal128(_, _) => Decimal,
DataType::Decimal256(_, _) => Double,
DataType::Map(_, _) => Map,
_ => {
Expand All @@ -177,35 +177,34 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn

/// Convert arrow DataType to duckdb logical type
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<dyn std::error::Error>> {
if data_type.is_primitive()
|| matches!(
data_type,
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary
)
{
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
} else if let DataType::Dictionary(_, value_type) = data_type {
to_duckdb_logical_type(value_type)
} else if let DataType::Struct(fields) = data_type {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
}
Ok(LogicalType::struct_type(shape.as_slice()))
} else if let DataType::List(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::LargeList(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, array_size) = data_type {
Ok(LogicalType::array(
match data_type {
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
DataType::Struct(fields) => {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
}
Ok(LogicalType::struct_type(shape.as_slice()))
}
DataType::List(child) | DataType::LargeList(child) => {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
}
DataType::FixedSizeList(child, array_size) => Ok(LogicalType::array(
&to_duckdb_logical_type(child.data_type())?,
*array_size as u64,
))
} else {
Err(
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
.into(),
)),
DataType::Decimal128(width, scale) if *scale > 0 => {
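// DuckDB does not support negative decimal scales.
// NOTE(review): the guard above is `*scale > 0`, so a scale of exactly 0 is also
// excluded from this arm and falls through to the later arms — confirm whether `>= 0` was intended.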
Ok(LogicalType::decimal(*width, (*scale).try_into().unwrap()))
}
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
}
dtype if dtype.is_primitive() => Ok(LogicalType::new(to_duckdb_type_id(data_type)?)),
_ => Err(format!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
)
.into()),
}
}

Expand Down Expand Up @@ -354,13 +353,11 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Decimal128(_, _) => {
DataType::Decimal128(width, _) => {
decimal_array_to_vector(
array
.as_any()
.downcast_ref::<Decimal128Array>()
.expect("Unable to downcast to BooleanArray"),
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
*width,
);
}

Expand Down Expand Up @@ -407,10 +404,43 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
}

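/// Convert Arrow [`Decimal128Array`] to a duckdb vector.
///
/// Each value is written as the integer type selected by the decimal `width`
/// (`i16` for 1-4, `i32` for 5-9, `i64` for 10-18, `i128` for 19-38), after which
/// every null slot reported by the array is marked null in `out`.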
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());
for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width: u8) {
match width {
1..=4 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i16().unwrap();
}
}
5..=9 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i32().unwrap();
}
}
10..=18 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i64().unwrap();
}
}
19..=38 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i128().unwrap();
}
}
// This should never happen, arrow only supports 1-38 decimal digits
_ => panic!("Invalid decimal width: {}", width),
}

// Set nulls
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out.set_null(i);
}
}
}
}

Expand Down Expand Up @@ -581,8 +611,8 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray,
Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
Expand All @@ -606,9 +636,9 @@ mod test {
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
assert_eq!(rb.num_columns(), 1);
let column = rb.column(0).as_any().downcast_ref::<Float64Array>().unwrap();
let column = rb.column(0).as_any().downcast_ref::<Decimal128Array>().unwrap();
assert_eq!(column.len(), 1);
assert_eq!(column.value(0), 300.0);
assert_eq!(column.value(0), i128::from(30000));
Ok(())
}

Expand Down Expand Up @@ -896,6 +926,19 @@ mod test {
Ok(())
}

/// Feeds Decimal128 arrays to `check_rust_primitive_array_roundtrip`, passing a
/// clone as the first argument and the array itself as the second: once for an
/// array built without an explicit data type, and once for an array tagged
/// with `DataType::Decimal128(5, 2)`.
#[test]
fn test_decimal128_roundtrip() -> Result<(), Box<dyn Error>> {
    type DecimalArray = PrimitiveArray<arrow::datatypes::Decimal128Type>;

    // Without an explicit width and scale
    let plain: DecimalArray = Decimal128Array::from(vec![1_i128, 2_i128, 3_i128]);
    check_rust_primitive_array_roundtrip(plain.clone(), plain)?;

    // With width and scale
    let scaled: DecimalArray =
        Decimal128Array::from(vec![12345_i128]).with_data_type(DataType::Decimal128(5, 2));
    check_rust_primitive_array_roundtrip(scaled.clone(), scaled)?;

    Ok(())
}

#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly
Expand Down
Loading