diff --git a/CMakeLists.txt b/CMakeLists.txt index 2d286cc75..feb53d3e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -591,6 +591,8 @@ add_library(duckdb_java SHARED src/jni/bindings_common.cpp src/jni/bindings_data_chunk.cpp src/jni/bindings_logical_type.cpp + src/jni/bindings_scalar_function.cpp + src/jni/bindings_table_function.cpp src/jni/bindings_validity.cpp src/jni/bindings_vector.cpp src/jni/config.cpp @@ -598,6 +600,12 @@ add_library(duckdb_java SHARED src/jni/functions.cpp src/jni/refs.cpp src/jni/types.cpp + src/jni/udf_callbacks.cpp + src/jni/udf_registration.cpp + src/jni/udf_registration_impl.cpp + src/jni/udf_table_bind_conversion.cpp + src/jni/udf_types.cpp + src/jni/udf_vector_accessors.cpp src/jni/util.cpp ${DUCKDB_SRC_FILES}) diff --git a/README.md b/README.md index 4042a68db..d836d886f 100644 --- a/README.md +++ b/README.md @@ -20,3 +20,7 @@ This optionally takes an argument to only run a single test, for example: ``` java -cp "build/release/duckdb_jdbc_tests.jar:build/release/duckdb_jdbc.jar" org/duckdb/TestDuckDBJDBC test_valid_but_local_config_throws_exception ``` + +### User-Defined Functions (Java) + +All Java UDF documentation and examples are available in [UDF.MD](UDF.MD). diff --git a/UDF.MD b/UDF.MD new file mode 100644 index 000000000..5af0aed18 --- /dev/null +++ b/UDF.MD @@ -0,0 +1,212 @@ +# User-Defined Functions (Java) + +This guide shows how to use Java Scalar UDFs and Table Functions with `DuckDBConnection`. + +## Scalar UDF + +Scalar UDF callbacks use a vectorized contract: + +```java +ScalarUdf.apply(UdfContext ctx, UdfReader[] args, UdfScalarWriter out, int rowCount) +``` + +Use `rowCount` loops and write one output value per row. 
+ +### Basic example + +```java +try (DuckDBConnection conn = DriverManager.getConnection("jdbc:duckdb:").unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + + conn.registerScalarUdf("add_one", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + 1); + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT add_one(41)")) { + rs.next(); + System.out.println(rs.getInt(1)); // 42 + } +} +``` + +### Registration forms + +You can register scalar UDFs with: + +- `DuckDBColumnType` signatures (`registerScalarUdf`) +- `Class` signatures (`registerScalarUdf`) +- explicit `UdfLogicalType` signatures (`registerScalarUdf`) +- varargs signatures (`registerScalarUdfVarArgs`) + +For decimal precision/scale, prefer explicit logical types: + +```java +conn.registerScalarUdf( + "mul_decimal", + new UdfLogicalType[] {UdfLogicalType.decimal(20, 4), UdfLogicalType.decimal(20, 4)}, + UdfLogicalType.decimal(38, 8), + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setBigDecimal(row, args[0].getBigDecimal(row).multiply(args[1].getBigDecimal(row))); + } + } +); +``` + +### Options + +`UdfOptions` controls scalar behavior: + +- `deterministic(true|false)`: marks whether equal inputs always produce equal output. Use `false` for non-deterministic logic (for example random/time-based behavior). +- `nullSpecialHandling(true|false)`: when `true`, your callback receives rows that contain `NULL` input values; when `false`, DuckDB handles null propagation before callback execution. +- `returnNullOnException(true|false)`: when `true`, Java exceptions in callback rows are returned as `NULL`; when `false`, the query fails with an error. +- `varArgs(true|false)`: enables varargs registration (normally used via `registerScalarUdfVarArgs`). 
+ +Example: + +```java +UdfOptions options = new UdfOptions() + .deterministic(true) + .nullSpecialHandling(true) + .returnNullOnException(false); + +conn.registerScalarUdf("safe_add", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + }, options); +``` + +## UdfReader / UdfScalarWriter object mappings + +| DuckDB type | Reader object | Writer object | +| --- | --- | --- | +| `BOOLEAN` | `Boolean` | `Boolean` | +| `TINYINT`, `SMALLINT`, `INTEGER`, `UTINYINT`, `USMALLINT` | `Integer` | `Integer` | +| `BIGINT`, `UINTEGER`, `UBIGINT` | `Long` | `Long` | +| `FLOAT` | `Float` | `Float` | +| `DOUBLE` | `Double` | `Double` | +| `DECIMAL` | `BigDecimal` | `BigDecimal` | +| `VARCHAR` | `String` | `String` | +| `BLOB` | `byte[]` | `byte[]` | +| `DATE` | `LocalDate` or `Date` | `LocalDate` or `Date` | +| `TIME`, `TIME_NS` | `LocalTime` | `LocalTime` | +| `TIME_WITH_TIME_ZONE` | `OffsetTime` | `OffsetTime` | +| `TIMESTAMP`, `TIMESTAMP_S`, `TIMESTAMP_MS`, `TIMESTAMP_NS` | `LocalDateTime` | `LocalDateTime` or `Date` | +| `TIMESTAMP_WITH_TIME_ZONE` | `OffsetDateTime` | `OffsetDateTime` or `Date` | +| `UUID` | `UUID` | `UUID` | +| `HUGEINT`, `UHUGEINT` | `byte[]` | `byte[]` | + +`UdfScalarWriter` supports explicit setters and `setObject(...)`. + +## Table Function + +Table function callbacks use: + +- `bind(BindContext ctx, Object[] parameters) -> TableBindResult` +- `init(InitContext ctx, TableBindResult bind) -> TableState` +- `produce(TableState state, UdfOutputAppender out) -> int` + +What each callback does: + +- `bind`: runs once per invocation to validate/interpret parameters, define output schema, and create bind state. +- `init`: runs after bind to initialize execution state (cursor/counters/chunk state). 
+- `produce`: runs repeatedly to emit rows in chunks; return the number of rows produced in that call. + +### Basic example + +```java +conn.registerTableFunction( + "range_java", + new TableFunction() { + @Override + public TableBindResult bind(BindContext ctx, Object[] parameters) { + long end = ((Number) parameters[0]).longValue(); + return new TableBindResult( + new String[] {"i"}, + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.BIGINT)}, + new long[] {0L, end} + ); + } + + @Override + public TableState init(InitContext ctx, TableBindResult bind) { + return new TableState(bind.getBindState()); + } + + @Override + public int produce(TableState state, UdfOutputAppender out) { + long[] st = (long[]) state.getState(); + long current = st[0]; + long end = st[1]; + int produced = 0; + + while (produced < 256 && current < end) { + out.beginRow().append(current).endRow(); + current++; + produced++; + } + + st[0] = current; + return produced; + } + }, + new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.BIGINT}), + new TableFunctionOptions().threadSafe(false).maxThreads(1) +); +``` + +### Bind parameter object mappings + +In `bind`, parameters are materialized as Java objects. Common mappings: + +- `DECIMAL -> BigDecimal` +- `DATE -> LocalDate` +- `TIME`, `TIME_NS -> LocalTime` +- `TIMESTAMP* -> LocalDateTime` +- `TIME_WITH_TIME_ZONE -> OffsetTime` +- `TIMESTAMP_WITH_TIME_ZONE -> OffsetDateTime` +- `UUID -> UUID` + +### Output writing with UdfOutputAppender + +`UdfOutputAppender` supports: + +- primitive/object `append(...)` for one column at a time +- `setObject(...)` and typed setters (`setBigDecimal`, `setLocalDate`, etc.) 
+- nested output objects for container types: + - `LIST`/`ARRAY`: Java arrays or `Collection` + - `MAP`: `Map` + - `STRUCT`: positional `List`/array or named `Map` + - `UNION`: `AbstractMap.SimpleEntry` + - `ENUM`: `String` + +## Table function options + +`TableFunctionOptions`: + +- `threadSafe(false|true)` +- `maxThreads(int >= 1)` + +`TableFunctionDefinition`: + +- `withParameterTypes(...)` +- `withProjectionPushdown(true|false)` + +## Unsupported in scalar signatures + +Scalar UDF signatures do not support nested/container logical types (`LIST`, `STRUCT`, `MAP`, `ARRAY`, `UNION`, `ENUM`) and `INTERVAL`. + +## Practical recommendations + +- Use chunk-oriented loops (`rowCount`) for scalar UDF throughput. +- Avoid executing SQL on the same `DuckDBConnection` from inside callbacks. +- Use explicit logical types for decimal-sensitive workloads. diff --git a/duckdb_java.def b/duckdb_java.def index 68ff3031b..191a5759c 100644 --- a/duckdb_java.def +++ b/duckdb_java.def @@ -50,7 +50,6 @@ Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1auto_1commit Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup - Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id @@ -82,6 +81,12 @@ Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1set_1size Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1reserve Java_org_duckdb_DuckDBBindings_duckdb_1struct_1vector_1get_1child Java_org_duckdb_DuckDBBindings_duckdb_1array_1vector_1get_1child +Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1varchar_1bytes +Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1varchar_1bytes +Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1blob_1bytes +Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1blob_1bytes +Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1decimal 
+Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1decimal Java_org_duckdb_DuckDBBindings_duckdb_1create_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1data_1chunk_1reset @@ -98,6 +103,38 @@ Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk +Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java_1with_1function +Java_org_duckdb_DuckDBBindings_duckdb_1create_1table_1function +Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1table_1function +Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1set_1name +Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1add_1parameter +Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1supports_1projection_1pushdown +Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function +Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java +Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java_1with_1function +Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter_1count +Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter 
+Java_org_duckdb_DuckDBBindings_duckdb_1bind_1add_1result_1column +Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1bind_1data +Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1error +Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1init_1data +Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1count +Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1index +Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1max_1threads +Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1error +Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1bind_1data +Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1init_1data +Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1local_1init_1data +Java_org_duckdb_DuckDBBindings_duckdb_1function_1set_1error duckdb_adbc_init duckdb_add_aggregate_function_to_set diff --git a/duckdb_java.exp b/duckdb_java.exp index 6b6cb687d..b6dcc6f7f 100644 --- a/duckdb_java.exp +++ b/duckdb_java.exp @@ -47,7 +47,6 @@ _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1auto_1commit _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema _Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup - _Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size _Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type _Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id @@ -79,6 +78,12 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1set_1size _Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1reserve _Java_org_duckdb_DuckDBBindings_duckdb_1struct_1vector_1get_1child _Java_org_duckdb_DuckDBBindings_duckdb_1array_1vector_1get_1child +_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1varchar_1bytes +_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1varchar_1bytes +_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1blob_1bytes +_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1blob_1bytes +_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1decimal 
+_Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1decimal _Java_org_duckdb_DuckDBBindings_duckdb_1create_1data_1chunk _Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1data_1chunk _Java_org_duckdb_DuckDBBindings_duckdb_1data_1chunk_1reset @@ -95,6 +100,38 @@ _Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1count _Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type _Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk _Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile +_Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java_1with_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1create_1table_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1table_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1set_1name +_Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1add_1parameter +_Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1supports_1projection_1pushdown +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java +_Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java_1with_1function +_Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter_1count +_Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter 
+_Java_org_duckdb_DuckDBBindings_duckdb_1bind_1add_1result_1column +_Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1bind_1data +_Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1error +_Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1init_1data +_Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1count +_Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1index +_Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1max_1threads +_Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1error +_Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1bind_1data +_Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1init_1data +_Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1local_1init_1data +_Java_org_duckdb_DuckDBBindings_duckdb_1function_1set_1error _duckdb_adbc_init _duckdb_add_aggregate_function_to_set diff --git a/duckdb_java.map b/duckdb_java.map index 7ed2d7233..a4f2faeea 100644 --- a/duckdb_java.map +++ b/duckdb_java.map @@ -49,7 +49,6 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1catalog; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1set_1schema; Java_org_duckdb_DuckDBNative_duckdb_1jdbc_1startup; - Java_org_duckdb_DuckDBBindings_duckdb_1vector_1size; Java_org_duckdb_DuckDBBindings_duckdb_1create_1logical_1type; Java_org_duckdb_DuckDBBindings_duckdb_1get_1type_1id; @@ -81,6 +80,12 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1list_1vector_1reserve; Java_org_duckdb_DuckDBBindings_duckdb_1struct_1vector_1get_1child; Java_org_duckdb_DuckDBBindings_duckdb_1array_1vector_1get_1child; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1varchar_1bytes; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1varchar_1bytes; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1blob_1bytes; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1blob_1bytes; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1decimal; + Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1decimal; Java_org_duckdb_DuckDBBindings_duckdb_1create_1data_1chunk; 
Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1data_1chunk_1reset; @@ -97,6 +102,38 @@ DUCKDB_JAVA { Java_org_duckdb_DuckDBBindings_duckdb_1appender_1column_1type; Java_org_duckdb_DuckDBBindings_duckdb_1append_1data_1chunk; Java_org_duckdb_DuckDBBindings_duckdb_1append_1default_1to_1chunk; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile; + Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java_1with_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1create_1table_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1table_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1set_1name; + Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1add_1parameter; + Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1supports_1projection_1pushdown; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java; + Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java_1with_1function; + Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter_1count; + Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter; + Java_org_duckdb_DuckDBBindings_duckdb_1bind_1add_1result_1column; + Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1bind_1data; + 
Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1error; + Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1init_1data; + Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1count; + Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1index; + Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1max_1threads; + Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1error; + Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1bind_1data; + Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1init_1data; + Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1local_1init_1data; + Java_org_duckdb_DuckDBBindings_duckdb_1function_1set_1error; duckdb_adbc_init; duckdb_add_aggregate_function_to_set; diff --git a/src/jni/bindings_scalar_function.cpp b/src/jni/bindings_scalar_function.cpp new file mode 100644 index 000000000..18b7e12e6 --- /dev/null +++ b/src/jni/bindings_scalar_function.cpp @@ -0,0 +1,185 @@ +#include "bindings.hpp" +#include "holders.hpp" +#include "refs.hpp" +#include "udf_registration.hpp" +#include "util.hpp" + +static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) { + if (scalar_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function buffer"); + return nullptr; + } + + auto scalar_function = reinterpret_cast<duckdb_scalar_function>(env->GetDirectBufferAddress(scalar_function_buf)); + if (scalar_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid scalar function"); + return nullptr; + } + + return scalar_function; +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_create_scalar_function + * Signature: ()Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1scalar_1function(JNIEnv *env, jclass) { + auto scalar_function = duckdb_create_scalar_function(); + return make_ptr_buf(env, scalar_function); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_destroy_scalar_function + * Signature: 
(Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1scalar_1function(JNIEnv *env, jclass, + jobject scalar_function) { + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + duckdb_destroy_scalar_function(&scalar_function_ptr); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_name + * Signature: (Ljava/nio/ByteBuffer;[B)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1name(JNIEnv *env, jclass, + jobject scalar_function, + jbyteArray name) { + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + auto name_string = jbyteArray_to_string(env, name); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_set_name(scalar_function_ptr, name_string.c_str()); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_add_parameter + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1add_1parameter(JNIEnv *env, jclass, + jobject scalar_function, + jobject logical_type) { + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto logical_type_ptr = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_add_parameter(scalar_function_ptr, logical_type_ptr); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_return_type + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1return_1type( + JNIEnv *env, jclass, jobject scalar_function, jobject logical_type) { + auto 
scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + auto logical_type_ptr = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_set_return_type(scalar_function_ptr, logical_type_ptr); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_volatile + * Signature: (Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1volatile(JNIEnv *env, jclass, + jobject scalar_function) { + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_set_volatile(scalar_function_ptr); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_scalar_function_set_special_handling + * Signature: (Ljava/nio/ByteBuffer;)V + */ +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1scalar_1function_1set_1special_1handling( + JNIEnv *env, jclass, jobject scalar_function) { + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return; + } + + duckdb_scalar_function_set_special_handling(scalar_function_ptr); +} + +/* + * Class: org_duckdb_DuckDBBindings + * Method: duckdb_register_scalar_function + * Signature: (Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;)I + */ +JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function(JNIEnv *env, jclass, + jobject connection, + jobject scalar_function) { + auto conn = conn_ref_buf_to_conn(env, connection); + if (env->ExceptionCheck()) { + return -1; + } + auto scalar_function_ptr = scalar_function_buf_to_scalar_function(env, scalar_function); + if (env->ExceptionCheck()) { + return -1; + } + + auto state = duckdb_register_scalar_function(conn, scalar_function_ptr); + return static_cast<jint>(state); +} + 
+extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java( + JNIEnv *env, jclass, jobject connection, jbyteArray name, jobject callback, jobjectArray argument_logical_types, + jobject return_logical_type, jboolean null_special_handling, jboolean return_null_on_exception, + jboolean deterministic, jboolean var_args) { + try { + _duckdb_jdbc_register_scalar_udf(env, nullptr, connection, name, callback, argument_logical_types, + return_logical_type, null_special_handling, return_null_on_exception, + deterministic, var_args); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_duckdb_DuckDBBindings_duckdb_1register_1scalar_1function_1java_1with_1function( + JNIEnv *env, jclass, jobject connection, jobject scalar_function, jobject callback, + jobjectArray argument_logical_types, jobject return_logical_type, jboolean return_null_on_exception, + jboolean var_args) { + try { + _duckdb_jdbc_register_scalar_udf_on_function(env, nullptr, connection, scalar_function, callback, + argument_logical_types, return_logical_type, + return_null_on_exception, var_args); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} diff --git a/src/jni/bindings_table_function.cpp b/src/jni/bindings_table_function.cpp new file mode 100644 index 000000000..5d656c1f4 --- /dev/null +++ b/src/jni/bindings_table_function.cpp @@ -0,0 +1,321 @@ +#include "bindings.hpp" +#include "holders.hpp" +#include "refs.hpp" +#include "udf_registration.hpp" +#include "util.hpp" + +static duckdb_table_function table_function_buf_to_table_function(JNIEnv *env, jobject table_function_buf) { + if (table_function_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid table function buffer"); + return nullptr; + } + auto table_function = 
reinterpret_cast<duckdb_table_function>(env->GetDirectBufferAddress(table_function_buf)); + if (table_function == nullptr) { + env->ThrowNew(J_SQLException, "Invalid table function"); + return nullptr; + } + return table_function; +} + +static duckdb_bind_info bind_info_buf_to_bind_info(JNIEnv *env, jobject bind_info_buf) { + if (bind_info_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid bind info buffer"); + return nullptr; + } + auto bind_info = reinterpret_cast<duckdb_bind_info>(env->GetDirectBufferAddress(bind_info_buf)); + if (bind_info == nullptr) { + env->ThrowNew(J_SQLException, "Invalid bind info"); + return nullptr; + } + return bind_info; +} + +static duckdb_init_info init_info_buf_to_init_info(JNIEnv *env, jobject init_info_buf) { + if (init_info_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid init info buffer"); + return nullptr; + } + auto init_info = reinterpret_cast<duckdb_init_info>(env->GetDirectBufferAddress(init_info_buf)); + if (init_info == nullptr) { + env->ThrowNew(J_SQLException, "Invalid init info"); + return nullptr; + } + return init_info; +} + +static duckdb_function_info function_info_buf_to_function_info(JNIEnv *env, jobject function_info_buf) { + if (function_info_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid function info buffer"); + return nullptr; + } + auto function_info = reinterpret_cast<duckdb_function_info>(env->GetDirectBufferAddress(function_info_buf)); + if (function_info == nullptr) { + env->ThrowNew(J_SQLException, "Invalid function info"); + return nullptr; + } + return function_info; +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1create_1table_1function(JNIEnv *env, jclass) { + return make_ptr_buf(env, duckdb_create_table_function()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1destroy_1table_1function(JNIEnv *env, jclass, + jobject table_function) { + auto tf = table_function_buf_to_table_function(env, table_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_destroy_table_function(&tf); +} + +JNIEXPORT 
void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1set_1name(JNIEnv *env, jclass, + jobject table_function, + jbyteArray name) { + auto tf = table_function_buf_to_table_function(env, table_function); + if (env->ExceptionCheck()) { + return; + } + auto name_string = jbyteArray_to_string(env, name); + if (env->ExceptionCheck()) { + return; + } + duckdb_table_function_set_name(tf, name_string.c_str()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1add_1parameter(JNIEnv *env, jclass, + jobject table_function, + jobject logical_type) { + auto tf = table_function_buf_to_table_function(env, table_function); + if (env->ExceptionCheck()) { + return; + } + auto lt = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_table_function_add_parameter(tf, lt); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1table_1function_1supports_1projection_1pushdown( + JNIEnv *env, jclass, jobject table_function, jboolean pushdown) { + auto tf = table_function_buf_to_table_function(env, table_function); + if (env->ExceptionCheck()) { + return; + } + duckdb_table_function_supports_projection_pushdown(tf, pushdown); +} + +JNIEXPORT jint JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function(JNIEnv *env, jclass, + jobject connection, + jobject table_function) { + auto conn = conn_ref_buf_to_conn(env, connection); + if (env->ExceptionCheck()) { + return -1; + } + auto tf = table_function_buf_to_table_function(env, table_function); + if (env->ExceptionCheck()) { + return -1; + } + return static_cast(duckdb_register_table_function(conn, tf)); +} + +extern "C" JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java( + JNIEnv *env, jclass, jobject connection, jbyteArray name, jobject callback, jobjectArray parameter_logical_types, + jboolean supports_projection_pushdown, jint max_threads, jboolean thread_safe) { + 
try { + _duckdb_jdbc_register_table_function(env, nullptr, connection, name, callback, parameter_logical_types, + supports_projection_pushdown, max_threads, thread_safe); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_duckdb_DuckDBBindings_duckdb_1register_1table_1function_1java_1with_1function( + JNIEnv *env, jclass, jobject connection, jobject table_function, jobject callback, + jobjectArray parameter_logical_types, jint max_threads, jboolean thread_safe) { + try { + _duckdb_jdbc_register_table_function_on_function(env, nullptr, connection, table_function, callback, + parameter_logical_types, max_threads, thread_safe); + } catch (const std::exception &e) { + duckdb::ErrorData error(e); + ThrowJNI(env, error.Message().c_str()); + } +} + +JNIEXPORT jlong JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter_1count(JNIEnv *env, jclass, + jobject bind_info_buf) { + auto info = bind_info_buf_to_bind_info(env, bind_info_buf); + if (env->ExceptionCheck()) { + return -1; + } + return static_cast(duckdb_bind_get_parameter_count(info)); +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1bind_1get_1parameter(JNIEnv *env, jclass, + jobject bind_info_buf, + jlong index) { + auto info = bind_info_buf_to_bind_info(env, bind_info_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + auto idx = jlong_to_idx(env, index); + if (env->ExceptionCheck()) { + return nullptr; + } + auto value = duckdb_bind_get_parameter(info, idx); + return make_ptr_buf(env, value); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1bind_1add_1result_1column(JNIEnv *env, jclass, + jobject bind_info_buf, + jbyteArray name, + jobject logical_type) { + auto info = bind_info_buf_to_bind_info(env, bind_info_buf); + if (env->ExceptionCheck()) { + return; + } + auto name_string = jbyteArray_to_string(env, name); + if 
(env->ExceptionCheck()) { + return; + } + auto lt = logical_type_buf_to_logical_type(env, logical_type); + if (env->ExceptionCheck()) { + return; + } + duckdb_bind_add_result_column(info, name_string.c_str(), lt); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1bind_1data(JNIEnv *env, jclass, + jobject bind_info_buf, + jobject bind_data_buf) { + auto info = bind_info_buf_to_bind_info(env, bind_info_buf); + if (env->ExceptionCheck()) { + return; + } + if (bind_data_buf == nullptr) { + duckdb_bind_set_bind_data(info, nullptr, nullptr); + return; + } + void *bind_data = env->GetDirectBufferAddress(bind_data_buf); + duckdb_bind_set_bind_data(info, bind_data, nullptr); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1bind_1set_1error(JNIEnv *env, jclass, + jobject bind_info_buf, + jbyteArray error) { + auto info = bind_info_buf_to_bind_info(env, bind_info_buf); + if (env->ExceptionCheck()) { + return; + } + auto error_string = jbyteArray_to_string(env, error); + if (env->ExceptionCheck()) { + return; + } + duckdb_bind_set_error(info, error_string.c_str()); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1init_1data(JNIEnv *env, jclass, + jobject init_info_buf, + jobject init_data_buf) { + auto info = init_info_buf_to_init_info(env, init_info_buf); + if (env->ExceptionCheck()) { + return; + } + if (init_data_buf == nullptr) { + duckdb_init_set_init_data(info, nullptr, nullptr); + return; + } + void *init_data = env->GetDirectBufferAddress(init_data_buf); + duckdb_init_set_init_data(info, init_data, nullptr); +} + +JNIEXPORT jlong JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1count(JNIEnv *env, jclass, + jobject init_info_buf) { + auto info = init_info_buf_to_init_info(env, init_info_buf); + if (env->ExceptionCheck()) { + return -1; + } + return static_cast(duckdb_init_get_column_count(info)); +} + +JNIEXPORT jlong JNICALL 
Java_org_duckdb_DuckDBBindings_duckdb_1init_1get_1column_1index(JNIEnv *env, jclass, + jobject init_info_buf, + jlong column_index) { + auto info = init_info_buf_to_init_info(env, init_info_buf); + if (env->ExceptionCheck()) { + return -1; + } + auto idx = jlong_to_idx(env, column_index); + if (env->ExceptionCheck()) { + return -1; + } + return static_cast(duckdb_init_get_column_index(info, idx)); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1max_1threads(JNIEnv *env, jclass, + jobject init_info_buf, + jlong max_threads) { + auto info = init_info_buf_to_init_info(env, init_info_buf); + if (env->ExceptionCheck()) { + return; + } + duckdb_init_set_max_threads(info, static_cast(max_threads)); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1init_1set_1error(JNIEnv *env, jclass, + jobject init_info_buf, + jbyteArray error) { + auto info = init_info_buf_to_init_info(env, init_info_buf); + if (env->ExceptionCheck()) { + return; + } + auto error_string = jbyteArray_to_string(env, error); + if (env->ExceptionCheck()) { + return; + } + duckdb_init_set_error(info, error_string.c_str()); +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1bind_1data(JNIEnv *env, jclass, + jobject function_info_buf) { + auto info = function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + return make_ptr_buf(env, duckdb_function_get_bind_data(info)); +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1init_1data(JNIEnv *env, jclass, + jobject function_info_buf) { + auto info = function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + return make_ptr_buf(env, duckdb_function_get_init_data(info)); +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1function_1get_1local_1init_1data( + JNIEnv *env, jclass, jobject function_info_buf) { + auto info = 
function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + return make_ptr_buf(env, duckdb_function_get_local_init_data(info)); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1function_1set_1error(JNIEnv *env, jclass, + jobject function_info_buf, + jbyteArray error) { + auto info = function_info_buf_to_function_info(env, function_info_buf); + if (env->ExceptionCheck()) { + return; + } + auto error_string = jbyteArray_to_string(env, error); + if (env->ExceptionCheck()) { + return; + } + duckdb_function_set_error(info, error_string.c_str()); +} diff --git a/src/jni/bindings_vector.cpp b/src/jni/bindings_vector.cpp index 56876d68e..582739ac3 100644 --- a/src/jni/bindings_vector.cpp +++ b/src/jni/bindings_vector.cpp @@ -1,5 +1,6 @@ #include "bindings.hpp" #include "refs.hpp" +#include "udf_vector_accessors.hpp" #include "util.hpp" #include @@ -293,3 +294,38 @@ JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1array_1vector_1 return make_ptr_buf(env, res); } + +JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1varchar_1bytes(JNIEnv *env, jclass clazz, + jobject vector_ref, + jint row) { + return _duckdb_jdbc_udf_get_varchar_bytes(env, clazz, vector_ref, row); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1varchar_1bytes(JNIEnv *env, jclass clazz, + jobject vector_ref, jint row, + jbyteArray value) { + _duckdb_jdbc_udf_set_varchar_bytes(env, clazz, vector_ref, row, value); +} + +JNIEXPORT jbyteArray JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1blob_1bytes(JNIEnv *env, jclass clazz, + jobject vector_ref, + jint row) { + return _duckdb_jdbc_udf_get_blob_bytes(env, clazz, vector_ref, row); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1blob_1bytes(JNIEnv *env, jclass clazz, + jobject vector_ref, jint row, + jbyteArray value) { + _duckdb_jdbc_udf_set_blob_bytes(env, clazz, vector_ref, 
row, value); +} + +JNIEXPORT jobject JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1get_1decimal(JNIEnv *env, jclass clazz, + jobject vector_ref, jint row) { + return _duckdb_jdbc_udf_get_decimal(env, clazz, vector_ref, row); +} + +JNIEXPORT void JNICALL Java_org_duckdb_DuckDBBindings_duckdb_1udf_1set_1decimal(JNIEnv *env, jclass clazz, + jobject vector_ref, jint row, + jobject value) { + _duckdb_jdbc_udf_set_decimal(env, clazz, vector_ref, row, value); +} diff --git a/src/jni/duckdb_java.cpp b/src/jni/duckdb_java.cpp index 436dda5c4..a51a2f9d6 100644 --- a/src/jni/duckdb_java.cpp +++ b/src/jni/duckdb_java.cpp @@ -23,6 +23,7 @@ extern "C" { #include "util.hpp" #include +#include #include using namespace duckdb; @@ -30,11 +31,6 @@ using namespace std; static jint JNI_VERSION = JNI_VERSION_1_6; -void ThrowJNI(JNIEnv *env, const char *message) { - D_ASSERT(J_SQLException); - env->ThrowNew(J_SQLException, message); -} - JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { JNIEnv *env; if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { diff --git a/src/jni/refs.cpp b/src/jni/refs.cpp index f5a0a795e..f9b51cbf9 100644 --- a/src/jni/refs.cpp +++ b/src/jni/refs.cpp @@ -25,8 +25,15 @@ jclass J_Byte; jclass J_Short; jclass J_Int; jclass J_Long; +jmethodID J_Bool_init; +jmethodID J_Byte_init; +jmethodID J_Short_init; +jmethodID J_Int_init; +jmethodID J_Long_init; jclass J_Float; jclass J_Double; +jmethodID J_Float_init; +jmethodID J_Double_init; jclass J_String; jclass J_Timestamp; jmethodID J_Timestamp_valueOf; @@ -49,6 +56,7 @@ jmethodID J_BigDecimal_scale; jmethodID J_BigDecimal_scaleByPowTen; jmethodID J_BigDecimal_toPlainString; jmethodID J_BigDecimal_longValue; +jmethodID J_BigDecimal_initString; jfieldID J_HugeInt_lower; jfieldID J_HugeInt_upper; @@ -85,8 +93,14 @@ jmethodID J_DuckMap_getSQLTypeName; jclass J_List; jmethodID J_List_iterator; +jclass J_ArrayList; +jmethodID J_ArrayList_init; +jmethodID J_ArrayList_add; jclass J_Map; 
jmethodID J_Map_entrySet; +jclass J_LinkedHashMap; +jmethodID J_LinkedHashMap_init; +jmethodID J_LinkedHashMap_put; jclass J_Set; jmethodID J_Set_iterator; jclass J_Iterator; @@ -97,14 +111,31 @@ jmethodID J_Entry_getKey; jmethodID J_Entry_getValue; jclass J_UUID; +jmethodID J_UUID_init; jmethodID J_UUID_getMostSignificantBits; jmethodID J_UUID_getLeastSignificantBits; +jclass J_LocalDate; +jmethodID J_LocalDate_ofEpochDay; +jclass J_LocalTime; +jmethodID J_LocalTime_ofNanoOfDay; +jclass J_LocalDateTime; +jmethodID J_LocalDateTime_ofEpochSecond; +jmethodID J_LocalDateTime_atOffset; +jclass J_OffsetTime; +jmethodID J_OffsetTime_of; +jclass J_ZoneOffset; +jobject J_ZoneOffset_UTC; +jmethodID J_ZoneOffset_ofTotalSeconds; + jclass J_DuckDBDate; jmethodID J_DuckDBDate_getDaysSinceEpoch; jclass J_Object; jmethodID J_Object_toString; +jclass J_StringArray; +jclass J_Enum; +jmethodID J_Enum_name; jclass J_DuckDBTime; @@ -119,6 +150,40 @@ jobject J_ProfilerPrintFormat_GRAPHVIZ; jclass J_QueryProgress; jmethodID J_QueryProgress_init; +jclass J_ScalarUdf; +jmethodID J_ScalarUdf_apply; +jclass J_UdfReader; +jclass J_UdfNativeReader; +jmethodID J_UdfNativeReader_init; +jclass J_UdfScalarWriter; +jmethodID J_UdfScalarWriter_init; +jclass J_TableFunction; +jmethodID J_TableFunction_bind; +jmethodID J_TableFunction_init; +jmethodID J_TableFunction_produce; +jclass J_TableBindResult; +jmethodID J_TableBindResult_getColumnNames; +jmethodID J_TableBindResult_getColumnTypes; +jmethodID J_TableBindResult_getColumnLogicalTypes; +jclass J_TableState; +jclass J_TableInitContext; +jmethodID J_TableInitContext_init; +jclass J_UdfOutputAppender; +jmethodID J_UdfOutputAppender_init; +jmethodID J_UdfOutputAppender_close; +jclass J_DuckDBColumnType; +jclass J_UdfLogicalType; +jmethodID J_UdfLogicalType_getType; +jmethodID J_UdfLogicalType_getChildType; +jmethodID J_UdfLogicalType_getArraySize; +jmethodID J_UdfLogicalType_getKeyType; +jmethodID J_UdfLogicalType_getValueType; +jmethodID 
J_UdfLogicalType_getFieldNames; +jmethodID J_UdfLogicalType_getFieldTypes; +jmethodID J_UdfLogicalType_getEnumValues; +jmethodID J_UdfLogicalType_getDecimalWidth; +jmethodID J_UdfLogicalType_getDecimalScale; + static std::vector global_refs; template @@ -190,8 +255,15 @@ void create_refs(JNIEnv *env) { J_Short = make_class_ref(env, "java/lang/Short"); J_Int = make_class_ref(env, "java/lang/Integer"); J_Long = make_class_ref(env, "java/lang/Long"); + J_Bool_init = get_method_id(env, J_Bool, "", "(Z)V"); + J_Byte_init = get_method_id(env, J_Byte, "", "(B)V"); + J_Short_init = get_method_id(env, J_Short, "", "(S)V"); + J_Int_init = get_method_id(env, J_Int, "", "(I)V"); + J_Long_init = get_method_id(env, J_Long, "", "(J)V"); J_Float = make_class_ref(env, "java/lang/Float"); J_Double = make_class_ref(env, "java/lang/Double"); + J_Float_init = get_method_id(env, J_Float, "", "(F)V"); + J_Double_init = get_method_id(env, J_Double, "", "(D)V"); J_String = make_class_ref(env, "java/lang/String"); J_BigDecimal = make_class_ref(env, "java/math/BigDecimal"); J_HugeInt = make_class_ref(env, "org/duckdb/DuckDBHugeInt"); @@ -210,8 +282,15 @@ void create_refs(JNIEnv *env) { J_List = make_class_ref(env, "java/util/List"); J_List_iterator = get_method_id(env, J_List, "iterator", "()Ljava/util/Iterator;"); + J_ArrayList = make_class_ref(env, "java/util/ArrayList"); + J_ArrayList_init = get_method_id(env, J_ArrayList, "", "()V"); + J_ArrayList_add = get_method_id(env, J_ArrayList, "add", "(Ljava/lang/Object;)Z"); J_Map = make_class_ref(env, "java/util/Map"); J_Map_entrySet = get_method_id(env, J_Map, "entrySet", "()Ljava/util/Set;"); + J_LinkedHashMap = make_class_ref(env, "java/util/LinkedHashMap"); + J_LinkedHashMap_init = get_method_id(env, J_LinkedHashMap, "", "()V"); + J_LinkedHashMap_put = + get_method_id(env, J_LinkedHashMap, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); J_Set = make_class_ref(env, "java/util/Set"); J_Set_iterator = get_method_id(env, 
J_Set, "iterator", "()Ljava/util/Iterator;"); J_Iterator = make_class_ref(env, "java/util/Iterator"); @@ -219,8 +298,25 @@ void create_refs(JNIEnv *env) { J_Iterator_next = get_method_id(env, J_Iterator, "next", "()Ljava/lang/Object;"); J_UUID = make_class_ref(env, "java/util/UUID"); + J_UUID_init = get_method_id(env, J_UUID, "", "(JJ)V"); J_UUID_getMostSignificantBits = get_method_id(env, J_UUID, "getMostSignificantBits", "()J"); J_UUID_getLeastSignificantBits = get_method_id(env, J_UUID, "getLeastSignificantBits", "()J"); + J_LocalDate = make_class_ref(env, "java/time/LocalDate"); + J_LocalDate_ofEpochDay = get_static_method_id(env, J_LocalDate, "ofEpochDay", "(J)Ljava/time/LocalDate;"); + J_LocalTime = make_class_ref(env, "java/time/LocalTime"); + J_LocalTime_ofNanoOfDay = get_static_method_id(env, J_LocalTime, "ofNanoOfDay", "(J)Ljava/time/LocalTime;"); + J_LocalDateTime = make_class_ref(env, "java/time/LocalDateTime"); + J_LocalDateTime_ofEpochSecond = get_static_method_id(env, J_LocalDateTime, "ofEpochSecond", + "(JILjava/time/ZoneOffset;)Ljava/time/LocalDateTime;"); + J_LocalDateTime_atOffset = + get_method_id(env, J_LocalDateTime, "atOffset", "(Ljava/time/ZoneOffset;)Ljava/time/OffsetDateTime;"); + J_OffsetTime = make_class_ref(env, "java/time/OffsetTime"); + J_OffsetTime_of = get_static_method_id(env, J_OffsetTime, "of", + "(Ljava/time/LocalTime;Ljava/time/ZoneOffset;)Ljava/time/OffsetTime;"); + J_ZoneOffset = make_class_ref(env, "java/time/ZoneOffset"); + J_ZoneOffset_UTC = make_static_object_field_ref(env, J_ZoneOffset, "UTC", "Ljava/time/ZoneOffset;"); + J_ZoneOffset_ofTotalSeconds = + get_static_method_id(env, J_ZoneOffset, "ofTotalSeconds", "(I)Ljava/time/ZoneOffset;"); J_DuckArray = make_class_ref(env, "org/duckdb/DuckDBArray"); J_DuckArray_init = get_method_id(env, J_DuckArray, "", "(Lorg/duckdb/DuckDBVector;II)V"); @@ -239,6 +335,9 @@ void create_refs(JNIEnv *env) { J_Object = make_class_ref(env, "java/lang/Object"); J_Object_toString = 
get_method_id(env, J_Object, "toString", "()Ljava/lang/String;"); + J_StringArray = make_class_ref(env, "[Ljava/lang/String;"); + J_Enum = make_class_ref(env, "java/lang/Enum"); + J_Enum_name = get_method_id(env, J_Enum, "name", "()Ljava/lang/String;"); J_Entry = make_class_ref(env, "java/util/Map$Entry"); J_Entry_getKey = get_method_id(env, J_Entry, "getKey", "()Ljava/lang/Object;"); @@ -258,6 +357,7 @@ void create_refs(JNIEnv *env) { J_BigDecimal_scaleByPowTen = get_method_id(env, J_BigDecimal, "scaleByPowerOfTen", "(I)Ljava/math/BigDecimal;"); J_BigDecimal_toPlainString = get_method_id(env, J_BigDecimal, "toPlainString", "()Ljava/lang/String;"); J_BigDecimal_longValue = get_method_id(env, J_BigDecimal, "longValue", "()J"); + J_BigDecimal_initString = get_method_id(env, J_BigDecimal, "", "(Ljava/lang/String;)V"); J_HugeInt_lower = get_field_id(env, J_HugeInt, "lower", "J"); J_HugeInt_upper = get_field_id(env, J_HugeInt, "upper", "J"); @@ -296,6 +396,55 @@ void create_refs(JNIEnv *env) { J_QueryProgress = make_class_ref(env, "org/duckdb/QueryProgress"); J_QueryProgress_init = get_method_id(env, J_QueryProgress, "", "(DJJ)V"); + + J_ScalarUdf = make_class_ref(env, "org/duckdb/udf/ScalarUdf"); + J_ScalarUdf_apply = get_method_id(env, J_ScalarUdf, "apply", + "(Lorg/duckdb/udf/UdfContext;[Lorg/duckdb/UdfReader;" + "Lorg/duckdb/UdfScalarWriter;I)V"); + J_UdfReader = make_class_ref(env, "org/duckdb/UdfReader"); + J_UdfNativeReader = make_class_ref(env, "org/duckdb/UdfNativeReader"); + J_UdfNativeReader_init = get_method_id(env, J_UdfNativeReader, "", + "(ILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;I)V"); + J_UdfScalarWriter = make_class_ref(env, "org/duckdb/UdfScalarWriter"); + J_UdfScalarWriter_init = get_method_id(env, J_UdfScalarWriter, "", + "(ILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;I)V"); + J_TableFunction = make_class_ref(env, "org/duckdb/udf/TableFunction"); + J_TableFunction_bind = get_method_id(env, J_TableFunction, 
"bind", + "(Lorg/duckdb/udf/BindContext;[Ljava/lang/Object;)Lorg/duckdb/udf/" + "TableBindResult;"); + J_TableFunction_init = + get_method_id(env, J_TableFunction, "init", + "(Lorg/duckdb/udf/InitContext;Lorg/duckdb/udf/TableBindResult;)Lorg/duckdb/udf/TableState;"); + J_TableFunction_produce = + get_method_id(env, J_TableFunction, "produce", "(Lorg/duckdb/udf/TableState;Lorg/duckdb/UdfOutputAppender;)I"); + J_TableBindResult = make_class_ref(env, "org/duckdb/udf/TableBindResult"); + J_TableBindResult_getColumnNames = get_method_id(env, J_TableBindResult, "getColumnNames", "()[Ljava/lang/String;"); + J_TableBindResult_getColumnTypes = + get_method_id(env, J_TableBindResult, "getColumnTypes", "()[Lorg/duckdb/DuckDBColumnType;"); + J_TableBindResult_getColumnLogicalTypes = + get_method_id(env, J_TableBindResult, "getColumnLogicalTypes", "()[Lorg/duckdb/udf/UdfLogicalType;"); + J_TableState = make_class_ref(env, "org/duckdb/udf/TableState"); + J_TableInitContext = make_class_ref(env, "org/duckdb/udf/TableInitContext"); + J_TableInitContext_init = get_method_id(env, J_TableInitContext, "", "([I)V"); + J_UdfOutputAppender = make_class_ref(env, "org/duckdb/UdfOutputAppender"); + J_UdfOutputAppender_init = get_method_id(env, J_UdfOutputAppender, "", "(Ljava/nio/ByteBuffer;)V"); + J_UdfOutputAppender_close = get_method_id(env, J_UdfOutputAppender, "close", "()V"); + J_DuckDBColumnType = make_class_ref(env, "org/duckdb/DuckDBColumnType"); + J_UdfLogicalType = make_class_ref(env, "org/duckdb/udf/UdfLogicalType"); + J_UdfLogicalType_getType = get_method_id(env, J_UdfLogicalType, "getType", "()Lorg/duckdb/DuckDBColumnType;"); + J_UdfLogicalType_getChildType = + get_method_id(env, J_UdfLogicalType, "getChildType", "()Lorg/duckdb/udf/UdfLogicalType;"); + J_UdfLogicalType_getArraySize = get_method_id(env, J_UdfLogicalType, "getArraySize", "()J"); + J_UdfLogicalType_getKeyType = + get_method_id(env, J_UdfLogicalType, "getKeyType", "()Lorg/duckdb/udf/UdfLogicalType;"); + 
J_UdfLogicalType_getValueType = + get_method_id(env, J_UdfLogicalType, "getValueType", "()Lorg/duckdb/udf/UdfLogicalType;"); + J_UdfLogicalType_getFieldNames = get_method_id(env, J_UdfLogicalType, "getFieldNames", "()[Ljava/lang/String;"); + J_UdfLogicalType_getFieldTypes = + get_method_id(env, J_UdfLogicalType, "getFieldTypes", "()[Lorg/duckdb/udf/UdfLogicalType;"); + J_UdfLogicalType_getEnumValues = get_method_id(env, J_UdfLogicalType, "getEnumValues", "()[Ljava/lang/String;"); + J_UdfLogicalType_getDecimalWidth = get_method_id(env, J_UdfLogicalType, "getDecimalWidth", "()I"); + J_UdfLogicalType_getDecimalScale = get_method_id(env, J_UdfLogicalType, "getDecimalScale", "()I"); } void delete_global_refs(JNIEnv *env) noexcept { diff --git a/src/jni/refs.hpp b/src/jni/refs.hpp index cd7b20121..865f2a7b5 100644 --- a/src/jni/refs.hpp +++ b/src/jni/refs.hpp @@ -22,8 +22,15 @@ extern jclass J_Byte; extern jclass J_Short; extern jclass J_Int; extern jclass J_Long; +extern jmethodID J_Bool_init; +extern jmethodID J_Byte_init; +extern jmethodID J_Short_init; +extern jmethodID J_Int_init; +extern jmethodID J_Long_init; extern jclass J_Float; extern jclass J_Double; +extern jmethodID J_Float_init; +extern jmethodID J_Double_init; extern jclass J_String; extern jclass J_Timestamp; extern jmethodID J_Timestamp_valueOf; @@ -46,6 +53,7 @@ extern jmethodID J_BigDecimal_scale; extern jmethodID J_BigDecimal_scaleByPowTen; extern jmethodID J_BigDecimal_toPlainString; extern jmethodID J_BigDecimal_longValue; +extern jmethodID J_BigDecimal_initString; extern jfieldID J_HugeInt_lower; extern jfieldID J_HugeInt_upper; @@ -82,8 +90,14 @@ extern jmethodID J_DuckMap_getSQLTypeName; extern jclass J_List; extern jmethodID J_List_iterator; +extern jclass J_ArrayList; +extern jmethodID J_ArrayList_init; +extern jmethodID J_ArrayList_add; extern jclass J_Map; extern jmethodID J_Map_entrySet; +extern jclass J_LinkedHashMap; +extern jmethodID J_LinkedHashMap_init; +extern jmethodID 
J_LinkedHashMap_put; extern jclass J_Set; extern jmethodID J_Set_iterator; extern jclass J_Iterator; @@ -94,14 +108,31 @@ extern jmethodID J_Entry_getKey; extern jmethodID J_Entry_getValue; extern jclass J_UUID; +extern jmethodID J_UUID_init; extern jmethodID J_UUID_getMostSignificantBits; extern jmethodID J_UUID_getLeastSignificantBits; +extern jclass J_LocalDate; +extern jmethodID J_LocalDate_ofEpochDay; +extern jclass J_LocalTime; +extern jmethodID J_LocalTime_ofNanoOfDay; +extern jclass J_LocalDateTime; +extern jmethodID J_LocalDateTime_ofEpochSecond; +extern jmethodID J_LocalDateTime_atOffset; +extern jclass J_OffsetTime; +extern jmethodID J_OffsetTime_of; +extern jclass J_ZoneOffset; +extern jobject J_ZoneOffset_UTC; +extern jmethodID J_ZoneOffset_ofTotalSeconds; + extern jclass J_DuckDBDate; extern jmethodID J_DuckDBDate_getDaysSinceEpoch; extern jclass J_Object; extern jmethodID J_Object_toString; +extern jclass J_StringArray; +extern jclass J_Enum; +extern jmethodID J_Enum_name; extern jclass J_DuckDBTime; @@ -116,6 +147,40 @@ extern jobject J_ProfilerPrintFormat_GRAPHVIZ; extern jclass J_QueryProgress; extern jmethodID J_QueryProgress_init; +extern jclass J_ScalarUdf; +extern jmethodID J_ScalarUdf_apply; +extern jclass J_UdfReader; +extern jclass J_UdfNativeReader; +extern jmethodID J_UdfNativeReader_init; +extern jclass J_UdfScalarWriter; +extern jmethodID J_UdfScalarWriter_init; +extern jclass J_TableFunction; +extern jmethodID J_TableFunction_bind; +extern jmethodID J_TableFunction_init; +extern jmethodID J_TableFunction_produce; +extern jclass J_TableBindResult; +extern jmethodID J_TableBindResult_getColumnNames; +extern jmethodID J_TableBindResult_getColumnTypes; +extern jmethodID J_TableBindResult_getColumnLogicalTypes; +extern jclass J_TableState; +extern jclass J_TableInitContext; +extern jmethodID J_TableInitContext_init; +extern jclass J_UdfOutputAppender; +extern jmethodID J_UdfOutputAppender_init; +extern jmethodID J_UdfOutputAppender_close; 
+extern jclass J_DuckDBColumnType; +extern jclass J_UdfLogicalType; +extern jmethodID J_UdfLogicalType_getType; +extern jmethodID J_UdfLogicalType_getChildType; +extern jmethodID J_UdfLogicalType_getArraySize; +extern jmethodID J_UdfLogicalType_getKeyType; +extern jmethodID J_UdfLogicalType_getValueType; +extern jmethodID J_UdfLogicalType_getFieldNames; +extern jmethodID J_UdfLogicalType_getFieldTypes; +extern jmethodID J_UdfLogicalType_getEnumValues; +extern jmethodID J_UdfLogicalType_getDecimalWidth; +extern jmethodID J_UdfLogicalType_getDecimalScale; + void create_refs(JNIEnv *env); void delete_global_refs(JNIEnv *env) noexcept; diff --git a/src/jni/udf_callbacks.cpp b/src/jni/udf_callbacks.cpp new file mode 100644 index 000000000..521685608 --- /dev/null +++ b/src/jni/udf_callbacks.cpp @@ -0,0 +1,550 @@ +extern "C" { +#include "duckdb.h" +} + +#include "refs.hpp" +#include "udf_callbacks.hpp" +#include "udf_table_bind_conversion.hpp" +#include "udf_types.hpp" +#include "util.hpp" + +#include +#include +#include + +static jobject create_scalar_udf_input_reader(JNIEnv *env, duckdb_vector vector, duckdb_type type, idx_t row_count, + jlong validity_size, std::vector &local_refs) { + auto spec = find_udf_type_spec(type); + if (!spec || !spec->udf_vector_supported) { + env->ThrowNew(J_SQLException, "Unsupported scalar UDF type"); + return nullptr; + } + + auto validity_ptr = reinterpret_cast(duckdb_vector_get_validity(vector)); + auto validity_buf = validity_ptr ? 
env->NewDirectByteBuffer(validity_ptr, validity_size) : nullptr; + if (validity_buf) { + local_refs.push_back(validity_buf); + } + + jobject data_buf = nullptr; + jobject vector_ref_buf = nullptr; + if (spec->requires_vector_ref) { + vector_ref_buf = env->NewDirectByteBuffer(vector, 0); + local_refs.push_back(vector_ref_buf); + } else { + auto data_ptr = duckdb_vector_get_data(vector); + data_buf = env->NewDirectByteBuffer( + data_ptr, static_cast(row_count * static_cast(spec->fixed_width_bytes))); + local_refs.push_back(data_buf); + } + + auto reader = env->NewObject(J_UdfNativeReader, J_UdfNativeReader_init, static_cast(type), data_buf, + vector_ref_buf, validity_buf, static_cast(row_count)); + local_refs.push_back(reader); + return reader; +} + +static jobject create_scalar_udf_output_writer(JNIEnv *env, duckdb_vector vector, duckdb_type type, idx_t row_count, + jlong validity_size, std::vector &local_refs) { + auto spec = find_udf_type_spec(type); + if (!spec || !spec->udf_vector_supported) { + env->ThrowNew(J_SQLException, "Unsupported scalar UDF output type"); + return nullptr; + } + + auto validity_ptr = reinterpret_cast(duckdb_vector_get_validity(vector)); + auto validity_buf = validity_ptr ? 
env->NewDirectByteBuffer(validity_ptr, validity_size) : nullptr; + if (validity_buf) { + local_refs.push_back(validity_buf); + } + + jobject data_buf = nullptr; + jobject vector_ref_buf = nullptr; + if (spec->requires_vector_ref) { + vector_ref_buf = env->NewDirectByteBuffer(vector, 0); + local_refs.push_back(vector_ref_buf); + } else { + auto data_ptr = duckdb_vector_get_data(vector); + data_buf = env->NewDirectByteBuffer( + data_ptr, static_cast(row_count * static_cast(spec->fixed_width_bytes))); + local_refs.push_back(data_buf); + } + + auto writer = env->NewObject(J_UdfScalarWriter, J_UdfScalarWriter_init, static_cast(type), data_buf, + vector_ref_buf, validity_buf, static_cast(row_count)); + local_refs.push_back(writer); + return writer; +} + +void destroy_java_scalar_udf_callback_data(void *ptr) { + if (!ptr) { + return; + } + auto data = reinterpret_cast(ptr); + CallbackEnvGuard env_guard(data->vm); + auto env = env_guard.env(); + if (env && data->callback_ref) { + delete_global_ref(env, data->callback_ref); + } + delete data; +} + +void destroy_java_table_function_callback_data(void *ptr) { + if (!ptr) { + return; + } + auto data = reinterpret_cast(ptr); + for (auto &logical_type : data->parameter_logical_types) { + if (logical_type) { + duckdb_destroy_logical_type(&logical_type); + } + } + data->parameter_logical_types.clear(); + CallbackEnvGuard env_guard(data->vm); + auto env = env_guard.env(); + if (env && data->callback_ref) { + delete_global_ref(env, data->callback_ref); + } + delete data; +} + +void destroy_java_table_function_bind_data(void *ptr) { + if (!ptr) { + return; + } + auto data = reinterpret_cast(ptr); + CallbackEnvGuard env_guard(data->vm); + auto env = env_guard.env(); + if (env && data->bind_result_ref) { + delete_global_ref(env, data->bind_result_ref); + } + delete data; +} + +void destroy_java_table_function_init_data(void *ptr) { + if (!ptr) { + return; + } + auto data = reinterpret_cast(ptr); + CallbackEnvGuard env_guard(data->vm); 
+	auto env = env_guard.env();
+	if (env && data->state_ref) {
+		delete_global_ref(env, data->state_ref);
+	}
+	delete data;
+}
+
+// Clears a Java exception that is still pending on `env`, if any.
+// Used before handing control back to DuckDB so that no later JNI call runs with an exception pending.
+static void clear_pending_java_exception(JNIEnv *env) {
+	if (env->ExceptionCheck()) {
+		env->ExceptionClear();
+	}
+}
+
+// Converts an exception that has already been cleared from `env` into an error string.
+// Returns Throwable.getMessage() when it yields a non-null string, otherwise `fallback`.
+// The local reference `exception` is released; a null `exception` simply yields `fallback`.
+static std::string consume_java_exception(JNIEnv *env, jthrowable exception, const char *fallback) {
+	std::string error = fallback;
+	if (!exception) {
+		return error;
+	}
+	auto message = reinterpret_cast<jstring>(env->CallObjectMethod(exception, J_Throwable_getMessage));
+	if (env->ExceptionCheck()) {
+		// getMessage() itself threw: keep the fallback text and do not leave a second exception pending.
+		env->ExceptionClear();
+		message = nullptr;
+	}
+	if (message != nullptr) {
+		error = jstring_to_string(env, message);
+		delete_local_ref(env, message);
+	}
+	delete_local_ref(env, exception);
+	return error;
+}
+
+// DuckDB scalar-function callback that forwards one input chunk to the registered Java ScalarUdf.
+// Wraps every input vector in a UdfReader and the output vector in a UdfScalarWriter, then calls
+// ScalarUdf.apply(null, readers, writer, rowCount). A Java exception either nulls the whole output chunk
+// (return_null_on_exception) or is reported through duckdb_scalar_function_set_error.
+void java_scalar_udf_callback(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
+	auto data = reinterpret_cast<JavaScalarUdfCallbackData *>(duckdb_scalar_function_get_extra_info(info));
+	if (!data) {
+		duckdb_scalar_function_set_error(info, "Missing callback state for Java scalar UDF");
+		return;
+	}
+
+	CallbackEnvGuard env_guard(data->vm);
+	auto env = env_guard.env();
+	if (!env) {
+		duckdb_scalar_function_set_error(info, "Failed to acquire JNIEnv for Java scalar UDF callback");
+		return;
+	}
+
+	auto row_count = duckdb_data_chunk_get_size(input);
+	auto arg_count = duckdb_data_chunk_get_column_count(input);
+	if ((!data->var_args && data->argument_types.size() != arg_count) ||
+	    (data->var_args && data->argument_types.size() != 1)) {
+		duckdb_scalar_function_set_error(info, "Scalar UDF argument mismatch");
+		return;
+	}
+	duckdb_vector_ensure_validity_writable(output);
+	auto output_validity_ptr = duckdb_vector_get_validity(output);
+
+	// One validity bit per row, rounded up to whole 64-bit words.
+	auto validity_size = static_cast<jlong>(((row_count + 63) / 64) * sizeof(uint64_t));
+	std::vector<jobject> local_refs;
+	auto args = env->NewObjectArray(static_cast<jsize>(arg_count), J_UdfReader, nullptr);
+	if (!args || env->ExceptionCheck()) {
+		clear_pending_java_exception(env);
+		duckdb_scalar_function_set_error(info, "Failed to materialize scalar UDF input readers");
+		return;
+	}
+	local_refs.push_back(args);
+	for (idx_t arg_idx = 0; arg_idx < arg_count; arg_idx++) {
+		auto input_vector = duckdb_data_chunk_get_vector(input, arg_idx);
+		auto argument_type = data->var_args ? data->var_args_type : data->argument_types[arg_idx];
+		auto input_reader =
+		    create_scalar_udf_input_reader(env, input_vector, argument_type, row_count, validity_size, local_refs);
+		if (env->ExceptionCheck()) {
+			// Stop immediately: further JNI calls are not allowed while an exception is pending.
+			break;
+		}
+		env->SetObjectArrayElement(args, static_cast<jsize>(arg_idx), input_reader);
+	}
+	if (env->ExceptionCheck()) {
+		env->ExceptionClear();
+		duckdb_scalar_function_set_error(info, "Failed to materialize scalar UDF input readers");
+		delete_local_refs(env, local_refs);
+		return;
+	}
+	auto output_writer =
+	    create_scalar_udf_output_writer(env, output, data->return_type, row_count, validity_size, local_refs);
+	if (env->ExceptionCheck()) {
+		env->ExceptionClear();
+		duckdb_scalar_function_set_error(info, "Failed to materialize scalar UDF output writer");
+		delete_local_refs(env, local_refs);
+		return;
+	}
+
+	env->CallVoidMethod(data->callback_ref, J_ScalarUdf_apply, nullptr, args, output_writer,
+	                    static_cast<jint>(row_count));
+
+	if (env->ExceptionCheck()) {
+		auto exception = env->ExceptionOccurred();
+		env->ExceptionClear();
+		if (data->return_null_on_exception) {
+			// Zeroing the validity mask marks every row of this chunk as NULL.
+			if (output_validity_ptr) {
+				std::memset(output_validity_ptr, 0, static_cast<size_t>(validity_size));
+			}
+			if (exception) {
+				delete_local_ref(env, exception);
+			}
+		} else {
+			auto error = consume_java_exception(env, exception, "Exception in Java scalar UDF callback");
+			duckdb_scalar_function_set_error(info, error.c_str());
+		}
+	}
+
+	delete_local_refs(env, local_refs);
+}
+
+// DuckDB bind callback for Java table functions.
+// Converts the SQL parameters to Java objects, calls TableFunction.bind(null, parameters), registers the
+// result columns described by the returned TableBindResult and stores a global ref to that result as bind
+// data. Every failure path clears a pending Java exception, reports through duckdb_bind_set_error and
+// releases all local references created so far.
+void java_table_function_bind_callback(duckdb_bind_info info) {
+	auto callback_data = reinterpret_cast<JavaTableFunctionCallbackData *>(duckdb_bind_get_extra_info(info));
+	if (!callback_data) {
+		duckdb_bind_set_error(info, "Missing callback state for Java table function");
+		return;
+	}
+	CallbackEnvGuard env_guard(callback_data->vm);
+	auto env = env_guard.env();
+	if (!env) {
+		duckdb_bind_set_error(info, "Failed to acquire JNIEnv for Java table bind callback");
+		return;
+	}
+
+	auto parameter_count = duckdb_bind_get_parameter_count(info);
+	if (callback_data->parameter_logical_types.size() != parameter_count) {
+		duckdb_bind_set_error(info, "Table function parameter count mismatch");
+		return;
+	}
+
+	std::vector<jobject> bind_local_refs;
+	jobject bind_result = nullptr;
+	jobjectArray column_names = nullptr;
+	jobjectArray column_types = nullptr;
+	jobjectArray column_logical_types = nullptr;
+	// Releases every local reference owned by this callback; references that were never set are null.
+	auto release_refs = [&]() {
+		delete_local_ref(env, column_names);
+		delete_local_ref(env, column_types);
+		delete_local_ref(env, column_logical_types);
+		delete_local_ref(env, bind_result);
+		delete_local_refs(env, bind_local_refs);
+	};
+	// Common failure exit: never return to DuckDB with a Java exception still pending.
+	auto fail = [&](const char *message) {
+		clear_pending_java_exception(env);
+		duckdb_bind_set_error(info, message);
+		release_refs();
+	};
+
+	auto parameters = env->NewObjectArray(static_cast<jsize>(parameter_count), J_Object, nullptr);
+	if (!parameters) {
+		fail("Failed to materialize table function bind parameter");
+		return;
+	}
+	bind_local_refs.push_back(parameters);
+	for (idx_t i = 0; i < parameter_count; i++) {
+		auto val = duckdb_bind_get_parameter(info, i);
+		std::string parameter_error;
+		auto param_obj = table_bind_parameter_to_java(env, val, callback_data->parameter_logical_types[i],
+		                                              bind_local_refs, parameter_error);
+		duckdb_destroy_value(&val);
+		if ((!parameter_error.empty() && param_obj == nullptr) || env->ExceptionCheck()) {
+			if (parameter_error.empty()) {
+				parameter_error = "Failed to materialize table function bind parameter";
+			}
+			fail(parameter_error.c_str());
+			return;
+		}
+		env->SetObjectArrayElement(parameters, static_cast<jsize>(i), param_obj);
+		if (env->ExceptionCheck()) {
+			fail("Failed to pass table function bind parameters to Java");
+			return;
+		}
+	}
+
+	bind_result = env->CallObjectMethod(callback_data->callback_ref, J_TableFunction_bind, nullptr, parameters);
+	if (env->ExceptionCheck()) {
+		auto exception = env->ExceptionOccurred();
+		env->ExceptionClear();
+		auto error = consume_java_exception(env, exception, "Exception in Java table function bind callback");
+		fail(error.c_str());
+		return;
+	}
+	if (bind_result == nullptr) {
+		fail("Java table function bind returned null");
+		return;
+	}
+
+	// Stop at the first getter that throws; calling the next one with an exception pending is not allowed.
+	column_names =
+	    reinterpret_cast<jobjectArray>(env->CallObjectMethod(bind_result, J_TableBindResult_getColumnNames));
+	if (!env->ExceptionCheck()) {
+		column_types =
+		    reinterpret_cast<jobjectArray>(env->CallObjectMethod(bind_result, J_TableBindResult_getColumnTypes));
+	}
+	if (!env->ExceptionCheck()) {
+		column_logical_types = reinterpret_cast<jobjectArray>(
+		    env->CallObjectMethod(bind_result, J_TableBindResult_getColumnLogicalTypes));
+	}
+	if (env->ExceptionCheck() || column_names == nullptr || column_types == nullptr) {
+		fail("Invalid Java table bind result");
+		return;
+	}
+	auto name_count = env->GetArrayLength(column_names);
+	auto type_count = env->GetArrayLength(column_types);
+	if (name_count != type_count) {
+		fail("Java table bind result has mismatched schema lengths");
+		return;
+	}
+	if (column_logical_types != nullptr && env->GetArrayLength(column_logical_types) != name_count) {
+		fail("Java table bind result has mismatched logical schema lengths");
+		return;
+	}
+	for (jsize i = 0; i < name_count; i++) {
+		auto name_j = reinterpret_cast<jstring>(env->GetObjectArrayElement(column_names, i));
+		auto type_obj = env->GetObjectArrayElement(column_types, i);
+		auto release_column = [&]() {
+			delete_local_ref(env, name_j);
+			delete_local_ref(env, type_obj);
+		};
+		if (!name_j || !type_obj) {
+			release_column();
+			fail("Unsupported column descriptor in Java table bind result");
+			return;
+		}
+		auto name = jstring_to_string(env, name_j);
+		duckdb_logical_type logical_type = nullptr;
+		if (column_logical_types != nullptr) {
+			// Explicit logical types take precedence over the plain column type tags.
+			auto logical_type_obj = env->GetObjectArrayElement(column_logical_types, i);
+			std::string logical_error;
+			logical_type = create_table_logical_type_from_java(env, logical_type_obj, logical_error);
+			delete_local_ref(env, logical_type_obj);
+			if (env->ExceptionCheck() || !logical_type) {
+				if (logical_type) {
+					// A type was created but an exception is pending: destroy it instead of leaking it.
+					duckdb_destroy_logical_type(&logical_type);
+				}
+				if (logical_error.empty()) {
+					logical_error = "Unsupported logical type in Java table bind result";
+				}
+				release_column();
+				fail(logical_error.c_str());
+				return;
+			}
+		} else {
+			duckdb_type duck_type = DUCKDB_TYPE_INVALID;
+			if (!table_column_type_from_java(env, type_obj, duck_type)) {
+				release_column();
+				fail("Unsupported column type in Java table bind result");
+				return;
+			}
+			logical_type = create_udf_logical_type(duck_type);
+		}
+		duckdb_bind_add_result_column(info, name.c_str(), logical_type);
+		duckdb_destroy_logical_type(&logical_type);
+		release_column();
+	}
+
+	auto bind_data = new JavaTableFunctionBindData();
+	bind_data->vm = callback_data->vm;
+	bind_data->bind_result_ref = env->NewGlobalRef(bind_result);
+	if (!bind_data->bind_result_ref) {
+		delete bind_data;
+		fail("Failed to create global ref for Java table bind state");
+		return;
+	}
+	// DuckDB takes ownership of bind_data and frees it through the destroy callback.
+	duckdb_bind_set_bind_data(info, bind_data, destroy_java_table_function_bind_data);
+
+	release_refs();
+}
+
+// DuckDB init callback for Java table functions.
+// Builds a TableInitContext holding the projected column indexes, calls TableFunction.init(ctx, bindResult)
+// and stores a global ref to the returned state object as init data. Also sets the maximum thread count.
+void java_table_function_init_callback(duckdb_init_info info) {
+	auto callback_data = reinterpret_cast<JavaTableFunctionCallbackData *>(duckdb_init_get_extra_info(info));
+	auto bind_data = reinterpret_cast<JavaTableFunctionBindData *>(duckdb_init_get_bind_data(info));
+	if (!callback_data || !bind_data) {
+		duckdb_init_set_error(info, "Missing callback/bind state for Java table function init");
+		return;
+	}
+	CallbackEnvGuard env_guard(callback_data->vm);
+	auto env = env_guard.env();
+	if (!env) {
+		duckdb_init_set_error(info, "Failed to acquire JNIEnv for Java table init callback");
+		return;
+	}
+
+	auto projected_column_count = duckdb_init_get_column_count(info);
+	auto projected_column_indexes = env->NewIntArray(static_cast<jsize>(projected_column_count));
+	if (!projected_column_indexes) {
+		clear_pending_java_exception(env);
+		duckdb_init_set_error(info, "Failed to allocate projected column index array");
+		return;
+	}
+	std::vector<jint> projected_columns;
+	projected_columns.reserve(projected_column_count);
+	for (idx_t i = 0; i < projected_column_count; i++) {
+		projected_columns.push_back(static_cast<jint>(duckdb_init_get_column_index(info, i)));
+	}
+	if (!projected_columns.empty()) {
+		env->SetIntArrayRegion(projected_column_indexes, 0, static_cast<jsize>(projected_columns.size()),
+		                       projected_columns.data());
+	}
+	auto init_ctx = env->NewObject(J_TableInitContext, J_TableInitContext_init, projected_column_indexes);
+	delete_local_ref(env, projected_column_indexes);
+	if (env->ExceptionCheck() || !init_ctx) {
+		clear_pending_java_exception(env);
+		delete_local_ref(env, init_ctx);
+		duckdb_init_set_error(info, "Failed to construct Java table init context");
+		return;
+	}
+
+	auto state =
+	    env->CallObjectMethod(callback_data->callback_ref, J_TableFunction_init, init_ctx, bind_data->bind_result_ref);
+	delete_local_ref(env, init_ctx);
+	if (env->ExceptionCheck()) {
+		auto exception = env->ExceptionOccurred();
+		env->ExceptionClear();
+		auto error = consume_java_exception(env, exception, "Exception in Java table function init callback");
+		duckdb_init_set_error(info, error.c_str());
+		return;
+	}
+	if (state == nullptr) {
+		duckdb_init_set_error(info, "Java table function init returned null");
+		return;
+	}
+	auto init_data = new JavaTableFunctionInitData();
+	init_data->vm = callback_data->vm;
+	init_data->state_ref = env->NewGlobalRef(state);
+	if (!init_data->state_ref) {
+		delete init_data;
+		duckdb_init_set_error(info, "Failed to create global ref for Java table init state");
+		delete_local_ref(env, state);
+		return;
+	}
+	// DuckDB takes ownership of init_data and frees it through the destroy callback.
+	duckdb_init_set_init_data(info, init_data, destroy_java_table_function_init_data);
+	if (callback_data->thread_safe) {
+		duckdb_init_set_max_threads(info, callback_data->max_threads < 1 ? 1 : callback_data->max_threads);
+	} else {
+		// Functions not declared thread safe are always limited to a single thread.
+		duckdb_init_set_max_threads(info, 1);
+	}
+
+	delete_local_ref(env, state);
+}
+
+// DuckDB main callback for Java table functions.
+// Wraps the output chunk in a UdfOutputAppender, calls TableFunction.produce(state, appender), always
+// closes the appender, and on success sets the chunk size to the produced row count clamped to
+// [0, duckdb_vector_size()]. An exception from produce() takes precedence over one from close().
+void java_table_function_main_callback(duckdb_function_info info, duckdb_data_chunk output) {
+	auto callback_data = reinterpret_cast<JavaTableFunctionCallbackData *>(duckdb_function_get_extra_info(info));
+	auto init_data = reinterpret_cast<JavaTableFunctionInitData *>(duckdb_function_get_init_data(info));
+	if (!callback_data || !init_data) {
+		duckdb_function_set_error(info, "Missing callback/init state for Java table function");
+		return;
+	}
+	CallbackEnvGuard env_guard(callback_data->vm);
+	auto env = env_guard.env();
+	if (!env) {
+		duckdb_function_set_error(info, "Failed to acquire JNIEnv for Java table function callback");
+		return;
+	}
+
+	auto row_capacity = duckdb_vector_size();
+	// Zero-capacity direct buffer: only carries the chunk handle's address to the Java side.
+	auto output_ref = env->NewDirectByteBuffer(output, 0);
+	if (!output_ref || env->ExceptionCheck()) {
+		clear_pending_java_exception(env);
+		delete_local_ref(env, output_ref);
+		duckdb_function_set_error(info, "Failed to materialize Java table function output chunk");
+		return;
+	}
+
+	auto out_appender = env->NewObject(J_UdfOutputAppender, J_UdfOutputAppender_init, output_ref);
+	if (!out_appender || env->ExceptionCheck()) {
+		clear_pending_java_exception(env);
+		delete_local_ref(env, out_appender);
+		delete_local_ref(env, output_ref);
+		duckdb_function_set_error(info, "Failed to initialize Java table function output appender");
+		return;
+	}
+
+	auto produced =
+	    env->CallIntMethod(callback_data->callback_ref, J_TableFunction_produce, init_data->state_ref, out_appender);
+
+	jthrowable callback_exception = nullptr;
+	if (env->ExceptionCheck()) {
+		callback_exception = env->ExceptionOccurred();
+		env->ExceptionClear();
+	}
+
+	// close() runs even when produce() threw, so the appender is always closed.
+	env->CallVoidMethod(out_appender, J_UdfOutputAppender_close);
+	jthrowable close_exception = nullptr;
+	if (env->ExceptionCheck()) {
+		close_exception = env->ExceptionOccurred();
+		env->ExceptionClear();
+	}
+
+	if (callback_exception || close_exception) {
+		jthrowable reported = callback_exception ? callback_exception : close_exception;
+		jthrowable suppressed = callback_exception ? close_exception : nullptr;
+		auto error = consume_java_exception(env, reported, "Exception in Java table function callback");
+		delete_local_ref(env, suppressed);
+		duckdb_function_set_error(info, error.c_str());
+	} else {
+		if (produced < 0) {
+			produced = 0;
+		}
+		if (produced > static_cast<jint>(row_capacity)) {
+			produced = static_cast<jint>(row_capacity);
+		}
+		duckdb_data_chunk_set_size(output, static_cast<idx_t>(produced));
+	}
+
+	delete_local_ref(env, out_appender);
+	delete_local_ref(env, output_ref);
+}
diff --git a/src/jni/udf_callbacks.hpp b/src/jni/udf_callbacks.hpp
new file mode 100644
index 000000000..c52f0c277
--- /dev/null
+++ b/src/jni/udf_callbacks.hpp
@@ -0,0 +1,52 @@
+#pragma once
+
+extern "C" {
+#include "duckdb.h"
+}
+
+#include <jni.h>
+#include <vector>
+
+struct JavaScalarUdfCallbackData { // extra info attached to a Java-backed scalar function
+	JavaVM *vm;
+	jobject callback_ref; // JNI global ref to the Java callback
+	bool return_null_on_exception;
+	bool var_args;
+	std::vector<duckdb_type> argument_types; // type ids of the declared arguments
+	duckdb_type var_args_type;               // argument_types[0], or DUCKDB_TYPE_INVALID when empty
+	duckdb_type return_type;
+};
+
+struct JavaTableFunctionCallbackData { // extra info attached to a Java-backed table function
+	JavaVM *vm;
+	jobject callback_ref; // JNI global ref to the Java callback
+	bool thread_safe;
+	idx_t max_threads;
+	std::vector<duckdb_logical_type> parameter_logical_types; // owned; destroyed with this struct
+};
+
+struct JavaTableFunctionBindData {
+	JavaVM *vm;
+	jobject bind_result_ref; // JNI global ref to the object returned by the Java bind call
+};
+
+struct JavaTableFunctionInitData {
+	JavaVM *vm;
+	jobject state_ref; // JNI global ref to the object returned by the Java init call
+};
+
+void destroy_java_scalar_udf_callback_data(void *ptr);
+
+void destroy_java_table_function_callback_data(void *ptr);
+
+void destroy_java_table_function_bind_data(void *ptr);
+
+void destroy_java_table_function_init_data(void *ptr);
+
+void java_scalar_udf_callback(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output);
+
+void
java_table_function_bind_callback(duckdb_bind_info info);
+// Init-phase callback installed on Java-backed table functions.
+void java_table_function_init_callback(duckdb_init_info info);
+// Main (row-producing) callback installed on Java-backed table functions.
+void java_table_function_main_callback(duckdb_function_info info, duckdb_data_chunk output);
diff --git a/src/jni/udf_registration.cpp b/src/jni/udf_registration.cpp
new file mode 100644
index 000000000..abe26a41e
--- /dev/null
+++ b/src/jni/udf_registration.cpp
@@ -0,0 +1,37 @@
+#include "udf_registration.hpp"
+
+#include "udf_registration_internal.hpp"
+// Registers a named Java scalar UDF; forwards every argument unchanged to the *_impl function.
+void _duckdb_jdbc_register_scalar_udf(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j,
+                                      jobject callback, jobjectArray argument_logical_types_j,
+                                      jobject return_logical_type_j, jboolean special_handling,
+                                      jboolean return_null_on_exception, jboolean deterministic, jboolean var_args) {
+	duckdb_jdbc_register_scalar_udf_impl(env, clazz, conn_ref_buf, name_j, callback, argument_logical_types_j,
+	                                     return_logical_type_j, special_handling, return_null_on_exception,
+	                                     deterministic, var_args);
+}
+// Same, for a scalar function handle supplied by the caller in scalar_function_buf; pure forwarding.
+void _duckdb_jdbc_register_scalar_udf_on_function(JNIEnv *env, jclass clazz, jobject conn_ref_buf,
+                                                  jobject scalar_function_buf, jobject callback,
+                                                  jobjectArray argument_logical_types_j, jobject return_logical_type_j,
+                                                  jboolean return_null_on_exception, jboolean var_args) {
+	duckdb_jdbc_register_scalar_udf_on_function_impl(env, clazz, conn_ref_buf, scalar_function_buf, callback,
+	                                                 argument_logical_types_j, return_logical_type_j,
+	                                                 return_null_on_exception, var_args);
+}
+// Registers a named Java table function; forwards every argument unchanged to the *_impl function.
+void _duckdb_jdbc_register_table_function(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j,
+                                          jobject callback, jobjectArray parameter_types_j,
+                                          jboolean supports_projection_pushdown, jint max_threads,
+                                          jboolean thread_safe) {
+	duckdb_jdbc_register_table_function_impl(env, clazz, conn_ref_buf, name_j, callback, parameter_types_j,
+	                                         supports_projection_pushdown, max_threads, thread_safe);
+}
+// Same, for a table function handle supplied by the caller in table_function_buf; pure forwarding.
+void _duckdb_jdbc_register_table_function_on_function(JNIEnv *env, jclass clazz, jobject
conn_ref_buf,
+                                                      jobject table_function_buf, jobject callback,
+                                                      jobjectArray parameter_types_j, jint max_threads,
+                                                      jboolean thread_safe) {
+	duckdb_jdbc_register_table_function_on_function_impl(env, clazz, conn_ref_buf, table_function_buf, callback,
+	                                                     parameter_types_j, max_threads, thread_safe);
+}
diff --git a/src/jni/udf_registration.hpp b/src/jni/udf_registration.hpp
new file mode 100644
index 000000000..d6a912e8f
--- /dev/null
+++ b/src/jni/udf_registration.hpp
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <jni.h>
+// Registers a named Java scalar UDF on the connection referenced by conn_ref_buf.
+void _duckdb_jdbc_register_scalar_udf(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j,
+                                      jobject callback, jobjectArray argument_logical_types_j,
+                                      jobject return_logical_type_j, jboolean special_handling,
+                                      jboolean return_null_on_exception, jboolean deterministic, jboolean var_args);
+// Registers a Java scalar UDF using the scalar function handle passed in scalar_function_buf.
+void _duckdb_jdbc_register_scalar_udf_on_function(JNIEnv *env, jclass clazz, jobject conn_ref_buf,
+                                                  jobject scalar_function_buf, jobject callback,
+                                                  jobjectArray argument_logical_types_j, jobject return_logical_type_j,
+                                                  jboolean return_null_on_exception, jboolean var_args);
+// Registers a named Java table function on the connection referenced by conn_ref_buf.
+void _duckdb_jdbc_register_table_function(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j,
+                                          jobject callback, jobjectArray parameter_types_j,
+                                          jboolean supports_projection_pushdown, jint max_threads,
+                                          jboolean thread_safe);
+// Registers a Java table function using the table function handle passed in table_function_buf.
+void _duckdb_jdbc_register_table_function_on_function(JNIEnv *env, jclass clazz, jobject conn_ref_buf,
+                                                      jobject table_function_buf, jobject callback,
+                                                      jobjectArray parameter_types_j, jint max_threads,
+                                                      jboolean thread_safe);
diff --git a/src/jni/udf_registration_impl.cpp b/src/jni/udf_registration_impl.cpp
new file mode 100644
index 000000000..c7302ba75
--- /dev/null
+++ b/src/jni/udf_registration_impl.cpp
@@ -0,0 +1,375 @@
+extern "C" {
+#include "duckdb.h"
+}
+
+#include "duckdb.hpp"
+#include "holders.hpp"
+#include "refs.hpp"
+#include "types.hpp"
+#include "udf_callbacks.hpp"
+#include "udf_registration_internal.hpp"
+#include "udf_types.hpp"
+#include "util.hpp"
+
+#include // NOTE(review): header name lost in extraction (a standard <...> header) - restore from upstream
+#include // NOTE(review): header name lost in extraction (a standard <...> header) - restore from upstream
+
+using namespace duckdb;
+// Returns the JavaVM that owns `env`; on failure throws a Java SQLException and returns nullptr.
+static JavaVM *resolve_java_vm(JNIEnv *env) {
+	JavaVM *vm = nullptr;
+	if (env->GetJavaVM(&vm) != JNI_OK || vm == nullptr) {
+		env->ThrowNew(J_SQLException, "Failed to resolve JavaVM for UDF registration");
+		return nullptr;
+	}
+	return vm;
+}
+// Reads the scalar function handle from the direct buffer's address; throws SQLException and returns nullptr if absent.
+static duckdb_scalar_function scalar_function_buf_to_scalar_function(JNIEnv *env, jobject scalar_function_buf) {
+	if (scalar_function_buf == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid scalar function buffer");
+		return nullptr;
+	}
+
+	auto scalar_function = reinterpret_cast<duckdb_scalar_function>(env->GetDirectBufferAddress(scalar_function_buf));
+	if (scalar_function == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid scalar function");
+		return nullptr;
+	}
+
+	return scalar_function;
+}
+// Reads the table function handle from the direct buffer's address; throws SQLException and returns nullptr if absent.
+static duckdb_table_function table_function_buf_to_table_function(JNIEnv *env, jobject table_function_buf) {
+	if (table_function_buf == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid table function buffer");
+		return nullptr;
+	}
+
+	auto table_function = reinterpret_cast<duckdb_table_function>(env->GetDirectBufferAddress(table_function_buf));
+	if (table_function == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid table function");
+		return nullptr;
+	}
+
+	return table_function;
+}
+// Configures `scalar_function` (parameters, return type, callback state) and registers it on `conn`.
+static void register_scalar_udf_on_function(JNIEnv *env, duckdb_connection conn, duckdb_scalar_function scalar_function,
+                                            jobject callback, jobjectArray argument_logical_types_j,
+                                            jobject return_logical_type_j, jboolean return_null_on_exception,
+                                            jboolean var_args, JavaVM *vm) {
+	if (argument_logical_types_j == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null argument types");
+		return;
+	}
+	auto arg_count = env->GetArrayLength(argument_logical_types_j);
+	if (arg_count < 0) {
+		env->ThrowNew(J_SQLException, "Invalid scalar UDF argument count");
+		return;
+	}
+	if (var_args && arg_count != 1) {
+		env->ThrowNew(J_SQLException, "Scalar UDF varargs registration expects
exactly one argument logical type");
+		return;
+	}
+
+	std::vector<duckdb_logical_type> arg_types;
+	arg_types.reserve(static_cast<size_t>(arg_count));
+	std::vector<duckdb_type> arg_type_tags;
+	arg_type_tags.reserve(static_cast<size_t>(arg_count));
+	auto destroy_arg_types = [&arg_types]() {
+		for (auto &arg_type : arg_types) {
+			duckdb_destroy_logical_type(&arg_type);
+		}
+	};
+	for (jsize i = 0; i < arg_count; i++) {
+		auto argument_type_obj = env->GetObjectArrayElement(argument_logical_types_j, i);
+		if (env->ExceptionCheck() || !argument_type_obj) {
+			destroy_arg_types();
+			env->ThrowNew(J_SQLException, "Invalid scalar UDF argument logical type descriptor");
+			return;
+		}
+
+		std::string logical_type_error;
+		auto arg_logical_type = create_table_logical_type_from_java(env, argument_type_obj, logical_type_error);
+		delete_local_ref(env, argument_type_obj);
+		if (env->ExceptionCheck() || !arg_logical_type) {
+			destroy_arg_types();
+			if (logical_type_error.empty()) {
+				logical_type_error = "Unsupported scalar UDF argument logical type";
+			}
+			env->ThrowNew(J_SQLException, logical_type_error.c_str());
+			return;
+		}
+
+		auto arg_type_id = duckdb_get_type_id(arg_logical_type);
+		if (!is_supported_scalar_udf_type(arg_type_id)) {
+			duckdb_destroy_logical_type(&arg_logical_type);
+			destroy_arg_types();
+			env->ThrowNew(J_SQLException, UNSUPPORTED_SCALAR_UDF_TYPE_ERROR);
+			return;
+		}
+		arg_types.push_back(arg_logical_type);
+		arg_type_tags.push_back(arg_type_id);
+	}
+
+	if (return_logical_type_j == nullptr) {
+		destroy_arg_types();
+		env->ThrowNew(J_SQLException, "Invalid null return type");
+		return;
+	}
+
+	std::string return_type_error;
+	auto return_type = create_table_logical_type_from_java(env, return_logical_type_j, return_type_error);
+	if (env->ExceptionCheck() || !return_type) {
+		destroy_arg_types();
+		if (return_type_error.empty()) {
+			return_type_error = "Unsupported scalar UDF return logical type";
+		}
+		env->ThrowNew(J_SQLException, return_type_error.c_str());
+		return;
+	}
+	auto return_type_tag = duckdb_get_type_id(return_type);
+	if (!is_supported_scalar_udf_type(return_type_tag)) {
+		destroy_arg_types();
+		duckdb_destroy_logical_type(&return_type);
+		env->ThrowNew(J_SQLException, UNSUPPORTED_SCALAR_UDF_TYPE_ERROR);
+		return;
+	}
+
+	auto callback_data = new JavaScalarUdfCallbackData();
+	callback_data->vm = vm;
+	callback_data->callback_ref = env->NewGlobalRef(callback);
+	callback_data->return_null_on_exception = return_null_on_exception;
+	callback_data->var_args = var_args;
+	callback_data->argument_types = std::move(arg_type_tags);
+	callback_data->var_args_type =
+	    callback_data->argument_types.empty() ? DUCKDB_TYPE_INVALID : callback_data->argument_types[0];
+	callback_data->return_type = return_type_tag;
+	if (!callback_data->callback_ref) {
+		delete callback_data;
+		destroy_arg_types();
+		duckdb_destroy_logical_type(&return_type);
+		throw InvalidInputException("Failed to create global ref for Java scalar UDF callback");
+	}
+
+	if (var_args) {
+		duckdb_scalar_function_set_varargs(scalar_function, arg_types[0]);
+	} else {
+		for (auto &arg_type : arg_types) {
+			duckdb_scalar_function_add_parameter(scalar_function, arg_type);
+		}
+	}
+	duckdb_scalar_function_set_return_type(scalar_function, return_type);
+	duckdb_scalar_function_set_extra_info(scalar_function, callback_data, destroy_java_scalar_udf_callback_data);
+	duckdb_scalar_function_set_function(scalar_function, java_scalar_udf_callback);
+
+	auto register_state = duckdb_register_scalar_function(conn, scalar_function);
+
+	destroy_arg_types();
+	duckdb_destroy_logical_type(&return_type);
+
+	if (register_state != DuckDBSuccess) {
+		throw InvalidInputException("Failed to register Java scalar UDF");
+	}
+}
+
+// Configures `table_fn` for a Java TableFunction callback and registers it on `conn`.
+// Each entry of parameter_types_j is converted to a DuckDB logical type, validated as a supported bind
+// parameter type and added as a positional parameter. The logical types are then handed to the callback
+// state, which owns them for the lifetime of the function (the bind callback reads them).
+// Validation failures throw a Java SQLException and return; a failed global ref or a failed registration
+// throws InvalidInputException.
+static void register_table_function_on_function(JNIEnv *env, duckdb_connection conn, duckdb_table_function table_fn,
+                                                jobject callback, jobjectArray parameter_types_j, jint max_threads,
+                                                jboolean thread_safe, JavaVM *vm) {
+	if (parameter_types_j == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null table function parameter types");
+		return;
+	}
+
+	auto parameter_count = env->GetArrayLength(parameter_types_j);
+	std::vector<duckdb_logical_type> parameter_logical_types;
+	parameter_logical_types.reserve(static_cast<size_t>(parameter_count));
+	// Destroys every logical type collected so far; used on all failure paths before ownership moves.
+	auto destroy_parameter_types = [&parameter_logical_types]() {
+		for (auto &parameter_logical_type : parameter_logical_types) {
+			duckdb_destroy_logical_type(&parameter_logical_type);
+		}
+	};
+	for (jsize i = 0; i < parameter_count; i++) {
+		auto parameter_type_obj = env->GetObjectArrayElement(parameter_types_j, i);
+		if (env->ExceptionCheck() || !parameter_type_obj) {
+			destroy_parameter_types();
+			env->ThrowNew(J_SQLException, "Invalid table function parameter logical type descriptor");
+			return;
+		}
+		std::string logical_type_error;
+		auto parameter_logical_type = create_table_logical_type_from_java(env, parameter_type_obj, logical_type_error);
+		delete_local_ref(env, parameter_type_obj);
+		if (env->ExceptionCheck() || !parameter_logical_type) {
+			if (parameter_logical_type) {
+				// A type was created but an exception is pending: destroy it instead of leaking it.
+				duckdb_destroy_logical_type(&parameter_logical_type);
+			}
+			destroy_parameter_types();
+			if (logical_type_error.empty()) {
+				logical_type_error = "Unsupported table function parameter logical type";
+			}
+			env->ThrowNew(J_SQLException, logical_type_error.c_str());
+			return;
+		}
+		std::string support_error;
+		if (!is_supported_table_bind_parameter_logical_type(parameter_logical_type, support_error)) {
+			duckdb_destroy_logical_type(&parameter_logical_type);
+			destroy_parameter_types();
+			if (support_error.empty()) {
+				support_error = UNSUPPORTED_TABLE_FUNCTION_PARAMETER_TYPE_ERROR;
+			}
+			env->ThrowNew(J_SQLException, support_error.c_str());
+			return;
+		}
+		duckdb_table_function_add_parameter(table_fn, parameter_logical_type);
+		parameter_logical_types.push_back(parameter_logical_type);
+	}
+
+	// Create the global ref before moving the types so the failure path can still clean them up here.
+	auto callback_ref = env->NewGlobalRef(callback);
+	if (!callback_ref) {
+		destroy_parameter_types();
+		throw InvalidInputException("Failed to create global ref for Java table function callback");
+	}
+	auto callback_data = new JavaTableFunctionCallbackData();
+	callback_data->vm = vm;
+	callback_data->callback_ref = callback_ref;
+	callback_data->thread_safe = thread_safe;
+	callback_data->max_threads = static_cast<idx_t>(max_threads < 1 ? 1 : max_threads);
+	callback_data->parameter_logical_types = std::move(parameter_logical_types);
+	// From here on the table function owns callback_data and frees it through the destroy callback.
+	duckdb_table_function_set_extra_info(table_fn, callback_data, destroy_java_table_function_callback_data);
+	duckdb_table_function_set_bind(table_fn, java_table_function_bind_callback);
+	duckdb_table_function_set_init(table_fn, java_table_function_init_callback);
+	duckdb_table_function_set_function(table_fn, java_table_function_main_callback);
+
+	auto register_state = duckdb_register_table_function(conn, table_fn);
+	if (register_state != DuckDBSuccess) {
+		throw InvalidInputException("Failed to register Java table function");
+	}
+}
+
+void duckdb_jdbc_register_scalar_udf_impl(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray name_j,
+                                          jobject callback, jobjectArray argument_logical_types_j,
+                                          jobject return_logical_type_j, jboolean special_handling,
+                                          jboolean return_null_on_exception, jboolean deterministic,
+                                          jboolean var_args) {
+	auto conn = conn_ref_buf_to_conn(env, conn_ref_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	if (callback == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null callback");
+		return;
+	}
+	auto vm = resolve_java_vm(env);
+	if (env->ExceptionCheck() || vm == nullptr) {
+		return;
+	}
+
+	auto udf_name = jbyteArray_to_string(env, name_j);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+
+	auto scalar_function = duckdb_create_scalar_function();
+	duckdb_scalar_function_set_name(scalar_function, udf_name.c_str());
+	if (special_handling) {
+		duckdb_scalar_function_set_special_handling(scalar_function);
+	}
+	if (!deterministic) {
duckdb_scalar_function_set_volatile(scalar_function);
+	}
+	try {
+		register_scalar_udf_on_function(env, conn, scalar_function, callback, argument_logical_types_j,
+		                                return_logical_type_j, return_null_on_exception, var_args, vm);
+	} catch (...) {
+		duckdb_destroy_scalar_function(&scalar_function); // destroy the locally created handle before rethrowing
+		throw;
+	}
+	duckdb_destroy_scalar_function(&scalar_function); // also reached when registration returned after a Java throw
+}
+// Registers a Java scalar UDF on a caller-supplied function handle; this function never destroys that handle.
+void duckdb_jdbc_register_scalar_udf_on_function_impl(JNIEnv *env, jclass, jobject conn_ref_buf,
+                                                      jobject scalar_function_buf, jobject callback,
+                                                      jobjectArray argument_logical_types_j,
+                                                      jobject return_logical_type_j, jboolean return_null_on_exception,
+                                                      jboolean var_args) {
+	auto conn = conn_ref_buf_to_conn(env, conn_ref_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	if (callback == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null callback");
+		return;
+	}
+	auto vm = resolve_java_vm(env);
+	if (env->ExceptionCheck() || vm == nullptr) {
+		return;
+	}
+	auto scalar_function = scalar_function_buf_to_scalar_function(env, scalar_function_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	register_scalar_udf_on_function(env, conn, scalar_function, callback, argument_logical_types_j,
+	                                return_logical_type_j, return_null_on_exception, var_args, vm);
+}
+// Creates a table function named after name_j, configures it for the Java callback and registers it on the connection.
+void duckdb_jdbc_register_table_function_impl(JNIEnv *env, jclass, jobject conn_ref_buf, jbyteArray name_j,
+                                              jobject callback, jobjectArray parameter_types_j,
+                                              jboolean supports_projection_pushdown, jint max_threads,
+                                              jboolean thread_safe) {
+	auto conn = conn_ref_buf_to_conn(env, conn_ref_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	if (callback == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null callback");
+		return;
+	}
+	auto vm = resolve_java_vm(env);
+	if (env->ExceptionCheck() || vm == nullptr) {
+		return;
+	}
+	auto fn_name = jbyteArray_to_string(env, name_j);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	auto table_fn = duckdb_create_table_function();
+	duckdb_table_function_set_name(table_fn,
fn_name.c_str());
+	if (supports_projection_pushdown) {
+		duckdb_table_function_supports_projection_pushdown(table_fn, true);
+	}
+	try {
+		register_table_function_on_function(env, conn, table_fn, callback, parameter_types_j, max_threads, thread_safe,
+		                                    vm);
+	} catch (...) {
+		duckdb_destroy_table_function(&table_fn); // destroy the locally created handle before rethrowing
+		throw;
+	}
+	duckdb_destroy_table_function(&table_fn);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+}
+// Registers a Java table function on a caller-supplied function handle; this function never destroys that handle.
+void duckdb_jdbc_register_table_function_on_function_impl(JNIEnv *env, jclass, jobject conn_ref_buf,
+                                                          jobject table_function_buf, jobject callback,
+                                                          jobjectArray parameter_types_j, jint max_threads,
+                                                          jboolean thread_safe) {
+	auto conn = conn_ref_buf_to_conn(env, conn_ref_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	if (callback == nullptr) {
+		env->ThrowNew(J_SQLException, "Invalid null callback");
+		return;
+	}
+	auto vm = resolve_java_vm(env);
+	if (env->ExceptionCheck() || vm == nullptr) {
+		return;
+	}
+	auto table_fn = table_function_buf_to_table_function(env, table_function_buf);
+	if (env->ExceptionCheck()) {
+		return;
+	}
+	register_table_function_on_function(env, conn, table_fn, callback, parameter_types_j, max_threads, thread_safe, vm);
+}
diff --git a/src/jni/udf_registration_internal.hpp b/src/jni/udf_registration_internal.hpp
new file mode 100644
index 000000000..958b1b72d
--- /dev/null
+++ b/src/jni/udf_registration_internal.hpp
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <jni.h>
+// Implementation behind _duckdb_jdbc_register_scalar_udf.
+void duckdb_jdbc_register_scalar_udf_impl(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j,
+                                          jobject callback, jobjectArray argument_logical_types_j,
+                                          jobject return_logical_type_j, jboolean special_handling,
+                                          jboolean return_null_on_exception, jboolean deterministic, jboolean var_args);
+// Implementation behind _duckdb_jdbc_register_scalar_udf_on_function.
+void duckdb_jdbc_register_scalar_udf_on_function_impl(JNIEnv *env, jclass clazz, jobject conn_ref_buf,
+                                                      jobject scalar_function_buf, jobject callback,
+                                                      jobjectArray argument_logical_types_j,
+                                                      jobject return_logical_type_j, jboolean
return_null_on_exception, + jboolean var_args); + +void duckdb_jdbc_register_table_function_impl(JNIEnv *env, jclass clazz, jobject conn_ref_buf, jbyteArray name_j, + jobject callback, jobjectArray parameter_types_j, + jboolean supports_projection_pushdown, jint max_threads, + jboolean thread_safe); + +void duckdb_jdbc_register_table_function_on_function_impl(JNIEnv *env, jclass clazz, jobject conn_ref_buf, + jobject table_function_buf, jobject callback, + jobjectArray parameter_types_j, jint max_threads, + jboolean thread_safe); diff --git a/src/jni/udf_table_bind_conversion.cpp b/src/jni/udf_table_bind_conversion.cpp new file mode 100644 index 000000000..e82721e18 --- /dev/null +++ b/src/jni/udf_table_bind_conversion.cpp @@ -0,0 +1,534 @@ +extern "C" { +#include "duckdb.h" +} + +#include "duckdb/common/assert.hpp" +#include "refs.hpp" +#include "udf_table_bind_conversion.hpp" +#include "util.hpp" + +#include +#include + +static jobject table_bind_parameter_to_java_internal(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::string &error); + +static int64_t floor_divide_i64(int64_t value, int64_t divisor) { + D_ASSERT(divisor > 0); + auto quotient = value / divisor; + auto remainder = value % divisor; + if (remainder < 0) { + quotient -= 1; + } + return quotient; +} + +static int64_t floor_modulo_i64(int64_t value, int64_t divisor) { + D_ASSERT(divisor > 0); + auto remainder = value % divisor; + if (remainder < 0) { + remainder += divisor; + } + return remainder; +} + +static jobject date_to_local_date(JNIEnv *env, int64_t epoch_days, std::string &error) { + auto local_date = env->CallStaticObjectMethod(J_LocalDate, J_LocalDate_ofEpochDay, static_cast(epoch_days)); + if (env->ExceptionCheck() || !local_date) { + error = "Failed to materialize DATE table function parameter as LocalDate"; + return nullptr; + } + return local_date; +} + +static jobject nanos_to_local_time(JNIEnv *env, int64_t nanos_of_day, std::string &error) { + auto local_time 
= + env->CallStaticObjectMethod(J_LocalTime, J_LocalTime_ofNanoOfDay, static_cast(nanos_of_day)); + if (env->ExceptionCheck() || !local_time) { + error = "Failed to materialize TIME table function parameter as LocalTime"; + return nullptr; + } + return local_time; +} + +static jobject timestamp_to_local_date_time(JNIEnv *env, int64_t epoch_value, int64_t units_per_second, + int64_t nanos_per_unit, std::string &error) { + auto seconds = floor_divide_i64(epoch_value, units_per_second); + auto remainder_units = floor_modulo_i64(epoch_value, units_per_second); + auto nanos = remainder_units * nanos_per_unit; + auto local_date_time = + env->CallStaticObjectMethod(J_LocalDateTime, J_LocalDateTime_ofEpochSecond, static_cast(seconds), + static_cast(nanos), J_ZoneOffset_UTC); + if (env->ExceptionCheck() || !local_date_time) { + error = "Failed to materialize TIMESTAMP table function parameter as LocalDateTime"; + return nullptr; + } + return local_date_time; +} + +static jobject timestamp_to_offset_date_time(JNIEnv *env, int64_t micros_since_epoch, std::string &error) { + auto local_date_time = timestamp_to_local_date_time(env, micros_since_epoch, 1000000, 1000, error); + if (!local_date_time) { + return nullptr; + } + auto offset_date_time = env->CallObjectMethod(local_date_time, J_LocalDateTime_atOffset, J_ZoneOffset_UTC); + env->DeleteLocalRef(local_date_time); + if (env->ExceptionCheck() || !offset_date_time) { + error = "Failed to materialize TIMESTAMP WITH TIME ZONE table function parameter as OffsetDateTime"; + return nullptr; + } + return offset_date_time; +} + +static jobject timetz_to_offset_time(JNIEnv *env, uint64_t time_tz_bits, std::string &error) { + static constexpr int64_t MAX_TZ_SECONDS = 16 * 60 * 60 - 1; + int64_t signed_bits = static_cast(time_tz_bits); + int64_t micros = signed_bits >> 24; + int64_t inverted_biased_offset = signed_bits & 0x0FFFFFF; + int64_t offset_seconds = MAX_TZ_SECONDS - inverted_biased_offset; + + auto local_time = 
nanos_to_local_time(env, micros * 1000, error); + if (!local_time) { + return nullptr; + } + auto zone_offset = + env->CallStaticObjectMethod(J_ZoneOffset, J_ZoneOffset_ofTotalSeconds, static_cast(offset_seconds)); + if (env->ExceptionCheck() || !zone_offset) { + env->DeleteLocalRef(local_time); + error = "Failed to materialize TIME WITH TIME ZONE offset for table function parameter"; + return nullptr; + } + auto offset_time = env->CallStaticObjectMethod(J_OffsetTime, J_OffsetTime_of, local_time, zone_offset); + env->DeleteLocalRef(local_time); + env->DeleteLocalRef(zone_offset); + if (env->ExceptionCheck() || !offset_time) { + error = "Failed to materialize TIME WITH TIME ZONE table function parameter as OffsetTime"; + return nullptr; + } + return offset_time; +} + +static jobject uuid_to_java_uuid(JNIEnv *env, duckdb_uhugeint uuid, std::string &error) { + auto most_significant_bits = static_cast(uuid.upper); + auto least_significant_bits = static_cast(uuid.lower); + auto uuid_obj = env->NewObject(J_UUID, J_UUID_init, most_significant_bits, least_significant_bits); + if (env->ExceptionCheck() || !uuid_obj) { + error = "Failed to materialize UUID table function parameter as java.util.UUID"; + return nullptr; + } + return uuid_obj; +} + +static jobject decimal_to_bigdecimal(JNIEnv *env, duckdb_decimal decimal, std::string &error) { + auto decimal_value = duckdb_create_decimal(decimal); + if (!decimal_value) { + error = "Failed to materialize DECIMAL value"; + return nullptr; + } + varchar_ptr decimal_str_ptr(duckdb_value_to_string(decimal_value), varchar_deleter); + duckdb_destroy_value(&decimal_value); + if (!decimal_str_ptr) { + error = "Failed to convert DECIMAL value to string"; + return nullptr; + } + + auto decimal_str_len = static_cast(std::strlen(decimal_str_ptr.get())); + auto decimal_str_j = decode_charbuffer_to_jstring(env, decimal_str_ptr.get(), decimal_str_len); + if (env->ExceptionCheck() || !decimal_str_j) { + if (decimal_str_j) { + 
delete_local_ref(env, decimal_str_j); + } + error = "Failed to decode DECIMAL value as UTF-8"; + return nullptr; + } + + auto big_decimal = env->NewObject(J_BigDecimal, J_BigDecimal_initString, decimal_str_j); + delete_local_ref(env, decimal_str_j); + if (env->ExceptionCheck() || !big_decimal) { + error = "Failed to allocate BigDecimal for DECIMAL value"; + return nullptr; + } + return big_decimal; +} + +static jobject table_bind_parameter_scalar_to_java(JNIEnv *env, duckdb_value val, duckdb_type type, + duckdb_logical_type logical_type, std::string &error) { + switch (type) { + case DUCKDB_TYPE_BOOLEAN: + return env->NewObject(J_Bool, J_Bool_init, static_cast(duckdb_get_bool(val))); + case DUCKDB_TYPE_TINYINT: + return env->NewObject(J_Byte, J_Byte_init, static_cast(duckdb_get_int64(val))); + case DUCKDB_TYPE_SMALLINT: + return env->NewObject(J_Short, J_Short_init, static_cast(duckdb_get_int64(val))); + case DUCKDB_TYPE_INTEGER: + return env->NewObject(J_Int, J_Int_init, static_cast(duckdb_get_int64(val))); + case DUCKDB_TYPE_BIGINT: + return env->NewObject(J_Long, J_Long_init, static_cast(duckdb_get_int64(val))); + case DUCKDB_TYPE_UTINYINT: + return env->NewObject(J_Short, J_Short_init, static_cast(duckdb_get_uint8(val))); + case DUCKDB_TYPE_USMALLINT: + return env->NewObject(J_Int, J_Int_init, static_cast(duckdb_get_uint16(val))); + case DUCKDB_TYPE_UINTEGER: + return env->NewObject(J_Long, J_Long_init, static_cast(duckdb_get_uint32(val))); + case DUCKDB_TYPE_UBIGINT: + return env->NewObject(J_Long, J_Long_init, static_cast(duckdb_get_uint64(val))); + case DUCKDB_TYPE_HUGEINT: { + auto huge = duckdb_get_hugeint(val); + auto bytes = env->NewByteArray(static_cast(sizeof(huge))); + if (!bytes) { + error = "Failed to allocate byte array for HUGEINT table function parameter"; + return nullptr; + } + env->SetByteArrayRegion(bytes, 0, static_cast(sizeof(huge)), reinterpret_cast(&huge)); + return bytes; + } + case DUCKDB_TYPE_UHUGEINT: { + auto uhuge = 
duckdb_get_uhugeint(val); + auto bytes = env->NewByteArray(static_cast(sizeof(uhuge))); + if (!bytes) { + error = "Failed to allocate byte array for UHUGEINT table function parameter"; + return nullptr; + } + env->SetByteArrayRegion(bytes, 0, static_cast(sizeof(uhuge)), reinterpret_cast(&uhuge)); + return bytes; + } + case DUCKDB_TYPE_FLOAT: + return env->NewObject(J_Float, J_Float_init, static_cast(duckdb_get_double(val))); + case DUCKDB_TYPE_DOUBLE: + return env->NewObject(J_Double, J_Double_init, static_cast(duckdb_get_double(val))); + case DUCKDB_TYPE_DECIMAL: { + auto decimal = duckdb_get_decimal(val); + auto decimal_obj = decimal_to_bigdecimal(env, decimal, error); + if (!decimal_obj) { + if (error.empty()) { + error = "Failed to materialize DECIMAL table function parameter"; + } + return nullptr; + } + return decimal_obj; + } + case DUCKDB_TYPE_VARCHAR: { + varchar_ptr str_ptr(duckdb_get_varchar(val), varchar_deleter); + if (!str_ptr) { + error = "Failed to materialize VARCHAR table function parameter"; + return nullptr; + } + auto str_len = static_cast(std::strlen(str_ptr.get())); + auto jstr = decode_charbuffer_to_jstring(env, str_ptr.get(), str_len); + if (env->ExceptionCheck()) { + error = "Failed to decode VARCHAR table function parameter"; + return nullptr; + } + return jstr; + } + case DUCKDB_TYPE_BLOB: { + auto blob = duckdb_get_blob(val); + if (blob.size > 0 && blob.data == nullptr) { + error = "Failed to materialize BLOB table function parameter"; + return nullptr; + } + auto bytes = env->NewByteArray(static_cast(blob.size)); + if (blob.size > 0 && bytes) { + env->SetByteArrayRegion(bytes, 0, static_cast(blob.size), + reinterpret_cast(blob.data)); + } + if (blob.data) { + duckdb_free(blob.data); + } + if (!bytes) { + error = "Failed to allocate byte array for BLOB table function parameter"; + return nullptr; + } + return bytes; + } + case DUCKDB_TYPE_DATE: { + auto date = duckdb_get_date(val); + return date_to_local_date(env, date.days, error); + } 
+ case DUCKDB_TYPE_TIME: { + auto time = duckdb_get_time(val); + return nanos_to_local_time(env, time.micros * 1000, error); + } + case DUCKDB_TYPE_TIME_NS: { + auto time_ns = duckdb_get_time_ns(val); + return nanos_to_local_time(env, time_ns.nanos, error); + } + case DUCKDB_TYPE_TIME_TZ: { + auto time_tz = duckdb_get_time_tz(val); + return timetz_to_offset_time(env, time_tz.bits, error); + } + case DUCKDB_TYPE_TIMESTAMP: { + auto ts = duckdb_get_timestamp(val); + return timestamp_to_local_date_time(env, ts.micros, 1000000, 1000, error); + } + case DUCKDB_TYPE_TIMESTAMP_S: { + auto ts_s = duckdb_get_timestamp_s(val); + return timestamp_to_local_date_time(env, ts_s.seconds, 1, 1, error); + } + case DUCKDB_TYPE_TIMESTAMP_MS: { + auto ts_ms = duckdb_get_timestamp_ms(val); + return timestamp_to_local_date_time(env, ts_ms.millis, 1000, 1000000, error); + } + case DUCKDB_TYPE_TIMESTAMP_NS: { + auto ts_ns = duckdb_get_timestamp_ns(val); + return timestamp_to_local_date_time(env, ts_ns.nanos, 1000000000, 1, error); + } + case DUCKDB_TYPE_TIMESTAMP_TZ: { + auto ts_tz = duckdb_get_timestamp_tz(val); + return timestamp_to_offset_date_time(env, ts_tz.micros, error); + } + case DUCKDB_TYPE_UUID: { + auto uuid = duckdb_get_uuid(val); + return uuid_to_java_uuid(env, uuid, error); + } + case DUCKDB_TYPE_ENUM: { + auto enum_idx = duckdb_get_enum_value(val); + auto dictionary_size = duckdb_enum_dictionary_size(logical_type); + if (enum_idx >= dictionary_size) { + error = "Invalid enum value index in table function parameter"; + return nullptr; + } + varchar_ptr enum_value_ptr(duckdb_enum_dictionary_value(logical_type, enum_idx), varchar_deleter); + if (!enum_value_ptr) { + error = "Failed to materialize ENUM table function parameter"; + return nullptr; + } + auto str_len = static_cast(std::strlen(enum_value_ptr.get())); + auto jstr = decode_charbuffer_to_jstring(env, enum_value_ptr.get(), str_len); + if (env->ExceptionCheck()) { + error = "Failed to decode ENUM table function 
parameter"; + return nullptr; + } + return jstr; + } + default: + error = "Unsupported scalar parameter type in Java table function bind callback"; + return nullptr; + } +} + +static jobject table_bind_list_parameter_to_java(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + duckdb_type type, std::string &error) { + auto child_type = type == DUCKDB_TYPE_LIST ? duckdb_list_type_child_type(logical_type) + : duckdb_array_type_child_type(logical_type); + if (!child_type) { + error = "Failed to inspect LIST/ARRAY child type in table function parameter"; + return nullptr; + } + + auto list = env->NewObject(J_ArrayList, J_ArrayList_init); + if (!list || env->ExceptionCheck()) { + duckdb_destroy_logical_type(&child_type); + error = "Failed to allocate Java list for table function parameter"; + return nullptr; + } + + auto size = duckdb_get_list_size(val); + for (idx_t i = 0; i < size; i++) { + auto child_value = duckdb_get_list_child(val, i); + std::string child_error; + auto child_obj = table_bind_parameter_to_java_internal(env, child_value, child_type, child_error); + duckdb_destroy_value(&child_value); + if (!child_error.empty() || env->ExceptionCheck()) { + if (child_obj) { + env->DeleteLocalRef(child_obj); + } + env->DeleteLocalRef(list); + duckdb_destroy_logical_type(&child_type); + error = + child_error.empty() ? 
"Failed to convert LIST/ARRAY child in table function parameter" : child_error; + return nullptr; + } + env->CallBooleanMethod(list, J_ArrayList_add, child_obj); + if (child_obj) { + env->DeleteLocalRef(child_obj); + } + if (env->ExceptionCheck()) { + env->DeleteLocalRef(list); + duckdb_destroy_logical_type(&child_type); + error = "Failed to append LIST/ARRAY child in Java table function parameter"; + return nullptr; + } + } + + duckdb_destroy_logical_type(&child_type); + return list; +} + +static jobject table_bind_map_parameter_to_java(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::string &error) { + auto key_type = duckdb_map_type_key_type(logical_type); + auto value_type = duckdb_map_type_value_type(logical_type); + if (!key_type || !value_type) { + if (key_type) { + duckdb_destroy_logical_type(&key_type); + } + if (value_type) { + duckdb_destroy_logical_type(&value_type); + } + error = "Failed to inspect MAP key/value types in table function parameter"; + return nullptr; + } + + auto map = env->NewObject(J_LinkedHashMap, J_LinkedHashMap_init); + if (!map || env->ExceptionCheck()) { + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + error = "Failed to allocate Java map for table function parameter"; + return nullptr; + } + + auto size = duckdb_get_map_size(val); + for (idx_t i = 0; i < size; i++) { + auto key_value = duckdb_get_map_key(val, i); + auto mapped_value = duckdb_get_map_value(val, i); + std::string key_error; + std::string value_error; + auto key_obj = table_bind_parameter_to_java_internal(env, key_value, key_type, key_error); + auto value_obj = table_bind_parameter_to_java_internal(env, mapped_value, value_type, value_error); + duckdb_destroy_value(&key_value); + duckdb_destroy_value(&mapped_value); + if (!key_error.empty() || !value_error.empty() || env->ExceptionCheck()) { + if (key_obj) { + env->DeleteLocalRef(key_obj); + } + if (value_obj) { + env->DeleteLocalRef(value_obj); + } + 
env->DeleteLocalRef(map); + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + error = !key_error.empty() ? key_error : value_error; + if (error.empty()) { + error = "Failed to convert MAP entry in table function parameter"; + } + return nullptr; + } + auto old_value = env->CallObjectMethod(map, J_LinkedHashMap_put, key_obj, value_obj); + if (old_value) { + env->DeleteLocalRef(old_value); + } + if (key_obj) { + env->DeleteLocalRef(key_obj); + } + if (value_obj) { + env->DeleteLocalRef(value_obj); + } + if (env->ExceptionCheck()) { + env->DeleteLocalRef(map); + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + error = "Failed to append MAP entry in Java table function parameter"; + return nullptr; + } + } + + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + return map; +} + +static jobject table_bind_struct_parameter_to_java(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::string &error) { + auto map = env->NewObject(J_LinkedHashMap, J_LinkedHashMap_init); + if (!map || env->ExceptionCheck()) { + error = "Failed to allocate Java struct map for table function parameter"; + return nullptr; + } + + auto child_count = duckdb_struct_type_child_count(logical_type); + for (idx_t i = 0; i < child_count; i++) { + varchar_ptr child_name_ptr(duckdb_struct_type_child_name(logical_type, i), varchar_deleter); + auto child_type = duckdb_struct_type_child_type(logical_type, i); + auto child_value = duckdb_get_struct_child(val, i); + if (!child_name_ptr || !child_type) { + if (child_type) { + duckdb_destroy_logical_type(&child_type); + } + duckdb_destroy_value(&child_value); + env->DeleteLocalRef(map); + error = "Failed to inspect STRUCT child metadata in table function parameter"; + return nullptr; + } + + auto child_name_len = static_cast(std::strlen(child_name_ptr.get())); + auto child_name_j = decode_charbuffer_to_jstring(env, child_name_ptr.get(), 
child_name_len); + std::string child_error; + auto child_obj = table_bind_parameter_to_java_internal(env, child_value, child_type, child_error); + duckdb_destroy_logical_type(&child_type); + duckdb_destroy_value(&child_value); + if (env->ExceptionCheck() || !child_error.empty() || !child_name_j) { + if (child_name_j) { + env->DeleteLocalRef(child_name_j); + } + if (child_obj) { + env->DeleteLocalRef(child_obj); + } + env->DeleteLocalRef(map); + error = child_error.empty() ? "Failed to convert STRUCT child in table function parameter" : child_error; + return nullptr; + } + auto old_value = env->CallObjectMethod(map, J_LinkedHashMap_put, child_name_j, child_obj); + if (old_value) { + env->DeleteLocalRef(old_value); + } + env->DeleteLocalRef(child_name_j); + if (child_obj) { + env->DeleteLocalRef(child_obj); + } + if (env->ExceptionCheck()) { + env->DeleteLocalRef(map); + error = "Failed to append STRUCT child in Java table function parameter"; + return nullptr; + } + } + + return map; +} + +static jobject table_bind_parameter_to_java_internal(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::string &error) { + if (duckdb_is_null_value(val)) { + return nullptr; + } + if (!logical_type) { + error = "Invalid logical type for table function parameter"; + return nullptr; + } + + auto type = static_cast(duckdb_get_type_id(logical_type)); + switch (type) { + case DUCKDB_TYPE_LIST: + case DUCKDB_TYPE_ARRAY: + return table_bind_list_parameter_to_java(env, val, logical_type, type, error); + case DUCKDB_TYPE_MAP: + return table_bind_map_parameter_to_java(env, val, logical_type, error); + case DUCKDB_TYPE_STRUCT: + return table_bind_struct_parameter_to_java(env, val, logical_type, error); + case DUCKDB_TYPE_UNION: { + varchar_ptr repr_ptr(duckdb_value_to_string(val), varchar_deleter); + if (!repr_ptr) { + error = "Failed to materialize UNION table function parameter"; + return nullptr; + } + auto repr_len = static_cast(std::strlen(repr_ptr.get())); + auto 
jstr = decode_charbuffer_to_jstring(env, repr_ptr.get(), repr_len); + if (env->ExceptionCheck()) { + error = "Failed to decode UNION table function parameter"; + return nullptr; + } + return jstr; + } + default: + return table_bind_parameter_scalar_to_java(env, val, type, logical_type, error); + } +} + +jobject table_bind_parameter_to_java(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::vector &local_refs, std::string &error) { + auto param_obj = table_bind_parameter_to_java_internal(env, val, logical_type, error); + if (param_obj != nullptr) { + local_refs.push_back(param_obj); + } + return param_obj; +} diff --git a/src/jni/udf_table_bind_conversion.hpp b/src/jni/udf_table_bind_conversion.hpp new file mode 100644 index 000000000..51a36d093 --- /dev/null +++ b/src/jni/udf_table_bind_conversion.hpp @@ -0,0 +1,12 @@ +#pragma once + +extern "C" { +#include "duckdb.h" +} + +#include +#include +#include + +jobject table_bind_parameter_to_java(JNIEnv *env, duckdb_value val, duckdb_logical_type logical_type, + std::vector &local_refs, std::string &error); diff --git a/src/jni/udf_types.cpp b/src/jni/udf_types.cpp new file mode 100644 index 000000000..8496ea8e6 --- /dev/null +++ b/src/jni/udf_types.cpp @@ -0,0 +1,526 @@ +extern "C" { +#include "duckdb.h" +} + +#include "refs.hpp" +#include "udf_types.hpp" +#include "util.hpp" + +#include +#include + +static const UdfTypeSpec UDF_TYPE_SPECS[] = { + {DUCKDB_TYPE_BOOLEAN, "BOOLEAN", true, true, true, false, 1, ACCESS_GET_BOOLEAN | ACCESS_SET_BOOLEAN}, + {DUCKDB_TYPE_TINYINT, "TINYINT", true, true, true, false, 1, ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_SMALLINT, "SMALLINT", true, true, true, false, 2, ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_INTEGER, "INTEGER", true, true, true, false, 4, ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_BIGINT, "BIGINT", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_UTINYINT, "UTINYINT", true, true, true, false, 1, 
ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_USMALLINT, "USMALLINT", true, true, true, false, 2, ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_UINTEGER, "UINTEGER", true, true, true, false, 4, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_UBIGINT, "UBIGINT", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_HUGEINT, "HUGEINT", true, true, true, false, 16, ACCESS_GET_BYTES | ACCESS_SET_BYTES}, + {DUCKDB_TYPE_UHUGEINT, "UHUGEINT", true, true, true, false, 16, ACCESS_GET_BYTES | ACCESS_SET_BYTES}, + {DUCKDB_TYPE_FLOAT, "FLOAT", true, true, true, false, 4, ACCESS_GET_FLOAT | ACCESS_SET_FLOAT}, + {DUCKDB_TYPE_DOUBLE, "DOUBLE", true, true, true, false, 8, ACCESS_GET_DOUBLE | ACCESS_SET_DOUBLE}, + {DUCKDB_TYPE_VARCHAR, "VARCHAR", true, true, true, true, 0, ACCESS_GET_STRING | ACCESS_SET_STRING}, + {DUCKDB_TYPE_BLOB, "BLOB", true, true, true, true, 0, ACCESS_GET_BYTES | ACCESS_SET_BYTES}, + {DUCKDB_TYPE_DECIMAL, "DECIMAL", true, true, true, true, 0, ACCESS_GET_DECIMAL | ACCESS_SET_DECIMAL}, + {DUCKDB_TYPE_DATE, "DATE", true, true, true, false, 4, ACCESS_GET_INT | ACCESS_SET_INT}, + {DUCKDB_TYPE_TIME, "TIME", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIME_NS, "TIME_NS", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIME_TZ, "TIME_WITH_TIME_ZONE", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIMESTAMP, "TIMESTAMP", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIMESTAMP_S, "TIMESTAMP_S", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIMESTAMP_MS, "TIMESTAMP_MS", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIMESTAMP_NS, "TIMESTAMP_NS", true, true, true, false, 8, ACCESS_GET_LONG | ACCESS_SET_LONG}, + {DUCKDB_TYPE_TIMESTAMP_TZ, "TIMESTAMP_WITH_TIME_ZONE", true, true, true, false, 8, + ACCESS_GET_LONG | ACCESS_SET_LONG}, + 
{DUCKDB_TYPE_UUID, "UUID", true, true, true, false, 16, ACCESS_GET_BYTES | ACCESS_SET_BYTES}, +}; + +static constexpr uint8_t DEFAULT_DECIMAL_WIDTH = 18; +static constexpr uint8_t DEFAULT_DECIMAL_SCALE = 3; + +const char *UNSUPPORTED_SCALAR_UDF_TYPE_ERROR = + "Supported scalar UDF types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, VARCHAR, DECIMAL, " + "BLOB, DATE, TIME, TIME_NS, TIMESTAMP, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UTINYINT, USMALLINT, " + "UINTEGER, UBIGINT, HUGEINT, UHUGEINT, TIME_WITH_TIME_ZONE, TIMESTAMP_WITH_TIME_ZONE, UUID"; + +const char *UNSUPPORTED_TABLE_FUNCTION_PARAMETER_TYPE_ERROR = + "Supported table function parameter types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, " + "VARCHAR, DECIMAL, BLOB, DATE, TIME, TIME_NS, TIMESTAMP, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, " + "UTINYINT, USMALLINT, UINTEGER, UBIGINT, HUGEINT, UHUGEINT, TIME_WITH_TIME_ZONE, TIMESTAMP_WITH_TIME_ZONE, " + "UUID"; + +static const UdfTypeSpec *find_udf_type_spec_by_name(const std::string &name) { + for (const auto &spec : UDF_TYPE_SPECS) { + if (name == spec.duckdb_column_type_name) { + return &spec; + } + } + return nullptr; +} + +static bool is_supported_table_function_parameter_type(duckdb_type type) { + auto spec = find_udf_type_spec(type); + return spec != nullptr && spec->udf_vector_supported; +} + +static bool duckdb_type_from_java_column_type_name(const std::string &name, duckdb_type &out_type) { + if (name == "LIST") { + out_type = DUCKDB_TYPE_LIST; + return true; + } + if (name == "ARRAY") { + out_type = DUCKDB_TYPE_ARRAY; + return true; + } + if (name == "MAP") { + out_type = DUCKDB_TYPE_MAP; + return true; + } + if (name == "STRUCT") { + out_type = DUCKDB_TYPE_STRUCT; + return true; + } + if (name == "UNION") { + out_type = DUCKDB_TYPE_UNION; + return true; + } + if (name == "ENUM") { + out_type = DUCKDB_TYPE_ENUM; + return true; + } + auto spec = find_udf_type_spec_by_name(name); + if (spec) { + out_type = spec->type; 
+ return true; + } + return false; +} + +static bool table_any_column_type_from_java(JNIEnv *env, jobject duckdb_column_type_obj, duckdb_type &out_type) { + auto name_j = reinterpret_cast(env->CallObjectMethod(duckdb_column_type_obj, J_Enum_name)); + if (env->ExceptionCheck() || !name_j) { + return false; + } + auto name = jstring_to_string(env, name_j); + env->DeleteLocalRef(name_j); + return duckdb_type_from_java_column_type_name(name, out_type); +} + +duckdb_logical_type create_udf_logical_type(duckdb_type type) { + if (type == DUCKDB_TYPE_DECIMAL) { + return duckdb_create_decimal_type(DEFAULT_DECIMAL_WIDTH, DEFAULT_DECIMAL_SCALE); + } + return duckdb_create_logical_type(type); +} + +const UdfTypeSpec *find_udf_type_spec(duckdb_type type) { + for (const auto &spec : UDF_TYPE_SPECS) { + if (spec.type == type) { + return &spec; + } + } + return nullptr; +} + +bool capi_type_id_to_duckdb_type(jint type_id, duckdb_type &out_type) { + auto requested_type = static_cast(type_id); + auto spec = find_udf_type_spec(requested_type); + if (!spec) { + return false; + } + out_type = spec->type; + return true; +} + +bool is_supported_scalar_udf_type(duckdb_type type) { + auto spec = find_udf_type_spec(type); + return spec != nullptr && spec->scalar_udf_implemented; +} + +bool table_column_type_from_java(JNIEnv *env, jobject duckdb_column_type_obj, duckdb_type &out_type) { + auto name_j = reinterpret_cast(env->CallObjectMethod(duckdb_column_type_obj, J_Enum_name)); + if (env->ExceptionCheck() || !name_j) { + return false; + } + auto name = jstring_to_string(env, name_j); + env->DeleteLocalRef(name_j); + duckdb_type resolved_type = DUCKDB_TYPE_INVALID; + if (!duckdb_type_from_java_column_type_name(name, resolved_type)) { + return false; + } + auto spec = find_udf_type_spec(resolved_type); + if (spec && spec->table_bind_schema_supported) { + out_type = resolved_type; + return true; + } + return false; +} + +bool is_supported_table_bind_parameter_logical_type(duckdb_logical_type 
logical_type, std::string &error) { + if (!logical_type) { + error = "Invalid null logical type for table function parameter"; + return false; + } + + auto type = static_cast(duckdb_get_type_id(logical_type)); + if (is_supported_table_function_parameter_type(type)) { + return true; + } + + switch (type) { + case DUCKDB_TYPE_ENUM: + return true; + case DUCKDB_TYPE_LIST: + case DUCKDB_TYPE_ARRAY: { + auto child_type = type == DUCKDB_TYPE_LIST ? duckdb_list_type_child_type(logical_type) + : duckdb_array_type_child_type(logical_type); + if (!child_type) { + error = "Failed to inspect child type for LIST/ARRAY parameter"; + return false; + } + std::string child_error; + auto ok = is_supported_table_bind_parameter_logical_type(child_type, child_error); + duckdb_destroy_logical_type(&child_type); + if (!ok) { + error = "Unsupported LIST/ARRAY child type: " + child_error; + } + return ok; + } + case DUCKDB_TYPE_MAP: { + auto key_type = duckdb_map_type_key_type(logical_type); + auto value_type = duckdb_map_type_value_type(logical_type); + if (!key_type || !value_type) { + if (key_type) { + duckdb_destroy_logical_type(&key_type); + } + if (value_type) { + duckdb_destroy_logical_type(&value_type); + } + error = "Failed to inspect key/value types for MAP parameter"; + return false; + } + std::string key_error; + std::string value_error; + auto key_ok = is_supported_table_bind_parameter_logical_type(key_type, key_error); + auto value_ok = is_supported_table_bind_parameter_logical_type(value_type, value_error); + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + if (!key_ok || !value_ok) { + error = "Unsupported MAP parameter type: key(" + key_error + "), value(" + value_error + ")"; + return false; + } + return true; + } + case DUCKDB_TYPE_STRUCT: + case DUCKDB_TYPE_UNION: { + auto child_count = type == DUCKDB_TYPE_STRUCT ? 
duckdb_struct_type_child_count(logical_type) + : duckdb_union_type_member_count(logical_type); + for (idx_t i = 0; i < child_count; i++) { + auto child_type = type == DUCKDB_TYPE_STRUCT ? duckdb_struct_type_child_type(logical_type, i) + : duckdb_union_type_member_type(logical_type, i); + if (!child_type) { + error = "Failed to inspect child type for STRUCT/UNION parameter"; + return false; + } + std::string child_error; + auto child_ok = is_supported_table_bind_parameter_logical_type(child_type, child_error); + duckdb_destroy_logical_type(&child_type); + if (!child_ok) { + error = "Unsupported STRUCT/UNION child type: " + child_error; + return false; + } + } + return true; + } + default: + error = "Unsupported table function parameter logical type id: " + std::to_string(static_cast(type)); + return false; + } +} + +duckdb_logical_type create_table_logical_type_from_java(JNIEnv *env, jobject logical_type_obj, std::string &error); + +static duckdb_logical_type create_struct_or_union_type_from_java(JNIEnv *env, jobject logical_type_obj, + duckdb_type type, std::string &error) { + auto field_names = + reinterpret_cast(env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getFieldNames)); + auto field_types = + reinterpret_cast(env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getFieldTypes)); + if (env->ExceptionCheck() || !field_names || !field_types) { + error = "Invalid Java logical type for struct/union"; + if (field_names) { + env->DeleteLocalRef(field_names); + } + if (field_types) { + env->DeleteLocalRef(field_types); + } + return nullptr; + } + auto field_count = env->GetArrayLength(field_names); + if (field_count <= 0 || field_count != env->GetArrayLength(field_types)) { + error = "Struct/union logical type requires matching non-empty field names/types"; + env->DeleteLocalRef(field_names); + env->DeleteLocalRef(field_types); + return nullptr; + } + + std::vector field_name_storage; + std::vector field_name_ptrs; + std::vector member_types; + 
field_name_storage.reserve(field_count); + field_name_ptrs.reserve(field_count); + member_types.reserve(field_count); + + for (jsize i = 0; i < field_count; i++) { + auto field_name_j = reinterpret_cast(env->GetObjectArrayElement(field_names, i)); + auto field_type_obj = env->GetObjectArrayElement(field_types, i); + if (env->ExceptionCheck() || !field_name_j || !field_type_obj) { + error = "Invalid struct/union field descriptor in Java logical type"; + if (field_name_j) { + env->DeleteLocalRef(field_name_j); + } + if (field_type_obj) { + env->DeleteLocalRef(field_type_obj); + } + break; + } + auto field_name = jstring_to_string(env, field_name_j); + env->DeleteLocalRef(field_name_j); + if (field_name.empty()) { + error = "Struct/union field names must be non-empty"; + env->DeleteLocalRef(field_type_obj); + break; + } + std::string member_error; + auto member_type = create_table_logical_type_from_java(env, field_type_obj, member_error); + env->DeleteLocalRef(field_type_obj); + if (env->ExceptionCheck() || !member_type) { + error = member_error.empty() ? 
"Invalid struct/union field type in Java logical type" : member_error; + break; + } + field_name_storage.push_back(field_name); + field_name_ptrs.push_back(field_name_storage.back().c_str()); + member_types.push_back(member_type); + } + + duckdb_logical_type result = nullptr; + if (error.empty()) { + if (type == DUCKDB_TYPE_STRUCT) { + result = duckdb_create_struct_type(member_types.data(), field_name_ptrs.data(), member_types.size()); + } else { + result = duckdb_create_union_type(member_types.data(), field_name_ptrs.data(), member_types.size()); + } + if (!result) { + error = "Failed to create struct/union logical type"; + } + } + + for (auto &member_type : member_types) { + duckdb_destroy_logical_type(&member_type); + } + env->DeleteLocalRef(field_names); + env->DeleteLocalRef(field_types); + return result; +} + +static duckdb_logical_type create_enum_type_from_java(JNIEnv *env, jobject logical_type_obj, std::string &error) { + auto enum_values = + reinterpret_cast(env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getEnumValues)); + if (env->ExceptionCheck() || !enum_values) { + error = "Invalid Java logical type for enum"; + if (enum_values) { + env->DeleteLocalRef(enum_values); + } + return nullptr; + } + + auto value_count = env->GetArrayLength(enum_values); + if (value_count <= 0) { + error = "Enum logical type requires at least one value"; + env->DeleteLocalRef(enum_values); + return nullptr; + } + + std::vector value_storage; + std::vector value_ptrs; + value_storage.reserve(value_count); + value_ptrs.reserve(value_count); + for (jsize i = 0; i < value_count; i++) { + auto enum_value_j = reinterpret_cast(env->GetObjectArrayElement(enum_values, i)); + if (env->ExceptionCheck() || !enum_value_j) { + error = "Enum values must be non-null strings"; + if (enum_value_j) { + env->DeleteLocalRef(enum_value_j); + } + break; + } + auto enum_value = jstring_to_string(env, enum_value_j); + env->DeleteLocalRef(enum_value_j); + if (enum_value.empty()) { + error = 
"Enum values must be non-empty strings"; + break; + } + value_storage.push_back(enum_value); + value_ptrs.push_back(value_storage.back().c_str()); + } + env->DeleteLocalRef(enum_values); + + if (!error.empty()) { + return nullptr; + } + auto enum_type = duckdb_create_enum_type(value_ptrs.data(), value_ptrs.size()); + if (!enum_type) { + error = "Failed to create enum logical type"; + } + return enum_type; +} + +duckdb_logical_type create_table_logical_type_from_java(JNIEnv *env, jobject logical_type_obj, std::string &error) { + if (!logical_type_obj) { + error = "Java logical type is null"; + return nullptr; + } + + auto column_type_obj = env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getType); + if (env->ExceptionCheck() || !column_type_obj) { + error = "Java logical type has invalid DuckDBColumnType"; + if (column_type_obj) { + env->DeleteLocalRef(column_type_obj); + } + return nullptr; + } + + duckdb_type type = DUCKDB_TYPE_INVALID; + auto type_ok = table_any_column_type_from_java(env, column_type_obj, type); + env->DeleteLocalRef(column_type_obj); + if (!type_ok) { + error = "Unsupported DuckDBColumnType in Java logical type"; + return nullptr; + } + + switch (type) { + case DUCKDB_TYPE_LIST: { + auto child_type_obj = env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getChildType); + if (env->ExceptionCheck() || !child_type_obj) { + error = "List logical type requires child type"; + if (child_type_obj) { + env->DeleteLocalRef(child_type_obj); + } + return nullptr; + } + std::string child_error; + auto child_type = create_table_logical_type_from_java(env, child_type_obj, child_error); + env->DeleteLocalRef(child_type_obj); + if (!child_type || env->ExceptionCheck()) { + error = child_error.empty() ? 
"Failed to create list child logical type" : child_error; + return nullptr; + } + auto list_type = duckdb_create_list_type(child_type); + duckdb_destroy_logical_type(&child_type); + if (!list_type) { + error = "Failed to create list logical type"; + } + return list_type; + } + case DUCKDB_TYPE_ARRAY: { + auto child_type_obj = env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getChildType); + auto array_size = env->CallLongMethod(logical_type_obj, J_UdfLogicalType_getArraySize); + if (env->ExceptionCheck() || !child_type_obj || array_size <= 0) { + error = "Array logical type requires child type and positive array size"; + if (child_type_obj) { + env->DeleteLocalRef(child_type_obj); + } + return nullptr; + } + std::string child_error; + auto child_type = create_table_logical_type_from_java(env, child_type_obj, child_error); + env->DeleteLocalRef(child_type_obj); + if (!child_type || env->ExceptionCheck()) { + error = child_error.empty() ? "Failed to create array child logical type" : child_error; + return nullptr; + } + auto array_type = duckdb_create_array_type(child_type, array_size); + duckdb_destroy_logical_type(&child_type); + if (!array_type) { + error = "Failed to create array logical type"; + } + return array_type; + } + case DUCKDB_TYPE_MAP: { + auto key_type_obj = env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getKeyType); + auto value_type_obj = env->CallObjectMethod(logical_type_obj, J_UdfLogicalType_getValueType); + if (env->ExceptionCheck() || !key_type_obj || !value_type_obj) { + error = "Map logical type requires key and value types"; + if (key_type_obj) { + env->DeleteLocalRef(key_type_obj); + } + if (value_type_obj) { + env->DeleteLocalRef(value_type_obj); + } + return nullptr; + } + std::string key_error; + std::string value_error; + auto key_type = create_table_logical_type_from_java(env, key_type_obj, key_error); + auto value_type = create_table_logical_type_from_java(env, value_type_obj, value_error); + 
env->DeleteLocalRef(key_type_obj); + env->DeleteLocalRef(value_type_obj); + if (!key_type || !value_type || env->ExceptionCheck()) { + error = !key_error.empty() ? key_error : value_error; + if (error.empty()) { + error = "Failed to create map key/value logical types"; + } + if (key_type) { + duckdb_destroy_logical_type(&key_type); + } + if (value_type) { + duckdb_destroy_logical_type(&value_type); + } + return nullptr; + } + auto map_type = duckdb_create_map_type(key_type, value_type); + duckdb_destroy_logical_type(&key_type); + duckdb_destroy_logical_type(&value_type); + if (!map_type) { + error = "Failed to create map logical type"; + } + return map_type; + } + case DUCKDB_TYPE_STRUCT: + case DUCKDB_TYPE_UNION: + return create_struct_or_union_type_from_java(env, logical_type_obj, type, error); + case DUCKDB_TYPE_ENUM: + return create_enum_type_from_java(env, logical_type_obj, error); + case DUCKDB_TYPE_DECIMAL: { + static constexpr jint DECIMAL_WIDTH_MIN = 1; + static constexpr jint DECIMAL_WIDTH_MAX = 38; + static constexpr jint DECIMAL_SCALE_MIN = 0; + + auto width = env->CallIntMethod(logical_type_obj, J_UdfLogicalType_getDecimalWidth); + auto scale = env->CallIntMethod(logical_type_obj, J_UdfLogicalType_getDecimalScale); + if (env->ExceptionCheck()) { + error = "Decimal logical type has invalid width/scale"; + return nullptr; + } + if (width < DECIMAL_WIDTH_MIN || width > DECIMAL_WIDTH_MAX) { + error = "Decimal logical type width must be between 1 and 38"; + return nullptr; + } + if (scale < DECIMAL_SCALE_MIN || scale > width) { + error = "Decimal logical type scale must be between 0 and width"; + return nullptr; + } + auto decimal_type = duckdb_create_decimal_type(static_cast(width), static_cast(scale)); + if (!decimal_type) { + error = "Failed to create decimal logical type"; + } + return decimal_type; + } + default: + return create_udf_logical_type(type); + } +} diff --git a/src/jni/udf_types.hpp b/src/jni/udf_types.hpp new file mode 100644 index 
000000000..9f3b26bdb --- /dev/null +++ b/src/jni/udf_types.hpp @@ -0,0 +1,56 @@ +#pragma once + +extern "C" { +#include "duckdb.h" +} + +#include +#include +#include + +enum UdfVectorAccessMask : uint32_t { + ACCESS_GET_INT = 1 << 0, + ACCESS_GET_LONG = 1 << 1, + ACCESS_GET_FLOAT = 1 << 2, + ACCESS_GET_DOUBLE = 1 << 3, + ACCESS_GET_BOOLEAN = 1 << 4, + ACCESS_GET_STRING = 1 << 5, + ACCESS_GET_BYTES = 1 << 6, + ACCESS_SET_INT = 1 << 7, + ACCESS_SET_LONG = 1 << 8, + ACCESS_SET_FLOAT = 1 << 9, + ACCESS_SET_DOUBLE = 1 << 10, + ACCESS_SET_BOOLEAN = 1 << 11, + ACCESS_SET_STRING = 1 << 12, + ACCESS_SET_BYTES = 1 << 13, + ACCESS_GET_DECIMAL = 1 << 14, + ACCESS_SET_DECIMAL = 1 << 15, +}; + +struct UdfTypeSpec { + duckdb_type type; + const char *duckdb_column_type_name; + bool udf_vector_supported; + bool scalar_udf_implemented; + bool table_bind_schema_supported; + bool requires_vector_ref; + uint8_t fixed_width_bytes; + uint32_t access_mask; +}; + +extern const char *UNSUPPORTED_SCALAR_UDF_TYPE_ERROR; +extern const char *UNSUPPORTED_TABLE_FUNCTION_PARAMETER_TYPE_ERROR; + +const UdfTypeSpec *find_udf_type_spec(duckdb_type type); + +bool capi_type_id_to_duckdb_type(jint type_id, duckdb_type &out_type); + +bool is_supported_scalar_udf_type(duckdb_type type); + +bool table_column_type_from_java(JNIEnv *env, jobject duckdb_column_type_obj, duckdb_type &out_type); + +bool is_supported_table_bind_parameter_logical_type(duckdb_logical_type logical_type, std::string &error); + +duckdb_logical_type create_udf_logical_type(duckdb_type type); + +duckdb_logical_type create_table_logical_type_from_java(JNIEnv *env, jobject logical_type_obj, std::string &error); diff --git a/src/jni/udf_vector_accessors.cpp b/src/jni/udf_vector_accessors.cpp new file mode 100644 index 000000000..1e9d7e8c5 --- /dev/null +++ b/src/jni/udf_vector_accessors.cpp @@ -0,0 +1,294 @@ +extern "C" { +#include "duckdb.h" +} + +#include "duckdb.hpp" +#include "refs.hpp" +#include "types.hpp" +#include 
"udf_vector_accessors.hpp" +#include "util.hpp" + +#include +#include +#include + +static duckdb_vector udf_vector_ref_buf_to_vector(JNIEnv *env, jobject vector_ref_buf) { + if (vector_ref_buf == nullptr) { + env->ThrowNew(J_SQLException, "Invalid null vector reference"); + return nullptr; + } + auto vec = reinterpret_cast(env->GetDirectBufferAddress(vector_ref_buf)); + if (!vec) { + env->ThrowNew(J_SQLException, "Invalid vector reference"); + return nullptr; + } + return vec; +} + +static jbyteArray udf_get_varlen_bytes(JNIEnv *env, duckdb_vector vector, jint row) { + auto data = reinterpret_cast(duckdb_vector_get_data(vector)); + auto &value = data[row]; + auto len = duckdb_string_t_length(value); + auto ptr = duckdb_string_t_data(&value); + auto bytes = env->NewByteArray(static_cast(len)); + if (!bytes) { + return nullptr; + } + if (len > 0) { + env->SetByteArrayRegion(bytes, 0, static_cast(len), reinterpret_cast(ptr)); + } + return bytes; +} + +static void udf_set_varlen_bytes(JNIEnv *env, duckdb_vector vector, jint row, jbyteArray value, + const char *null_value_error) { + if (value == nullptr) { + env->ThrowNew(J_SQLException, null_value_error); + return; + } + auto len = env->GetArrayLength(value); + auto bytes = env->GetByteArrayElements(value, nullptr); + if (!bytes) { + env->ThrowNew(J_SQLException, "Failed to access varlen bytes"); + return; + } + duckdb_vector_assign_string_element_len(vector, row, reinterpret_cast(bytes), len); + env->ReleaseByteArrayElements(value, bytes, JNI_ABORT); +} + +static duckdb_hugeint int64_to_hugeint(int64_t value) { + duckdb_hugeint result; + result.lower = static_cast(value); + result.upper = value < 0 ? 
-1 : 0; + return result; +} + +static bool decimal_vector_meta(duckdb_vector vector, uint8_t &width, uint8_t &scale, duckdb_type &internal_type, + std::string &error) { + auto logical_type = duckdb_vector_get_column_type(vector); + if (!logical_type) { + error = "Failed to get DECIMAL logical type"; + return false; + } + auto type_id = duckdb_get_type_id(logical_type); + if (type_id != DUCKDB_TYPE_DECIMAL) { + duckdb_destroy_logical_type(&logical_type); + error = "Native decimal accessor requires DECIMAL vector"; + return false; + } + width = duckdb_decimal_width(logical_type); + scale = duckdb_decimal_scale(logical_type); + internal_type = duckdb_decimal_internal_type(logical_type); + duckdb_destroy_logical_type(&logical_type); + return true; +} + +static bool read_decimal_from_vector(duckdb_vector vector, jint row, duckdb_decimal &out_decimal, std::string &error) { + uint8_t width = 0; + uint8_t scale = 0; + duckdb_type internal_type = DUCKDB_TYPE_INVALID; + if (!decimal_vector_meta(vector, width, scale, internal_type, error)) { + return false; + } + auto data = duckdb_vector_get_data(vector); + if (!data) { + error = "Failed to get DECIMAL vector data"; + return false; + } + + out_decimal.width = width; + out_decimal.scale = scale; + switch (internal_type) { + case DUCKDB_TYPE_SMALLINT: + out_decimal.value = int64_to_hugeint(static_cast(reinterpret_cast(data)[row])); + return true; + case DUCKDB_TYPE_INTEGER: + out_decimal.value = int64_to_hugeint(static_cast(reinterpret_cast(data)[row])); + return true; + case DUCKDB_TYPE_BIGINT: + out_decimal.value = int64_to_hugeint(reinterpret_cast(data)[row]); + return true; + case DUCKDB_TYPE_HUGEINT: + out_decimal.value = reinterpret_cast(data)[row]; + return true; + default: + error = "Unsupported DECIMAL physical type for native accessor"; + return false; + } +} + +static jobject decimal_to_bigdecimal(JNIEnv *env, duckdb_decimal decimal, std::string &error) { + auto decimal_value = duckdb_create_decimal(decimal); + if 
(!decimal_value) { + error = "Failed to materialize DECIMAL value"; + return nullptr; + } + varchar_ptr decimal_str_ptr(duckdb_value_to_string(decimal_value), varchar_deleter); + duckdb_destroy_value(&decimal_value); + if (!decimal_str_ptr) { + error = "Failed to convert DECIMAL value to string"; + return nullptr; + } + + auto decimal_str_len = static_cast(std::strlen(decimal_str_ptr.get())); + auto decimal_str_j = decode_charbuffer_to_jstring(env, decimal_str_ptr.get(), decimal_str_len); + if (env->ExceptionCheck() || !decimal_str_j) { + if (decimal_str_j) { + delete_local_ref(env, decimal_str_j); + } + error = "Failed to decode DECIMAL value as UTF-8"; + return nullptr; + } + + auto big_decimal = env->NewObject(J_BigDecimal, J_BigDecimal_initString, decimal_str_j); + delete_local_ref(env, decimal_str_j); + if (env->ExceptionCheck() || !big_decimal) { + error = "Failed to allocate BigDecimal for DECIMAL value"; + return nullptr; + } + return big_decimal; +} + +static bool write_decimal_to_vector(JNIEnv *env, duckdb_vector vector, jint row, jobject value, std::string &error) { + if (!value) { + error = "Invalid null decimal value"; + return false; + } + if (!env->IsInstanceOf(value, J_BigDecimal)) { + error = "Decimal accessor requires java.math.BigDecimal"; + return false; + } + + uint8_t width = 0; + uint8_t scale = 0; + duckdb_type internal_type = DUCKDB_TYPE_INVALID; + if (!decimal_vector_meta(vector, width, scale, internal_type, error)) { + return false; + } + auto data = duckdb_vector_get_data(vector); + if (!data) { + error = "Failed to get DECIMAL vector data"; + return false; + } + + try { + auto decimal_value = create_value_from_bigdecimal(env, value); + if (env->ExceptionCheck()) { + error = "Failed to parse BigDecimal input"; + return false; + } + auto casted = decimal_value.DefaultCastAs(duckdb::LogicalType::DECIMAL(width, scale)); + switch (internal_type) { + case DUCKDB_TYPE_SMALLINT: + reinterpret_cast(data)[row] = casted.GetValueUnsafe(); + return 
true; + case DUCKDB_TYPE_INTEGER: + reinterpret_cast(data)[row] = casted.GetValueUnsafe(); + return true; + case DUCKDB_TYPE_BIGINT: + reinterpret_cast(data)[row] = casted.GetValueUnsafe(); + return true; + case DUCKDB_TYPE_HUGEINT: { + auto huge_value = casted.GetValueUnsafe(); + reinterpret_cast(data)[row].lower = huge_value.lower; + reinterpret_cast(data)[row].upper = huge_value.upper; + return true; + } + default: + error = "Unsupported DECIMAL physical type for native accessor"; + return false; + } + } catch (const std::exception &e) { + error = std::string("Failed to cast BigDecimal to DECIMAL: ") + e.what(); + return false; + } +} + +jbyteArray _duckdb_jdbc_udf_get_varchar_bytes(JNIEnv *env, jclass, jobject vector_ref_buf, jint row) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return nullptr; + } + return udf_get_varlen_bytes(env, vector, row); +} + +void _duckdb_jdbc_udf_set_varchar_bytes(JNIEnv *env, jclass, jobject vector_ref_buf, jint row, jbyteArray value) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if (env->ExceptionCheck()) { + return; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return; + } + udf_set_varlen_bytes(env, vector, row, value, "Invalid null string bytes"); +} + +jbyteArray _duckdb_jdbc_udf_get_blob_bytes(JNIEnv *env, jclass, jobject vector_ref_buf, jint row) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return nullptr; + } + return udf_get_varlen_bytes(env, vector, row); +} + +void _duckdb_jdbc_udf_set_blob_bytes(JNIEnv *env, jclass, jobject vector_ref_buf, jint row, jbyteArray value) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if 
(env->ExceptionCheck()) { + return; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return; + } + udf_set_varlen_bytes(env, vector, row, value, "Invalid null blob bytes"); +} + +jobject _duckdb_jdbc_udf_get_decimal(JNIEnv *env, jclass, jobject vector_ref_buf, jint row) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if (env->ExceptionCheck()) { + return nullptr; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return nullptr; + } + duckdb_decimal decimal; + std::string error; + if (!read_decimal_from_vector(vector, row, decimal, error)) { + env->ThrowNew(J_SQLException, error.c_str()); + return nullptr; + } + auto result = decimal_to_bigdecimal(env, decimal, error); + if (!result) { + if (error.empty()) { + error = "Failed to convert DECIMAL value to BigDecimal"; + } + env->ThrowNew(J_SQLException, error.c_str()); + return nullptr; + } + return result; +} + +void _duckdb_jdbc_udf_set_decimal(JNIEnv *env, jclass, jobject vector_ref_buf, jint row, jobject value) { + auto vector = udf_vector_ref_buf_to_vector(env, vector_ref_buf); + if (env->ExceptionCheck()) { + return; + } + if (row < 0) { + env->ThrowNew(J_SQLException, "Invalid negative row index"); + return; + } + std::string error; + if (!write_decimal_to_vector(env, vector, row, value, error)) { + env->ThrowNew(J_SQLException, error.c_str()); + } +} diff --git a/src/jni/udf_vector_accessors.hpp b/src/jni/udf_vector_accessors.hpp new file mode 100644 index 000000000..d8ac6b233 --- /dev/null +++ b/src/jni/udf_vector_accessors.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +jbyteArray _duckdb_jdbc_udf_get_varchar_bytes(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row); + +void _duckdb_jdbc_udf_set_varchar_bytes(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row, jbyteArray value); + +jbyteArray _duckdb_jdbc_udf_get_blob_bytes(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row); + +void 
_duckdb_jdbc_udf_set_blob_bytes(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row, jbyteArray value); + +jobject _duckdb_jdbc_udf_get_decimal(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row); + +void _duckdb_jdbc_udf_set_decimal(JNIEnv *env, jclass clazz, jobject vector_ref_buf, jint row, jobject value); diff --git a/src/jni/util.cpp b/src/jni/util.cpp index c22d03c65..20e9e478e 100644 --- a/src/jni/util.cpp +++ b/src/jni/util.cpp @@ -7,6 +7,66 @@ #include #include +void ThrowJNI(JNIEnv *env, const char *message) { + if (!J_SQLException) { + return; + } + env->ThrowNew(J_SQLException, message); +} + +JNIEnv *get_callback_env(JavaVM *vm, bool &did_attach) { + did_attach = false; + JNIEnv *env = nullptr; + auto get_env_state = vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); + if (get_env_state == JNI_EDETACHED) { + if (vm->AttachCurrentThread(reinterpret_cast(&env), nullptr) != JNI_OK) { + return nullptr; + } + did_attach = true; + } + return env; +} + +void cleanup_callback_env(JavaVM *vm, bool did_attach) { + if (did_attach) { + vm->DetachCurrentThread(); + } +} + +void delete_local_ref(JNIEnv *env, jobject ref) { + if (ref) { + env->DeleteLocalRef(ref); + } +} + +void delete_local_refs(JNIEnv *env, const std::vector &refs) { + for (auto &ref : refs) { + delete_local_ref(env, ref); + } +} + +void delete_global_ref(JNIEnv *env, jobject ref) { + if (ref) { + env->DeleteGlobalRef(ref); + } +} + +CallbackEnvGuard::CallbackEnvGuard(JavaVM *vm_p) : vm(vm_p), jni_env(nullptr), did_attach(false) { + if (vm) { + jni_env = get_callback_env(vm, did_attach); + } +} + +CallbackEnvGuard::~CallbackEnvGuard() { + if (vm) { + cleanup_callback_env(vm, did_attach); + } +} + +JNIEnv *CallbackEnvGuard::env() const { + return jni_env; +} + void check_java_exception_and_rethrow(JNIEnv *env) { if (env->ExceptionCheck()) { jthrowable exc = env->ExceptionOccurred(); diff --git a/src/jni/util.hpp b/src/jni/util.hpp index 83d0ddd1f..a4b6a3c16 100644 --- 
a/src/jni/util.hpp +++ b/src/jni/util.hpp @@ -8,6 +8,7 @@ extern "C" { #include #include #include +#include using jbyteArray_ptr = std::unique_ptr>; @@ -17,8 +18,33 @@ inline void varchar_deleter(char *val) { duckdb_free(val); } +void ThrowJNI(JNIEnv *env, const char *message); + void check_java_exception_and_rethrow(JNIEnv *env); +JNIEnv *get_callback_env(JavaVM *vm, bool &did_attach); + +void cleanup_callback_env(JavaVM *vm, bool did_attach); + +void delete_local_ref(JNIEnv *env, jobject ref); + +void delete_local_refs(JNIEnv *env, const std::vector &refs); + +void delete_global_ref(JNIEnv *env, jobject ref); + +class CallbackEnvGuard { +public: + explicit CallbackEnvGuard(JavaVM *vm_p); + ~CallbackEnvGuard(); + + JNIEnv *env() const; + +private: + JavaVM *vm; + JNIEnv *jni_env; + bool did_attach; +}; + std::string jbyteArray_to_string(JNIEnv *env, jbyteArray ba_j); std::string jstring_to_string(JNIEnv *env, jstring string_j); diff --git a/src/main/java/org/duckdb/DuckDBAppender.java b/src/main/java/org/duckdb/DuckDBAppender.java index 38371ae2b..bb4a14ffa 100644 --- a/src/main/java/org/duckdb/DuckDBAppender.java +++ b/src/main/java/org/duckdb/DuckDBAppender.java @@ -17,6 +17,7 @@ import java.util.*; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import org.duckdb.DuckDBVectorWriteCore.Column; public class DuckDBAppender implements AutoCloseable { @@ -78,7 +79,7 @@ public class DuckDBAppender implements AutoCloseable { private static final LocalDateTime EPOCH_DATE_TIME = LocalDateTime.ofEpochSecond(0, 0, UTC); - private static final long MAX_TOP_LEVEL_ROWS = duckdb_vector_size(); + private static final long MAX_TOP_LEVEL_ROWS = DuckDBVectorWriteCore.MAX_TOP_LEVEL_ROWS; private final DuckDBConnection conn; @@ -98,7 +99,6 @@ public class DuckDBAppender implements AutoCloseable { private Column prevColumn = null; private boolean writeInlinedStrings = true; - private long ownerThreadId = currentThread().getId(); 
DuckDBAppender(DuckDBConnection conn, String catalog, String schema, String table) throws SQLException { @@ -159,8 +159,8 @@ public DuckDBAppender endRow() throws SQLException { createErrMsg("all columns must be appended to before calling 'endRow', expected columns count: " + columns.size() + ", actual: " + (topCol.idx + 1))); } else { - throw new SQLException(createErrMsg( - "calls to 'beginRow' and 'endRow' must be paired and cannot be interleaved with other 'begin*' and 'end*' calls")); + throw new SQLException(createErrMsg("calls to 'beginRow' and 'endRow' must be paired and cannot be " + + "interleaved with other 'begin*' and 'end*' calls")); } } @@ -2042,9 +2042,9 @@ private void putCompositeElementStruct(Column structCol, long vectorIdx, Object LinkedHashMap map = (LinkedHashMap) structValue; collection = map.values(); } else { - throw new SQLException(createErrMsg( - "struct values must be specified as an instance of a 'java.util.LinkedHashMap' or as a collection of objects, actual class: " + - structValue.getClass().getName())); + throw new SQLException( + createErrMsg("struct values must be specified as an instance of a 'java.util.LinkedHashMap' or " + + "as a collection of objects, actual class: " + structValue.getClass().getName())); } } else { collection = (Collection) structValue; @@ -2065,9 +2065,9 @@ private void putCompositeElementStruct(Column structCol, long vectorIdx, Object private void putCompositeElementUnion(Column unionCol, long vectorIdx, Object unionValue) throws SQLException { if (!(unionValue instanceof AbstractMap.SimpleEntry)) { - throw new SQLException(createErrMsg( - "union values must be specified as an instance of 'java.util.AbstractMap.SimpleEntry', actual type: " + - unionValue.getClass().getName())); + throw new SQLException(createErrMsg("union values must be specified as an instance of " + + "'java.util.AbstractMap.SimpleEntry', actual type: " + + unionValue.getClass().getName())); } AbstractMap.SimpleEntry entry = 
(AbstractMap.SimpleEntry) unionValue; String tag = String.valueOf(entry.getKey()); @@ -2269,270 +2269,7 @@ private static ByteBuffer createChunk(ByteBuffer[] colTypes) throws SQLException return chunkRef; } - private static void initVecChildren(Column parent) throws SQLException { - switch (parent.colType) { - case DUCKDB_TYPE_LIST: - case DUCKDB_TYPE_MAP: { - ByteBuffer vec = duckdb_list_vector_get_child(parent.vectorRef); - Column col = new Column(parent, 0, null, vec); - parent.children.add(col); - break; - } - case DUCKDB_TYPE_STRUCT: - case DUCKDB_TYPE_UNION: { - long count = duckdb_struct_type_child_count(parent.colTypeRef); - for (int i = 0; i < count; i++) { - ByteBuffer vec = duckdb_struct_vector_get_child(parent.vectorRef, i); - Column col = new Column(parent, i, null, vec, i); - parent.children.add(col); - } - break; - } - case DUCKDB_TYPE_ARRAY: { - ByteBuffer vec = duckdb_array_vector_get_child(parent.vectorRef); - Column col = new Column(parent, 0, null, vec); - parent.children.add(col); - break; - } - } - } - private static List createTopLevelColumns(ByteBuffer chunkRef, ByteBuffer[] colTypes) throws SQLException { - List columns = new ArrayList<>(colTypes.length); - try { - for (int i = 0; i < colTypes.length; i++) { - ByteBuffer vector = duckdb_data_chunk_get_vector(chunkRef, i); - Column col = new Column(null, i, colTypes[i], vector); - columns.add(col); - colTypes[i] = null; - } - } catch (Exception e) { - for (Column col : columns) { - if (null != col) { - col.destroy(); - } - } - throw e; - } - return columns; - } - - private static Map readEnumDict(ByteBuffer colTypeRef) { - Map dict = new LinkedHashMap<>(); - long size = duckdb_enum_dictionary_size(colTypeRef); - for (long i = 0; i < size; i++) { - byte[] nameUtf8 = duckdb_enum_dictionary_value(colTypeRef, i); - String name = strFromUTF8(nameUtf8); - dict.put(name, (int) i); - } - return dict; - } - - private static class Column { - private final Column parent; - private final int idx; - 
private /* final */ ByteBuffer colTypeRef; - private final CAPIType colType; - private final CAPIType decimalInternalType; - private final int decimalPrecision; - private final int decimalScale; - private final long arraySize; - private final String structFieldName; - private final Map enumDict; - private final CAPIType enumInternalType; - - private final ByteBuffer vectorRef; - private final List children = new ArrayList<>(); - - private long listSize = 0; - private ByteBuffer data = null; - private ByteBuffer validity = null; - - private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector) throws SQLException { - this(parent, idx, colTypeRef, vector, -1); - } - - private Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector, int structFieldIdx) - throws SQLException { - this.parent = parent; - this.idx = idx; - - if (null == vector) { - throw new SQLException("cannot initialize data chunk vector"); - } - - if (null == colTypeRef) { - this.colTypeRef = duckdb_vector_get_column_type(vector); - if (null == this.colTypeRef) { - throw new SQLException("cannot initialize data chunk vector type"); - } - } else { - this.colTypeRef = colTypeRef; - } - - int colTypeId = duckdb_get_type_id(this.colTypeRef); - this.colType = capiTypeFromTypeId(colTypeId); - - if (colType == DUCKDB_TYPE_DECIMAL) { - int decimalInternalTypeId = duckdb_decimal_internal_type(this.colTypeRef); - this.decimalInternalType = capiTypeFromTypeId(decimalInternalTypeId); - this.decimalPrecision = duckdb_decimal_width(this.colTypeRef); - this.decimalScale = duckdb_decimal_scale(this.colTypeRef); - } else { - this.decimalInternalType = DUCKDB_TYPE_INVALID; - this.decimalPrecision = -1; - this.decimalScale = -1; - } - - if (structFieldIdx >= 0) { - byte[] nameUTF8 = duckdb_struct_type_child_name(parent.colTypeRef, structFieldIdx); - this.structFieldName = strFromUTF8(nameUTF8); - } else { - this.structFieldName = null; - } - - this.vectorRef = vector; - - if (null == 
parent || parent.colType != DUCKDB_TYPE_ARRAY) { - this.arraySize = 1; - } else { - this.arraySize = duckdb_array_type_array_size(parent.colTypeRef); - } - - if (colType == DUCKDB_TYPE_ENUM) { - this.enumDict = readEnumDict(this.colTypeRef); - int enumInternalTypeId = duckdb_enum_internal_type(this.colTypeRef); - this.enumInternalType = capiTypeFromTypeId(enumInternalTypeId); - } else { - this.enumDict = null; - this.enumInternalType = null; - } - - long maxElems = maxElementsCount(); - if (colType.widthBytes > 0 || colType == DUCKDB_TYPE_DECIMAL || colType == DUCKDB_TYPE_ENUM) { - long vectorSizeBytes = maxElems * widthBytes(); - this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); - if (null == this.data) { - throw new SQLException("cannot initialize data chunk vector data"); - } - } else { - this.data = null; - } - - duckdb_vector_ensure_validity_writable(vectorRef); - this.validity = duckdb_vector_get_validity(vectorRef, maxElems); - if (null == this.validity) { - throw new SQLException("cannot initialize data chunk vector validity"); - } - - // last call in constructor - initVecChildren(this); - } - - void reset(long listSize) throws SQLException { - if (null == parent || !(parent.colType == DUCKDB_TYPE_LIST || parent.colType == DUCKDB_TYPE_MAP)) { - throw new SQLException("invalid list column"); - } - this.listSize = listSize; - reset(); - } - - void reset() throws SQLException { - long maxElems = maxElementsCount(); - - if (null != this.data) { - long vectorSizeBytes = maxElems * widthBytes(); - this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); - if (null == this.data) { - throw new SQLException("cannot reset data chunk vector data"); - } - } - - duckdb_vector_ensure_validity_writable(vectorRef); - this.validity = duckdb_vector_get_validity(vectorRef, maxElems); - if (null == this.validity) { - throw new SQLException("cannot reset data chunk vector validity"); - } - - for (Column col : children) { - col.reset(); - } - } - - void 
destroy() { - for (Column cvec : children) { - cvec.destroy(); - } - children.clear(); - if (null != colTypeRef) { - duckdb_destroy_logical_type(colTypeRef); - colTypeRef = null; - } - } - - void setNull(long vectorIdx) throws SQLException { - if (colType == DUCKDB_TYPE_ARRAY) { - setNullOnArrayIdx(vectorIdx, 0); - for (Column col : children) { - for (int i = 0; i < col.arraySize; i++) { - col.setNullOnArrayIdx(vectorIdx, i); - } - } - } else { - setNullOnVectorIdx(vectorIdx); - if (colType == DUCKDB_TYPE_LIST || colType == DUCKDB_TYPE_MAP) { - return; - } - for (Column col : children) { - col.setNull(vectorIdx); - } - } - } - - void setNullOnArrayIdx(long rowIdx, int arrayIdx) { - long vectorIdx = rowIdx * arraySize * parentArraySize() + arrayIdx; - setNullOnVectorIdx(vectorIdx); - } - - void setNullOnVectorIdx(long vectorIdx) { - long validityPos = vectorIdx / 64; - LongBuffer entries = this.validity.asLongBuffer(); - entries.position((int) validityPos); - long mask = entries.get(); - long idxInEntry = vectorIdx % 64; - mask &= ~(1L << idxInEntry); - entries.position((int) validityPos); - entries.put(mask); - } - - long widthBytes() { - if (colType == DUCKDB_TYPE_DECIMAL) { - return decimalInternalType.widthBytes; - } else if (colType == DUCKDB_TYPE_ENUM) { - return enumInternalType.widthBytes; - } else { - return colType.widthBytes; - } - } - - long parentArraySize() { - if (null == parent) { - return 1; - } - return parent.arraySize; - } - - long maxElementsCount() { - Column ancestor = this; - while (null != ancestor) { - if (null != ancestor.parent && - (ancestor.parent.colType == DUCKDB_TYPE_LIST || ancestor.parent.colType == DUCKDB_TYPE_MAP)) { - break; - } - ancestor = ancestor.parent; - } - long maxEntries = null != ancestor ? 
ancestor.listSize : DuckDBAppender.MAX_TOP_LEVEL_ROWS; - return maxEntries * arraySize * parentArraySize(); - } + return DuckDBVectorWriteCore.createTopLevelColumns(chunkRef, colTypes); } } diff --git a/src/main/java/org/duckdb/DuckDBBindings.java b/src/main/java/org/duckdb/DuckDBBindings.java index 4ee45c04d..468e4c04f 100644 --- a/src/main/java/org/duckdb/DuckDBBindings.java +++ b/src/main/java/org/duckdb/DuckDBBindings.java @@ -2,6 +2,9 @@ import java.nio.ByteBuffer; import java.sql.SQLException; +import org.duckdb.udf.ScalarUdf; +import org.duckdb.udf.TableFunction; +import org.duckdb.udf.UdfLogicalType; public class DuckDBBindings { @@ -77,6 +80,20 @@ public class DuckDBBindings { static native ByteBuffer duckdb_array_vector_get_child(ByteBuffer vector); + // udf vector accessors + + static native byte[] duckdb_udf_get_varchar_bytes(ByteBuffer vector_ref, int row); + + static native void duckdb_udf_set_varchar_bytes(ByteBuffer vector_ref, int row, byte[] value); + + static native byte[] duckdb_udf_get_blob_bytes(ByteBuffer vector_ref, int row); + + static native void duckdb_udf_set_blob_bytes(ByteBuffer vector_ref, int row, byte[] value); + + static native java.math.BigDecimal duckdb_udf_get_decimal(ByteBuffer vector_ref, int row); + + static native void duckdb_udf_set_decimal(ByteBuffer vector_ref, int row, java.math.BigDecimal value); + // validity static native boolean duckdb_validity_row_is_valid(ByteBuffer validity, long row); @@ -120,6 +137,88 @@ static native int duckdb_appender_create_ext(ByteBuffer connection, byte[] catal static native int duckdb_append_default_to_chunk(ByteBuffer appender, ByteBuffer chunk, long col, long row); + // scalar function + + // The returned object must be released with duckdb_destroy_scalar_function. 
+ static native ByteBuffer duckdb_create_scalar_function(); + + static native void duckdb_destroy_scalar_function(ByteBuffer scalar_function); + + static native void duckdb_scalar_function_set_name(ByteBuffer scalar_function, byte[] name); + + static native void duckdb_scalar_function_add_parameter(ByteBuffer scalar_function, ByteBuffer logical_type); + + static native void duckdb_scalar_function_set_return_type(ByteBuffer scalar_function, ByteBuffer logical_type); + + static native void duckdb_scalar_function_set_volatile(ByteBuffer scalar_function); + + static native void duckdb_scalar_function_set_special_handling(ByteBuffer scalar_function); + + static native int duckdb_register_scalar_function(ByteBuffer connection, ByteBuffer scalar_function); + + static native void duckdb_register_scalar_function_java(ByteBuffer connection, byte[] name, ScalarUdf callback, + UdfLogicalType[] argumentLogicalTypes, + UdfLogicalType returnLogicalType, + boolean nullSpecialHandling, boolean returnNullOnException, + boolean deterministic, boolean varArgs); + + static native void duckdb_register_scalar_function_java_with_function( + ByteBuffer connection, ByteBuffer scalarFunction, ScalarUdf callback, UdfLogicalType[] argumentLogicalTypes, + UdfLogicalType returnLogicalType, boolean returnNullOnException, boolean varArgs); + + // table function + + static native ByteBuffer duckdb_create_table_function(); + + static native void duckdb_destroy_table_function(ByteBuffer table_function); + + static native void duckdb_table_function_set_name(ByteBuffer table_function, byte[] name); + + static native void duckdb_table_function_add_parameter(ByteBuffer table_function, ByteBuffer logical_type); + + static native void duckdb_table_function_supports_projection_pushdown(ByteBuffer table_function, boolean pushdown); + + static native int duckdb_register_table_function(ByteBuffer connection, ByteBuffer table_function); + + static native void duckdb_register_table_function_java(ByteBuffer 
connection, byte[] name, TableFunction callback, + UdfLogicalType[] parameterLogicalTypes, + boolean supportsProjectionPushdown, int maxThreads, + boolean threadSafe); + + static native void duckdb_register_table_function_java_with_function(ByteBuffer connection, + ByteBuffer tableFunction, + TableFunction callback, + UdfLogicalType[] parameterLogicalTypes, + int maxThreads, boolean threadSafe); + + static native long duckdb_bind_get_parameter_count(ByteBuffer bind_info); + + static native ByteBuffer duckdb_bind_get_parameter(ByteBuffer bind_info, long index); + + static native void duckdb_bind_add_result_column(ByteBuffer bind_info, byte[] name, ByteBuffer logical_type); + + static native void duckdb_bind_set_bind_data(ByteBuffer bind_info, ByteBuffer bind_data); + + static native void duckdb_bind_set_error(ByteBuffer bind_info, byte[] error); + + static native void duckdb_init_set_init_data(ByteBuffer init_info, ByteBuffer init_data); + + static native long duckdb_init_get_column_count(ByteBuffer init_info); + + static native long duckdb_init_get_column_index(ByteBuffer init_info, long column_index); + + static native void duckdb_init_set_max_threads(ByteBuffer init_info, long max_threads); + + static native void duckdb_init_set_error(ByteBuffer init_info, byte[] error); + + static native ByteBuffer duckdb_function_get_bind_data(ByteBuffer function_info); + + static native ByteBuffer duckdb_function_get_init_data(ByteBuffer function_info); + + static native ByteBuffer duckdb_function_get_local_init_data(ByteBuffer function_info); + + static native void duckdb_function_set_error(ByteBuffer function_info, byte[] error); + enum CAPIType { DUCKDB_TYPE_INVALID(0, 0), // bool @@ -197,7 +296,9 @@ enum CAPIType { // enum type, only useful as logical type DUCKDB_TYPE_STRING_LITERAL(37), // enum type, only useful as logical type - DUCKDB_TYPE_INTEGER_LITERAL(38); + DUCKDB_TYPE_INTEGER_LITERAL(38), + // duckdb_time_ns + DUCKDB_TYPE_TIME_NS(39, 8); final int typeId; final 
long widthBytes; diff --git a/src/main/java/org/duckdb/DuckDBConnection.java b/src/main/java/org/duckdb/DuckDBConnection.java index d51c0c00e..275385a95 100644 --- a/src/main/java/org/duckdb/DuckDBConnection.java +++ b/src/main/java/org/duckdb/DuckDBConnection.java @@ -26,6 +26,12 @@ import java.util.*; import java.util.concurrent.Executor; import java.util.concurrent.locks.ReentrantLock; +import org.duckdb.udf.ScalarUdf; +import org.duckdb.udf.TableFunction; +import org.duckdb.udf.TableFunctionDefinition; +import org.duckdb.udf.TableFunctionOptions; +import org.duckdb.udf.UdfLogicalType; +import org.duckdb.udf.UdfOptions; import org.duckdb.user.DuckDBMap; import org.duckdb.user.DuckDBUserArray; import org.duckdb.user.DuckDBUserStruct; @@ -267,6 +273,388 @@ public void setCatalog(String catalog) throws SQLException { } } + public void registerScalarUdf(String name, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER)}, + UdfLogicalType.of(DuckDBColumnType.INTEGER), callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER)}, + UdfLogicalType.of(DuckDBColumnType.INTEGER), callback, options); + } + + public void registerTableFunction(String name, TableFunction callback) throws SQLException { + registerTableFunction(name, callback, new TableFunctionDefinition(), new TableFunctionOptions()); + } + + public void registerTableFunction(String name, TableFunction callback, TableFunctionOptions options) + throws SQLException { + registerTableFunction(name, callback, new TableFunctionDefinition(), options); + } + + public void registerTableFunction(String name, TableFunction callback, TableFunctionDefinition definition) + throws SQLException { + registerTableFunction(name, callback, definition, new TableFunctionOptions()); 
+ } + + public void registerTableFunction(String name, TableFunction callback, TableFunctionDefinition definition, + TableFunctionOptions options) throws SQLException { + Objects.requireNonNull(name, "name"); + Objects.requireNonNull(callback, "callback"); + Objects.requireNonNull(definition, "definition"); + Objects.requireNonNull(options, "options"); + if (options.maxThreads < 1) { + throw new SQLException("TableFunctionOptions.maxThreads must be >= 1"); + } + UdfLogicalType[] parameterTypes = definition.getParameterLogicalTypes(); + for (UdfLogicalType parameterType : parameterTypes) { + Objects.requireNonNull(parameterType, "parameterTypes cannot contain null values"); + UdfTypeCatalog.validateTableFunctionParameterLogicalType(parameterType); + } + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + ByteBuffer tableFunction = DuckDBBindings.duckdb_create_table_function(); + try { + DuckDBBindings.duckdb_table_function_set_name(tableFunction, name.getBytes(UTF_8)); + if (definition.isProjectionPushdownEnabled()) { + DuckDBBindings.duckdb_table_function_supports_projection_pushdown(tableFunction, true); + } + DuckDBBindings.duckdb_register_table_function_java_with_function( + connRef, tableFunction, callback, parameterTypes, options.maxThreads, options.threadSafe); + } finally { + DuckDBBindings.duckdb_destroy_table_function(tableFunction); + } + } finally { + connRefLock.unlock(); + } + } + + public void registerScalarUdf(String name, DuckDBColumnType[] argumentTypes, DuckDBColumnType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdf(name, argumentTypes, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[0], returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType returnType, ScalarUdf callback, UdfOptions options) + 
throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[0], returnType, callback, options); + } + + public void registerScalarUdf(String name, DuckDBColumnType argumentType, DuckDBColumnType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[] {argumentType}, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType argumentType, DuckDBColumnType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[] {argumentType}, returnType, callback, options); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[] {firstArgumentType, secondArgumentType}, returnType, callback, + new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[] {firstArgumentType, secondArgumentType}, returnType, callback, + options); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType thirdArgumentType, DuckDBColumnType returnType, ScalarUdf callback) + throws SQLException { + registerScalarUdf(name, new DuckDBColumnType[] {firstArgumentType, secondArgumentType, thirdArgumentType}, + returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType thirdArgumentType, DuckDBColumnType returnType, ScalarUdf callback, + UdfOptions options) throws SQLException { + registerScalarUdf(name, new 
DuckDBColumnType[] {firstArgumentType, secondArgumentType, thirdArgumentType}, + returnType, callback, options); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType thirdArgumentType, DuckDBColumnType fourthArgumentType, + DuckDBColumnType returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf( + name, new DuckDBColumnType[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, DuckDBColumnType firstArgumentType, DuckDBColumnType secondArgumentType, + DuckDBColumnType thirdArgumentType, DuckDBColumnType fourthArgumentType, + DuckDBColumnType returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf( + name, new DuckDBColumnType[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, options); + } + + public void registerScalarUdf(String name, DuckDBColumnType[] argumentTypes, DuckDBColumnType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + UdfLogicalType[] argumentLogicalTypes = mapDuckdbTypesToLogicalTypes(argumentTypes); + DuckDBColumnType nonNullReturnType = Objects.requireNonNull(returnType, "returnType"); + UdfTypeCatalog.toCapiTypeIdForScalarRegistration(nonNullReturnType); + registerScalarUdfInternal(name, argumentLogicalTypes, UdfLogicalType.of(nonNullReturnType), callback, options); + } + + public void registerScalarUdf(String name, Class[] argumentTypes, Class returnType, ScalarUdf callback) + throws SQLException { + registerScalarUdf(name, argumentTypes, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, Class[] argumentTypes, Class returnType, ScalarUdf callback, + UdfOptions options) throws SQLException { + UdfLogicalType[] argumentLogicalTypes = 
mapJavaClassesToLogicalTypes(argumentTypes); + UdfLogicalType returnLogicalType = mapJavaClassToLogicalType(returnType, "returnType"); + registerScalarUdfInternal(name, argumentLogicalTypes, returnLogicalType, callback, options); + } + + public void registerScalarUdf(String name, Class returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new Class[ 0 ], returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, Class returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf(name, new Class[ 0 ], returnType, callback, options); + } + + public void registerScalarUdf(String name, Class argumentType, Class returnType, ScalarUdf callback) + throws SQLException { + registerScalarUdf(name, new Class[] {argumentType}, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, Class argumentType, Class returnType, ScalarUdf callback, + UdfOptions options) throws SQLException { + registerScalarUdf(name, new Class[] {argumentType}, returnType, callback, options); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new Class[] {firstArgumentType, secondArgumentType}, returnType, callback, + new UdfOptions()); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class returnType, ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdf(name, new Class[] {firstArgumentType, secondArgumentType}, returnType, callback, options); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class thirdArgumentType, Class returnType, ScalarUdf callback) + throws SQLException { + registerScalarUdf(name, new Class[] {firstArgumentType, secondArgumentType, thirdArgumentType}, returnType, + callback, new 
UdfOptions()); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class thirdArgumentType, Class returnType, ScalarUdf callback, + UdfOptions options) throws SQLException { + registerScalarUdf(name, new Class[] {firstArgumentType, secondArgumentType, thirdArgumentType}, returnType, + callback, options); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class thirdArgumentType, Class fourthArgumentType, Class returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdf(name, + new Class[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, Class firstArgumentType, Class secondArgumentType, + Class thirdArgumentType, Class fourthArgumentType, Class returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdf(name, + new Class[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, options); + } + + public void registerScalarUdf(String name, UdfLogicalType[] argumentTypes, UdfLogicalType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdf(name, argumentTypes, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[0], returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf(name, new UdfLogicalType[0], returnType, callback, options); + } + + public void registerScalarUdf(String name, UdfLogicalType argumentType, UdfLogicalType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] 
{argumentType}, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType argumentType, UdfLogicalType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {argumentType}, returnType, callback, options); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType returnType, ScalarUdf callback) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {firstArgumentType, secondArgumentType}, returnType, callback, + new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {firstArgumentType, secondArgumentType}, returnType, callback, + options); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType thirdArgumentType, UdfLogicalType returnType, ScalarUdf callback) + throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {firstArgumentType, secondArgumentType, thirdArgumentType}, + returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType thirdArgumentType, UdfLogicalType returnType, ScalarUdf callback, + UdfOptions options) throws SQLException { + registerScalarUdf(name, new UdfLogicalType[] {firstArgumentType, secondArgumentType, thirdArgumentType}, + returnType, callback, options); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType thirdArgumentType, UdfLogicalType fourthArgumentType, + UdfLogicalType returnType, ScalarUdf callback) throws 
SQLException { + registerScalarUdf( + name, new UdfLogicalType[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, new UdfOptions()); + } + + public void registerScalarUdf(String name, UdfLogicalType firstArgumentType, UdfLogicalType secondArgumentType, + UdfLogicalType thirdArgumentType, UdfLogicalType fourthArgumentType, + UdfLogicalType returnType, ScalarUdf callback, UdfOptions options) + throws SQLException { + registerScalarUdf( + name, new UdfLogicalType[] {firstArgumentType, secondArgumentType, thirdArgumentType, fourthArgumentType}, + returnType, callback, options); + } + + public void registerScalarUdf(String name, UdfLogicalType[] argumentTypes, UdfLogicalType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + registerScalarUdfInternal(name, argumentTypes, returnType, callback, options); + } + + private void registerScalarUdfInternal(String name, UdfLogicalType[] argumentTypes, UdfLogicalType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + Objects.requireNonNull(name, "name"); + Objects.requireNonNull(argumentTypes, "argumentTypes"); + Objects.requireNonNull(returnType, "returnType"); + Objects.requireNonNull(callback, "callback"); + Objects.requireNonNull(options, "options"); + if (options.varArgs && argumentTypes.length != 1) { + throw new SQLException("Scalar UDF varargs registration expects exactly one argument logical type"); + } + for (int i = 0; i < argumentTypes.length; i++) { + Objects.requireNonNull(argumentTypes[i], "argumentTypes cannot contain null values"); + UdfTypeCatalog.validateScalarLogicalType(argumentTypes[i]); + } + UdfTypeCatalog.validateScalarLogicalType(returnType); + checkOpen(); + connRefLock.lock(); + try { + checkOpen(); + ByteBuffer scalarFunction = DuckDBBindings.duckdb_create_scalar_function(); + try { + DuckDBBindings.duckdb_scalar_function_set_name(scalarFunction, name.getBytes(UTF_8)); + if 
(options.nullSpecialHandling) { + DuckDBBindings.duckdb_scalar_function_set_special_handling(scalarFunction); + } + if (!options.deterministic) { + DuckDBBindings.duckdb_scalar_function_set_volatile(scalarFunction); + } + DuckDBBindings.duckdb_register_scalar_function_java_with_function( + connRef, scalarFunction, callback, argumentTypes, returnType, options.returnNullOnException, + options.varArgs); + } finally { + DuckDBBindings.duckdb_destroy_scalar_function(scalarFunction); + } + } finally { + connRefLock.unlock(); + } + } + + private static UdfLogicalType[] mapDuckdbTypesToLogicalTypes(DuckDBColumnType[] argumentTypes) throws SQLException { + Objects.requireNonNull(argumentTypes, "argumentTypes"); + UdfLogicalType[] argumentLogicalTypes = new UdfLogicalType[argumentTypes.length]; + for (int i = 0; i < argumentTypes.length; i++) { + DuckDBColumnType argumentType = + Objects.requireNonNull(argumentTypes[i], "argumentTypes cannot contain null values"); + UdfTypeCatalog.toCapiTypeIdForScalarRegistration(argumentType); + argumentLogicalTypes[i] = UdfLogicalType.of(argumentType); + } + return argumentLogicalTypes; + } + + private static UdfLogicalType mapJavaClassToLogicalType(Class javaType, String argumentName) + throws SQLException { + Objects.requireNonNull(javaType, argumentName); + try { + return UdfJavaTypeMapper.toLogicalType(javaType); + } catch (IllegalArgumentException e) { + throw new SQLException(e.getMessage(), e); + } + } + + private static UdfLogicalType[] mapJavaClassesToLogicalTypes(Class[] argumentTypes) throws SQLException { + Objects.requireNonNull(argumentTypes, "argumentTypes"); + UdfLogicalType[] argumentLogicalTypes = new UdfLogicalType[argumentTypes.length]; + for (int i = 0; i < argumentTypes.length; i++) { + argumentLogicalTypes[i] = + mapJavaClassToLogicalType(argumentTypes[i], "argumentTypes cannot contain null values"); + } + return argumentLogicalTypes; + } + + public void registerScalarUdfVarArgs(String name, DuckDBColumnType 
argumentType, DuckDBColumnType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdfVarArgs(name, argumentType, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdfVarArgs(String name, DuckDBColumnType argumentType, DuckDBColumnType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + Objects.requireNonNull(options, "options"); + UdfOptions normalizedOptions = new UdfOptions() + .deterministic(options.deterministic) + .nullSpecialHandling(options.nullSpecialHandling) + .returnNullOnException(options.returnNullOnException) + .varArgs(true); + registerScalarUdf(name, new DuckDBColumnType[] {argumentType}, returnType, callback, normalizedOptions); + } + + public void registerScalarUdfVarArgs(String name, UdfLogicalType argumentType, UdfLogicalType returnType, + ScalarUdf callback) throws SQLException { + registerScalarUdfVarArgs(name, argumentType, returnType, callback, new UdfOptions()); + } + + public void registerScalarUdfVarArgs(String name, UdfLogicalType argumentType, UdfLogicalType returnType, + ScalarUdf callback, UdfOptions options) throws SQLException { + Objects.requireNonNull(options, "options"); + UdfOptions normalizedOptions = new UdfOptions() + .deterministic(options.deterministic) + .nullSpecialHandling(options.nullSpecialHandling) + .returnNullOnException(options.returnNullOnException) + .varArgs(true); + registerScalarUdf(name, new UdfLogicalType[] {argumentType}, returnType, callback, normalizedOptions); + } + public String getCatalog() throws SQLException { checkOpen(); connRefLock.lock(); diff --git a/src/main/java/org/duckdb/DuckDBVectorWriteCore.java b/src/main/java/org/duckdb/DuckDBVectorWriteCore.java new file mode 100644 index 000000000..0982cb2e3 --- /dev/null +++ b/src/main/java/org/duckdb/DuckDBVectorWriteCore.java @@ -0,0 +1,287 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.duckdb.DuckDBBindings.*; +import static 
org.duckdb.DuckDBBindings.CAPIType.*; + +import java.nio.ByteBuffer; +import java.nio.LongBuffer; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +final class DuckDBVectorWriteCore { + static final long MAX_TOP_LEVEL_ROWS = duckdb_vector_size(); + + private DuckDBVectorWriteCore() { + } + + static List createTopLevelColumns(ByteBuffer chunkRef, ByteBuffer[] colTypes) throws SQLException { + List columns = new ArrayList<>(colTypes.length); + try { + for (int i = 0; i < colTypes.length; i++) { + ByteBuffer vector = duckdb_data_chunk_get_vector(chunkRef, i); + Column col = new Column(null, i, colTypes[i], vector); + columns.add(col); + colTypes[i] = null; + } + } catch (Exception e) { + for (Column col : columns) { + if (null != col) { + col.destroy(); + } + } + throw e; + } + return columns; + } + + private static void initVecChildren(Column parent) throws SQLException { + switch (parent.colType) { + case DUCKDB_TYPE_LIST: + case DUCKDB_TYPE_MAP: { + ByteBuffer vec = duckdb_list_vector_get_child(parent.vectorRef); + Column col = new Column(parent, 0, null, vec); + parent.children.add(col); + break; + } + case DUCKDB_TYPE_STRUCT: + case DUCKDB_TYPE_UNION: { + long count = duckdb_struct_type_child_count(parent.colTypeRef); + for (int i = 0; i < count; i++) { + ByteBuffer vec = duckdb_struct_vector_get_child(parent.vectorRef, i); + Column col = new Column(parent, i, null, vec, i); + parent.children.add(col); + } + break; + } + case DUCKDB_TYPE_ARRAY: { + ByteBuffer vec = duckdb_array_vector_get_child(parent.vectorRef); + Column col = new Column(parent, 0, null, vec); + parent.children.add(col); + break; + } + } + } + + private static Map readEnumDict(ByteBuffer colTypeRef) { + Map dict = new LinkedHashMap<>(); + long size = duckdb_enum_dictionary_size(colTypeRef); + for (long i = 0; i < size; i++) { + byte[] nameUtf8 = duckdb_enum_dictionary_value(colTypeRef, i); + String name 
= new String(nameUtf8, UTF_8); + dict.put(name, (int) i); + } + return dict; + } + + static final class Column { + final Column parent; + final int idx; + ByteBuffer colTypeRef; + final CAPIType colType; + final CAPIType decimalInternalType; + final int decimalPrecision; + final int decimalScale; + final long arraySize; + final String structFieldName; + final Map enumDict; + final CAPIType enumInternalType; + + final ByteBuffer vectorRef; + final List children = new ArrayList<>(); + + long listSize = 0; + ByteBuffer data = null; + ByteBuffer validity = null; + + Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector) throws SQLException { + this(parent, idx, colTypeRef, vector, -1); + } + + Column(Column parent, int idx, ByteBuffer colTypeRef, ByteBuffer vector, int structFieldIdx) + throws SQLException { + this.parent = parent; + this.idx = idx; + + if (null == vector) { + throw new SQLException("cannot initialize data chunk vector"); + } + + if (null == colTypeRef) { + this.colTypeRef = duckdb_vector_get_column_type(vector); + if (null == this.colTypeRef) { + throw new SQLException("cannot initialize data chunk vector type"); + } + } else { + this.colTypeRef = colTypeRef; + } + + int colTypeId = duckdb_get_type_id(this.colTypeRef); + this.colType = capiTypeFromTypeId(colTypeId); + + if (colType == DUCKDB_TYPE_DECIMAL) { + int decimalInternalTypeId = duckdb_decimal_internal_type(this.colTypeRef); + this.decimalInternalType = capiTypeFromTypeId(decimalInternalTypeId); + this.decimalPrecision = duckdb_decimal_width(this.colTypeRef); + this.decimalScale = duckdb_decimal_scale(this.colTypeRef); + } else { + this.decimalInternalType = DUCKDB_TYPE_INVALID; + this.decimalPrecision = -1; + this.decimalScale = -1; + } + + if (structFieldIdx >= 0) { + byte[] nameUTF8 = duckdb_struct_type_child_name(parent.colTypeRef, structFieldIdx); + this.structFieldName = new String(nameUTF8, UTF_8); + } else { + this.structFieldName = null; + } + + this.vectorRef = 
vector; + + if (null == parent || parent.colType != DUCKDB_TYPE_ARRAY) { + this.arraySize = 1; + } else { + this.arraySize = duckdb_array_type_array_size(parent.colTypeRef); + } + + if (colType == DUCKDB_TYPE_ENUM) { + this.enumDict = readEnumDict(this.colTypeRef); + int enumInternalTypeId = duckdb_enum_internal_type(this.colTypeRef); + this.enumInternalType = capiTypeFromTypeId(enumInternalTypeId); + } else { + this.enumDict = null; + this.enumInternalType = null; + } + + long maxElems = maxElementsCount(); + if (colType.widthBytes > 0 || colType == DUCKDB_TYPE_DECIMAL || colType == DUCKDB_TYPE_ENUM) { + long vectorSizeBytes = maxElems * widthBytes(); + this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); + if (null == this.data) { + throw new SQLException("cannot initialize data chunk vector data"); + } + } else { + this.data = null; + } + + duckdb_vector_ensure_validity_writable(vectorRef); + this.validity = duckdb_vector_get_validity(vectorRef, maxElems); + if (null == this.validity) { + throw new SQLException("cannot initialize data chunk vector validity"); + } + + // Last call in constructor, after the current column is fully initialized. 
+ initVecChildren(this); + } + + void reset(long listSize) throws SQLException { + if (null == parent || !(parent.colType == DUCKDB_TYPE_LIST || parent.colType == DUCKDB_TYPE_MAP)) { + throw new SQLException("invalid list column"); + } + this.listSize = listSize; + reset(); + } + + void reset() throws SQLException { + long maxElems = maxElementsCount(); + + if (null != this.data) { + long vectorSizeBytes = maxElems * widthBytes(); + this.data = duckdb_vector_get_data(vectorRef, vectorSizeBytes); + if (null == this.data) { + throw new SQLException("cannot reset data chunk vector data"); + } + } + + duckdb_vector_ensure_validity_writable(vectorRef); + this.validity = duckdb_vector_get_validity(vectorRef, maxElems); + if (null == this.validity) { + throw new SQLException("cannot reset data chunk vector validity"); + } + + for (Column col : children) { + col.reset(); + } + } + + void destroy() { + for (Column cvec : children) { + cvec.destroy(); + } + children.clear(); + if (null != colTypeRef) { + duckdb_destroy_logical_type(colTypeRef); + colTypeRef = null; + } + } + + void setNull(long vectorIdx) throws SQLException { + if (colType == DUCKDB_TYPE_ARRAY) { + setNullOnArrayIdx(vectorIdx, 0); + for (Column col : children) { + for (int i = 0; i < col.arraySize; i++) { + col.setNullOnArrayIdx(vectorIdx, i); + } + } + } else { + setNullOnVectorIdx(vectorIdx); + if (colType == DUCKDB_TYPE_LIST || colType == DUCKDB_TYPE_MAP) { + return; + } + for (Column col : children) { + col.setNull(vectorIdx); + } + } + } + + void setNullOnArrayIdx(long rowIdx, int arrayIdx) { + long vectorIdx = rowIdx * arraySize * parentArraySize() + arrayIdx; + setNullOnVectorIdx(vectorIdx); + } + + private void setNullOnVectorIdx(long vectorIdx) { + long validityPos = vectorIdx / 64; + LongBuffer entries = validity.asLongBuffer(); + entries.position((int) validityPos); + long mask = entries.get(); + long idxInEntry = vectorIdx % 64; + mask &= ~(1L << idxInEntry); + entries.position((int) 
validityPos); + entries.put(mask); + } + + long widthBytes() { + if (colType == DUCKDB_TYPE_DECIMAL) { + return decimalInternalType.widthBytes; + } else if (colType == DUCKDB_TYPE_ENUM) { + return enumInternalType.widthBytes; + } else { + return colType.widthBytes; + } + } + + private long parentArraySize() { + if (null == parent) { + return 1; + } + return parent.arraySize; + } + + private long maxElementsCount() { + Column ancestor = this; + while (null != ancestor) { + if (null != ancestor.parent && + (ancestor.parent.colType == DUCKDB_TYPE_LIST || ancestor.parent.colType == DUCKDB_TYPE_MAP)) { + break; + } + ancestor = ancestor.parent; + } + long maxEntries = null != ancestor ? ancestor.listSize : MAX_TOP_LEVEL_ROWS; + return maxEntries * arraySize * parentArraySize(); + } + } +} diff --git a/src/main/java/org/duckdb/UdfJavaTypeMapper.java b/src/main/java/org/duckdb/UdfJavaTypeMapper.java new file mode 100644 index 000000000..b155cb260 --- /dev/null +++ b/src/main/java/org/duckdb/UdfJavaTypeMapper.java @@ -0,0 +1,84 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.sql.Date; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import org.duckdb.udf.UdfLogicalType; + +public final class UdfJavaTypeMapper { + private static final Map, UdfLogicalType> TYPE_MAPPINGS = new HashMap<>(); + + static { + TYPE_MAPPINGS.put(Boolean.class, UdfLogicalType.of(DuckDBColumnType.BOOLEAN)); + TYPE_MAPPINGS.put(Byte.class, UdfLogicalType.of(DuckDBColumnType.TINYINT)); + TYPE_MAPPINGS.put(Short.class, UdfLogicalType.of(DuckDBColumnType.SMALLINT)); + TYPE_MAPPINGS.put(Integer.class, UdfLogicalType.of(DuckDBColumnType.INTEGER)); + TYPE_MAPPINGS.put(Long.class, UdfLogicalType.of(DuckDBColumnType.BIGINT)); + TYPE_MAPPINGS.put(Float.class, 
UdfLogicalType.of(DuckDBColumnType.FLOAT)); + TYPE_MAPPINGS.put(Double.class, UdfLogicalType.of(DuckDBColumnType.DOUBLE)); + TYPE_MAPPINGS.put(String.class, UdfLogicalType.of(DuckDBColumnType.VARCHAR)); + TYPE_MAPPINGS.put(byte[].class, UdfLogicalType.of(DuckDBColumnType.BLOB)); + TYPE_MAPPINGS.put(BigInteger.class, UdfLogicalType.of(DuckDBColumnType.HUGEINT)); + TYPE_MAPPINGS.put(UUID.class, UdfLogicalType.of(DuckDBColumnType.UUID)); + TYPE_MAPPINGS.put(LocalDate.class, UdfLogicalType.of(DuckDBColumnType.DATE)); + TYPE_MAPPINGS.put(Date.class, UdfLogicalType.of(DuckDBColumnType.DATE)); + TYPE_MAPPINGS.put(LocalTime.class, UdfLogicalType.of(DuckDBColumnType.TIME)); + TYPE_MAPPINGS.put(OffsetTime.class, UdfLogicalType.of(DuckDBColumnType.TIME_WITH_TIME_ZONE)); + TYPE_MAPPINGS.put(LocalDateTime.class, UdfLogicalType.of(DuckDBColumnType.TIMESTAMP)); + TYPE_MAPPINGS.put(OffsetDateTime.class, UdfLogicalType.of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE)); + } + + private UdfJavaTypeMapper() { + } + + public static UdfLogicalType toLogicalType(Class javaType) { + if (javaType == null) { + throw new IllegalArgumentException("javaType must not be null"); + } + Class normalizedType = javaType.isPrimitive() ? 
wrapPrimitive(javaType) : javaType; + if (normalizedType == BigDecimal.class) { + throw new IllegalArgumentException( + "BigDecimal requires explicit logical type; use UdfLogicalType.decimal(width, scale)"); + } + UdfLogicalType logicalType = TYPE_MAPPINGS.get(normalizedType); + if (logicalType == null) { + throw new IllegalArgumentException("Unsupported Java class for scalar UDF mapping: " + + normalizedType.getName()); + } + return logicalType; + } + + private static Class wrapPrimitive(Class primitiveType) { + if (primitiveType == boolean.class) { + return Boolean.class; + } + if (primitiveType == byte.class) { + return Byte.class; + } + if (primitiveType == short.class) { + return Short.class; + } + if (primitiveType == int.class) { + return Integer.class; + } + if (primitiveType == long.class) { + return Long.class; + } + if (primitiveType == float.class) { + return Float.class; + } + if (primitiveType == double.class) { + return Double.class; + } + throw new IllegalArgumentException("Unsupported primitive Java type for scalar UDF mapping: " + + primitiveType.getName()); + } +} diff --git a/src/main/java/org/duckdb/UdfNative.java b/src/main/java/org/duckdb/UdfNative.java new file mode 100644 index 000000000..cd8d09011 --- /dev/null +++ b/src/main/java/org/duckdb/UdfNative.java @@ -0,0 +1,36 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; + +public final class UdfNative { + private UdfNative() { + } + + public static String getVarchar(ByteBuffer vectorRef, int row) throws SQLException { + byte[] bytes = DuckDBBindings.duckdb_udf_get_varchar_bytes(vectorRef, row); + return new String(bytes, StandardCharsets.UTF_8); + } + + public static void setVarchar(ByteBuffer vectorRef, int row, String value) throws SQLException { + DuckDBBindings.duckdb_udf_set_varchar_bytes(vectorRef, row, value.getBytes(StandardCharsets.UTF_8)); + } + + public static byte[] 
getBlob(ByteBuffer vectorRef, int row) throws SQLException { + return DuckDBBindings.duckdb_udf_get_blob_bytes(vectorRef, row); + } + + public static void setBlob(ByteBuffer vectorRef, int row, byte[] value) throws SQLException { + DuckDBBindings.duckdb_udf_set_blob_bytes(vectorRef, row, value); + } + + public static BigDecimal getDecimal(ByteBuffer vectorRef, int row) throws SQLException { + return DuckDBBindings.duckdb_udf_get_decimal(vectorRef, row); + } + + public static void setDecimal(ByteBuffer vectorRef, int row, BigDecimal value) throws SQLException { + DuckDBBindings.duckdb_udf_set_decimal(vectorRef, row, value); + } +} diff --git a/src/main/java/org/duckdb/UdfNativeReader.java b/src/main/java/org/duckdb/UdfNativeReader.java new file mode 100644 index 000000000..11d33ff8c --- /dev/null +++ b/src/main/java/org/duckdb/UdfNativeReader.java @@ -0,0 +1,110 @@ +package org.duckdb; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.util.Date; +import java.util.UUID; + +final class UdfNativeReader implements UdfReader { + private final UdfScalarWriter vector; + + UdfNativeReader(int capiTypeId, ByteBuffer data, ByteBuffer vectorRef, ByteBuffer validity, int rowCount) { + this.vector = new UdfScalarWriter(capiTypeId, data, vectorRef, validity, rowCount); + } + + @Override + public DuckDBColumnType getType() { + return vector.getType(); + } + + @Override + public boolean isNull(int row) { + return vector.isNull(row); + } + + @Override + public int getInt(int row) { + return vector.getInt(row); + } + + @Override + public long getLong(int row) { + return vector.getLong(row); + } + + @Override + public float getFloat(int row) { + return vector.getFloat(row); + } + + @Override + public double getDouble(int row) { + return vector.getDouble(row); + } + + @Override + 
public BigDecimal getBigDecimal(int row) { + return vector.getBigDecimal(row); + } + + @Override + public BigInteger getBigInteger(int row) { + return vector.getBigInteger(row); + } + + @Override + public Date getDate(int row) { + return vector.getDate(row); + } + + @Override + public LocalDate getLocalDate(int row) { + return vector.getLocalDate(row); + } + + @Override + public LocalTime getLocalTime(int row) { + return vector.getLocalTime(row); + } + + @Override + public OffsetTime getOffsetTime(int row) { + return vector.getOffsetTime(row); + } + + @Override + public LocalDateTime getLocalDateTime(int row) { + return vector.getLocalDateTime(row); + } + + @Override + public OffsetDateTime getOffsetDateTime(int row) { + return vector.getOffsetDateTime(row); + } + + @Override + public UUID getUUID(int row) { + return vector.getUUID(row); + } + + @Override + public boolean getBoolean(int row) { + return vector.getBoolean(row); + } + + @Override + public String getString(int row) { + return vector.getString(row); + } + + @Override + public byte[] getBytes(int row) { + return vector.getBytes(row); + } +} diff --git a/src/main/java/org/duckdb/UdfOutputAppender.java b/src/main/java/org/duckdb/UdfOutputAppender.java new file mode 100644 index 000000000..45a2a98fb --- /dev/null +++ b/src/main/java/org/duckdb/UdfOutputAppender.java @@ -0,0 +1,1068 @@ +package org.duckdb; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; +import static java.time.temporal.ChronoUnit.MICROS; +import static java.time.temporal.ChronoUnit.MILLIS; +import static java.time.temporal.ChronoUnit.NANOS; +import static java.time.temporal.ChronoUnit.SECONDS; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_ARRAY; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_BLOB; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_BOOLEAN; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_DATE; +import static 
org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_DECIMAL; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_DOUBLE; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_ENUM; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_FLOAT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_HUGEINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_LIST; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_MAP; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_SMALLINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_STRUCT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_MS; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_NS; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_S; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_TZ; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME_NS; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME_TZ; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_TINYINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UBIGINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UHUGEINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UINTEGER; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UNION; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_USMALLINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UTINYINT; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_UUID; +import static org.duckdb.DuckDBBindings.CAPIType.DUCKDB_TYPE_VARCHAR; +import static org.duckdb.DuckDBBindings.duckdb_data_chunk_get_column_count; +import static 
org.duckdb.DuckDBBindings.duckdb_data_chunk_get_vector; +import static org.duckdb.DuckDBBindings.duckdb_destroy_logical_type; +import static org.duckdb.DuckDBBindings.duckdb_list_vector_get_size; +import static org.duckdb.DuckDBBindings.duckdb_list_vector_reserve; +import static org.duckdb.DuckDBBindings.duckdb_list_vector_set_size; +import static org.duckdb.DuckDBBindings.duckdb_vector_assign_string_element_len; +import static org.duckdb.DuckDBBindings.duckdb_vector_get_column_type; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.LongBuffer; +import java.sql.SQLException; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.time.ZonedDateTime; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import org.duckdb.DuckDBVectorWriteCore.Column; + +public final class UdfOutputAppender implements AutoCloseable { + private static final long UNSIGNED_INT_MAX = 0xFFFF_FFFFL; + private static final LocalDateTime EPOCH_DATE_TIME = LocalDateTime.ofEpochSecond(0, 0, UTC); + private static final int MAX_TZ_SECONDS = 16 * 60 * 60 - 1; + private static final BigInteger U64_MIN = BigInteger.ZERO; + private static final BigInteger U64_MAX = BigInteger.ONE.shiftLeft(64).subtract(BigInteger.ONE); + private static final BigInteger U64_MODULUS = BigInteger.ONE.shiftLeft(64); + private static final BigInteger HUGEINT_MIN = BigInteger.ONE.shiftLeft(127).negate(); + private static final BigInteger HUGEINT_MAX = BigInteger.ONE.shiftLeft(127).subtract(BigInteger.ONE); + private static final BigInteger UHUGEINT_MIN = BigInteger.ZERO; + private static final BigInteger UHUGEINT_MAX = 
/**
 * Creates an appender over the vectors of {@code chunkRef}.
 *
 * <p>A logical type handle is fetched for every column and handed to
 * {@link DuckDBVectorWriteCore#createTopLevelColumns}. If fetching a type or creating the
 * columns fails, every handle still held in the local array is destroyed before the exception
 * propagates, so a failure on a later column does not leak the handles of earlier ones.
 *
 * <p>Columns whose type is rejected by {@code UdfTypeCatalog.fromCapiTypeId} get no scalar
 * writer; the typed setters then refuse them and point to {@code setObject/append(Object)}.
 *
 * @param chunkRef output data chunk handle; must not be {@code null}
 * @throws SQLException if a vector or its type cannot be initialized
 */
public UdfOutputAppender(ByteBuffer chunkRef) throws SQLException {
    Objects.requireNonNull(chunkRef, "chunkRef");

    this.rowCapacity = Math.toIntExact(DuckDBVectorWriteCore.MAX_TOP_LEVEL_ROWS);

    int columnCount = Math.toIntExact(duckdb_data_chunk_get_column_count(chunkRef));
    ByteBuffer[] colTypes = new ByteBuffer[columnCount];

    List<Column> createdColumns;
    try {
        // Collected inside the try so that a failure on column i releases types 0..i-1 too.
        for (int i = 0; i < columnCount; i++) {
            ByteBuffer vector = duckdb_data_chunk_get_vector(chunkRef, i);
            if (vector == null) {
                throw new SQLException("cannot initialize output chunk vector");
            }
            colTypes[i] = duckdb_vector_get_column_type(vector);
            if (colTypes[i] == null) {
                throw new SQLException("cannot initialize output chunk vector type");
            }
        }
        createdColumns = DuckDBVectorWriteCore.createTopLevelColumns(chunkRef, colTypes);
    } catch (Exception e) {
        // createTopLevelColumns clears the slots it handed to columns, so only handles
        // that nobody else will destroy are still non-null here.
        for (ByteBuffer colType : colTypes) {
            if (colType != null) {
                duckdb_destroy_logical_type(colType);
            }
        }
        throw e;
    }
    this.columns = createdColumns;
    this.scalarWriters = new UdfScalarWriter[columnCount];
    try {
        for (int i = 0; i < columnCount; i++) {
            Column column = columns.get(i);
            DuckDBColumnType type;
            try {
                type = UdfTypeCatalog.fromCapiTypeId(column.colType.typeId);
            } catch (Exception unsupportedType) {
                // Deliberate: no scalar writer for this column, scalarWriters[i] stays null.
                continue;
            }
            boolean viaVectorRef = UdfTypeCatalog.requiresVectorRef(type);
            ByteBuffer data = viaVectorRef ? null : column.data;
            ByteBuffer vectorRef = viaVectorRef ? column.vectorRef : null;
            scalarWriters[i] =
                new UdfScalarWriter(column.colType.typeId, data, vectorRef, column.validity, rowCapacity);
        }
    } catch (Exception e) {
        destroyColumns();
        throw e;
    }
}
+ checkRowIndex(row); + scalarWriter(columnIndex, UdfTypeCatalog.Accessor.SET_BYTES, "setBytes").setBytes(row, value); + } + + public void setObject(int columnIndex, int row, Object value) { + checkOpen(); + checkColumnIndex(columnIndex); + checkRowIndex(row); + try { + writeValue(columns.get(columnIndex), row, value); + } catch (SQLException e) { + throw new IllegalStateException("Failed to set value in output column: " + e.getMessage(), e); + } + } + + public void setBigDecimal(int columnIndex, int row, BigDecimal value) { + checkOpen(); + checkRowIndex(row); + scalarWriter(columnIndex, UdfTypeCatalog.Accessor.SET_DECIMAL, "setBigDecimal").setBigDecimal(row, value); + } + + public void setLocalDate(int columnIndex, int row, LocalDate value) { + setObject(columnIndex, row, value); + } + + public void setLocalTime(int columnIndex, int row, LocalTime value) { + setObject(columnIndex, row, value); + } + + public void setOffsetTime(int columnIndex, int row, OffsetTime value) { + setObject(columnIndex, row, value); + } + + public void setLocalDateTime(int columnIndex, int row, LocalDateTime value) { + setObject(columnIndex, row, value); + } + + public void setDate(int columnIndex, int row, java.util.Date value) { + setObject(columnIndex, row, value); + } + + public void setOffsetDateTime(int columnIndex, int row, OffsetDateTime value) { + setObject(columnIndex, row, value); + } + + public void setUUID(int columnIndex, int row, UUID value) { + setObject(columnIndex, row, value); + } + + public UdfOutputAppender beginRow() { + checkOpen(); + if (pendingColumnIndex >= 0) { + throw new IllegalStateException("endRow must be called before beginRow"); + } + if (rowCount >= rowCapacity) { + throw new IllegalStateException("output row capacity exceeded"); + } + pendingColumnIndex = 0; + return this; + } + + public UdfOutputAppender appendNull() { + return appendNextValue(null); + } + + public UdfOutputAppender append(boolean value) { + return appendNextValue(value); + } + + 
public UdfOutputAppender append(int value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(long value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(float value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(double value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(BigDecimal value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(String value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(byte[] value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(LocalDate value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(LocalTime value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(OffsetTime value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(LocalDateTime value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(java.util.Date value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(OffsetDateTime value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(UUID value) { + return appendNextValue(value); + } + + public UdfOutputAppender append(Object value) { + return appendNextValue(value); + } + + public UdfOutputAppender endRow() { + checkOpen(); + if (pendingColumnIndex < 0) { + throw new IllegalStateException("beginRow must be called before endRow"); + } + if (pendingColumnIndex != columns.size()) { + throw new IllegalStateException("all columns must be appended before endRow"); + } + pendingColumnIndex = -1; + rowCount++; + return this; + } + + @Override + public void close() { + if (closed) { + return; + } + destroyColumns(); + closed = true; + } + + private UdfOutputAppender appendNextValue(Object value) { + Column column = nextColumnForAppendColumn(); + try { + writeValue(column, rowCount, value); + } catch 
(SQLException e) { + throw new IllegalStateException("Failed to append row value: " + e.getMessage(), e); + } + pendingColumnIndex++; + return this; + } + + private void destroyColumns() { + for (Column column : columns) { + column.destroy(); + } + } + + private Column nextColumnForAppendColumn() { + checkOpen(); + if (pendingColumnIndex < 0) { + throw new IllegalStateException("beginRow must be called before append"); + } + if (pendingColumnIndex >= columns.size()) { + throw new IllegalStateException("too many values appended in current row"); + } + return columns.get(pendingColumnIndex); + } + + private UdfScalarWriter scalarWriter(int columnIndex, UdfTypeCatalog.Accessor accessor, String method) { + checkColumnIndex(columnIndex); + UdfScalarWriter writer = scalarWriters[columnIndex]; + if (writer == null) { + throw new UnsupportedOperationException(method + " is not available for non-scalar output type " + + columns.get(columnIndex).colType + " at column " + columnIndex + + "; use setObject/append(Object) for nested types"); + } + if (!UdfTypeCatalog.supportsAccessor(writer.getType(), accessor)) { + throw new UnsupportedOperationException(method + " is not supported for output type " + writer.getType() + + " at column " + columnIndex); + } + return writer; + } + + private void checkColumnIndex(int columnIndex) { + if (columnIndex < 0 || columnIndex >= columns.size()) { + throw new IndexOutOfBoundsException("column=" + columnIndex + ", columnCount=" + columns.size()); + } + } + + private void checkRowIndex(int row) { + if (row < 0 || row >= rowCapacity) { + throw new IndexOutOfBoundsException("row=" + row + ", rowCapacity=" + rowCapacity); + } + } + + private void checkOpen() { + if (closed) { + throw new IllegalStateException("UdfOutputAppender is closed"); + } + } + + private void writeValue(Column col, long vectorIdx, Object value) throws SQLException { + if (value == null) { + col.setNull(vectorIdx); + return; + } + switch (col.colType) { + case 
DUCKDB_TYPE_BOOLEAN: + putByte(col, vectorIdx, (byte) (requireBoolean(value) ? 1 : 0)); + return; + case DUCKDB_TYPE_TINYINT: + putByte(col, vectorIdx, (byte) requireSignedLongInRange(value, Byte.MIN_VALUE, Byte.MAX_VALUE, "TINYINT")); + return; + case DUCKDB_TYPE_UTINYINT: + putByte(col, vectorIdx, (byte) requireSignedLongInRange(value, 0, 0xFFL, "UTINYINT")); + return; + case DUCKDB_TYPE_SMALLINT: + putShort(col, vectorIdx, + (short) requireSignedLongInRange(value, Short.MIN_VALUE, Short.MAX_VALUE, "SMALLINT")); + return; + case DUCKDB_TYPE_USMALLINT: + putShort(col, vectorIdx, (short) requireSignedLongInRange(value, 0, 0xFFFFL, "USMALLINT")); + return; + case DUCKDB_TYPE_INTEGER: + putInt(col, vectorIdx, + (int) requireSignedLongInRange(value, Integer.MIN_VALUE, Integer.MAX_VALUE, "INTEGER")); + return; + case DUCKDB_TYPE_UINTEGER: + putInt(col, vectorIdx, (int) requireSignedLongInRange(value, 0, UNSIGNED_INT_MAX, "UINTEGER")); + return; + case DUCKDB_TYPE_BIGINT: + case DUCKDB_TYPE_UBIGINT: + case DUCKDB_TYPE_TIME: + case DUCKDB_TYPE_TIME_NS: + case DUCKDB_TYPE_TIME_TZ: + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_S: + case DUCKDB_TYPE_TIMESTAMP_MS: + case DUCKDB_TYPE_TIMESTAMP_NS: + case DUCKDB_TYPE_TIMESTAMP_TZ: + putLong(col, vectorIdx, requireLongOrTemporal(value, col.colType)); + return; + case DUCKDB_TYPE_FLOAT: + putFloat(col, vectorIdx, (float) requireDouble(value)); + return; + case DUCKDB_TYPE_DOUBLE: + putDouble(col, vectorIdx, requireDouble(value)); + return; + case DUCKDB_TYPE_DECIMAL: + UdfNative.setDecimal(col.vectorRef, Math.toIntExact(vectorIdx), requireBigDecimal(value)); + return; + case DUCKDB_TYPE_DATE: + putInt(col, vectorIdx, requireDateEpochDays(value)); + return; + case DUCKDB_TYPE_HUGEINT: + case DUCKDB_TYPE_UHUGEINT: + putFixedWidthBytes(col, vectorIdx, requireInt128Bytes(value, col.colType), 16, col.colType.toString()); + return; + case DUCKDB_TYPE_UUID: + if (value instanceof UUID) { + UUID uuid = (UUID) value; + 
putUUID(col, vectorIdx, uuid); + } else { + putFixedWidthBytes(col, vectorIdx, requireBytes(value), 16, "UUID"); + } + return; + case DUCKDB_TYPE_VARCHAR: + putStringOrBlob(col, vectorIdx, requireString(value).getBytes(UTF_8)); + return; + case DUCKDB_TYPE_BLOB: + putStringOrBlob(col, vectorIdx, requireBytes(value)); + return; + case DUCKDB_TYPE_ENUM: + putEnum(col, vectorIdx, requireString(value)); + return; + case DUCKDB_TYPE_ARRAY: + case DUCKDB_TYPE_LIST: + writeCollection(col, vectorIdx, value); + return; + case DUCKDB_TYPE_MAP: + writeMap(col, vectorIdx, value); + return; + case DUCKDB_TYPE_STRUCT: + writeStruct(col, vectorIdx, value); + return; + case DUCKDB_TYPE_UNION: + writeUnion(col, vectorIdx, value); + return; + default: + throw new IllegalArgumentException("Unsupported output type for UdfOutputAppender: " + col.colType); + } + } + + private void writeCollection(Column parentCol, long vectorIdx, Object value) throws SQLException { + Column innerCol = requireSingleChild(parentCol, "collection"); + List values = asListValues(value); + int count = values.size(); + int offset = prepareListColumn(innerCol, vectorIdx, count); + for (int i = 0; i < count; i++) { + writeValue(innerCol, offset + i, values.get(i)); + } + } + + private void writeMap(Column mapCol, long vectorIdx, Object value) throws SQLException { + if (!(value instanceof Map)) { + throw new IllegalArgumentException("Expected java.util.Map for MAP column but got " + + value.getClass().getName()); + } + Column entryStructCol = requireSingleChild(mapCol, "map"); + Map map = (Map) value; + int offset = prepareListColumn(entryStructCol, vectorIdx, map.size()); + int index = 0; + for (Map.Entry entry : map.entrySet()) { + writeStructByPosition(entryStructCol, offset + index, Arrays.asList(entry.getKey(), entry.getValue())); + index++; + } + } + + private void writeStruct(Column structCol, long vectorIdx, Object value) throws SQLException { + if (value instanceof Map) { + writeStructByName(structCol, 
vectorIdx, (Map) value); + return; + } + writeStructByPosition(structCol, vectorIdx, asListValues(value)); + } + + private void writeStructByName(Column structCol, long vectorIdx, Map values) throws SQLException { + if (structCol.children.size() != values.size()) { + throw new IllegalArgumentException("Struct field count mismatch, expected " + structCol.children.size() + + " values but got " + values.size()); + } + for (Column child : structCol.children) { + String fieldName = child.structFieldName; + if (!values.containsKey(fieldName)) { + throw new IllegalArgumentException("Struct value map does not contain field '" + fieldName + "'"); + } + writeValue(child, vectorIdx, values.get(fieldName)); + } + } + + private void writeStructByPosition(Column structCol, long vectorIdx, List values) throws SQLException { + if (structCol.children.size() != values.size()) { + throw new IllegalArgumentException("Struct field count mismatch, expected " + structCol.children.size() + + " values but got " + values.size()); + } + for (int i = 0; i < values.size(); i++) { + writeValue(structCol.children.get(i), vectorIdx, values.get(i)); + } + } + + private void writeUnion(Column unionCol, long vectorIdx, Object value) throws SQLException { + if (!(value instanceof AbstractMap.SimpleEntry)) { + throw new IllegalArgumentException( + "Union values must be java.util.AbstractMap.SimpleEntry"); + } + AbstractMap.SimpleEntry entry = (AbstractMap.SimpleEntry) value; + String tag = String.valueOf(entry.getKey()); + Column selected = selectUnionField(unionCol, vectorIdx, tag); + writeValue(selected, vectorIdx, entry.getValue()); + } + + private Column selectUnionField(Column unionCol, long vectorIdx, String tag) throws SQLException { + if (unionCol.children.isEmpty()) { + throw new IllegalArgumentException("Invalid UNION column without children"); + } + int selectedIndex = -1; + for (int i = 1; i < unionCol.children.size(); i++) { + Column child = unionCol.children.get(i); + if 
(Objects.equals(child.structFieldName, tag)) { + selectedIndex = i; + break; + } + } + if (selectedIndex < 0) { + throw new IllegalArgumentException("Unknown UNION tag '" + tag + "'"); + } + + Column tagCol = unionCol.children.get(0); + putByte(tagCol, vectorIdx, (byte) (selectedIndex - 1)); + for (int i = 1; i < unionCol.children.size(); i++) { + if (i != selectedIndex) { + unionCol.children.get(i).setNull(vectorIdx); + } + } + return unionCol.children.get(selectedIndex); + } + + private void putEnum(Column col, long vectorIdx, String value) { + Integer dictValue = col.enumDict.get(value); + if (dictValue == null) { + throw new IllegalArgumentException("Invalid enum value '" + value + "', expected one of " + + col.enumDict.keySet()); + } + + int pos = (int) (vectorIdx * col.enumInternalType.widthBytes); + col.data.position(pos); + switch (col.enumInternalType) { + case DUCKDB_TYPE_UTINYINT: + col.data.put(dictValue.byteValue()); + return; + case DUCKDB_TYPE_USMALLINT: + col.data.putShort(dictValue.shortValue()); + return; + case DUCKDB_TYPE_UINTEGER: + col.data.putInt(dictValue); + return; + default: + throw new IllegalArgumentException("Unsupported enum storage type " + col.enumInternalType); + } + } + + private int prepareListColumn(Column innerCol, long vectorIdx, long listElementsCount) throws SQLException { + if (innerCol.parent == null) { + throw new IllegalArgumentException("Invalid collection inner column"); + } + Column parentCol = innerCol.parent; + switch (parentCol.colType) { + case DUCKDB_TYPE_ARRAY: + if (innerCol.arraySize != listElementsCount) { + throw new IllegalArgumentException("Fixed ARRAY size mismatch, expected " + innerCol.arraySize + + " values but got " + listElementsCount); + } + return (int) (vectorIdx * innerCol.arraySize); + case DUCKDB_TYPE_LIST: + case DUCKDB_TYPE_MAP: + long offset = duckdb_list_vector_get_size(parentCol.vectorRef); + LongBuffer listEntries = parentCol.data.asLongBuffer(); + int entryPos = (int) (vectorIdx * 
DUCKDB_TYPE_LIST.widthBytes / Long.BYTES); + listEntries.position(entryPos); + listEntries.put(offset); + listEntries.put(listElementsCount); + long listSize = offset + listElementsCount; + int reserveStatus = duckdb_list_vector_reserve(parentCol.vectorRef, listSize); + if (reserveStatus != 0) { + throw new SQLException("'duckdb_list_vector_reserve' failed for list size " + listSize); + } + innerCol.reset(listSize); + int setSizeStatus = duckdb_list_vector_set_size(parentCol.vectorRef, listSize); + if (setSizeStatus != 0) { + throw new SQLException("'duckdb_list_vector_set_size' failed for list size " + listSize); + } + return (int) offset; + default: + throw new IllegalArgumentException("Invalid collection parent type " + parentCol.colType); + } + } + + private static Column requireSingleChild(Column parentCol, String kind) { + if (parentCol.children.size() != 1) { + throw new IllegalArgumentException("Invalid " + kind + " type layout, expected single child"); + } + return parentCol.children.get(0); + } + + private static List asListValues(Object value) { + if (value instanceof List) { + return (List) value; + } + if (value instanceof Collection) { + return new ArrayList<>((Collection) value); + } + if (value instanceof Object[]) { + return Arrays.asList((Object[]) value); + } + if (value instanceof boolean[]) { + boolean[] arr = (boolean[]) value; + List out = new ArrayList<>(arr.length); + for (boolean v : arr) { + out.add(v); + } + return out; + } + if (value instanceof byte[]) { + byte[] arr = (byte[]) value; + List out = new ArrayList<>(arr.length); + for (byte v : arr) { + out.add(v); + } + return out; + } + if (value instanceof short[]) { + short[] arr = (short[]) value; + List out = new ArrayList<>(arr.length); + for (short v : arr) { + out.add(v); + } + return out; + } + if (value instanceof int[]) { + int[] arr = (int[]) value; + List out = new ArrayList<>(arr.length); + for (int v : arr) { + out.add(v); + } + return out; + } + if (value instanceof 
long[]) { + long[] arr = (long[]) value; + List out = new ArrayList<>(arr.length); + for (long v : arr) { + out.add(v); + } + return out; + } + if (value instanceof float[]) { + float[] arr = (float[]) value; + List out = new ArrayList<>(arr.length); + for (float v : arr) { + out.add(v); + } + return out; + } + if (value instanceof double[]) { + double[] arr = (double[]) value; + List out = new ArrayList<>(arr.length); + for (double v : arr) { + out.add(v); + } + return out; + } + throw new IllegalArgumentException("Expected collection/array value but got " + value.getClass().getName()); + } + + private static boolean requireBoolean(Object value) { + if (!(value instanceof Boolean)) { + throw new IllegalArgumentException("Expected Boolean value but got " + value.getClass().getName()); + } + return (Boolean) value; + } + + private static BigDecimal requireBigDecimal(Object value) { + if (value instanceof BigDecimal) { + return (BigDecimal) value; + } + if (value instanceof Number) { + return toBigDecimal((Number) value); + } + throw new IllegalArgumentException("Expected BigDecimal/Number value but got " + value.getClass().getName()); + } + + private static BigDecimal toBigDecimal(Number value) { + if (value instanceof BigInteger) { + return new BigDecimal((BigInteger) value); + } + if (value instanceof Byte || value instanceof Short || value instanceof Integer || value instanceof Long) { + return BigDecimal.valueOf(value.longValue()); + } + if (value instanceof Float || value instanceof Double) { + try { + return BigDecimal.valueOf(value.doubleValue()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Expected finite floating-point value for DECIMAL conversion", e); + } + } + try { + return new BigDecimal(value.toString()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to coerce numeric value to BigDecimal", e); + } + } + + private static int requireDateEpochDays(Object value) { + if (value instanceof 
LocalDate) { + long days = ((LocalDate) value).toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Expected LocalDate epoch day to fit int32 but got " + days); + } + return (int) days; + } + if (value instanceof java.util.Date) { + LocalDate localDate = Instant.ofEpochMilli(((java.util.Date) value).getTime()).atOffset(UTC).toLocalDate(); + long days = localDate.toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Expected Date epoch day to fit int32 but got " + days); + } + return (int) days; + } + return (int) requireSignedLongInRange(value, Integer.MIN_VALUE, Integer.MAX_VALUE, "DATE"); + } + + private static long requireLongOrTemporal(Object value, DuckDBBindings.CAPIType colType) { + if (value instanceof Number) { + if (colType == DUCKDB_TYPE_UBIGINT) { + return requireUnsignedLongBits(value, "UBIGINT"); + } + return requireIntegralLong(value, colType.toString()); + } + switch (colType) { + case DUCKDB_TYPE_TIME: + if (value instanceof LocalTime) { + return ((LocalTime) value).toNanoOfDay() / 1000L; + } + break; + case DUCKDB_TYPE_TIME_NS: + if (value instanceof LocalTime) { + return ((LocalTime) value).toNanoOfDay(); + } + break; + case DUCKDB_TYPE_TIME_TZ: + if (value instanceof OffsetTime) { + OffsetTime time = (OffsetTime) value; + return packTimeTzMicros(time.toLocalTime().toNanoOfDay() / 1000L, time.getOffset().getTotalSeconds()); + } + break; + case DUCKDB_TYPE_TIMESTAMP_S: + case DUCKDB_TYPE_TIMESTAMP_MS: + case DUCKDB_TYPE_TIMESTAMP: + case DUCKDB_TYPE_TIMESTAMP_NS: + if (value instanceof LocalDateTime) { + return localDateTimeToMoment((LocalDateTime) value, colType); + } + if (value instanceof java.util.Date) { + return dateToMoment((java.util.Date) value, colType); + } + break; + case DUCKDB_TYPE_TIMESTAMP_TZ: + if (value instanceof OffsetDateTime) { + ZonedDateTime zdt = ((OffsetDateTime) value).atZoneSameInstant(UTC); + return 
EPOCH_DATE_TIME.until(zdt.toLocalDateTime(), MICROS); + } + if (value instanceof java.util.Date) { + return Math.multiplyExact(((java.util.Date) value).getTime(), 1000L); + } + break; + default: + break; + } + throw new IllegalArgumentException("Expected numeric/temporal value compatible with " + colType + " but got " + + value.getClass().getName()); + } + + private static long localDateTimeToMoment(LocalDateTime value, DuckDBBindings.CAPIType colType) { + switch (colType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return EPOCH_DATE_TIME.until(value, SECONDS); + case DUCKDB_TYPE_TIMESTAMP_MS: + return EPOCH_DATE_TIME.until(value, MILLIS); + case DUCKDB_TYPE_TIMESTAMP: + return EPOCH_DATE_TIME.until(value, MICROS); + case DUCKDB_TYPE_TIMESTAMP_NS: + return EPOCH_DATE_TIME.until(value, NANOS); + default: + throw new IllegalArgumentException("Unsupported LocalDateTime conversion for " + colType); + } + } + + private static long dateToMoment(java.util.Date value, DuckDBBindings.CAPIType colType) { + long millis = value.getTime(); + switch (colType) { + case DUCKDB_TYPE_TIMESTAMP_S: + return millis / 1000L; + case DUCKDB_TYPE_TIMESTAMP_MS: + return millis; + case DUCKDB_TYPE_TIMESTAMP: + return Math.multiplyExact(millis, 1000L); + case DUCKDB_TYPE_TIMESTAMP_NS: + return Math.multiplyExact(millis, 1000000L); + default: + throw new IllegalArgumentException("Unsupported java.util.Date conversion for " + colType); + } + } + + private static long packTimeTzMicros(long micros, int offsetSeconds) { + if (offsetSeconds < -MAX_TZ_SECONDS || offsetSeconds > MAX_TZ_SECONDS) { + throw new IllegalArgumentException("TIME WITH TIME ZONE offset out of range: " + offsetSeconds + + " seconds (allowed range: -" + MAX_TZ_SECONDS + ".." 
+ MAX_TZ_SECONDS + + ")"); + } + long normalizedOffset = MAX_TZ_SECONDS - offsetSeconds; + return ((micros & 0xFFFFFFFFFFL) << 24) | (normalizedOffset & 0xFFFFFFL); + } + + private static double requireDouble(Object value) { + if (!(value instanceof Number)) { + throw new IllegalArgumentException("Expected numeric value but got " + value.getClass().getName()); + } + return ((Number) value).doubleValue(); + } + + private static long requireSignedLongInRange(Object value, long min, long max, String typeName) { + long num = requireIntegralLong(value, typeName); + if (num < min || num > max) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + num); + } + return num; + } + + private static long requireIntegralLong(Object value, String typeName) { + BigInteger integerValue = requireIntegralBigInteger(value, typeName); + try { + return integerValue.longValueExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + integerValue, e); + } + } + + private static long requireUnsignedLongBits(Object value, String typeName) { + BigInteger integerValue = requireIntegralBigInteger(value, typeName); + if (integerValue.compareTo(U64_MIN) < 0 || integerValue.compareTo(U64_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + integerValue); + } + if (integerValue.signum() >= 0 && integerValue.bitLength() <= 63) { + return integerValue.longValue(); + } + return integerValue.subtract(U64_MODULUS).longValue(); + } + + private static BigInteger requireIntegralBigInteger(Object value, String typeName) { + if (!(value instanceof Number)) { + throw new IllegalArgumentException("Expected numeric value for " + typeName + " but got " + + value.getClass().getName()); + } + + Number number = (Number) value; + if (number instanceof BigInteger) { + return (BigInteger) number; + } + if (number instanceof BigDecimal) { + try { + return ((BigDecimal) 
number).toBigIntegerExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + if (number instanceof Byte || number instanceof Short || number instanceof Integer || number instanceof Long) { + return BigInteger.valueOf(number.longValue()); + } + if (number instanceof Float || number instanceof Double) { + double d = number.doubleValue(); + if (!Double.isFinite(d)) { + throw new IllegalArgumentException("Expected finite value for " + typeName + " but got " + d); + } + try { + return BigDecimal.valueOf(d).toBigIntegerExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + + try { + return new BigDecimal(number.toString()).toBigIntegerExact(); + } catch (NumberFormatException | ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + + private static String requireString(Object value) { + if (!(value instanceof String)) { + throw new IllegalArgumentException("Expected String value but got " + value.getClass().getName()); + } + return (String) value; + } + + private static byte[] requireBytes(Object value) { + if (!(value instanceof byte[])) { + throw new IllegalArgumentException("Expected byte[] value but got " + value.getClass().getName()); + } + return (byte[]) value; + } + + private static byte[] requireInt128Bytes(Object value, DuckDBBindings.CAPIType type) { + if (value instanceof byte[]) { + return (byte[]) value; + } + if (value instanceof BigInteger) { + return toInt128Bytes((BigInteger) value, type); + } + throw new IllegalArgumentException("Expected BigInteger/byte[] value for " + type + " but got " + + value.getClass().getName()); + } + + private static byte[] toInt128Bytes(BigInteger value, DuckDBBindings.CAPIType type) { + if (value == null) { + throw new 
IllegalArgumentException("BigInteger value must not be null"); + } + if (type == DUCKDB_TYPE_HUGEINT) { + if (value.compareTo(HUGEINT_MIN) < 0 || value.compareTo(HUGEINT_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for HUGEINT: " + value); + } + } else if (type == DUCKDB_TYPE_UHUGEINT) { + if (value.compareTo(UHUGEINT_MIN) < 0 || value.compareTo(UHUGEINT_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for UHUGEINT: " + value); + } + } else { + throw new IllegalArgumentException("Int128 conversion is only supported for HUGEINT/UHUGEINT"); + } + + long lower = value.longValue(); + long upper = value.shiftRight(64).longValue(); + ByteBuffer buffer = ByteBuffer.allocate(16).order(java.nio.ByteOrder.nativeOrder()); + buffer.putLong(lower); + buffer.putLong(upper); + return buffer.array(); + } + + private static void putByte(Column col, long vectorIdx, byte value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.put(value); + } + + private static void putShort(Column col, long vectorIdx, short value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.putShort(value); + } + + private static void putInt(Column col, long vectorIdx, int value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.putInt(value); + } + + private static void putLong(Column col, long vectorIdx, long value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.putLong(value); + } + + private static void putFloat(Column col, long vectorIdx, float value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.putFloat(value); + } + + private static void putDouble(Column col, long vectorIdx, double value) { + int pos = (int) (vectorIdx * col.colType.widthBytes); + col.data.position(pos); + col.data.putDouble(value); + } + + private static void 
/**
 * Row-indexed, typed read access to one column of values.
 *
 * UDF.MD documents scalar UDF callbacks as receiving their arguments as an array of
 * UdfReader, e.g. {@code args[0].getInt(row)}.
 *
 * NOTE(review): how each getter behaves for NULL rows, and which getters are valid for
 * which column types, is decided by the implementing class and is not visible in this
 * interface - confirm against the implementation before relying on it.
 */
public interface UdfReader {
    /** Returns the column type of the values exposed by this reader. */
    DuckDBColumnType getType();

    /** Reports whether the value at the given row is NULL. */
    boolean isNull(int row);

    int getInt(int row);

    long getLong(int row);

    float getFloat(int row);

    double getDouble(int row);

    BigDecimal getBigDecimal(int row);

    /**
     * Optional accessor: the default implementation always throws
     * UnsupportedOperationException, so only readers that override it support BigInteger.
     */
    default BigInteger getBigInteger(int row) {
        throw new UnsupportedOperationException("getBigInteger is not supported for this reader implementation");
    }

    // Temporal accessors. The return type here is java.util.Date (see the imports).
    Date getDate(int row);

    LocalDate getLocalDate(int row);

    LocalTime getLocalTime(int row);

    OffsetTime getOffsetTime(int row);

    LocalDateTime getLocalDateTime(int row);

    OffsetDateTime getOffsetDateTime(int row);

    UUID getUUID(int row);

    boolean getBoolean(int row);

    String getString(int row);

    byte[] getBytes(int row);
}
rowCount; + + public UdfScalarWriter(int capiTypeId, ByteBuffer data, ByteBuffer vectorRef, ByteBuffer validity, int rowCount) { + this(resolveType(capiTypeId), data, vectorRef, validity, rowCount); + } + + private UdfScalarWriter(DuckDBColumnType type, ByteBuffer data, ByteBuffer vectorRef, ByteBuffer validity, + int rowCount) { + if (type == null) { + throw new IllegalArgumentException("type must not be null"); + } + if (!UdfTypeCatalog.isScalarUdfImplemented(type)) { + throw new IllegalArgumentException("Unsupported scalar UDF output type: " + type); + } + if (rowCount < 0) { + throw new IllegalArgumentException("rowCount must be non-negative"); + } + if (UdfTypeCatalog.requiresVectorRef(type)) { + if (vectorRef == null) { + throw new IllegalArgumentException("vectorRef is required for vectors backed by native accessors"); + } + } else if (data == null) { + throw new IllegalArgumentException("data is required for fixed-size vectors"); + } + + this.type = type; + this.data = data == null ? 
null : data.order(ByteOrder.nativeOrder()); + this.vectorRef = vectorRef; + this.validity = validity; + this.rowCount = rowCount; + } + + private static DuckDBColumnType resolveType(int capiTypeId) { + try { + return UdfTypeCatalog.fromCapiTypeId(capiTypeId); + } catch (SQLException e) { + throw new IllegalArgumentException("Unsupported scalar UDF C API type id: " + capiTypeId, e); + } + } + + private void checkIndex(int row) { + if (row < 0 || row >= rowCount) { + throw new IndexOutOfBoundsException("row=" + row + ", rowCount=" + rowCount); + } + } + + private void requireAccessor(UdfTypeCatalog.Accessor accessor, String method) { + if (!UdfTypeCatalog.supportsAccessor(type, accessor)) { + throw new UnsupportedOperationException(method + " is not supported for " + type + " vectors"); + } + } + + private void markValid(int row) { + if (validity == null) { + return; + } + int byteIndex = row / 8; + int bitIndex = row % 8; + int current = validity.get(byteIndex) & 0xFF; + validity.put(byteIndex, (byte) (current | (1 << bitIndex))); + } + + private int fixedWidthBytesForByteAccessor() { + switch (type) { + case HUGEINT: + case UHUGEINT: + case UUID: + return 16; + default: + throw new IllegalStateException("Unexpected type for byte accessor: " + type); + } + } + + private byte[] readFixedWidthBytes(int row) { + int width = fixedWidthBytesForByteAccessor(); + byte[] value = new byte[width]; + ByteBuffer buffer = data.duplicate(); + buffer.position(row * width); + buffer.get(value); + return value; + } + + private void writeFixedWidthBytes(int row, byte[] value) { + int width = fixedWidthBytesForByteAccessor(); + if (value.length != width) { + throw new IllegalArgumentException("Expected " + width + " bytes for " + type + " value, got " + + value.length); + } + ByteBuffer buffer = data.duplicate(); + buffer.position(row * width); + buffer.put(value); + } + + public DuckDBColumnType getType() { + return type; + } + + public boolean isNull(int row) { + checkIndex(row); + 
if (validity == null) { + return false; + } + int byteIndex = row / 8; + int bitIndex = row % 8; + int mask = 1 << bitIndex; + return (validity.get(byteIndex) & mask) == 0; + } + + public void setNull(int row) { + checkIndex(row); + if (validity == null) { + throw new UnsupportedOperationException("setNull requires a writable validity buffer"); + } + int byteIndex = row / 8; + int bitIndex = row % 8; + int mask = ~(1 << bitIndex); + int current = validity.get(byteIndex) & 0xFF; + validity.put(byteIndex, (byte) (current & mask)); + } + + public int getInt(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_INT, "getInt"); + switch (type) { + case TINYINT: + return data.get(row); + case UTINYINT: + return Byte.toUnsignedInt(data.get(row)); + case SMALLINT: + return data.getShort(row * Short.BYTES); + case USMALLINT: + return Short.toUnsignedInt(data.getShort(row * Short.BYTES)); + case INTEGER: + case DATE: + return data.getInt(row * Integer.BYTES); + default: + throw new IllegalStateException("Unexpected type for getInt: " + type); + } + } + + public long getLong(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_LONG, "getLong"); + switch (type) { + case BIGINT: + case UBIGINT: + case TIME: + case TIME_NS: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return data.getLong(row * Long.BYTES); + case UINTEGER: + return Integer.toUnsignedLong(data.getInt(row * Integer.BYTES)); + default: + throw new IllegalStateException("Unexpected type for getLong: " + type); + } + } + + public float getFloat(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_FLOAT, "getFloat"); + return data.getFloat(row * Float.BYTES); + } + + public double getDouble(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_DOUBLE, "getDouble"); + return data.getDouble(row * Double.BYTES); + } + + public 
BigDecimal getBigDecimal(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_DECIMAL, "getBigDecimal"); + if (isNull(row)) { + return null; + } + try { + return UdfNative.getDecimal(vectorRef, row); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public BigInteger getBigInteger(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + if (type != DuckDBColumnType.HUGEINT && type != DuckDBColumnType.UHUGEINT) { + throw new UnsupportedOperationException("getBigInteger is not supported for " + type + " vectors"); + } + + byte[] bytes = readFixedWidthBytes(row); + ByteBuffer buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()); + long lower = buffer.getLong(); + long upper = buffer.getLong(); + if (type == DuckDBColumnType.HUGEINT) { + return BigInteger.valueOf(upper).shiftLeft(64).add(toUnsignedBigInteger(lower)); + } + return toUnsignedBigInteger(upper).shiftLeft(64).add(toUnsignedBigInteger(lower)); + } + + public Date getDate(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case DATE: + return java.sql.Date.valueOf(getLocalDate(row)); + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: { + LocalDateTime localDateTime = getLocalDateTime(row); + Timestamp dayTimestamp = Timestamp.valueOf(localDateTime.truncatedTo(ChronoUnit.DAYS)); + return new Date(dayTimestamp.getTime()); + } + default: + throw new UnsupportedOperationException("getDate is not supported for " + type + " vectors"); + } + } + + public LocalDate getLocalDate(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case DATE: + return LocalDate.ofEpochDay(getInt(row)); + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return getLocalDateTime(row).toLocalDate(); + default: + throw new UnsupportedOperationException("getLocalDate is not 
supported for " + type + " vectors"); + } + } + + public LocalTime getLocalTime(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case TIME: + return LocalTime.ofNanoOfDay(TimeUnit.MICROSECONDS.toNanos(getLong(row))); + case TIME_NS: + return LocalTime.ofNanoOfDay(getLong(row)); + case TIME_WITH_TIME_ZONE: + return getOffsetTime(row).toLocalTime(); + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return getLocalDateTime(row).toLocalTime(); + default: + throw new UnsupportedOperationException("getLocalTime is not supported for " + type + " vectors"); + } + } + + public OffsetTime getOffsetTime(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case TIME: + case TIME_NS: + return getLocalTime(row).atOffset(ZoneOffset.UTC); + case TIME_WITH_TIME_ZONE: + return DuckDBTimestamp.toOffsetTime(getLong(row)); + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return getOffsetDateTime(row).toOffsetTime(); + default: + throw new UnsupportedOperationException("getOffsetTime is not supported for " + type + " vectors"); + } + } + + public LocalDateTime getLocalDateTime(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case DATE: + return LocalDate.ofEpochDay(getInt(row)).atStartOfDay(); + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + return timestampToLocalDateTime(row); + default: + throw new UnsupportedOperationException("getLocalDateTime is not supported for " + type + " vectors"); + } + } + + public OffsetDateTime getOffsetDateTime(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + switch (type) { + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: { + LocalDateTime localDateTime = 
getLocalDateTime(row); + Instant instant = localDateTime.toInstant(ZoneOffset.UTC); + ZoneOffset zoneOffset = ZoneId.systemDefault().getRules().getOffset(instant); + return localDateTime.atOffset(zoneOffset); + } + default: + throw new UnsupportedOperationException("getOffsetDateTime is not supported for " + type + " vectors"); + } + } + + public UUID getUUID(int row) { + checkIndex(row); + if (isNull(row)) { + return null; + } + if (type != DuckDBColumnType.UUID) { + throw new UnsupportedOperationException("getUUID is not supported for " + type + " vectors"); + } + byte[] uuidBytes = getBytes(row); + ByteBuffer buffer = ByteBuffer.wrap(uuidBytes).order(ByteOrder.nativeOrder()); + long leastSignificantBits = buffer.getLong(); + long mostSignificantBits = buffer.getLong() ^ Long.MIN_VALUE; + return new UUID(mostSignificantBits, leastSignificantBits); + } + + private LocalDateTime timestampToLocalDateTime(int row) { + try { + long value = getLong(row); + switch (type) { + case TIMESTAMP: + return DuckDBTimestamp.localDateTimeFromTimestamp(value, ChronoUnit.MICROS, null); + case TIMESTAMP_MS: + return DuckDBTimestamp.localDateTimeFromTimestamp(value, ChronoUnit.MILLIS, null); + case TIMESTAMP_NS: + return DuckDBTimestamp.localDateTimeFromTimestamp(value, ChronoUnit.NANOS, null); + case TIMESTAMP_S: + return DuckDBTimestamp.localDateTimeFromTimestamp(value, ChronoUnit.SECONDS, null); + case TIMESTAMP_WITH_TIME_ZONE: + return DuckDBTimestamp.localDateTimeFromTimestampWithTimezone(value, ChronoUnit.MICROS, null); + default: + throw new UnsupportedOperationException("timestampToLocalDateTime is not supported for " + type + + " vectors"); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public boolean getBoolean(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_BOOLEAN, "getBoolean"); + return data.get(row) != 0; + } + + public String getString(int row) { + checkIndex(row); + 
requireAccessor(UdfTypeCatalog.Accessor.GET_STRING, "getString"); + if (isNull(row)) { + return null; + } + try { + return UdfNative.getVarchar(vectorRef, row); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public byte[] getBytes(int row) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.GET_BYTES, "getBytes"); + if (isNull(row)) { + return null; + } + if (type == DuckDBColumnType.BLOB) { + try { + return UdfNative.getBlob(vectorRef, row); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + return readFixedWidthBytes(row); + } + + public void setInt(int row, int value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_INT, "setInt"); + switch (type) { + case TINYINT: + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Value out of range for TINYINT: " + value); + } + data.put(row, (byte) value); + break; + case UTINYINT: + if (value < 0 || value > 0xFF) { + throw new IllegalArgumentException("Value out of range for UTINYINT: " + value); + } + data.put(row, (byte) value); + break; + case SMALLINT: + if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) { + throw new IllegalArgumentException("Value out of range for SMALLINT: " + value); + } + data.putShort(row * Short.BYTES, (short) value); + break; + case USMALLINT: + if (value < 0 || value > 0xFFFF) { + throw new IllegalArgumentException("Value out of range for USMALLINT: " + value); + } + data.putShort(row * Short.BYTES, (short) value); + break; + case INTEGER: + case DATE: + data.putInt(row * Integer.BYTES, value); + break; + default: + throw new IllegalStateException("Unexpected type for setInt: " + type); + } + markValid(row); + } + + public void setLong(int row, long value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_LONG, "setLong"); + switch (type) { + case BIGINT: + case UBIGINT: + case TIME: + case TIME_NS: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP: + case 
TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + data.putLong(row * Long.BYTES, value); + break; + case UINTEGER: + if (value < 0 || value > UNSIGNED_INT_MAX) { + throw new IllegalArgumentException("Value out of range for UINTEGER: " + value); + } + data.putInt(row * Integer.BYTES, (int) value); + break; + default: + throw new IllegalStateException("Unexpected type for setLong: " + type); + } + markValid(row); + } + + public void setFloat(int row, float value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_FLOAT, "setFloat"); + data.putFloat(row * Float.BYTES, value); + markValid(row); + } + + public void setDouble(int row, double value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_DOUBLE, "setDouble"); + data.putDouble(row * Double.BYTES, value); + markValid(row); + } + + public void setBigDecimal(int row, BigDecimal value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_DECIMAL, "setBigDecimal"); + if (value == null) { + setNull(row); + return; + } + try { + UdfNative.setDecimal(vectorRef, row, value); + } catch (SQLException e) { + throw new RuntimeException(e); + } + markValid(row); + } + + public void setBoolean(int row, boolean value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_BOOLEAN, "setBoolean"); + data.put(row, (byte) (value ? 
1 : 0)); + markValid(row); + } + + public void setString(int row, String value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_STRING, "setString"); + if (value == null) { + setNull(row); + return; + } + try { + UdfNative.setVarchar(vectorRef, row, value); + } catch (SQLException e) { + throw new RuntimeException(e); + } + markValid(row); + } + + public void setBytes(int row, byte[] value) { + checkIndex(row); + requireAccessor(UdfTypeCatalog.Accessor.SET_BYTES, "setBytes"); + if (value == null) { + setNull(row); + return; + } + if (type == DuckDBColumnType.BLOB) { + try { + UdfNative.setBlob(vectorRef, row, value); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } else { + writeFixedWidthBytes(row, value); + } + markValid(row); + } + + public void setObject(int row, Object value) { + checkIndex(row); + if (value == null) { + setNull(row); + return; + } + switch (type) { + case BOOLEAN: + setBoolean(row, requireBoolean(value)); + return; + case TINYINT: + setInt(row, (int) requireSignedLongInRange(value, Byte.MIN_VALUE, Byte.MAX_VALUE, "TINYINT")); + return; + case UTINYINT: + setInt(row, (int) requireSignedLongInRange(value, 0, 0xFFL, "UTINYINT")); + return; + case SMALLINT: + setInt(row, (int) requireSignedLongInRange(value, Short.MIN_VALUE, Short.MAX_VALUE, "SMALLINT")); + return; + case USMALLINT: + setInt(row, (int) requireSignedLongInRange(value, 0, 0xFFFFL, "USMALLINT")); + return; + case INTEGER: + setInt(row, (int) requireSignedLongInRange(value, Integer.MIN_VALUE, Integer.MAX_VALUE, "INTEGER")); + return; + case DATE: + setInt(row, requireDateEpochDays(value)); + return; + case UINTEGER: + setLong(row, requireSignedLongInRange(value, 0, UNSIGNED_INT_MAX, "UINTEGER")); + return; + case BIGINT: + case UBIGINT: + case TIME: + case TIME_NS: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + case TIMESTAMP_WITH_TIME_ZONE: + setLong(row, 
requireLongOrTemporal(value, type)); + return; + case FLOAT: + setFloat(row, (float) requireDouble(value)); + return; + case DOUBLE: + setDouble(row, requireDouble(value)); + return; + case DECIMAL: + setBigDecimal(row, requireBigDecimal(value)); + return; + case VARCHAR: + setString(row, requireString(value)); + return; + case BLOB: + setBytes(row, requireBytes(value)); + return; + case HUGEINT: + case UHUGEINT: + if (value instanceof BigInteger) { + setBigInteger(row, (BigInteger) value); + } else { + setBytes(row, requireFixedWidthBytes(value, 16, type.toString())); + } + return; + case UUID: + if (value instanceof UUID) { + setBytes(row, uuidToBytes((UUID) value)); + } else { + setBytes(row, requireFixedWidthBytes(value, 16, "UUID")); + } + return; + default: + throw new IllegalArgumentException("Unsupported output type for setObject: " + type); + } + } + + public void setLocalDate(int row, LocalDate value) { + setObject(row, value); + } + + public void setLocalTime(int row, LocalTime value) { + setObject(row, value); + } + + public void setOffsetTime(int row, OffsetTime value) { + setObject(row, value); + } + + public void setLocalDateTime(int row, LocalDateTime value) { + setObject(row, value); + } + + public void setDate(int row, Date value) { + setObject(row, value); + } + + public void setOffsetDateTime(int row, OffsetDateTime value) { + setObject(row, value); + } + + public void setUUID(int row, UUID value) { + setObject(row, value); + } + + public void setBigInteger(int row, BigInteger value) { + checkIndex(row); + if (type != DuckDBColumnType.HUGEINT && type != DuckDBColumnType.UHUGEINT) { + throw new UnsupportedOperationException("setBigInteger is not supported for " + type + " vectors"); + } + if (value == null) { + setNull(row); + return; + } + setBytes(row, toInt128Bytes(value, type)); + } + + private static boolean requireBoolean(Object value) { + if (!(value instanceof Boolean)) { + throw new IllegalArgumentException("Expected Boolean value but 
got " + value.getClass().getName()); + } + return (Boolean) value; + } + + private static double requireDouble(Object value) { + if (!(value instanceof Number)) { + throw new IllegalArgumentException("Expected numeric value but got " + value.getClass().getName()); + } + return ((Number) value).doubleValue(); + } + + private static String requireString(Object value) { + if (!(value instanceof String)) { + throw new IllegalArgumentException("Expected String value but got " + value.getClass().getName()); + } + return (String) value; + } + + private static byte[] requireBytes(Object value) { + if (!(value instanceof byte[])) { + throw new IllegalArgumentException("Expected byte[] value but got " + value.getClass().getName()); + } + return (byte[]) value; + } + + private static byte[] requireFixedWidthBytes(Object value, int width, String typeName) { + byte[] bytes = requireBytes(value); + if (bytes.length != width) { + throw new IllegalArgumentException("Expected " + width + " bytes for " + typeName + " value, got " + + bytes.length); + } + return bytes; + } + + private static BigDecimal requireBigDecimal(Object value) { + if (value instanceof BigDecimal) { + return (BigDecimal) value; + } + if (value instanceof Number) { + return toBigDecimal((Number) value); + } + throw new IllegalArgumentException("Expected BigDecimal/Number value but got " + value.getClass().getName()); + } + + private static BigDecimal toBigDecimal(Number value) { + if (value instanceof BigInteger) { + return new BigDecimal((BigInteger) value); + } + if (value instanceof Byte || value instanceof Short || value instanceof Integer || value instanceof Long) { + return BigDecimal.valueOf(value.longValue()); + } + if (value instanceof Float || value instanceof Double) { + try { + return BigDecimal.valueOf(value.doubleValue()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Expected finite floating-point value for DECIMAL conversion", e); + } + } + try { + return new 
BigDecimal(value.toString()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to coerce numeric value to BigDecimal", e); + } + } + + private static long requireSignedLongInRange(Object value, long min, long max, String typeName) { + long num = requireIntegralLong(value, typeName); + if (num < min || num > max) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + num); + } + return num; + } + + private static int requireDateEpochDays(Object value) { + if (value instanceof LocalDate) { + long days = ((LocalDate) value).toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Expected LocalDate epoch day to fit int32 but got " + days); + } + return (int) days; + } + if (value instanceof Date) { + LocalDate localDate = Instant.ofEpochMilli(((Date) value).getTime()).atOffset(ZoneOffset.UTC).toLocalDate(); + long days = localDate.toEpochDay(); + if (days < Integer.MIN_VALUE || days > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Expected Date epoch day to fit int32 but got " + days); + } + return (int) days; + } + return (int) requireSignedLongInRange(value, Integer.MIN_VALUE, Integer.MAX_VALUE, "DATE"); + } + + private static long requireLongOrTemporal(Object value, DuckDBColumnType colType) { + if (value instanceof Number) { + if (colType == DuckDBColumnType.UBIGINT) { + return requireUnsignedLongBits(value, "UBIGINT"); + } + return requireIntegralLong(value, colType.toString()); + } + switch (colType) { + case TIME: + if (value instanceof LocalTime) { + return ((LocalTime) value).toNanoOfDay() / 1000L; + } + break; + case TIME_NS: + if (value instanceof LocalTime) { + return ((LocalTime) value).toNanoOfDay(); + } + break; + case TIME_WITH_TIME_ZONE: + if (value instanceof OffsetTime) { + OffsetTime time = (OffsetTime) value; + return packTimeTzMicros(time.toLocalTime().toNanoOfDay() / 1000L, time.getOffset().getTotalSeconds()); + } 
+ break; + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP: + case TIMESTAMP_NS: + if (value instanceof LocalDateTime) { + return localDateTimeToMoment((LocalDateTime) value, colType); + } + if (value instanceof Date) { + return dateToMoment((Date) value, colType); + } + break; + case TIMESTAMP_WITH_TIME_ZONE: + if (value instanceof OffsetDateTime) { + LocalDateTime utcDateTime = + ((OffsetDateTime) value).atZoneSameInstant(ZoneOffset.UTC).toLocalDateTime(); + return EPOCH_DATE_TIME.until(utcDateTime, ChronoUnit.MICROS); + } + if (value instanceof Date) { + return Math.multiplyExact(((Date) value).getTime(), 1000L); + } + break; + default: + break; + } + throw new IllegalArgumentException("Expected numeric/temporal value compatible with " + colType + " but got " + + value.getClass().getName()); + } + + private static long requireIntegralLong(Object value, String typeName) { + BigInteger integerValue = requireIntegralBigInteger(value, typeName); + try { + return integerValue.longValueExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + integerValue, e); + } + } + + private static long requireUnsignedLongBits(Object value, String typeName) { + BigInteger integerValue = requireIntegralBigInteger(value, typeName); + if (integerValue.compareTo(U64_MIN) < 0 || integerValue.compareTo(U64_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for " + typeName + ": " + integerValue); + } + if (integerValue.signum() >= 0 && integerValue.bitLength() <= 63) { + return integerValue.longValue(); + } + return integerValue.subtract(U64_MODULUS).longValue(); + } + + private static BigInteger requireIntegralBigInteger(Object value, String typeName) { + if (!(value instanceof Number)) { + throw new IllegalArgumentException("Expected numeric value for " + typeName + " but got " + + value.getClass().getName()); + } + + Number number = (Number) value; + if (number instanceof BigInteger) { + 
return (BigInteger) number; + } + if (number instanceof BigDecimal) { + try { + return ((BigDecimal) number).toBigIntegerExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + if (number instanceof Byte || number instanceof Short || number instanceof Integer || number instanceof Long) { + return BigInteger.valueOf(number.longValue()); + } + if (number instanceof Float || number instanceof Double) { + double d = number.doubleValue(); + if (!Double.isFinite(d)) { + throw new IllegalArgumentException("Expected finite value for " + typeName + " but got " + d); + } + try { + return BigDecimal.valueOf(d).toBigIntegerExact(); + } catch (ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + + try { + return new BigDecimal(number.toString()).toBigIntegerExact(); + } catch (NumberFormatException | ArithmeticException e) { + throw new IllegalArgumentException("Expected integral value for " + typeName + ": " + number, e); + } + } + + private static long localDateTimeToMoment(LocalDateTime value, DuckDBColumnType colType) { + switch (colType) { + case TIMESTAMP_S: + return EPOCH_DATE_TIME.until(value, ChronoUnit.SECONDS); + case TIMESTAMP_MS: + return EPOCH_DATE_TIME.until(value, ChronoUnit.MILLIS); + case TIMESTAMP: + return EPOCH_DATE_TIME.until(value, ChronoUnit.MICROS); + case TIMESTAMP_NS: + return EPOCH_DATE_TIME.until(value, ChronoUnit.NANOS); + default: + throw new IllegalArgumentException("Unsupported LocalDateTime conversion for " + colType); + } + } + + private static long dateToMoment(Date value, DuckDBColumnType colType) { + long millis = value.getTime(); + switch (colType) { + case TIMESTAMP_S: + return millis / 1000L; + case TIMESTAMP_MS: + return millis; + case TIMESTAMP: + return Math.multiplyExact(millis, 1000L); + case TIMESTAMP_NS: + return Math.multiplyExact(millis, 1000000L); + 
default: + throw new IllegalArgumentException("Unsupported Date conversion for " + colType); + } + } + + private static long packTimeTzMicros(long micros, int offsetSeconds) { + if (offsetSeconds < -MAX_TZ_SECONDS || offsetSeconds > MAX_TZ_SECONDS) { + throw new IllegalArgumentException("TIME WITH TIME ZONE offset out of range: " + offsetSeconds + + " seconds (allowed range: -" + MAX_TZ_SECONDS + ".." + MAX_TZ_SECONDS + + ")"); + } + long normalizedOffset = MAX_TZ_SECONDS - offsetSeconds; + return ((micros & 0xFFFFFFFFFFL) << 24) | (normalizedOffset & 0xFFFFFFL); + } + + private static BigInteger toUnsignedBigInteger(long value) { + if (value >= 0) { + return BigInteger.valueOf(value); + } + return BigInteger.valueOf(value & Long.MAX_VALUE).setBit(63); + } + + private static byte[] toInt128Bytes(BigInteger value, DuckDBColumnType targetType) { + if (value == null) { + throw new IllegalArgumentException("BigInteger value must not be null"); + } + if (targetType == DuckDBColumnType.HUGEINT) { + if (value.compareTo(HUGEINT_MIN) < 0 || value.compareTo(HUGEINT_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for HUGEINT: " + value); + } + } else if (targetType == DuckDBColumnType.UHUGEINT) { + if (value.compareTo(UHUGEINT_MIN) < 0 || value.compareTo(UHUGEINT_MAX) > 0) { + throw new IllegalArgumentException("Value out of range for UHUGEINT: " + value); + } + } else { + throw new IllegalArgumentException("Int128 conversion is only supported for HUGEINT/UHUGEINT"); + } + + long lower = value.longValue(); + long upper = value.shiftRight(64).longValue(); + ByteBuffer buffer = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder()); + buffer.putLong(lower); + buffer.putLong(upper); + return buffer.array(); + } + + private static byte[] uuidToBytes(UUID value) { + ByteBuffer buffer = ByteBuffer.allocate(16).order(ByteOrder.nativeOrder()); + buffer.putLong(value.getLeastSignificantBits()); + buffer.putLong(value.getMostSignificantBits() ^ Long.MIN_VALUE); + 
return buffer.array(); + } +} diff --git a/src/main/java/org/duckdb/UdfTypeCatalog.java b/src/main/java/org/duckdb/UdfTypeCatalog.java new file mode 100644 index 000000000..c88b776c1 --- /dev/null +++ b/src/main/java/org/duckdb/UdfTypeCatalog.java @@ -0,0 +1,264 @@ +package org.duckdb; + +import java.sql.SQLFeatureNotSupportedException; +import java.util.Collections; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.duckdb.udf.UdfLogicalType; + +public final class UdfTypeCatalog { + public enum Accessor { + GET_INT, + GET_LONG, + GET_FLOAT, + GET_DOUBLE, + GET_DECIMAL, + GET_BOOLEAN, + GET_STRING, + GET_BYTES, + SET_INT, + SET_LONG, + SET_FLOAT, + SET_DOUBLE, + SET_DECIMAL, + SET_BOOLEAN, + SET_STRING, + SET_BYTES + } + + private static final EnumSet CORE_TYPES = EnumSet.of( + DuckDBColumnType.BOOLEAN, DuckDBColumnType.TINYINT, DuckDBColumnType.SMALLINT, DuckDBColumnType.INTEGER, + DuckDBColumnType.BIGINT, DuckDBColumnType.FLOAT, DuckDBColumnType.DOUBLE, DuckDBColumnType.VARCHAR); + + private static final EnumSet EXTENDED_TYPES = + EnumSet.of(DuckDBColumnType.DECIMAL, DuckDBColumnType.BLOB, DuckDBColumnType.DATE, DuckDBColumnType.TIME, + DuckDBColumnType.TIME_NS, DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIMESTAMP_S, + DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_NS); + + private static final EnumSet ADVANCED_TYPES = EnumSet.of( + DuckDBColumnType.UTINYINT, DuckDBColumnType.USMALLINT, DuckDBColumnType.UINTEGER, DuckDBColumnType.UBIGINT, + DuckDBColumnType.HUGEINT, DuckDBColumnType.UHUGEINT, DuckDBColumnType.TIME_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.UUID); + + private static final EnumSet SUPPORTED_UDF_TYPES = EnumSet.copyOf(CORE_TYPES); + + static { + SUPPORTED_UDF_TYPES.addAll(EXTENDED_TYPES); + SUPPORTED_UDF_TYPES.addAll(ADVANCED_TYPES); + } + + private static final EnumSet SCALAR_IMPLEMENTED_TYPES = 
EnumSet.copyOf(SUPPORTED_UDF_TYPES); + + private static final EnumSet TABLE_BIND_SCHEMA_TYPES = EnumSet.copyOf(SUPPORTED_UDF_TYPES); + + private static final EnumSet TABLE_PARAMETER_TYPES = EnumSet.copyOf(SUPPORTED_UDF_TYPES); + + private static final EnumSet VARLEN_TYPES = + EnumSet.of(DuckDBColumnType.VARCHAR, DuckDBColumnType.BLOB); + private static final EnumSet VECTOR_REF_TYPES = + EnumSet.of(DuckDBColumnType.VARCHAR, DuckDBColumnType.BLOB, DuckDBColumnType.DECIMAL); + + private static final EnumMap CAPI_TYPE_IDS = new EnumMap<>(DuckDBColumnType.class); + private static final EnumMap> ACCESS_MATRIX = + new EnumMap<>(DuckDBColumnType.class); + private static final Map COLUMN_TYPES_BY_CAPI_ID = new HashMap<>(); + private static final Map> ACCESS_MATRIX_VIEW; + + static { + CAPI_TYPE_IDS.put(DuckDBColumnType.BOOLEAN, DuckDBBindings.CAPIType.DUCKDB_TYPE_BOOLEAN.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TINYINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_TINYINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.SMALLINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_SMALLINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.INTEGER, DuckDBBindings.CAPIType.DUCKDB_TYPE_INTEGER.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.BIGINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_BIGINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.UTINYINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_UTINYINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.USMALLINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_USMALLINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.UINTEGER, DuckDBBindings.CAPIType.DUCKDB_TYPE_UINTEGER.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.UBIGINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_UBIGINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.HUGEINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_HUGEINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.UHUGEINT, DuckDBBindings.CAPIType.DUCKDB_TYPE_UHUGEINT.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.FLOAT, DuckDBBindings.CAPIType.DUCKDB_TYPE_FLOAT.typeId); + 
CAPI_TYPE_IDS.put(DuckDBColumnType.DOUBLE, DuckDBBindings.CAPIType.DUCKDB_TYPE_DOUBLE.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.VARCHAR, DuckDBBindings.CAPIType.DUCKDB_TYPE_VARCHAR.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.BLOB, DuckDBBindings.CAPIType.DUCKDB_TYPE_BLOB.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.DECIMAL, DuckDBBindings.CAPIType.DUCKDB_TYPE_DECIMAL.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.DATE, DuckDBBindings.CAPIType.DUCKDB_TYPE_DATE.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIME, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIME_NS, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME_NS.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIMESTAMP, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIMESTAMP_S, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_S.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIMESTAMP_MS, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_MS.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIMESTAMP_NS, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_NS.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIME_WITH_TIME_ZONE, DuckDBBindings.CAPIType.DUCKDB_TYPE_TIME_TZ.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBBindings.CAPIType.DUCKDB_TYPE_TIMESTAMP_TZ.typeId); + CAPI_TYPE_IDS.put(DuckDBColumnType.UUID, DuckDBBindings.CAPIType.DUCKDB_TYPE_UUID.typeId); + + ACCESS_MATRIX.put(DuckDBColumnType.BOOLEAN, EnumSet.of(Accessor.GET_BOOLEAN, Accessor.SET_BOOLEAN)); + ACCESS_MATRIX.put(DuckDBColumnType.TINYINT, EnumSet.of(Accessor.GET_INT, Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.SMALLINT, EnumSet.of(Accessor.GET_INT, Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.INTEGER, EnumSet.of(Accessor.GET_INT, Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.BIGINT, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.UTINYINT, EnumSet.of(Accessor.GET_INT, 
Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.USMALLINT, EnumSet.of(Accessor.GET_INT, Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.UINTEGER, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.UBIGINT, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.HUGEINT, EnumSet.of(Accessor.GET_BYTES, Accessor.SET_BYTES)); + ACCESS_MATRIX.put(DuckDBColumnType.UHUGEINT, EnumSet.of(Accessor.GET_BYTES, Accessor.SET_BYTES)); + ACCESS_MATRIX.put(DuckDBColumnType.FLOAT, EnumSet.of(Accessor.GET_FLOAT, Accessor.SET_FLOAT)); + ACCESS_MATRIX.put(DuckDBColumnType.DOUBLE, EnumSet.of(Accessor.GET_DOUBLE, Accessor.SET_DOUBLE)); + ACCESS_MATRIX.put(DuckDBColumnType.VARCHAR, EnumSet.of(Accessor.GET_STRING, Accessor.SET_STRING)); + ACCESS_MATRIX.put(DuckDBColumnType.BLOB, EnumSet.of(Accessor.GET_BYTES, Accessor.SET_BYTES)); + ACCESS_MATRIX.put(DuckDBColumnType.DECIMAL, EnumSet.of(Accessor.GET_DECIMAL, Accessor.SET_DECIMAL)); + ACCESS_MATRIX.put(DuckDBColumnType.DATE, EnumSet.of(Accessor.GET_INT, Accessor.SET_INT)); + ACCESS_MATRIX.put(DuckDBColumnType.TIME, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIME_NS, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIMESTAMP, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIMESTAMP_S, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIMESTAMP_MS, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIMESTAMP_NS, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIME_WITH_TIME_ZONE, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, EnumSet.of(Accessor.GET_LONG, Accessor.SET_LONG)); + ACCESS_MATRIX.put(DuckDBColumnType.UUID, EnumSet.of(Accessor.GET_BYTES, 
Accessor.SET_BYTES)); + + for (Map.Entry entry : CAPI_TYPE_IDS.entrySet()) { + COLUMN_TYPES_BY_CAPI_ID.put(entry.getValue(), entry.getKey()); + } + + EnumMap> accessorMatrixView = new EnumMap<>(DuckDBColumnType.class); + for (Map.Entry> entry : ACCESS_MATRIX.entrySet()) { + accessorMatrixView.put(entry.getKey(), Collections.unmodifiableSet(EnumSet.copyOf(entry.getValue()))); + } + ACCESS_MATRIX_VIEW = Collections.unmodifiableMap(accessorMatrixView); + } + + private UdfTypeCatalog() { + } + + public static int toCapiTypeId(DuckDBColumnType type) throws SQLFeatureNotSupportedException { + Integer capiType = CAPI_TYPE_IDS.get(type); + if (capiType == null) { + throw new SQLFeatureNotSupportedException("Unsupported UDF type: " + type); + } + return capiType; + } + + public static DuckDBColumnType fromCapiTypeId(int capiTypeId) throws SQLFeatureNotSupportedException { + DuckDBColumnType type = COLUMN_TYPES_BY_CAPI_ID.get(capiTypeId); + if (type == null) { + throw new SQLFeatureNotSupportedException("Unsupported C API type id for UDF: " + capiTypeId); + } + return type; + } + + public static int toCapiTypeIdForScalarRegistration(DuckDBColumnType type) throws SQLFeatureNotSupportedException { + if (!SCALAR_IMPLEMENTED_TYPES.contains(type)) { + throw new SQLFeatureNotSupportedException( + "Supported scalar UDF types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, VARCHAR, " + + "DECIMAL, BLOB, DATE, TIME, TIME_NS, TIMESTAMP, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, " + + "UTINYINT, USMALLINT, UINTEGER, UBIGINT, HUGEINT, UHUGEINT, TIME_WITH_TIME_ZONE, " + + "TIMESTAMP_WITH_TIME_ZONE, UUID"); + } + return toCapiTypeId(type); + } + + public static void validateScalarLogicalType(UdfLogicalType logicalType) throws SQLFeatureNotSupportedException { + if (logicalType == null) { + throw new SQLFeatureNotSupportedException("Scalar UDF logical type cannot be null"); + } + DuckDBColumnType type = logicalType.getType(); + if (type == null) { + throw new 
SQLFeatureNotSupportedException("Scalar UDF logical type has null DuckDBColumnType"); + } + if (!SCALAR_IMPLEMENTED_TYPES.contains(type)) { + throw new SQLFeatureNotSupportedException( + "Supported scalar UDF types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, VARCHAR, " + + "DECIMAL, BLOB, DATE, TIME, TIME_NS, TIMESTAMP, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, " + + "UTINYINT, USMALLINT, UINTEGER, UBIGINT, HUGEINT, UHUGEINT, TIME_WITH_TIME_ZONE, " + + "TIMESTAMP_WITH_TIME_ZONE, UUID"); + } + } + + public static boolean isScalarUdfImplemented(DuckDBColumnType type) { + return SCALAR_IMPLEMENTED_TYPES.contains(type); + } + + public static boolean isTableBindSchemaType(DuckDBColumnType type) { + return TABLE_BIND_SCHEMA_TYPES.contains(type); + } + + public static boolean isVarLenType(DuckDBColumnType type) { + return VARLEN_TYPES.contains(type); + } + + public static boolean requiresVectorRef(DuckDBColumnType type) { + return VECTOR_REF_TYPES.contains(type); + } + + public static boolean isTableFunctionParameterType(DuckDBColumnType type) { + return TABLE_PARAMETER_TYPES.contains(type); + } + + public static int toCapiTypeIdForTableFunctionParameter(DuckDBColumnType type) + throws SQLFeatureNotSupportedException { + if (!TABLE_PARAMETER_TYPES.contains(type)) { + throw new SQLFeatureNotSupportedException( + "Supported table function parameter types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, " + + "DOUBLE, VARCHAR, DECIMAL, BLOB, DATE, TIME, TIME_NS, TIMESTAMP, TIMESTAMP_S, TIMESTAMP_MS, " + + "TIMESTAMP_NS, UTINYINT, USMALLINT, UINTEGER, UBIGINT, HUGEINT, UHUGEINT, TIME_WITH_TIME_ZONE, " + + "TIMESTAMP_WITH_TIME_ZONE, UUID"); + } + return toCapiTypeId(type); + } + + public static void validateTableFunctionParameterLogicalType(UdfLogicalType logicalType) + throws SQLFeatureNotSupportedException { + if (logicalType == null) { + throw new SQLFeatureNotSupportedException("Table function parameter logical type cannot be null"); + } + 
DuckDBColumnType type = logicalType.getType(); + if (type == null) { + throw new SQLFeatureNotSupportedException( + "Table function parameter logical type has null DuckDBColumnType"); + } + + if (TABLE_PARAMETER_TYPES.contains(type) || type == DuckDBColumnType.ENUM) { + return; + } + + switch (type) { + case LIST: + case ARRAY: + validateTableFunctionParameterLogicalType(logicalType.getChildType()); + return; + case MAP: + validateTableFunctionParameterLogicalType(logicalType.getKeyType()); + validateTableFunctionParameterLogicalType(logicalType.getValueType()); + return; + case STRUCT: + case UNION: + UdfLogicalType[] fieldTypes = logicalType.getFieldTypes(); + if (fieldTypes == null || fieldTypes.length == 0) { + throw new SQLFeatureNotSupportedException("Table function " + type + " parameter requires fields"); + } + for (UdfLogicalType fieldType : fieldTypes) { + validateTableFunctionParameterLogicalType(fieldType); + } + return; + default: + throw new SQLFeatureNotSupportedException("Unsupported table function parameter type in logical schema: " + + type); + } + } + + public static boolean supportsAccessor(DuckDBColumnType type, Accessor accessor) { + EnumSet accessors = ACCESS_MATRIX.get(type); + return accessors != null && accessors.contains(accessor); + } + + public static Map> accessorMatrixView() { + return ACCESS_MATRIX_VIEW; + } +} diff --git a/src/main/java/org/duckdb/udf/BindContext.java b/src/main/java/org/duckdb/udf/BindContext.java new file mode 100644 index 000000000..c75cfc06b --- /dev/null +++ b/src/main/java/org/duckdb/udf/BindContext.java @@ -0,0 +1,3 @@ +package org.duckdb.udf; + +public interface BindContext {} diff --git a/src/main/java/org/duckdb/udf/InitContext.java b/src/main/java/org/duckdb/udf/InitContext.java new file mode 100644 index 000000000..f6f5ba21b --- /dev/null +++ b/src/main/java/org/duckdb/udf/InitContext.java @@ -0,0 +1,8 @@ +package org.duckdb.udf; + +public interface InitContext { + // The projection list is ordered to 
match the output columns passed to produce(). + int getColumnCount(); + + int getColumnIndex(int idx); +} diff --git a/src/main/java/org/duckdb/udf/ScalarUdf.java b/src/main/java/org/duckdb/udf/ScalarUdf.java new file mode 100644 index 000000000..0e47ff3ec --- /dev/null +++ b/src/main/java/org/duckdb/udf/ScalarUdf.java @@ -0,0 +1,9 @@ +package org.duckdb.udf; + +import org.duckdb.UdfReader; +import org.duckdb.UdfScalarWriter; + +@FunctionalInterface +public interface ScalarUdf { + void apply(UdfContext ctx, UdfReader[] args, UdfScalarWriter out, int rowCount) throws Exception; +} diff --git a/src/main/java/org/duckdb/udf/TableBindResult.java b/src/main/java/org/duckdb/udf/TableBindResult.java new file mode 100644 index 000000000..9d672ce7d --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableBindResult.java @@ -0,0 +1,69 @@ +package org.duckdb.udf; + +import java.util.Arrays; +import java.util.Objects; +import org.duckdb.DuckDBColumnType; + +public final class TableBindResult { + private final String[] columnNames; + private final DuckDBColumnType[] columnTypes; + private final UdfLogicalType[] columnLogicalTypes; + private final Object bindState; + + public TableBindResult(String[] columnNames, DuckDBColumnType[] columnTypes) { + this(columnNames, columnTypes, null); + } + + public TableBindResult(String[] columnNames, DuckDBColumnType[] columnTypes, Object bindState) { + this.columnNames = Objects.requireNonNull(columnNames, "columnNames").clone(); + this.columnTypes = Objects.requireNonNull(columnTypes, "columnTypes").clone(); + if (this.columnNames.length != this.columnTypes.length) { + throw new IllegalArgumentException("columnNames and columnTypes must have same length"); + } + this.columnLogicalTypes = null; + this.bindState = bindState; + } + + public TableBindResult(String[] columnNames, UdfLogicalType[] columnLogicalTypes) { + this(columnNames, columnLogicalTypes, null); + } + + public TableBindResult(String[] columnNames, UdfLogicalType[] 
columnLogicalTypes, Object bindState) { + this.columnNames = Objects.requireNonNull(columnNames, "columnNames").clone(); + this.columnLogicalTypes = Objects.requireNonNull(columnLogicalTypes, "columnLogicalTypes").clone(); + if (this.columnNames.length != this.columnLogicalTypes.length) { + throw new IllegalArgumentException("columnNames and columnLogicalTypes must have same length"); + } + this.columnTypes = new DuckDBColumnType[this.columnLogicalTypes.length]; + for (int i = 0; i < this.columnLogicalTypes.length; i++) { + if (this.columnLogicalTypes[i] == null) { + throw new IllegalArgumentException("columnLogicalTypes[" + i + "] must not be null"); + } + this.columnTypes[i] = this.columnLogicalTypes[i].getType(); + } + this.bindState = bindState; + } + + public String[] getColumnNames() { + return columnNames.clone(); + } + + public DuckDBColumnType[] getColumnTypes() { + return columnTypes.clone(); + } + + public UdfLogicalType[] getColumnLogicalTypes() { + return columnLogicalTypes == null ? 
null : columnLogicalTypes.clone(); + } + + public Object getBindState() { + return bindState; + } + + @Override + public String toString() { + return "TableBindResult{" + + "columnNames=" + Arrays.toString(columnNames) + ", columnTypes=" + Arrays.toString(columnTypes) + + ", columnLogicalTypes=" + Arrays.toString(columnLogicalTypes) + "}"; + } +} diff --git a/src/main/java/org/duckdb/udf/TableFunction.java b/src/main/java/org/duckdb/udf/TableFunction.java new file mode 100644 index 000000000..332f42fa5 --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableFunction.java @@ -0,0 +1,11 @@ +package org.duckdb.udf; + +import org.duckdb.UdfOutputAppender; + +public interface TableFunction { + TableBindResult bind(BindContext ctx, Object[] parameters) throws Exception; + + TableState init(InitContext ctx, TableBindResult bind) throws Exception; + + int produce(TableState state, UdfOutputAppender out) throws Exception; +} diff --git a/src/main/java/org/duckdb/udf/TableFunctionDefinition.java b/src/main/java/org/duckdb/udf/TableFunctionDefinition.java new file mode 100644 index 000000000..34f8df78e --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableFunctionDefinition.java @@ -0,0 +1,55 @@ +package org.duckdb.udf; + +import java.util.Objects; +import org.duckdb.DuckDBColumnType; + +public final class TableFunctionDefinition { + private final boolean projectionPushdown; + private final UdfLogicalType[] parameterLogicalTypes; + + public TableFunctionDefinition() { + this(false, new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.BIGINT)}); + } + + private TableFunctionDefinition(boolean projectionPushdown, UdfLogicalType[] parameterLogicalTypes) { + this.projectionPushdown = projectionPushdown; + this.parameterLogicalTypes = Objects.requireNonNull(parameterLogicalTypes, "parameterLogicalTypes").clone(); + for (UdfLogicalType parameterType : this.parameterLogicalTypes) { + Objects.requireNonNull(parameterType, "parameterLogicalTypes cannot contain null values"); + } + } 
+ + public TableFunctionDefinition withProjectionPushdown(boolean enabled) { + return new TableFunctionDefinition(enabled, parameterLogicalTypes); + } + + public TableFunctionDefinition withParameterTypes(DuckDBColumnType[] parameterTypes) { + Objects.requireNonNull(parameterTypes, "parameterTypes"); + UdfLogicalType[] logicalTypes = new UdfLogicalType[parameterTypes.length]; + for (int i = 0; i < parameterTypes.length; i++) { + logicalTypes[i] = UdfLogicalType.of( + Objects.requireNonNull(parameterTypes[i], "parameterTypes cannot contain null values")); + } + return new TableFunctionDefinition(projectionPushdown, logicalTypes); + } + + public TableFunctionDefinition withParameterTypes(UdfLogicalType[] parameterLogicalTypes) { + return new TableFunctionDefinition(projectionPushdown, parameterLogicalTypes); + } + + public boolean isProjectionPushdownEnabled() { + return projectionPushdown; + } + + public DuckDBColumnType[] getParameterTypes() { + DuckDBColumnType[] types = new DuckDBColumnType[parameterLogicalTypes.length]; + for (int i = 0; i < parameterLogicalTypes.length; i++) { + types[i] = parameterLogicalTypes[i].getType(); + } + return types; + } + + public UdfLogicalType[] getParameterLogicalTypes() { + return parameterLogicalTypes.clone(); + } +} diff --git a/src/main/java/org/duckdb/udf/TableFunctionOptions.java b/src/main/java/org/duckdb/udf/TableFunctionOptions.java new file mode 100644 index 000000000..de222bdad --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableFunctionOptions.java @@ -0,0 +1,22 @@ +package org.duckdb.udf; + +public final class TableFunctionOptions { + public boolean threadSafe = false; + public int maxThreads = 1; + + public TableFunctionOptions() { + } + + public TableFunctionOptions threadSafe(boolean value) { + this.threadSafe = value; + return this; + } + + public TableFunctionOptions maxThreads(int value) { + if (value < 1) { + throw new IllegalArgumentException("maxThreads must be >= 1"); + } + this.maxThreads = value; + 
return this; + } +} diff --git a/src/main/java/org/duckdb/udf/TableInitContext.java b/src/main/java/org/duckdb/udf/TableInitContext.java new file mode 100644 index 000000000..95044307a --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableInitContext.java @@ -0,0 +1,21 @@ +package org.duckdb.udf; + +import java.util.Objects; + +public final class TableInitContext implements InitContext { + private final int[] columnIndexes; + + public TableInitContext(int[] columnIndexes) { + this.columnIndexes = Objects.requireNonNull(columnIndexes, "columnIndexes").clone(); + } + + @Override + public int getColumnCount() { + return columnIndexes.length; + } + + @Override + public int getColumnIndex(int projectedColumnIndex) { + return columnIndexes[projectedColumnIndex]; + } +} diff --git a/src/main/java/org/duckdb/udf/TableState.java b/src/main/java/org/duckdb/udf/TableState.java new file mode 100644 index 000000000..632aa928f --- /dev/null +++ b/src/main/java/org/duckdb/udf/TableState.java @@ -0,0 +1,13 @@ +package org.duckdb.udf; + +public final class TableState { + private final Object state; + + public TableState(Object state) { + this.state = state; + } + + public Object getState() { + return state; + } +} diff --git a/src/main/java/org/duckdb/udf/UdfContext.java b/src/main/java/org/duckdb/udf/UdfContext.java new file mode 100644 index 000000000..fb03247ec --- /dev/null +++ b/src/main/java/org/duckdb/udf/UdfContext.java @@ -0,0 +1,3 @@ +package org.duckdb.udf; + +public interface UdfContext {} diff --git a/src/main/java/org/duckdb/udf/UdfLogicalType.java b/src/main/java/org/duckdb/udf/UdfLogicalType.java new file mode 100644 index 000000000..bb01d5d20 --- /dev/null +++ b/src/main/java/org/duckdb/udf/UdfLogicalType.java @@ -0,0 +1,179 @@ +package org.duckdb.udf; + +import java.util.Arrays; +import java.util.Objects; +import org.duckdb.DuckDBColumnType; + +public final class UdfLogicalType { + private static final int DECIMAL_WIDTH_MIN = 1; + private static final int 
DECIMAL_WIDTH_MAX = 38; + private static final int DECIMAL_SCALE_MIN = 0; + private static final int DEFAULT_DECIMAL_WIDTH = 18; + private static final int DEFAULT_DECIMAL_SCALE = 3; + + private final DuckDBColumnType type; + private final UdfLogicalType childType; + private final long arraySize; + private final UdfLogicalType keyType; + private final UdfLogicalType valueType; + private final String[] fieldNames; + private final UdfLogicalType[] fieldTypes; + private final String[] enumValues; + private final int decimalWidth; + private final int decimalScale; + + private UdfLogicalType(DuckDBColumnType type, UdfLogicalType childType, long arraySize, UdfLogicalType keyType, + UdfLogicalType valueType, String[] fieldNames, UdfLogicalType[] fieldTypes, + String[] enumValues, int decimalWidth, int decimalScale) { + this.type = Objects.requireNonNull(type, "type"); + this.childType = childType; + this.arraySize = arraySize; + this.keyType = keyType; + this.valueType = valueType; + this.fieldNames = fieldNames; + this.fieldTypes = fieldTypes; + this.enumValues = enumValues; + this.decimalWidth = decimalWidth; + this.decimalScale = decimalScale; + } + + public static UdfLogicalType of(DuckDBColumnType type) { + Objects.requireNonNull(type, "type"); + switch (type) { + case LIST: + case ARRAY: + case MAP: + case STRUCT: + case UNION: + case ENUM: + throw new IllegalArgumentException("Use container/enum-specific factory for type " + type); + case DECIMAL: + return decimal(DEFAULT_DECIMAL_WIDTH, DEFAULT_DECIMAL_SCALE); + default: + return new UdfLogicalType(type, null, 0, null, null, null, null, null, 0, 0); + } + } + + public static UdfLogicalType list(UdfLogicalType childType) { + return new UdfLogicalType(DuckDBColumnType.LIST, Objects.requireNonNull(childType, "childType"), 0, null, null, + null, null, null, 0, 0); + } + + public static UdfLogicalType array(UdfLogicalType childType, long arraySize) { + if (arraySize <= 0) { + throw new 
IllegalArgumentException("arraySize must be > 0"); + } + return new UdfLogicalType(DuckDBColumnType.ARRAY, Objects.requireNonNull(childType, "childType"), arraySize, + null, null, null, null, null, 0, 0); + } + + public static UdfLogicalType map(UdfLogicalType keyType, UdfLogicalType valueType) { + return new UdfLogicalType(DuckDBColumnType.MAP, null, 0, Objects.requireNonNull(keyType, "keyType"), + Objects.requireNonNull(valueType, "valueType"), null, null, null, 0, 0); + } + + public static UdfLogicalType struct(String[] fieldNames, UdfLogicalType[] fieldTypes) { + validateFields(fieldNames, fieldTypes, "struct"); + return new UdfLogicalType(DuckDBColumnType.STRUCT, null, 0, null, null, fieldNames.clone(), fieldTypes.clone(), + null, 0, 0); + } + + public static UdfLogicalType unionType(String[] fieldNames, UdfLogicalType[] fieldTypes) { + validateFields(fieldNames, fieldTypes, "union"); + return new UdfLogicalType(DuckDBColumnType.UNION, null, 0, null, null, fieldNames.clone(), fieldTypes.clone(), + null, 0, 0); + } + + public static UdfLogicalType enumeration(String... 
enumValues) { + Objects.requireNonNull(enumValues, "enumValues"); + if (enumValues.length == 0) { + throw new IllegalArgumentException("enumValues must not be empty"); + } + String[] values = enumValues.clone(); + for (int i = 0; i < values.length; i++) { + if (values[i] == null || values[i].isEmpty()) { + throw new IllegalArgumentException("enumValues[" + i + "] must not be null/empty"); + } + } + return new UdfLogicalType(DuckDBColumnType.ENUM, null, 0, null, null, null, null, values, 0, 0); + } + + public static UdfLogicalType decimal(int width, int scale) { + if (width < DECIMAL_WIDTH_MIN || width > DECIMAL_WIDTH_MAX) { + throw new IllegalArgumentException("decimal width must be between " + DECIMAL_WIDTH_MIN + " and " + + DECIMAL_WIDTH_MAX); + } + if (scale < DECIMAL_SCALE_MIN || scale > width) { + throw new IllegalArgumentException("decimal scale must be between " + DECIMAL_SCALE_MIN + " and width"); + } + return new UdfLogicalType(DuckDBColumnType.DECIMAL, null, 0, null, null, null, null, null, width, scale); + } + + public DuckDBColumnType getType() { + return type; + } + + public UdfLogicalType getChildType() { + return childType; + } + + public long getArraySize() { + return arraySize; + } + + public UdfLogicalType getKeyType() { + return keyType; + } + + public UdfLogicalType getValueType() { + return valueType; + } + + public String[] getFieldNames() { + return fieldNames == null ? null : fieldNames.clone(); + } + + public UdfLogicalType[] getFieldTypes() { + return fieldTypes == null ? null : fieldTypes.clone(); + } + + public String[] getEnumValues() { + return enumValues == null ? 
null : enumValues.clone(); + } + + public int getDecimalWidth() { + return decimalWidth; + } + + public int getDecimalScale() { + return decimalScale; + } + + @Override + public String toString() { + return "UdfLogicalType{" + + "type=" + type + ", childType=" + childType + ", arraySize=" + arraySize + ", keyType=" + keyType + + ", valueType=" + valueType + ", fieldNames=" + Arrays.toString(fieldNames) + + ", fieldTypes=" + Arrays.toString(fieldTypes) + ", enumValues=" + Arrays.toString(enumValues) + + ", decimalWidth=" + decimalWidth + ", decimalScale=" + decimalScale + "}"; + } + + private static void validateFields(String[] fieldNames, UdfLogicalType[] fieldTypes, String kind) { + Objects.requireNonNull(fieldNames, "fieldNames"); + Objects.requireNonNull(fieldTypes, "fieldTypes"); + if (fieldNames.length == 0) { + throw new IllegalArgumentException(kind + " fieldNames must not be empty"); + } + if (fieldNames.length != fieldTypes.length) { + throw new IllegalArgumentException(kind + " fieldNames/fieldTypes length mismatch"); + } + for (int i = 0; i < fieldNames.length; i++) { + if (fieldNames[i] == null || fieldNames[i].isEmpty()) { + throw new IllegalArgumentException(kind + " fieldNames[" + i + "] must not be null/empty"); + } + if (fieldTypes[i] == null) { + throw new IllegalArgumentException(kind + " fieldTypes[" + i + "] must not be null"); + } + } + } +} diff --git a/src/main/java/org/duckdb/udf/UdfOptions.java b/src/main/java/org/duckdb/udf/UdfOptions.java new file mode 100644 index 000000000..3c5533e01 --- /dev/null +++ b/src/main/java/org/duckdb/udf/UdfOptions.java @@ -0,0 +1,31 @@ +package org.duckdb.udf; + +public final class UdfOptions { + public boolean deterministic = true; + public boolean nullSpecialHandling = false; + public boolean returnNullOnException = false; + public boolean varArgs = false; + + public UdfOptions() { + } + + public UdfOptions deterministic(boolean value) { + this.deterministic = value; + return this; + } + + public 
UdfOptions nullSpecialHandling(boolean value) { + this.nullSpecialHandling = value; + return this; + } + + public UdfOptions returnNullOnException(boolean value) { + this.returnNullOnException = value; + return this; + } + + public UdfOptions varArgs(boolean value) { + this.varArgs = value; + return this; + } +} diff --git a/src/test/java/org/duckdb/TestBindings.java b/src/test/java/org/duckdb/TestBindings.java index 2c97dfee7..cae1ceee8 100644 --- a/src/test/java/org/duckdb/TestBindings.java +++ b/src/test/java/org/duckdb/TestBindings.java @@ -6,15 +6,29 @@ import static org.duckdb.TestDuckDBJDBC.JDBC_URL; import static org.duckdb.test.Assertions.*; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.sql.*; +import java.time.OffsetTime; +import java.time.ZoneOffset; import java.util.Arrays; +import java.util.EnumSet; +import org.duckdb.udf.UdfLogicalType; public class TestBindings { static final int STRING_T_SIZE_BYTES = 16; + private static void assertAccessors(DuckDBColumnType type, EnumSet<UdfTypeCatalog.Accessor> expected) + throws Exception { /* checks supportsAccessor() for every Accessor against expected, then the accessorMatrixView() entry */ + for (UdfTypeCatalog.Accessor accessor : UdfTypeCatalog.Accessor.values()) { + assertEquals(UdfTypeCatalog.supportsAccessor(type, accessor), expected.contains(accessor)); + } + assertEquals(UdfTypeCatalog.accessorMatrixView().get(type), expected); + } + public static void test_bindings_vector_size() throws Exception { long size = duckdb_vector_size(); assertTrue(size > 0); @@ -169,8 +183,9 @@ public static void test_bindings_array_vector() throws Exception { public static void test_bindings_struct_vector() throws Exception { ByteBuffer intType = duckdb_create_logical_type(DUCKDB_TYPE_INTEGER.typeId); ByteBuffer varcharType = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); - ByteBuffer structType = duckdb_create_struct_type(new ByteBuffer[] {intType, varcharType}, - new byte[][] {"foo".getBytes(UTF_8), "bar".getBytes(UTF_8)}); + ByteBuffer[]
childTypes = new ByteBuffer[] {intType, varcharType}; + byte[][] childNames = new byte[][] {"foo".getBytes(UTF_8), "bar".getBytes(UTF_8)}; + ByteBuffer structType = duckdb_create_struct_type(childTypes, childNames); assertTrue(duckdb_get_type_id(structType) != DUCKDB_TYPE_INVALID.typeId); assertEquals(duckdb_struct_type_child_count(structType), 2L); @@ -241,6 +256,589 @@ public static void test_bindings_data_chunk() throws Exception { duckdb_destroy_logical_type(intType); } + public static void test_bindings_udf_type_catalog_mappings() throws Exception { + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.INTEGER)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.VARCHAR)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.BOOLEAN)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TINYINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.SMALLINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.BIGINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.UTINYINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.USMALLINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.UINTEGER)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.UBIGINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.HUGEINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.UHUGEINT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.FLOAT)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.DOUBLE)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.DECIMAL)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.BLOB)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.DATE)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIME)); + 
assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIME_NS)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIMESTAMP)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIMESTAMP_S)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIMESTAMP_MS)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIMESTAMP_NS)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIME_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isScalarUdfImplemented(DuckDBColumnType.UUID)); + + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.BOOLEAN)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TINYINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.SMALLINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.INTEGER)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.BIGINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.UTINYINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.USMALLINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.UINTEGER)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.UBIGINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.HUGEINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.UHUGEINT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.FLOAT)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.DOUBLE)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.VARCHAR)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.DECIMAL)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.BLOB)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.DATE)); + 
assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIME)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIME_NS)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIMESTAMP)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIMESTAMP_S)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIMESTAMP_MS)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIMESTAMP_NS)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIME_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isTableBindSchemaType(DuckDBColumnType.UUID)); + + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.BOOLEAN)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TINYINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.SMALLINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.INTEGER)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.BIGINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.UTINYINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.USMALLINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.UINTEGER)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.UBIGINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.HUGEINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.UHUGEINT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.FLOAT)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.DOUBLE)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.VARCHAR)); + 
assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.DECIMAL)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.BLOB)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.DATE)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIME)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIME_NS)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIMESTAMP)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIMESTAMP_S)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIMESTAMP_MS)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIMESTAMP_NS)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIME_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE)); + assertTrue(UdfTypeCatalog.isTableFunctionParameterType(DuckDBColumnType.UUID)); + + assertTrue(UdfTypeCatalog.isVarLenType(DuckDBColumnType.VARCHAR)); + assertTrue(UdfTypeCatalog.isVarLenType(DuckDBColumnType.BLOB)); + assertFalse(UdfTypeCatalog.isVarLenType(DuckDBColumnType.INTEGER)); + assertTrue(UdfTypeCatalog.requiresVectorRef(DuckDBColumnType.VARCHAR)); + assertTrue(UdfTypeCatalog.requiresVectorRef(DuckDBColumnType.BLOB)); + assertTrue(UdfTypeCatalog.requiresVectorRef(DuckDBColumnType.DECIMAL)); + assertFalse(UdfTypeCatalog.requiresVectorRef(DuckDBColumnType.TIMESTAMP)); + + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.INTEGER, UdfTypeCatalog.Accessor.GET_INT)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.INTEGER, UdfTypeCatalog.Accessor.SET_INT)); + assertFalse(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.INTEGER, UdfTypeCatalog.Accessor.GET_LONG)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.VARCHAR, UdfTypeCatalog.Accessor.GET_STRING)); + 
assertFalse(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.VARCHAR, UdfTypeCatalog.Accessor.SET_DOUBLE)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.BLOB, UdfTypeCatalog.Accessor.GET_BYTES)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.DECIMAL, UdfTypeCatalog.Accessor.GET_DECIMAL)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.DATE, UdfTypeCatalog.Accessor.GET_INT)); + assertTrue(UdfTypeCatalog.supportsAccessor(DuckDBColumnType.TIMESTAMP, UdfTypeCatalog.Accessor.GET_LONG)); + + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.BOOLEAN), DUCKDB_TYPE_BOOLEAN.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TINYINT), DUCKDB_TYPE_TINYINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.SMALLINT), DUCKDB_TYPE_SMALLINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.INTEGER), DUCKDB_TYPE_INTEGER.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.BIGINT), DUCKDB_TYPE_BIGINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.UTINYINT), DUCKDB_TYPE_UTINYINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.USMALLINT), DUCKDB_TYPE_USMALLINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.UINTEGER), DUCKDB_TYPE_UINTEGER.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.UBIGINT), DUCKDB_TYPE_UBIGINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.HUGEINT), DUCKDB_TYPE_HUGEINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.UHUGEINT), DUCKDB_TYPE_UHUGEINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.FLOAT), DUCKDB_TYPE_FLOAT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.DOUBLE), DUCKDB_TYPE_DOUBLE.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.VARCHAR), DUCKDB_TYPE_VARCHAR.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.BLOB), 
DUCKDB_TYPE_BLOB.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.DECIMAL), DUCKDB_TYPE_DECIMAL.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.DATE), DUCKDB_TYPE_DATE.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIME), DUCKDB_TYPE_TIME.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIME_NS), DUCKDB_TYPE_TIME_NS.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIMESTAMP), DUCKDB_TYPE_TIMESTAMP.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIMESTAMP_S), DUCKDB_TYPE_TIMESTAMP_S.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIMESTAMP_MS), DUCKDB_TYPE_TIMESTAMP_MS.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIMESTAMP_NS), DUCKDB_TYPE_TIMESTAMP_NS.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIME_WITH_TIME_ZONE), DUCKDB_TYPE_TIME_TZ.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE), + DUCKDB_TYPE_TIMESTAMP_TZ.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeId(DuckDBColumnType.UUID), DUCKDB_TYPE_UUID.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForTableFunctionParameter(DuckDBColumnType.INTEGER), + DUCKDB_TYPE_INTEGER.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForTableFunctionParameter(DuckDBColumnType.DOUBLE), + DUCKDB_TYPE_DOUBLE.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForTableFunctionParameter(DuckDBColumnType.VARCHAR), + DUCKDB_TYPE_VARCHAR.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForTableFunctionParameter(DuckDBColumnType.DECIMAL), + DUCKDB_TYPE_DECIMAL.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForTableFunctionParameter(DuckDBColumnType.TIMESTAMP_NS), + DUCKDB_TYPE_TIMESTAMP_NS.typeId); + assertEquals(UdfTypeCatalog.fromCapiTypeId(DUCKDB_TYPE_INTEGER.typeId), DuckDBColumnType.INTEGER); + assertEquals(UdfTypeCatalog.fromCapiTypeId(DUCKDB_TYPE_VARCHAR.typeId), 
DuckDBColumnType.VARCHAR); + assertEquals(UdfTypeCatalog.fromCapiTypeId(DUCKDB_TYPE_BLOB.typeId), DuckDBColumnType.BLOB); + assertEquals(UdfTypeCatalog.fromCapiTypeId(DUCKDB_TYPE_TIME_NS.typeId), DuckDBColumnType.TIME_NS); + assertEquals(UdfTypeCatalog.fromCapiTypeId(DUCKDB_TYPE_TIME_TZ.typeId), DuckDBColumnType.TIME_WITH_TIME_ZONE); + + assertThrows(() -> { UdfTypeCatalog.fromCapiTypeId(-1); }, SQLFeatureNotSupportedException.class); + } + + public static void test_bindings_udf_java_type_mapper_biginteger_maps_to_hugeint() throws Exception { + UdfLogicalType mapped = UdfJavaTypeMapper.toLogicalType(BigInteger.class); + assertEquals(mapped.getType(), DuckDBColumnType.HUGEINT); + } + + public static void test_bindings_udf_type_catalog_accessor_matrix_all_supported_types() throws Exception { + assertAccessors(DuckDBColumnType.BOOLEAN, + EnumSet.of(UdfTypeCatalog.Accessor.GET_BOOLEAN, UdfTypeCatalog.Accessor.SET_BOOLEAN)); + assertAccessors(DuckDBColumnType.TINYINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.SMALLINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.INTEGER, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.BIGINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.UTINYINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.USMALLINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.UINTEGER, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.UBIGINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.HUGEINT, + 
EnumSet.of(UdfTypeCatalog.Accessor.GET_BYTES, UdfTypeCatalog.Accessor.SET_BYTES)); + assertAccessors(DuckDBColumnType.UHUGEINT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_BYTES, UdfTypeCatalog.Accessor.SET_BYTES)); + assertAccessors(DuckDBColumnType.FLOAT, + EnumSet.of(UdfTypeCatalog.Accessor.GET_FLOAT, UdfTypeCatalog.Accessor.SET_FLOAT)); + assertAccessors(DuckDBColumnType.DOUBLE, + EnumSet.of(UdfTypeCatalog.Accessor.GET_DOUBLE, UdfTypeCatalog.Accessor.SET_DOUBLE)); + assertAccessors(DuckDBColumnType.VARCHAR, + EnumSet.of(UdfTypeCatalog.Accessor.GET_STRING, UdfTypeCatalog.Accessor.SET_STRING)); + assertAccessors(DuckDBColumnType.BLOB, + EnumSet.of(UdfTypeCatalog.Accessor.GET_BYTES, UdfTypeCatalog.Accessor.SET_BYTES)); + assertAccessors(DuckDBColumnType.DECIMAL, + EnumSet.of(UdfTypeCatalog.Accessor.GET_DECIMAL, UdfTypeCatalog.Accessor.SET_DECIMAL)); + assertAccessors(DuckDBColumnType.DATE, + EnumSet.of(UdfTypeCatalog.Accessor.GET_INT, UdfTypeCatalog.Accessor.SET_INT)); + assertAccessors(DuckDBColumnType.TIME, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIME_NS, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIMESTAMP, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIMESTAMP_S, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIMESTAMP_MS, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIMESTAMP_NS, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIME_WITH_TIME_ZONE, + EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + 
EnumSet.of(UdfTypeCatalog.Accessor.GET_LONG, UdfTypeCatalog.Accessor.SET_LONG)); + assertAccessors(DuckDBColumnType.UUID, + EnumSet.of(UdfTypeCatalog.Accessor.GET_BYTES, UdfTypeCatalog.Accessor.SET_BYTES)); + } + + public static void test_bindings_udf_scalar_registration_type_ids() throws Exception { + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.BOOLEAN), + DUCKDB_TYPE_BOOLEAN.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TINYINT), + DUCKDB_TYPE_TINYINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.SMALLINT), + DUCKDB_TYPE_SMALLINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.INTEGER), + DUCKDB_TYPE_INTEGER.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.BIGINT), + DUCKDB_TYPE_BIGINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.UTINYINT), + DUCKDB_TYPE_UTINYINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.USMALLINT), + DUCKDB_TYPE_USMALLINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.UINTEGER), + DUCKDB_TYPE_UINTEGER.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.UBIGINT), + DUCKDB_TYPE_UBIGINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.HUGEINT), + DUCKDB_TYPE_HUGEINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.UHUGEINT), + DUCKDB_TYPE_UHUGEINT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.FLOAT), + DUCKDB_TYPE_FLOAT.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.DOUBLE), + DUCKDB_TYPE_DOUBLE.typeId); + assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.VARCHAR), + 
DUCKDB_TYPE_VARCHAR.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.DECIMAL),
+                     DUCKDB_TYPE_DECIMAL.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.BLOB), DUCKDB_TYPE_BLOB.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.DATE), DUCKDB_TYPE_DATE.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIME), DUCKDB_TYPE_TIME.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIME_NS),
+                     DUCKDB_TYPE_TIME_NS.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIMESTAMP),
+                     DUCKDB_TYPE_TIMESTAMP.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIMESTAMP_S),
+                     DUCKDB_TYPE_TIMESTAMP_S.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIMESTAMP_MS),
+                     DUCKDB_TYPE_TIMESTAMP_MS.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIMESTAMP_NS),
+                     DUCKDB_TYPE_TIMESTAMP_NS.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIME_WITH_TIME_ZONE),
+                     DUCKDB_TYPE_TIME_TZ.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE),
+                     DUCKDB_TYPE_TIMESTAMP_TZ.typeId);
+        assertEquals(UdfTypeCatalog.toCapiTypeIdForScalarRegistration(DuckDBColumnType.UUID), DUCKDB_TYPE_UUID.typeId);
+    }
+
+    // For every type id listed below: duckdb_create_logical_type must return a non-null handle whose
+    // duckdb_get_type_id reads back the same id, and UdfTypeCatalog.fromCapiTypeId -> toCapiTypeId must
+    // map the id back to itself.
+    public static void test_bindings_udf_capi_type_roundtrip_supported_types() throws Exception {
+        int[] supportedTypeIds = new int[] {
+            DUCKDB_TYPE_BOOLEAN.typeId,      DUCKDB_TYPE_TINYINT.typeId,      DUCKDB_TYPE_SMALLINT.typeId,
+            DUCKDB_TYPE_INTEGER.typeId,      DUCKDB_TYPE_BIGINT.typeId,       DUCKDB_TYPE_FLOAT.typeId,
+            DUCKDB_TYPE_UTINYINT.typeId,     DUCKDB_TYPE_USMALLINT.typeId,    DUCKDB_TYPE_UINTEGER.typeId,
+            DUCKDB_TYPE_UBIGINT.typeId,      DUCKDB_TYPE_HUGEINT.typeId,      DUCKDB_TYPE_UHUGEINT.typeId,
+            DUCKDB_TYPE_DOUBLE.typeId,       DUCKDB_TYPE_VARCHAR.typeId,      DUCKDB_TYPE_BLOB.typeId,
+            DUCKDB_TYPE_DATE.typeId,         DUCKDB_TYPE_TIME.typeId,         DUCKDB_TYPE_TIME_NS.typeId,
+            DUCKDB_TYPE_TIME_TZ.typeId,      DUCKDB_TYPE_TIMESTAMP.typeId,    DUCKDB_TYPE_TIMESTAMP_S.typeId,
+            DUCKDB_TYPE_TIMESTAMP_MS.typeId, DUCKDB_TYPE_TIMESTAMP_NS.typeId, DUCKDB_TYPE_TIMESTAMP_TZ.typeId,
+            DUCKDB_TYPE_UUID.typeId};
+        for (int typeId : supportedTypeIds) {
+            ByteBuffer lt = duckdb_create_logical_type(typeId);
+            // Checked before the try so the finally block never hands a null handle to the destroy call.
+            assertNotNull(lt);
+            try {
+                assertEquals(duckdb_get_type_id(lt), typeId);
+                assertEquals(UdfTypeCatalog.toCapiTypeId(UdfTypeCatalog.fromCapiTypeId(typeId)), typeId);
+            } finally {
+                // Destroy the handle even when an assertion above throws; previously a failing
+                // assertion skipped this call.
+                duckdb_destroy_logical_type(lt);
+            }
+        }
+    }
+
+    public static void test_bindings_udf_scalar_writer_fixed_width_accessors() throws Exception {
+        int rowCount = 8;
+        ByteBuffer validity = ByteBuffer.allocateDirect((rowCount + 7) / 8);
+        validity.put(0, (byte) 0xFF);
+
+        UdfScalarWriter boolVector = new UdfScalarWriter(DUCKDB_TYPE_BOOLEAN.typeId,
+                                                         ByteBuffer.allocateDirect(rowCount), null, validity, rowCount);
+        boolVector.setBoolean(0, true);
+        boolVector.setBoolean(1, false);
+        assertTrue(boolVector.getBoolean(0));
+        assertFalse(boolVector.getBoolean(1));
+        assertThrows(() -> { boolVector.getInt(0); }, UnsupportedOperationException.class);
+
+        UdfScalarWriter tinyVector = new UdfScalarWriter(DUCKDB_TYPE_TINYINT.typeId,
+                                                         ByteBuffer.allocateDirect(rowCount), null, validity, rowCount);
+        tinyVector.setInt(0, 127);
+        tinyVector.setInt(1, -12);
+        assertEquals(tinyVector.getInt(0), 127);
+        assertEquals(tinyVector.getInt(1), -12);
+        assertThrows(() -> { tinyVector.setInt(2, 128); }, IllegalArgumentException.class);
+
+        UdfScalarWriter smallVector = new UdfScalarWriter(
+            DUCKDB_TYPE_SMALLINT.typeId, ByteBuffer.allocateDirect(rowCount * Short.BYTES), null, validity, rowCount);
+        smallVector.setInt(0, 12345);
+        smallVector.setInt(1, -12345);
+        assertEquals(smallVector.getInt(0), 12345);
+
assertEquals(smallVector.getInt(1), -12345); + assertThrows(() -> { smallVector.setInt(2, 40000); }, IllegalArgumentException.class); + + UdfScalarWriter unsignedTinyVector = new UdfScalarWriter( + DUCKDB_TYPE_UTINYINT.typeId, ByteBuffer.allocateDirect(rowCount), null, validity, rowCount); + unsignedTinyVector.setInt(0, 255); + assertEquals(unsignedTinyVector.getInt(0), 255); + assertThrows(() -> { unsignedTinyVector.setInt(1, -1); }, IllegalArgumentException.class); + assertThrows(() -> { unsignedTinyVector.setInt(1, 256); }, IllegalArgumentException.class); + + UdfScalarWriter unsignedSmallVector = new UdfScalarWriter( + DUCKDB_TYPE_USMALLINT.typeId, ByteBuffer.allocateDirect(rowCount * Short.BYTES), null, validity, rowCount); + unsignedSmallVector.setInt(0, 65535); + assertEquals(unsignedSmallVector.getInt(0), 65535); + assertThrows(() -> { unsignedSmallVector.setInt(1, -1); }, IllegalArgumentException.class); + assertThrows(() -> { unsignedSmallVector.setInt(1, 70000); }, IllegalArgumentException.class); + + UdfScalarWriter intVector = new UdfScalarWriter( + DUCKDB_TYPE_INTEGER.typeId, ByteBuffer.allocateDirect(rowCount * Integer.BYTES), null, validity, rowCount); + intVector.setInt(0, 99); + assertEquals(intVector.getInt(0), 99); + intVector.setNull(0); + assertTrue(intVector.isNull(0)); + intVector.setInt(0, 77); + assertFalse(intVector.isNull(0)); + assertThrows(() -> { intVector.getLong(0); }, UnsupportedOperationException.class); + + UdfScalarWriter longVector = new UdfScalarWriter( + DUCKDB_TYPE_BIGINT.typeId, ByteBuffer.allocateDirect(rowCount * Long.BYTES), null, validity, rowCount); + longVector.setLong(0, 9_000_000_000L); + assertEquals(longVector.getLong(0), 9_000_000_000L); + longVector.setObject(1, new BigDecimal("9223372036854775807")); + assertEquals(longVector.getLong(1), Long.MAX_VALUE); + assertThrows(() -> { longVector.setObject(2, new BigDecimal("1.5")); }, IllegalArgumentException.class); + assertThrows( + () -> { longVector.setObject(3, 
new BigInteger("9223372036854775808")); }, IllegalArgumentException.class); + assertThrows(() -> { longVector.setInt(0, 1); }, UnsupportedOperationException.class); + + UdfScalarWriter unsignedIntVector = new UdfScalarWriter( + DUCKDB_TYPE_UINTEGER.typeId, ByteBuffer.allocateDirect(rowCount * Integer.BYTES), null, validity, rowCount); + unsignedIntVector.setLong(0, 4_000_000_000L); + assertEquals(unsignedIntVector.getLong(0), 4_000_000_000L); + assertThrows(() -> { unsignedIntVector.setLong(1, -1); }, IllegalArgumentException.class); + + UdfScalarWriter unsignedBigVector = new UdfScalarWriter( + DUCKDB_TYPE_UBIGINT.typeId, ByteBuffer.allocateDirect(rowCount * Long.BYTES), null, validity, rowCount); + unsignedBigVector.setLong(0, -1L); + assertEquals(unsignedBigVector.getLong(0), -1L); + unsignedBigVector.setObject(1, new BigInteger("18446744073709551615")); + assertEquals(unsignedBigVector.getLong(1), -1L); + assertThrows(() -> { unsignedBigVector.setObject(2, new BigDecimal("1.5")); }, IllegalArgumentException.class); + assertThrows(() -> { + unsignedBigVector.setObject(3, new BigInteger("18446744073709551616")); + }, IllegalArgumentException.class); + + UdfScalarWriter hugeVector = new UdfScalarWriter( + DUCKDB_TYPE_HUGEINT.typeId, ByteBuffer.allocateDirect(rowCount * 16), null, validity, rowCount); + byte[] hugeValue = new byte[16]; + hugeValue[0] = 42; + hugeValue[15] = 7; + hugeVector.setBytes(0, hugeValue); + assertEquals(hugeVector.getBytes(0), hugeValue); + BigInteger hugeBi = new BigInteger("170141183460469231731687303715884105727"); + hugeVector.setObject(1, hugeBi); + assertEquals(hugeVector.getBigInteger(1), hugeBi); + BigInteger hugeNegOne = BigInteger.valueOf(-1); + hugeVector.setObject(2, hugeNegOne); + assertEquals(hugeVector.getBigInteger(2), hugeNegOne); + BigInteger hugeMin = new BigInteger("-170141183460469231731687303715884105728"); + hugeVector.setObject(3, hugeMin); + assertEquals(hugeVector.getBigInteger(3), hugeMin); + BigInteger 
hugeMinPlusOne = hugeMin.add(BigInteger.ONE); + hugeVector.setObject(4, hugeMinPlusOne); + assertEquals(hugeVector.getBigInteger(4), hugeMinPlusOne); + assertThrows(() -> { + hugeVector.setObject(5, new BigInteger("170141183460469231731687303715884105728")); + }, IllegalArgumentException.class); + + UdfScalarWriter uhugeVector = new UdfScalarWriter( + DUCKDB_TYPE_UHUGEINT.typeId, ByteBuffer.allocateDirect(rowCount * 16), null, validity, rowCount); + byte[] uhugeValue = new byte[16]; + uhugeValue[3] = 9; + uhugeValue[8] = 11; + uhugeVector.setBytes(0, uhugeValue); + assertEquals(uhugeVector.getBytes(0), uhugeValue); + BigInteger uhugeBi = new BigInteger("340282366920938463463374607431768211455"); + uhugeVector.setObject(1, uhugeBi); + assertEquals(uhugeVector.getBigInteger(1), uhugeBi); + assertThrows(() -> { uhugeVector.setObject(2, BigInteger.valueOf(-1)); }, IllegalArgumentException.class); + + UdfScalarWriter uuidVector = new UdfScalarWriter( + DUCKDB_TYPE_UUID.typeId, ByteBuffer.allocateDirect(rowCount * 16), null, validity, rowCount); + byte[] uuidBytes = new byte[16]; + for (int i = 0; i < uuidBytes.length; i++) { + uuidBytes[i] = (byte) (i + 1); + } + uuidVector.setBytes(0, uuidBytes); + assertEquals(uuidVector.getBytes(0), uuidBytes); + assertThrows(() -> { uuidVector.setBytes(1, new byte[8]); }, IllegalArgumentException.class); + + UdfScalarWriter timetzVector = new UdfScalarWriter( + DUCKDB_TYPE_TIME_TZ.typeId, ByteBuffer.allocateDirect(rowCount * Long.BYTES), null, validity, rowCount); + timetzVector.setLong(0, 123456789L); + assertEquals(timetzVector.getLong(0), 123456789L); + timetzVector.setObject(1, OffsetTime.of(1, 2, 3, 0, ZoneOffset.ofHoursMinutesSeconds(15, 59, 59))); + assertThrows(() -> { + timetzVector.setObject(2, OffsetTime.of(1, 2, 3, 0, ZoneOffset.ofHours(16))); + }, IllegalArgumentException.class); + + UdfScalarWriter timestamptzVector = + new UdfScalarWriter(DUCKDB_TYPE_TIMESTAMP_TZ.typeId, ByteBuffer.allocateDirect(rowCount * 
Long.BYTES), null, + validity, rowCount); + timestamptzVector.setLong(0, 987654321L); + assertEquals(timestamptzVector.getLong(0), 987654321L); + + UdfScalarWriter floatVector = new UdfScalarWriter( + DUCKDB_TYPE_FLOAT.typeId, ByteBuffer.allocateDirect(rowCount * Float.BYTES), null, validity, rowCount); + floatVector.setFloat(0, 1.25f); + assertEquals(floatVector.getFloat(0), 1.25f, 0.0001f); + + UdfScalarWriter doubleVector = new UdfScalarWriter( + DUCKDB_TYPE_DOUBLE.typeId, ByteBuffer.allocateDirect(rowCount * Double.BYTES), null, validity, rowCount); + doubleVector.setDouble(0, 42.5d); + assertEquals(doubleVector.getDouble(0), 42.5d, 0.0000001d); + + assertThrows(() -> { + new UdfScalarWriter(DUCKDB_TYPE_INTEGER.typeId, null, null, validity, rowCount); + }, IllegalArgumentException.class); + assertThrows(() -> { + new UdfScalarWriter(-1, ByteBuffer.allocateDirect(Integer.BYTES), null, validity, 1); + }, IllegalArgumentException.class); + } + + public static void test_bindings_udf_scalar_writer_varchar_accessors() throws Exception { + ByteBuffer lt = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR.typeId); + ByteBuffer vec = duckdb_create_vector(lt); + try { + duckdb_vector_ensure_validity_writable(vec); + int rowCount = (int) duckdb_vector_size(); + ByteBuffer validity = duckdb_vector_get_validity(vec, rowCount); + UdfScalarWriter stringVector = + new UdfScalarWriter(DUCKDB_TYPE_VARCHAR.typeId, null, vec, validity, rowCount); + stringVector.setString(0, "alpha"); + assertEquals(stringVector.getString(0), "alpha"); + stringVector.setString(1, null); + assertTrue(stringVector.isNull(1)); + assertNull(stringVector.getString(1)); + assertThrows(() -> { stringVector.getDouble(0); }, UnsupportedOperationException.class); + } finally { + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + + ByteBuffer blobType = duckdb_create_logical_type(DUCKDB_TYPE_BLOB.typeId); + ByteBuffer blobVec = duckdb_create_vector(blobType); + try { + 
duckdb_vector_ensure_validity_writable(blobVec); + int rowCount = (int) duckdb_vector_size(); + ByteBuffer validity = duckdb_vector_get_validity(blobVec, rowCount); + UdfScalarWriter blobVector = + new UdfScalarWriter(DUCKDB_TYPE_BLOB.typeId, null, blobVec, validity, rowCount); + byte[] payload = new byte[] {1, 2, 3, 4}; + blobVector.setBytes(0, payload); + assertEquals(blobVector.getBytes(0), payload); + blobVector.setBytes(1, null); + assertTrue(blobVector.isNull(1)); + assertNull(blobVector.getBytes(1)); + assertThrows(() -> { blobVector.getString(0); }, UnsupportedOperationException.class); + } finally { + duckdb_destroy_vector(blobVec); + duckdb_destroy_logical_type(blobType); + } + } + + public static void test_bindings_udf_scalar_writer_native_vectors_supported_types() throws Exception { + int rowCount = (int) duckdb_vector_size(); + int[] udfTypeIds = new int[] { + DUCKDB_TYPE_BOOLEAN.typeId, DUCKDB_TYPE_TINYINT.typeId, DUCKDB_TYPE_SMALLINT.typeId, + DUCKDB_TYPE_INTEGER.typeId, DUCKDB_TYPE_BIGINT.typeId, DUCKDB_TYPE_UTINYINT.typeId, + DUCKDB_TYPE_USMALLINT.typeId, DUCKDB_TYPE_UINTEGER.typeId, DUCKDB_TYPE_UBIGINT.typeId, + DUCKDB_TYPE_HUGEINT.typeId, DUCKDB_TYPE_UHUGEINT.typeId, DUCKDB_TYPE_FLOAT.typeId, + DUCKDB_TYPE_DOUBLE.typeId, DUCKDB_TYPE_VARCHAR.typeId, DUCKDB_TYPE_BLOB.typeId, + DUCKDB_TYPE_DATE.typeId, DUCKDB_TYPE_TIME.typeId, DUCKDB_TYPE_TIME_NS.typeId, + DUCKDB_TYPE_TIME_TZ.typeId, DUCKDB_TYPE_TIMESTAMP.typeId, DUCKDB_TYPE_TIMESTAMP_S.typeId, + DUCKDB_TYPE_TIMESTAMP_MS.typeId, DUCKDB_TYPE_TIMESTAMP_NS.typeId, DUCKDB_TYPE_TIMESTAMP_TZ.typeId, + DUCKDB_TYPE_UUID.typeId}; + + for (int typeId : udfTypeIds) { + ByteBuffer lt = duckdb_create_logical_type(typeId); + ByteBuffer vec = duckdb_create_vector(lt); + try { + duckdb_vector_ensure_validity_writable(vec); + ByteBuffer validity = duckdb_vector_get_validity(vec, rowCount); + UdfScalarWriter column; + if (UdfTypeCatalog.requiresVectorRef(UdfTypeCatalog.fromCapiTypeId(typeId))) { + column = new 
UdfScalarWriter(typeId, null, vec, validity, rowCount); + } else { + long widthBytes = CAPIType.capiTypeFromTypeId(typeId).widthBytes; + ByteBuffer data = duckdb_vector_get_data(vec, rowCount * widthBytes); + column = new UdfScalarWriter(typeId, data, null, validity, rowCount); + } + + if (typeId == DUCKDB_TYPE_BOOLEAN.typeId) { + column.setBoolean(0, true); + assertTrue(column.getBoolean(0)); + } else if (typeId == DUCKDB_TYPE_TINYINT.typeId) { + column.setInt(0, -7); + assertEquals(column.getInt(0), -7); + } else if (typeId == DUCKDB_TYPE_SMALLINT.typeId) { + column.setInt(0, 32123); + assertEquals(column.getInt(0), 32123); + } else if (typeId == DUCKDB_TYPE_INTEGER.typeId) { + column.setInt(0, 123456789); + assertEquals(column.getInt(0), 123456789); + } else if (typeId == DUCKDB_TYPE_BIGINT.typeId) { + column.setLong(0, 9_876_543_210L); + assertEquals(column.getLong(0), 9_876_543_210L); + } else if (typeId == DUCKDB_TYPE_UTINYINT.typeId) { + column.setInt(0, 250); + assertEquals(column.getInt(0), 250); + } else if (typeId == DUCKDB_TYPE_USMALLINT.typeId) { + column.setInt(0, 65000); + assertEquals(column.getInt(0), 65000); + } else if (typeId == DUCKDB_TYPE_UINTEGER.typeId) { + column.setLong(0, 4_000_000_000L); + assertEquals(column.getLong(0), 4_000_000_000L); + } else if (typeId == DUCKDB_TYPE_UBIGINT.typeId) { + column.setLong(0, -1L); + assertEquals(column.getLong(0), -1L); + } else if (typeId == DUCKDB_TYPE_HUGEINT.typeId || typeId == DUCKDB_TYPE_UHUGEINT.typeId || + typeId == DUCKDB_TYPE_UUID.typeId) { + byte[] bytes = new byte[16]; + bytes[0] = 1; + bytes[15] = 42; + column.setBytes(0, bytes); + assertEquals(column.getBytes(0), bytes); + } else if (typeId == DUCKDB_TYPE_FLOAT.typeId) { + column.setFloat(0, 3.25f); + assertEquals(column.getFloat(0), 3.25f, 0.0001f); + } else if (typeId == DUCKDB_TYPE_DOUBLE.typeId) { + column.setDouble(0, 8.125d); + assertEquals(column.getDouble(0), 8.125d, 0.0000001d); + } else if (typeId == DUCKDB_TYPE_VARCHAR.typeId) { 
+ column.setString(0, "alpha"); + assertEquals(column.getString(0), "alpha"); + } else if (typeId == DUCKDB_TYPE_BLOB.typeId) { + byte[] blob = "blob-extended".getBytes(StandardCharsets.UTF_8); + column.setBytes(0, blob); + assertEquals(column.getBytes(0), blob); + } else if (typeId == DUCKDB_TYPE_DECIMAL.typeId) { + column.setBigDecimal(0, new BigDecimal("123.75")); + assertEquals(column.getBigDecimal(0), new BigDecimal("123.75")); + } else if (typeId == DUCKDB_TYPE_DATE.typeId) { + column.setInt(0, 19723); + assertEquals(column.getInt(0), 19723); + } else if (typeId == DUCKDB_TYPE_TIME.typeId || typeId == DUCKDB_TYPE_TIME_NS.typeId || + typeId == DUCKDB_TYPE_TIME_TZ.typeId || typeId == DUCKDB_TYPE_TIMESTAMP.typeId || + typeId == DUCKDB_TYPE_TIMESTAMP_S.typeId || typeId == DUCKDB_TYPE_TIMESTAMP_MS.typeId || + typeId == DUCKDB_TYPE_TIMESTAMP_NS.typeId || typeId == DUCKDB_TYPE_TIMESTAMP_TZ.typeId) { + column.setLong(0, 123456789L); + assertEquals(column.getLong(0), 123456789L); + } else { + fail("Unhandled UDF type id: " + typeId); + } + + column.setNull(1); + assertTrue(column.isNull(1)); + if (typeId == DUCKDB_TYPE_VARCHAR.typeId) { + assertNull(column.getString(1)); + } else if (typeId == DUCKDB_TYPE_BLOB.typeId || typeId == DUCKDB_TYPE_HUGEINT.typeId || + typeId == DUCKDB_TYPE_UHUGEINT.typeId || typeId == DUCKDB_TYPE_UUID.typeId) { + assertNull(column.getBytes(1)); + } + } finally { + duckdb_destroy_vector(vec); + duckdb_destroy_logical_type(lt); + } + } + } + public static void test_bindings_appender() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { @@ -301,6 +899,90 @@ public static void test_bindings_appender() throws Exception { } } + public static void test_bindings_scalar_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { 
+ duckdb_register_scalar_function_java(conn.connRef, "bindings_java_scalar".getBytes(UTF_8), + (ctx, args, out, rowCount) + -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 42); + } + }, + new org.duckdb.udf.UdfLogicalType[0], + org.duckdb.udf.UdfLogicalType.of(DuckDBColumnType.INTEGER), false, + false, true, false); + + try (ResultSet rs = stmt.executeQuery("SELECT bindings_java_scalar()")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 42); + assertFalse(rs.next()); + } + + duckdb_register_scalar_function_java( + conn.connRef, "bindings_java_scalar_add1".getBytes(UTF_8), + (ctx, args, out, rowCount) + -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + 1); + } + }, + new org.duckdb.udf.UdfLogicalType[] {org.duckdb.udf.UdfLogicalType.of(DuckDBColumnType.INTEGER)}, + org.duckdb.udf.UdfLogicalType.of(DuckDBColumnType.INTEGER), false, false, true, false); + + try (ResultSet rs = stmt.executeQuery("SELECT bindings_java_scalar_add1(41)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 42); + assertFalse(rs.next()); + } + } + } + + public static void test_bindings_table_function() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + duckdb_register_table_function_java( + conn.connRef, "bindings_range_java".getBytes(UTF_8), + new org.duckdb.udf.TableFunction() { + @Override + public org.duckdb.udf.TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new org.duckdb.udf.TableBindResult( + new String[] {"i"}, + new org.duckdb.udf.UdfLogicalType[] { + org.duckdb.udf.UdfLogicalType.of(DuckDBColumnType.BIGINT)}, + ((Number) parameters[0]).longValue()); + } + + @Override + public org.duckdb.udf.TableState init(org.duckdb.udf.InitContext ctx, + org.duckdb.udf.TableBindResult bind) { + long end = ((Number) bind.getBindState()).longValue(); + return new 
org.duckdb.udf.TableState(new long[] {0L, end}); + } + + @Override + public int produce(org.duckdb.udf.TableState state, org.duckdb.UdfOutputAppender out) { + long[] st = (long[]) state.getState(); + int produced = 0; + for (; produced < 64 && st[0] < st[1]; produced++, st[0]++) { + out.beginRow().append(st[0]).endRow(); + } + return out.getSize(); + } + }, + new org.duckdb.udf.UdfLogicalType[] {org.duckdb.udf.UdfLogicalType.of(DuckDBColumnType.BIGINT)}, true, + 4, true); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM bindings_range_java(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), (long) i); + } + assertFalse(rs.next()); + } + } + } + public static void test_bindings_decimal_type() throws Exception { try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); Statement stmt = conn.createStatement()) { diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index 3769121a7..25d1ce13b 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -30,6 +30,7 @@ import java.time.format.DateTimeFormatterBuilder; import java.time.format.ResolverStyle; import java.time.temporal.ChronoUnit; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -45,11 +46,18 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; import javax.sql.rowset.CachedRowSet; import javax.sql.rowset.RowSetProvider; import org.duckdb.test.TempDirectory; +import org.duckdb.udf.TableBindResult; +import org.duckdb.udf.TableFunctionDefinition; +import org.duckdb.udf.TableFunctionOptions; +import org.duckdb.udf.TableState; +import org.duckdb.udf.UdfLogicalType; 
+import org.duckdb.udf.UdfOptions;

 public class TestDuckDBJDBC {

@@ -86,65 +94,3706 @@ private static void executeStatementWithThread(Statement statement, ExecutorServ
         }
     }

+    // Fixed-width primitives plus VARCHAR: exactly the types nonNullLiteralForType() has a literal for.
+    private static DuckDBColumnType[] scalarCoreTypes() {
+        return new DuckDBColumnType[] {DuckDBColumnType.BOOLEAN, DuckDBColumnType.TINYINT, DuckDBColumnType.SMALLINT,
+                                       DuckDBColumnType.INTEGER, DuckDBColumnType.BIGINT,  DuckDBColumnType.FLOAT,
+                                       DuckDBColumnType.DOUBLE,  DuckDBColumnType.VARCHAR};
+    }
+
+    // DECIMAL, BLOB and the date/time/timestamp family: exactly the types
+    // nonNullLiteralForExtendedType() has a literal for.
+    private static DuckDBColumnType[] scalarExtendedTypes() {
+        return new DuckDBColumnType[] {
+            DuckDBColumnType.DECIMAL,     DuckDBColumnType.BLOB,         DuckDBColumnType.DATE,
+            DuckDBColumnType.TIME,        DuckDBColumnType.TIME_NS,      DuckDBColumnType.TIMESTAMP,
+            DuckDBColumnType.TIMESTAMP_S, DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_NS};
+    }
+
+    // Unsigned integers, 128-bit integers, the two time-zone-aware types and UUID: exactly the types
+    // nonNullLiteralForUnsignedAndSpecialType() has a literal for.
+    private static DuckDBColumnType[] scalarUnsignedAndSpecialTypes() {
+        return new DuckDBColumnType[] {DuckDBColumnType.UTINYINT,
+                                       DuckDBColumnType.USMALLINT,
+                                       DuckDBColumnType.UINTEGER,
+                                       DuckDBColumnType.UBIGINT,
+                                       DuckDBColumnType.HUGEINT,
+                                       DuckDBColumnType.UHUGEINT,
+                                       DuckDBColumnType.TIME_WITH_TIME_ZONE,
+                                       DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE,
+                                       DuckDBColumnType.UUID};
+    }
+
+    // Returns a fixed, non-NULL SQL literal with an explicit ::TYPE cast for each core type.
+    // Any other type is rejected with IllegalArgumentException.
+    private static String nonNullLiteralForType(DuckDBColumnType type) {
+        switch (type) {
+        case BOOLEAN:
+            return "TRUE::BOOLEAN";
+        case TINYINT:
+            return "7::TINYINT";
+        case SMALLINT:
+            return "32000::SMALLINT";
+        case INTEGER:
+            return "123456::INTEGER";
+        case BIGINT:
+            return "9876543210::BIGINT";
+        case FLOAT:
+            return "1.25::FLOAT";
+        case DOUBLE:
+            return "2.5::DOUBLE";
+        case VARCHAR:
+            return "'duck'::VARCHAR";
+        default:
+            throw new IllegalArgumentException("Unsupported test type: " + type);
+        }
+    }
+
+    // Maps a column type to the type name used in SQL text. The two *_WITH_TIME_ZONE constants are
+    // spelled out with spaces; every other type returns its enum constant name unchanged.
+    private static String sqlTypeNameForLiteral(DuckDBColumnType type) {
+        switch (type) {
+        case TIME_WITH_TIME_ZONE:
+            return "TIME WITH TIME ZONE";
+        case TIMESTAMP_WITH_TIME_ZONE:
+            return "TIMESTAMP WITH TIME ZONE";
+        default:
+            return type.name();
+        }
+    }
+
+
private static String nullLiteralForType(DuckDBColumnType type) {
+        // Typed NULL, e.g. "NULL::INTEGER"; the type name comes from sqlTypeNameForLiteral().
+        return "NULL::" + sqlTypeNameForLiteral(type);
+    }
+
+    // Returns a fixed, non-NULL SQL literal for each type listed by scalarExtendedTypes().
+    // Any other type is rejected with IllegalArgumentException.
+    private static String nonNullLiteralForExtendedType(DuckDBColumnType type) {
+        switch (type) {
+        case DECIMAL:
+            return "42.75::DECIMAL(18,2)";
+        case BLOB:
+            return "'blob-extended'::BLOB";
+        case DATE:
+            return "DATE '2024-01-03'";
+        case TIME:
+            return "TIME '01:02:03.123456'";
+        case TIME_NS:
+            return "TIME_NS '01:02:03.123456789'";
+        case TIMESTAMP:
+            return "TIMESTAMP '2024-01-03 04:05:06.123456'";
+        case TIMESTAMP_S:
+            return "TIMESTAMP_S '2024-01-03 04:05:06'";
+        case TIMESTAMP_MS:
+            return "TIMESTAMP_MS '2024-01-03 04:05:06.123'";
+        case TIMESTAMP_NS:
+            return "TIMESTAMP_NS '2024-01-03 04:05:06.123456789'";
+        default:
+            throw new IllegalArgumentException("Unsupported extended test type: " + type);
+        }
+    }
+
+    // Returns a fixed, non-NULL SQL literal for each type listed by scalarUnsignedAndSpecialTypes().
+    // Any other type is rejected with IllegalArgumentException.
+    private static String nonNullLiteralForUnsignedAndSpecialType(DuckDBColumnType type) {
+        switch (type) {
+        case UTINYINT:
+            return "250::UTINYINT";
+        case USMALLINT:
+            return "65000::USMALLINT";
+        case UINTEGER:
+            return "4000000000::UINTEGER";
+        case UBIGINT:
+            return "18446744073709551615::UBIGINT";
+        case HUGEINT:
+            return "170141183460469231731687303715884105727::HUGEINT";
+        case UHUGEINT:
+            return "340282366920938463463374607431768211455::UHUGEINT";
+        case TIME_WITH_TIME_ZONE:
+            return "'01:02:03+05:30'::TIME WITH TIME ZONE";
+        case TIMESTAMP_WITH_TIME_ZONE:
+            return "'2024-01-03 04:05:06+00'::TIMESTAMP WITH TIME ZONE";
+        case UUID:
+            return "'550e8400-e29b-41d4-a716-446655440000'::UUID";
+        default:
+            throw new IllegalArgumentException("Unsupported unsigned/special test type: " + type);
+        }
+    }
+
     public static void test_connection() throws Exception {
         Connection conn = DriverManager.getConnection(JDBC_URL);
         assertTrue(conn.isValid(0));
         assertFalse(conn.isClosed());
-        Statement stmt = conn.createStatement();
+        Statement stmt = conn.createStatement();
+
+        ResultSet rs = stmt.executeQuery("SELECT 42 as a");
+
assertFalse(stmt.isClosed());
+        assertFalse(rs.isClosed());
+
+        assertTrue(rs.next());
+        int res = rs.getInt(1);
+        assertEquals(res, 42);
+        assertFalse(rs.wasNull());
+
+        res = rs.getInt(1);
+        assertEquals(res, 42);
+        assertFalse(rs.wasNull());
+
+        res = rs.getInt("a");
+        assertEquals(res, 42);
+        assertFalse(rs.wasNull());
+
+        assertThrows(() -> rs.getInt(0), SQLException.class);
+
+        assertThrows(() -> rs.getInt(2), SQLException.class);
+
+        assertThrows(() -> rs.getInt("b"), SQLException.class);
+
+        assertFalse(rs.next());
+        assertFalse(rs.next());
+
+        rs.close();
+        rs.close();
+        assertTrue(rs.isClosed());
+
+        assertThrows(() -> rs.getInt(1), SQLException.class);
+
+        stmt.close();
+        stmt.close();
+        assertTrue(stmt.isClosed());
+
+        conn.close();
+        conn.close();
+        assertFalse(conn.isValid(0));
+        assertTrue(conn.isClosed());
+
+        assertThrows(conn::createStatement, SQLException.class);
+    }
+
+    // Executing the incomplete query "SELECT" must surface as SQLException.
+    public static void test_execute_exception() throws Exception {
+        // try-with-resources: the connection and statement are closed on every exit path,
+        // including when assertThrows itself fails. Previously neither was ever closed.
+        try (Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement()) {
+            assertThrows(() -> {
+                ResultSet rs = stmt.executeQuery("SELECT");
+                rs.next();
+            }, SQLException.class);
+        }
+    }
+
+    public static void test_range_java_smoke() throws Exception {
+        try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class);
+             Statement stmt = conn.createStatement()) {
+            conn.registerTableFunction("range_java", new org.duckdb.udf.TableFunction() {
+                @Override
+                public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) {
+                    return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER},
+                                               ((Number) parameters[0]).intValue());
+                }
+
+                @Override
+                public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) {
+                    int end = ((Number) bind.getBindState()).intValue();
+                    return new TableState(new int[] {0, end});
+                }
+
+                @Override
+                public int produce(TableState state, org.duckdb.UdfOutputAppender
out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM range_java(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + } + assertFalse(rs.next()); + } + } + } + + public static void test_range_java_streaming_large() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("range_java_large", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 256 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT count(*), sum(i) FROM range_java_large(10000)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 10000L); + assertEquals(rs.getLong(2), 49995000L); + assertFalse(rs.next()); + } + } + } + + public static void test_range_java_output_appender_api() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + 
Statement stmt = conn.createStatement()) { + conn.registerTableFunction("range_java_appender", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 128 && current < end; produced++, current++) { + out.beginRow().append(current).endRow(); + } + st[0] = current; + return out.getSize(); + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM range_java_appender(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + } + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_bind_typed_parameters() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + Object[] observedParameters = new Object[3]; + boolean[] observedNullPath = new boolean[] {false}; + + conn.registerTableFunction( + "tf_bind_typed", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + if (parameters.length != 3) { + throw new IllegalStateException("Expected 3 bind parameters"); + } + observedParameters[0] = parameters[0]; + observedParameters[1] = parameters[1]; + observedParameters[2] = parameters[2]; + + int start = parameters[0] == null ? 
0 : ((Number) parameters[0]).intValue(); + if (parameters[0] == null) { + observedNullPath[0] = true; + } + int delta = (int) Math.round(((Number) parameters[1]).doubleValue()); + int labelLen = ((String) parameters[2]).length(); + int end = start + delta + labelLen; + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, end); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }, + new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] { + DuckDBColumnType.INTEGER, DuckDBColumnType.DOUBLE, DuckDBColumnType.VARCHAR})); + + try (ResultSet rs = stmt.executeQuery("SELECT count(*), sum(i) FROM tf_bind_typed(5, 2.0, 'abc')")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 10L); + assertEquals(rs.getLong(2), 45L); + assertFalse(rs.next()); + } + + assertTrue(observedParameters[0] instanceof Integer); + assertTrue(observedParameters[1] instanceof Double); + assertTrue(observedParameters[2] instanceof String); + assertEquals(observedParameters[0], 5); + assertEquals(observedParameters[1], 2.0d); + assertEquals(observedParameters[2], "abc"); + + try (ResultSet rs = stmt.executeQuery("SELECT count(*) FROM tf_bind_typed(NULL::INTEGER, 2.0, 'xy')")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 4L); + assertFalse(rs.next()); + } + assertTrue(observedNullPath[0]); + } + } + + public static void test_table_function_bind_typed_parameters_extended_types() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + Object[] observedParameters = new Object[5]; + + conn.registerTableFunction( + "tf_bind_extended_types", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + if (parameters.length != 5) { + throw new IllegalStateException("Expected 5 bind parameters"); + } + System.arraycopy(parameters, 0, observedParameters, 0, parameters.length); + int rows = (int) Math.round(((Number) parameters[0]).doubleValue()); + rows = Math.max(rows, 0); + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, rows); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0, ((Number) bind.getBindState()).intValue()}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 64 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }, + new TableFunctionDefinition().withParameterTypes( + new DuckDBColumnType[] {DuckDBColumnType.DECIMAL, DuckDBColumnType.BLOB, DuckDBColumnType.DATE, + DuckDBColumnType.TIME, DuckDBColumnType.TIMESTAMP})); + + try (ResultSet rs = + stmt.executeQuery("SELECT count(*), sum(i) FROM tf_bind_extended_types(" + + "3.0::DECIMAL(18,2), 'blob-extended'::BLOB, DATE '2024-01-03', " + + "TIME '01:02:03.123456', TIMESTAMP '2024-01-03 04:05:06.123456')")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 3L); + assertEquals(rs.getLong(2), 3L); + assertFalse(rs.next()); + } + + assertTrue(observedParameters[0] instanceof BigDecimal); + assertEquals(((BigDecimal) observedParameters[0]).compareTo(new BigDecimal("3.00")), 0); + 
assertTrue(observedParameters[1] instanceof byte[]); + assertEquals((byte[]) observedParameters[1], "blob-extended".getBytes(StandardCharsets.UTF_8)); + assertTrue(observedParameters[2] instanceof LocalDate); + assertTrue(observedParameters[3] instanceof LocalTime); + assertTrue(observedParameters[4] instanceof LocalDateTime); + assertEquals(observedParameters[2], LocalDate.of(2024, 1, 3)); + assertEquals(observedParameters[3], LocalTime.parse("01:02:03.123456")); + assertEquals(observedParameters[4], LocalDateTime.parse("2024-01-03T04:05:06.123456")); + } + } + + public static void test_table_function_bind_temporal_and_uuid_objects() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + Object[] observedParameters = new Object[10]; + + conn.registerTableFunction( + "tf_bind_temporal_uuid_objects", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + if (parameters.length != observedParameters.length) { + throw new IllegalStateException("Expected 10 bind parameters"); + } + System.arraycopy(parameters, 0, observedParameters, 0, parameters.length); + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + out.setInt(0, 0, 1); + producedRows[0] = 1; + return 1; + } + }, + new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] { + DuckDBColumnType.DATE, DuckDBColumnType.TIME, DuckDBColumnType.TIME_NS, DuckDBColumnType.TIMESTAMP, + DuckDBColumnType.TIMESTAMP_S, 
DuckDBColumnType.TIMESTAMP_MS, DuckDBColumnType.TIMESTAMP_NS, + DuckDBColumnType.TIME_WITH_TIME_ZONE, DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, + DuckDBColumnType.UUID})); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(i) FROM tf_bind_temporal_uuid_objects(" + + "DATE '2024-01-03', " + + "TIME '01:02:03.123456', " + + "TIME_NS '01:02:03.123456789', " + + "TIMESTAMP '2024-01-03 04:05:06.123456', " + + "TIMESTAMP_S '2024-01-03 04:05:06', " + + "TIMESTAMP_MS '2024-01-03 04:05:06.123', " + + "TIMESTAMP_NS '2024-01-03 04:05:06.123456789', " + + "'01:02:03+05:30'::TIME WITH TIME ZONE, " + + "'2024-01-03 04:05:06+00'::TIMESTAMP WITH TIME ZONE, " + + "'550e8400-e29b-41d4-a716-446655440000'::UUID)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 1); + assertFalse(rs.next()); + } + + assertTrue(observedParameters[0] instanceof LocalDate); + assertTrue(observedParameters[1] instanceof LocalTime); + assertTrue(observedParameters[2] instanceof LocalTime); + assertTrue(observedParameters[3] instanceof LocalDateTime); + assertTrue(observedParameters[4] instanceof LocalDateTime); + assertTrue(observedParameters[5] instanceof LocalDateTime); + assertTrue(observedParameters[6] instanceof LocalDateTime); + assertTrue(observedParameters[7] instanceof OffsetTime); + assertTrue(observedParameters[8] instanceof OffsetDateTime); + assertTrue(observedParameters[9] instanceof UUID); + + assertEquals(observedParameters[0], LocalDate.of(2024, 1, 3)); + assertEquals(observedParameters[1], LocalTime.parse("01:02:03.123456")); + assertEquals(observedParameters[2], LocalTime.parse("01:02:03.123456789")); + assertEquals(observedParameters[3], LocalDateTime.parse("2024-01-03T04:05:06.123456")); + assertEquals(observedParameters[4], LocalDateTime.parse("2024-01-03T04:05:06")); + assertEquals(observedParameters[5], LocalDateTime.parse("2024-01-03T04:05:06.123")); + assertEquals(observedParameters[6], LocalDateTime.parse("2024-01-03T04:05:06.123456789")); + 
assertEquals(observedParameters[7], OffsetTime.parse("01:02:03+05:30")); + assertEquals(observedParameters[8], OffsetDateTime.parse("2024-01-03T04:05:06Z")); + assertEquals(observedParameters[9], UUID.fromString("550e8400-e29b-41d4-a716-446655440000")); + + try (ResultSet rs = + stmt.executeQuery("SELECT sum(i) FROM tf_bind_temporal_uuid_objects(" + + "NULL::DATE, NULL::TIME, NULL::TIME_NS, NULL::TIMESTAMP, NULL::TIMESTAMP_S, " + + "NULL::TIMESTAMP_MS, NULL::TIMESTAMP_NS, NULL::TIME WITH TIME ZONE, " + + "NULL::TIMESTAMP WITH TIME ZONE, NULL::UUID)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 1); + assertFalse(rs.next()); + } + for (Object observedParameter : observedParameters) { + assertEquals(observedParameter, null); + } + } + } + + public static void test_table_function_bind_decimal_parameter_exact_bigdecimal() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + Object[] observedParameters = new Object[1]; + + conn.registerTableFunction("tf_bind_decimal_exact", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + if (parameters.length != 1) { + throw new IllegalStateException("Expected 1 bind parameter"); + } + observedParameters[0] = parameters[0]; + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] produced = (int[]) state.getState(); + if (produced[0] > 0) { + return 0; + } + out.setInt(0, 0, 1); + produced[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.DECIMAL})); + + try (ResultSet rs 
= + stmt.executeQuery("SELECT sum(i) FROM tf_bind_decimal_exact(9007199254740.127::DECIMAL(18,3))")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 1); + assertFalse(rs.next()); + } + + assertTrue(observedParameters[0] instanceof BigDecimal); + assertEquals(((BigDecimal) observedParameters[0]).compareTo(new BigDecimal("9007199254740.127")), 0); + } + } + + public static void test_table_function_decimal_logical_output_boundaries_and_nulls() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_decimal_logical_boundaries", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"d4_1", "d9_4", "d18_6", "d30_10", "d38_10"}, + new UdfLogicalType[] {UdfLogicalType.decimal(4, 1), UdfLogicalType.decimal(9, 4), + UdfLogicalType.decimal(18, 6), UdfLogicalType.decimal(30, 10), + UdfLogicalType.decimal(38, 10)}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.setBigDecimal(0, 0, new BigDecimal("999.9")); + out.setBigDecimal(1, 0, new BigDecimal("99999.9999")); + out.setBigDecimal(2, 0, new BigDecimal("999999999999.999999")); + out.setBigDecimal(3, 0, new BigDecimal("99999999999999999999.9999999999")); + out.setBigDecimal(4, 0, new BigDecimal("9999999999999999999999999999.9999999999")); + + out.setBigDecimal(0, 1, new BigDecimal("-999.9")); + out.setBigDecimal(1, 1, new BigDecimal("-99999.9999")); + out.setBigDecimal(2, 1, new BigDecimal("-999999999999.999999")); + out.setBigDecimal(3, 1, new 
BigDecimal("-99999999999999999999.9999999999")); + out.setBigDecimal(4, 1, new BigDecimal("-9999999999999999999999999999.9999999999")); + + for (int col = 0; col < 5; col++) { + out.setNull(col, 2); + } + + producedRows[0] = 3; + return 3; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try (ResultSet rs = + stmt.executeQuery("SELECT CAST(d4_1 AS VARCHAR), CAST(d9_4 AS VARCHAR), CAST(d18_6 AS VARCHAR), " + + "CAST(d30_10 AS VARCHAR), CAST(d38_10 AS VARCHAR) " + + "FROM tf_decimal_logical_boundaries() ORDER BY d4_1 DESC NULLS LAST")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "999.9"); + assertEquals(rs.getString(2), "99999.9999"); + assertEquals(rs.getString(3), "999999999999.999999"); + assertEquals(rs.getString(4), "99999999999999999999.9999999999"); + assertEquals(rs.getString(5), "9999999999999999999999999999.9999999999"); + + assertTrue(rs.next()); + assertEquals(rs.getString(1), "-999.9"); + assertEquals(rs.getString(2), "-99999.9999"); + assertEquals(rs.getString(3), "-999999999999.999999"); + assertEquals(rs.getString(4), "-99999999999999999999.9999999999"); + assertEquals(rs.getString(5), "-9999999999999999999999999999.9999999999"); + + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertEquals(rs.getObject(2), null); + assertEquals(rs.getObject(3), null); + assertEquals(rs.getObject(4), null); + assertEquals(rs.getObject(5), null); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_bind_typed_parameters_complex() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + Object[] observedParameters = new Object[4]; + + conn.registerTableFunction( + "tf_bind_complex", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + if (parameters.length != 4) { + throw new 
IllegalStateException("Expected 4 bind parameters"); + } + System.arraycopy(parameters, 0, observedParameters, 0, parameters.length); + + List listParam = (List) parameters[0]; + Map mapParam = (Map) parameters[1]; + Map structParam = (Map) parameters[2]; + String enumParam = (String) parameters[3]; + + int rowCount = listParam.size() + mapParam.size() + ((Number) structParam.get("id")).intValue(); + if ("medium".equals(enumParam)) { + rowCount += 1; + } + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, rowCount); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 128 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }, + new TableFunctionDefinition().withParameterTypes(new UdfLogicalType[] { + UdfLogicalType.list(UdfLogicalType.of(DuckDBColumnType.INTEGER)), + UdfLogicalType.map(UdfLogicalType.of(DuckDBColumnType.VARCHAR), + UdfLogicalType.of(DuckDBColumnType.INTEGER)), + UdfLogicalType.struct(new String[] {"id", "txt"}, + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER), + UdfLogicalType.of(DuckDBColumnType.VARCHAR)}), + UdfLogicalType.enumeration("small", "medium", "large")})); + + try (ResultSet rs = stmt.executeQuery("SELECT count(*), sum(i) FROM tf_bind_complex(" + + "[10,20,30], map(['k1','k2'], [100,200]), " + + "{'id':4, 'txt':'duck'}, " + + "'medium'::ENUM('small','medium','large'))")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 10L); + assertEquals(rs.getLong(2), 45L); + assertFalse(rs.next()); + } + + assertTrue(observedParameters[0] 
instanceof List); + assertTrue(observedParameters[1] instanceof Map); + assertTrue(observedParameters[2] instanceof Map); + assertTrue(observedParameters[3] instanceof String); + + List observedList = (List) observedParameters[0]; + assertEquals(observedList.size(), 3); + assertEquals(((Number) observedList.get(0)).intValue(), 10); + assertEquals(((Number) observedList.get(1)).intValue(), 20); + assertEquals(((Number) observedList.get(2)).intValue(), 30); + + Map observedMap = (Map) observedParameters[1]; + assertEquals(observedMap.size(), 2); + assertEquals(((Number) observedMap.get("k1")).intValue(), 100); + assertEquals(((Number) observedMap.get("k2")).intValue(), 200); + + Map observedStruct = (Map) observedParameters[2]; + assertEquals(((Number) observedStruct.get("id")).intValue(), 4); + assertEquals(observedStruct.get("txt"), "duck"); + assertEquals(observedParameters[3], "medium"); + } + } + + public static void test_table_function_bind_parameter_type_validation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class)) { + assertThrows(() -> { + conn.registerTableFunction( + "tf_bad_param", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0, 0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + return 0; + } + }, + new TableFunctionDefinition().withParameterTypes( + new DuckDBColumnType[] {DuckDBColumnType.INTERVAL})); + }, SQLFeatureNotSupportedException.class); + + assertThrows(() -> { + conn.registerTableFunction( + "tf_bad_param_logical", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult 
bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0, 0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + return 0; + } + }, + new TableFunctionDefinition().withParameterTypes( + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTERVAL)})); + }, SQLException.class); + } + } + + public static void test_table_function_typed_outputs_core_types() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_core_out", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int end = ((Number) parameters[0]).intValue(); + return new TableBindResult( + new String[] {"b", "t8", "s16", "i32", "i64", "f32", "f64", "txt"}, + new DuckDBColumnType[] {DuckDBColumnType.BOOLEAN, DuckDBColumnType.TINYINT, + DuckDBColumnType.SMALLINT, DuckDBColumnType.INTEGER, + DuckDBColumnType.BIGINT, DuckDBColumnType.FLOAT, + DuckDBColumnType.DOUBLE, DuckDBColumnType.VARCHAR}, + end); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 128 && current < end; produced++, current++) { + out.setBoolean(0, produced, current % 2 == 0); + out.setInt(1, produced, current - 50); + out.setInt(2, produced, 1000 + current); + out.setInt(3, 
produced, current * 10); + out.setLong(4, produced, 1_000_000_000_000L + current); + out.setFloat(5, produced, current + 0.5f); + if (current % 2 == 0) { + out.setDouble(6, produced, current + 0.25d); + } else { + out.setNull(6, produced); + } + if (current % 3 == 0) { + out.setNull(7, produced); + } else { + out.setString(7, produced, "v" + current); + } + } + st[0] = current; + return produced; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + try (ResultSet rs = stmt.executeQuery("SELECT * FROM tf_core_out(6) ORDER BY i32")) { + for (int i = 0; i < 6; i++) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), i % 2 == 0); + assertEquals(rs.getInt(2), i - 50); + assertEquals(rs.getInt(3), 1000 + i); + assertEquals(rs.getInt(4), i * 10); + assertEquals(rs.getLong(5), 1_000_000_000_000L + i); + assertEquals(rs.getFloat(6), i + 0.5f, 0.0001f); + if (i % 2 == 0) { + assertEquals(rs.getDouble(7), i + 0.25d, 0.0000001d); + } else { + assertEquals(rs.getObject(7), null); + } + if (i % 3 == 0) { + assertEquals(rs.getString(8), null); + } else { + assertEquals(rs.getString(8), "v" + i); + } + } + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_typed_outputs_extended_types() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_extended_out", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int end = ((Number) parameters[0]).intValue(); + return new TableBindResult(new String[] {"id", "dec", "blob", "d", "t", "ts"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.DECIMAL, DuckDBColumnType.BLOB, + DuckDBColumnType.DATE, DuckDBColumnType.TIME, + DuckDBColumnType.TIMESTAMP}, + end); + } + + @Override + public TableState 
init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + LocalDate baseDate = LocalDate.of(2024, 1, 3); + LocalTime baseTime = LocalTime.of(1, 1, 1); + LocalDateTime baseTimestamp = LocalDateTime.of(2024, 1, 3, 4, 5, 6); + for (; produced < 64 && current < end; produced++, current++) { + out.setInt(0, produced, current); + out.setBigDecimal(1, produced, BigDecimal.valueOf(current).add(BigDecimal.valueOf(0.25d))); + out.setBytes(2, produced, ("b" + current).getBytes(StandardCharsets.UTF_8)); + out.setLocalDate(3, produced, baseDate.plusDays(current)); + out.setLocalTime(4, produced, baseTime.plusSeconds(current)); + out.setLocalDateTime(5, produced, baseTimestamp.plusSeconds(current)); + } + st[0] = current; + return produced; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + try (ResultSet rs = stmt.executeQuery("SELECT id, CAST(dec AS DOUBLE), blob, CAST(d AS VARCHAR), " + + "CAST(t AS VARCHAR), CAST(ts AS VARCHAR) " + + "FROM tf_extended_out(3) ORDER BY id")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 0); + assertEquals(rs.getDouble(2), 0.25d, 0.000001d); + assertEquals(rs.getBytes(3), "b0".getBytes(StandardCharsets.UTF_8)); + assertEquals(rs.getString(4), "2024-01-03"); + assertEquals(rs.getString(5), "01:01:01"); + assertEquals(rs.getString(6), "2024-01-03 04:05:06"); + + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 1); + assertEquals(rs.getDouble(2), 1.25d, 0.000001d); + assertEquals(rs.getBytes(3), "b1".getBytes(StandardCharsets.UTF_8)); + assertEquals(rs.getString(4), "2024-01-04"); + assertEquals(rs.getString(5), "01:01:02"); + assertEquals(rs.getString(6), "2024-01-03 
04:05:07"); + + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 2); + assertEquals(rs.getDouble(2), 2.25d, 0.000001d); + assertEquals(rs.getBytes(3), "b2".getBytes(StandardCharsets.UTF_8)); + assertEquals(rs.getString(4), "2024-01-05"); + assertEquals(rs.getString(5), "01:01:03"); + assertEquals(rs.getString(6), "2024-01-03 04:05:08"); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_java_object_methods() throws Exception { + final UUID expectedUuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_java_objects", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int end = ((Number) parameters[0]).intValue(); + return new TableBindResult( + new String[] {"dec", "d", "t", "ts", "t_tz", "ts_tz", "uuid_v"}, + new DuckDBColumnType[] {DuckDBColumnType.DECIMAL, DuckDBColumnType.DATE, DuckDBColumnType.TIME, + DuckDBColumnType.TIMESTAMP, DuckDBColumnType.TIME_WITH_TIME_ZONE, + DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE, DuckDBColumnType.UUID}, + end); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + if (st[0] >= st[1]) { + return 0; + } + out.setBigDecimal(0, 0, new BigDecimal("12.345")); + out.setLocalDate(1, 0, LocalDate.of(2024, 1, 3)); + out.setLocalTime(2, 0, LocalTime.of(1, 2, 3)); + out.setDate(3, 0, new java.util.Date(1_704_254_706_000L)); // 2024-01-03 04:05:06 UTC + out.setOffsetTime(4, 0, OffsetTime.parse("01:02:03+02:00")); + out.setOffsetDateTime(5, 0, 
OffsetDateTime.parse("2024-01-03T04:05:06+02:00")); + out.setUUID(6, 0, expectedUuid); + st[0]++; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + try (ResultSet rs = stmt.executeQuery("SELECT CAST(dec AS DOUBLE), CAST(d AS VARCHAR), CAST(t AS VARCHAR), " + + "CAST(ts AS VARCHAR), CAST(uuid_v AS VARCHAR), " + + "t_tz IS NOT NULL, ts_tz IS NOT NULL " + + "FROM tf_out_java_objects(1)")) { + assertTrue(rs.next()); + assertEquals(rs.getDouble(1), 12.345d, 0.000001d); + assertEquals(rs.getString(2), "2024-01-03"); + assertEquals(rs.getString(3), "01:02:03"); + assertEquals(rs.getString(4), "2024-01-03 04:05:06"); + assertEquals(rs.getString(5), expectedUuid.toString()); + assertEquals(rs.getBoolean(6), true); + assertEquals(rs.getBoolean(7), true); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_java_append_overloads() throws Exception { + final UUID expectedUuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440001"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_java_append_overloads", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"dec", "d", "t", "ts", "uuid_v"}, + new DuckDBColumnType[] {DuckDBColumnType.DECIMAL, DuckDBColumnType.DATE, DuckDBColumnType.TIME, + DuckDBColumnType.TIMESTAMP, DuckDBColumnType.UUID}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.beginRow() + .append(new 
BigDecimal("9.875")) + .append(LocalDate.of(2025, 2, 14)) + .append(LocalTime.of(12, 34, 56)) + .append(LocalDateTime.of(2025, 2, 14, 12, 34, 56)) + .append(expectedUuid) + .endRow(); + + out.beginRow().appendNull().appendNull().appendNull().appendNull().appendNull().endRow(); + + producedRows[0] = 2; + return 2; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try (ResultSet rs = stmt.executeQuery( + "SELECT CAST(dec AS DOUBLE), CAST(d AS VARCHAR), CAST(t AS VARCHAR), CAST(ts AS VARCHAR), " + + "CAST(uuid_v AS VARCHAR) FROM tf_out_java_append_overloads() ORDER BY d NULLS LAST")) { + assertTrue(rs.next()); + assertEquals(rs.getDouble(1), 9.875d, 0.000001d); + assertEquals(rs.getString(2), "2025-02-14"); + assertEquals(rs.getString(3), "12:34:56"); + assertEquals(rs.getString(4), "2025-02-14 12:34:56"); + assertEquals(rs.getString(5), expectedUuid.toString()); + + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertEquals(rs.getObject(2), null); + assertEquals(rs.getObject(3), null); + assertEquals(rs.getObject(4), null); + assertEquals(rs.getObject(5), null); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_decimal_exact_bigdecimal_paths() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_decimal_exact", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"dec"}, new DuckDBColumnType[] {DuckDBColumnType.DECIMAL}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) 
state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.beginRow().append(new BigDecimal("9007199254740.127")).endRow(); + out.setBigDecimal(0, 1, new BigDecimal("-9007199254740.127")); + + producedRows[0] = 2; + return 2; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try (ResultSet rs = + stmt.executeQuery("SELECT CAST(dec AS VARCHAR) FROM tf_out_decimal_exact() ORDER BY dec DESC")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "9007199254740.127"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), "-9007199254740.127"); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_decimal_number_coercion_exact() throws Exception { + final BigInteger hugeIntValue = new BigInteger("12345678901234567890123456789012345678"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_decimal_number_exact", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i", "dec"}, + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER), + UdfLogicalType.decimal(38, 0)}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.beginRow().append(0).append(Long.valueOf(9007199254740993L)).endRow(); + out.beginRow().append(1).append(hugeIntValue).endRow(); + + producedRows[0] = 2; + return 2; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try (ResultSet rs = + stmt.executeQuery("SELECT CAST(dec 
AS VARCHAR) FROM tf_out_decimal_number_exact() ORDER BY i")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "9007199254740993"); + assertTrue(rs.next()); + assertEquals(rs.getString(1), hugeIntValue.toString()); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_int128_biginteger_roundtrip_boundaries() throws Exception { + final BigInteger hugeintMin = new BigInteger("-170141183460469231731687303715884105728"); + final BigInteger hugeintMax = new BigInteger("170141183460469231731687303715884105727"); + final BigInteger uhugeintMax = new BigInteger("340282366920938463463374607431768211455"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_int128_bigint_boundaries", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"id", "h", "uh"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.HUGEINT, + DuckDBColumnType.UHUGEINT}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.beginRow().append(0).append(hugeintMin).append(BigInteger.ZERO).endRow(); + + out.setInt(0, 1, 1); + out.setObject(1, 1, hugeintMax); + out.setObject(2, 1, uhugeintMax); + + producedRows[0] = 2; + return 2; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT CAST(h AS VARCHAR), CAST(uh AS VARCHAR) FROM tf_out_int128_bigint_boundaries() ORDER BY id")) { + assertTrue(rs.next()); + 
assertEquals(rs.getString(1), hugeintMin.toString()); + assertEquals(rs.getString(2), "0"); + + assertTrue(rs.next()); + assertEquals(rs.getString(1), hugeintMax.toString()); + assertEquals(rs.getString(2), uhugeintMax.toString()); + + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_int128_biginteger_out_of_range() throws Exception { + final BigInteger hugeintOverflow = new BigInteger("170141183460469231731687303715884105728"); + final BigInteger uhugeintNegative = BigInteger.valueOf(-1); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_int128_bigint_huge_oob", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"id", "h", "uh"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.HUGEINT, + DuckDBColumnType.UHUGEINT}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.beginRow().append(0).append(hugeintOverflow).append(BigInteger.ZERO).endRow(); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + assertThrows( + () -> { stmt.executeQuery("SELECT * FROM tf_out_int128_bigint_huge_oob()"); }, SQLException.class); + + conn.registerTableFunction("tf_out_int128_bigint_uhuge_oob", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"id", "h", "uh"}, + new 
DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.HUGEINT, + DuckDBColumnType.UHUGEINT}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.setInt(0, 0, 0); + out.setObject(1, 0, BigInteger.ZERO); + out.setObject(2, 0, uhugeintNegative); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + assertThrows( + () -> { stmt.executeQuery("SELECT * FROM tf_out_int128_bigint_uhuge_oob()"); }, SQLException.class); + } + } + + public static void test_table_function_output_appender_exact_integer_object_coercion() throws Exception { + final BigInteger ubigintMax = new BigInteger("18446744073709551615"); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_exact_integer_objects", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"i", "u"}, + new DuckDBColumnType[] {DuckDBColumnType.BIGINT, DuckDBColumnType.UBIGINT}, null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + + out.setObject(0, 0, new BigDecimal("9223372036854775807")); + out.setObject(1, 0, ubigintMax); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new 
DuckDBColumnType[0])); + + try (ResultSet rs = stmt.executeQuery( + "SELECT CAST(i AS VARCHAR), CAST(u AS VARCHAR) FROM tf_out_exact_integer_objects()")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "9223372036854775807"); + assertEquals(rs.getString(2), ubigintMax.toString()); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_exact_integer_object_coercion_rejects_invalid() + throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_exact_integer_bad_bigint", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.BIGINT}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + out.setObject(0, 0, new BigDecimal("1.5")); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + assertThrows( + () -> { stmt.executeQuery("SELECT * FROM tf_out_exact_integer_bad_bigint()"); }, SQLException.class); + + conn.registerTableFunction("tf_out_exact_integer_bad_ubigint", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"u"}, new DuckDBColumnType[] {DuckDBColumnType.UBIGINT}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public 
int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + out.setObject(0, 0, new BigInteger("18446744073709551616")); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + assertThrows( + () -> { stmt.executeQuery("SELECT * FROM tf_out_exact_integer_bad_ubigint()"); }, SQLException.class); + } + } + + public static void test_table_function_output_appender_date_setdate_with_java_util_date() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_date_setdate_java_util", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"d"}, new DuckDBColumnType[] {DuckDBColumnType.DATE}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + out.setDate(0, 0, new java.util.Date(1_704_254_706_000L)); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + try (ResultSet rs = stmt.executeQuery("SELECT CAST(d AS VARCHAR) FROM tf_out_date_setdate_java_util()")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "2024-01-03"); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_output_appender_timetz_offset_range_validation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = 
conn.createStatement()) { + conn.registerTableFunction("tf_out_timetz_oob", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"t_tz"}, + new DuckDBColumnType[] {DuckDBColumnType.TIME_WITH_TIME_ZONE}, null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] producedRows = (int[]) state.getState(); + if (producedRows[0] > 0) { + return 0; + } + out.beginRow().append(OffsetTime.of(1, 2, 3, 0, ZoneOffset.ofHours(16))).endRow(); + producedRows[0] = 1; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + assertThrows(() -> { stmt.executeQuery("SELECT * FROM tf_out_timetz_oob()"); }, SQLException.class); + } + } + + public static void test_table_function_output_appender_java_object_type_mismatch() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_out_java_object_type_mismatch", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(null); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + out.setLocalDate(0, 0, LocalDate.of(2024, 1, 1)); + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[0])); + + assertThrows( + () -> { stmt.executeQuery("SELECT * FROM 
tf_out_java_object_type_mismatch()"); }, SQLException.class); + } + } + + public static void test_table_function_typed_outputs_unsigned_and_special_roundtrip_and_nulls() throws Exception { + final DuckDBColumnType[] types = scalarUnsignedAndSpecialTypes(); + final String[] columnNames = + new String[] {"u8", "u16", "u32", "u64", "i128", "u128", "timetz", "tstz", "uuid_v"}; + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_unsigned_special_out", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(columnNames, types, parameters.clone()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new Object[] {bind.getBindState(), Boolean.FALSE}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + Object[] stateValues = (Object[]) state.getState(); + if ((Boolean) stateValues[1]) { + return 0; + } + + Object[] parameters = (Object[]) stateValues[0]; + for (int i = 0; i < types.length; i++) { + Object parameter = parameters[i]; + if (parameter == null) { + out.setNull(i, 0); + continue; + } + + switch (types[i]) { + case UTINYINT: + case USMALLINT: + out.setInt(i, 0, ((Number) parameter).intValue()); + break; + case UINTEGER: + case UBIGINT: + out.setLong(i, 0, ((Number) parameter).longValue()); + break; + case TIME_WITH_TIME_ZONE: + out.setOffsetTime(i, 0, (OffsetTime) parameter); + break; + case TIMESTAMP_WITH_TIME_ZONE: + out.setOffsetDateTime(i, 0, (OffsetDateTime) parameter); + break; + case HUGEINT: + case UHUGEINT: + out.setBytes(i, 0, (byte[]) parameter); + break; + case UUID: + out.setUUID(i, 0, (UUID) parameter); + break; + default: + throw new IllegalStateException("Unexpected unsigned/special table type: " + 
types[i]); + } + } + + stateValues[1] = Boolean.TRUE; + return 1; + } + }, new TableFunctionDefinition().withParameterTypes(types)); + + StringBuilder nonNullArgs = new StringBuilder(); + StringBuilder nonNullChecks = new StringBuilder(); + for (int i = 0; i < types.length; i++) { + if (i > 0) { + nonNullArgs.append(", "); + nonNullChecks.append(", "); + } + String literal = nonNullLiteralForUnsignedAndSpecialType(types[i]); + nonNullArgs.append(literal); + nonNullChecks.append(columnNames[i]).append(" = ").append(literal); + } + + try (ResultSet rs = stmt.executeQuery("SELECT " + nonNullChecks + " FROM tf_unsigned_special_out(" + + nonNullArgs + ")")) { + assertTrue(rs.next()); + for (int i = 0; i < types.length; i++) { + assertEquals(rs.getBoolean(i + 1), true); + } + assertFalse(rs.next()); + } + + StringBuilder nullArgs = new StringBuilder(); + StringBuilder nullChecks = new StringBuilder(); + for (int i = 0; i < types.length; i++) { + if (i > 0) { + nullArgs.append(", "); + nullChecks.append(", "); + } + nullArgs.append(nullLiteralForType(types[i])); + nullChecks.append(columnNames[i]).append(" IS NULL"); + } + + try (ResultSet rs = + stmt.executeQuery("SELECT " + nullChecks + " FROM tf_unsigned_special_out(" + nullArgs + ")")) { + assertTrue(rs.next()); + for (int i = 0; i < types.length; i++) { + assertEquals(rs.getBoolean(i + 1), true); + } + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_nested_and_enum_outputs() throws Exception { + final UdfLogicalType listOfInt = UdfLogicalType.list(UdfLogicalType.of(DuckDBColumnType.INTEGER)); + final UdfLogicalType arrayOfVarchar = UdfLogicalType.array(UdfLogicalType.of(DuckDBColumnType.VARCHAR), 2); + final UdfLogicalType mapVarcharInt = UdfLogicalType.map(UdfLogicalType.of(DuckDBColumnType.VARCHAR), + UdfLogicalType.of(DuckDBColumnType.INTEGER)); + final UdfLogicalType structType = UdfLogicalType.struct( + new String[] {"id", "txt"}, new UdfLogicalType[] 
{UdfLogicalType.of(DuckDBColumnType.INTEGER), + UdfLogicalType.of(DuckDBColumnType.VARCHAR)}); + final UdfLogicalType unionType = UdfLogicalType.unionType( + new String[] {"num", "txt"}, new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER), + UdfLogicalType.of(DuckDBColumnType.VARCHAR)}); + final UdfLogicalType enumType = UdfLogicalType.enumeration("small", "medium", "large"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_nested_out", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int rowCount = ((Number) parameters[0]).intValue(); + return new TableBindResult(new String[] {"lst_i", "arr_txt", "kv", "s", "u", "en"}, + new UdfLogicalType[] {listOfInt, arrayOfVarchar, mapVarcharInt, + structType, unionType, enumType}, + rowCount); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 64 && current < end; produced++, current++) { + Map mapValue = new HashMap<>(); + mapValue.put("k", 10 + current * 10); + if (current % 2 == 0) { + out.beginRow() + .append(Arrays.asList(current + 1, current + 2, current + 3)) + .append(new String[] {"a" + current, "b" + current}) + .append(mapValue) + .append(Arrays.asList(current, "row" + current)) + .append(new AbstractMap.SimpleEntry("num", 100 + current)) + .append("medium") + .endRow(); + } else { + out.beginRow() + .append(Arrays.asList(current + 1, null, current + 3)) + .append(new String[] {"a" + current, "b" + current}) + .append(mapValue) + 
.append(Arrays.asList(current, "row" + current)) + .append(new AbstractMap.SimpleEntry("txt", "u" + current)) + .append("small") + .endRow(); + } + } + st[0] = current; + return out.getSize(); + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + try (ResultSet rs = stmt.executeQuery("SELECT list_extract(lst_i, 1), list_extract(lst_i, 2), " + + "array_extract(arr_txt, 2), list_extract(map_extract(kv, 'k'), 1), " + + "s.id, s.txt, union_tag(u), u.num, u.txt, CAST(en AS VARCHAR) " + + "FROM tf_nested_out(2) ORDER BY s.id")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 1); + assertEquals(rs.getInt(2), 2); + assertEquals(rs.getString(3), "b0"); + assertEquals(rs.getInt(4), 10); + assertEquals(rs.getInt(5), 0); + assertEquals(rs.getString(6), "row0"); + assertEquals(rs.getString(7), "num"); + assertEquals(rs.getInt(8), 100); + assertEquals(rs.getObject(9), null); + assertEquals(rs.getString(10), "medium"); + + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 2); + assertEquals(rs.getObject(2), null); + assertEquals(rs.getString(3), "b1"); + assertEquals(rs.getInt(4), 20); + assertEquals(rs.getInt(5), 1); + assertEquals(rs.getString(6), "row1"); + assertEquals(rs.getString(7), "txt"); + assertEquals(rs.getObject(8), null); + assertEquals(rs.getString(9), "u1"); + assertEquals(rs.getString(10), "small"); + assertFalse(rs.next()); + } + } + } + + public static void test_table_function_nested_projection_pushdown() throws Exception { + final UdfLogicalType structType = UdfLogicalType.struct( + new String[] {"id", "txt"}, new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER), + UdfLogicalType.of(DuckDBColumnType.VARCHAR)}); + final UdfLogicalType enumType = UdfLogicalType.enumeration("small", "large"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + int[] materializedBySourceColumn = new 
int[] {0, 0, 0}; + int[][] observedProjectedColumns = new int[2][]; + + conn.registerTableFunction( + "tf_nested_projected", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int end = ((Number) parameters[0]).intValue(); + return new TableBindResult( + new String[] {"id", "nested", "en"}, + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER), structType, enumType}, + end); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + int[] projectedColumns = new int[ctx.getColumnCount()]; + for (int i = 0; i < projectedColumns.length; i++) { + projectedColumns[i] = ctx.getColumnIndex(i); + } + if (observedProjectedColumns[0] == null) { + observedProjectedColumns[0] = projectedColumns.clone(); + } else { + observedProjectedColumns[1] = projectedColumns.clone(); + } + return new TableState(new Object[] {0, end, projectedColumns}); + } + + @Override + public int produce(TableState state, UdfOutputAppender out) { + Object[] st = (Object[]) state.getState(); + int current = (int) st[0]; + int end = (int) st[1]; + int[] projectedColumns = (int[]) st[2]; + int produced = 0; + for (; produced < 128 && current < end; produced++, current++) { + out.beginRow(); + for (int projectedCol = 0; projectedCol < projectedColumns.length; projectedCol++) { + int sourceColumn = projectedColumns[projectedCol]; + materializedBySourceColumn[sourceColumn]++; + if (sourceColumn == 0) { + out.append(current); + } else if (sourceColumn == 1) { + out.append(Arrays.asList(current, "p" + current)); + } else { + out.append(current % 2 == 0 ? 
"small" : "large"); + } + } + out.endRow(); + } + st[0] = current; + return out.getSize(); + } + }, + new TableFunctionDefinition() + .withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER}) + .withProjectionPushdown(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT en FROM tf_nested_projected(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), i % 2 == 0 ? "small" : "large"); + } + assertFalse(rs.next()); + } + assertEquals(materializedBySourceColumn[0], 0); + assertEquals(materializedBySourceColumn[1], 0); + assertEquals(materializedBySourceColumn[2], 5); + + try (ResultSet rs = + stmt.executeQuery("SELECT nested.id, nested.txt FROM tf_nested_projected(3) ORDER BY 1")) { + for (int i = 0; i < 3; i++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + assertEquals(rs.getString(2), "p" + i); + } + assertFalse(rs.next()); + } + assertEquals(materializedBySourceColumn[0], 0); + assertEquals(materializedBySourceColumn[1], 3); + assertEquals(materializedBySourceColumn[2], 5); + assertNotNull(observedProjectedColumns[0]); + assertNotNull(observedProjectedColumns[1]); + assertEquals(observedProjectedColumns[0].length, 1); + assertEquals(observedProjectedColumns[1].length, 1); + assertEquals(observedProjectedColumns[0][0], 2); + assertEquals(observedProjectedColumns[1][0], 1); + } + } + + public static void test_table_function_typed_outputs_streaming_chunks() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_core_stream", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + int end = ((Number) parameters[0]).intValue(); + return new TableBindResult(new String[] {"i", "d", "txt"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.DOUBLE, + 
DuckDBColumnType.VARCHAR}, + end); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 37 && current < end; produced++, current++) { + out.setInt(0, produced, current); + out.setDouble(1, produced, current * 1.5d); + out.setString(2, produced, "x" + current); + } + st[0] = current; + return produced; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + try (ResultSet rs = + stmt.executeQuery("SELECT count(*), sum(i), sum(d), count(txt) FROM tf_core_stream(5000)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 5000L); + assertEquals(rs.getLong(2), 12_497_500L); + assertEquals(rs.getDouble(3), 18_746_250d, 0.000001d); + assertEquals(rs.getLong(4), 5000L); + assertFalse(rs.next()); + } + } + } + + public static void test_range_java_error_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("range_java_error", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + null); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) throws Exception { + throw new Exception("range_java_error"); + } + }); + + assertThrows(() -> { 
stmt.executeQuery("SELECT * FROM range_java_error(1)"); }, SQLException.class); + } + } + + public static void test_table_function_projection_pushdown() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + int[] materializedBySourceColumn = new int[] {0, 0, 0}; + conn.registerTableFunction("tf_projected", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"col1", "col2", "col3"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + int[] projectedColumns = new int[ctx.getColumnCount()]; + for (int i = 0; i < projectedColumns.length; i++) { + projectedColumns[i] = ctx.getColumnIndex(i); + } + return new TableState(new Object[] {0, end, projectedColumns}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + Object[] st = (Object[]) state.getState(); + int current = (int) st[0]; + int end = (int) st[1]; + int[] projectedColumns = (int[]) st[2]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + for (int projectedCol = 0; projectedCol < projectedColumns.length; projectedCol++) { + int sourceColumn = projectedColumns[projectedCol]; + materializedBySourceColumn[sourceColumn]++; + if (sourceColumn == 0) { + out.setInt(projectedCol, produced, current); + } else if (sourceColumn == 1) { + out.setInt(projectedCol, produced, current * 10); + } else { + out.setInt(projectedCol, produced, current * 100); + } + } + } + st[0] = current; + return produced; + } + }, new 
TableFunctionDefinition().withProjectionPushdown(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT col1 FROM tf_projected(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), i); + } + assertFalse(rs.next()); + } + assertEquals(materializedBySourceColumn[0], 5); + assertEquals(materializedBySourceColumn[1], 0); + assertEquals(materializedBySourceColumn[2], 0); + } + } + + public static void test_table_function_projection_pushdown_mixed_schema() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + int[] materializedBySourceColumn = new int[] {0, 0, 0, 0, 0, 0}; + int[][] observedProjectedColumns = new int[1][]; + + conn.registerTableFunction("tf_projected_mixed", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"col_int", "col_txt", "col_dbl", "col_bool", "col_i64", "col_f32"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.VARCHAR, + DuckDBColumnType.DOUBLE, DuckDBColumnType.BOOLEAN, + DuckDBColumnType.BIGINT, DuckDBColumnType.FLOAT}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + int[] projectedColumns = new int[ctx.getColumnCount()]; + for (int i = 0; i < projectedColumns.length; i++) { + projectedColumns[i] = ctx.getColumnIndex(i); + } + observedProjectedColumns[0] = projectedColumns.clone(); + return new TableState(new Object[] {0, end, projectedColumns}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + Object[] st = (Object[]) state.getState(); + int current = (int) st[0]; + int end = (int) st[1]; + int[] projectedColumns = (int[]) st[2]; + int produced = 0; 
+ for (; produced < 256 && current < end; produced++, current++) { + for (int projectedCol = 0; projectedCol < projectedColumns.length; projectedCol++) { + int sourceColumn = projectedColumns[projectedCol]; + materializedBySourceColumn[sourceColumn]++; + if (sourceColumn == 0) { + out.setInt(projectedCol, produced, current); + } else if (sourceColumn == 1) { + out.setString(projectedCol, produced, "s" + current); + } else if (sourceColumn == 2) { + out.setDouble(projectedCol, produced, current + 0.5d); + } else if (sourceColumn == 3) { + out.setBoolean(projectedCol, produced, current % 2 == 0); + } else if (sourceColumn == 4) { + out.setLong(projectedCol, produced, 1_000_000_000L + current); + } else { + out.setFloat(projectedCol, produced, current + 0.25f); + } + } + } + st[0] = current; + return produced; + } + }, new TableFunctionDefinition().withProjectionPushdown(true)); + + try (ResultSet rs = + stmt.executeQuery("SELECT col_i64, col_txt FROM tf_projected_mixed(5) ORDER BY col_i64")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1_000_000_000L + i); + assertEquals(rs.getString(2), "s" + i); + } + assertFalse(rs.next()); + } + + assertNotNull(observedProjectedColumns[0]); + assertEquals(observedProjectedColumns[0].length, 2); + int[] sortedProjected = observedProjectedColumns[0].clone(); + Arrays.sort(sortedProjected); + assertTrue(Arrays.equals(sortedProjected, new int[] {1, 4})); + assertEquals(materializedBySourceColumn[0], 0); + assertEquals(materializedBySourceColumn[1], 5); + assertEquals(materializedBySourceColumn[2], 0); + assertEquals(materializedBySourceColumn[3], 0); + assertEquals(materializedBySourceColumn[4], 5); + assertEquals(materializedBySourceColumn[5], 0); + + observedProjectedColumns[0] = null; + try (ResultSet rs = stmt.executeQuery("SELECT col_bool FROM tf_projected_mixed(5)")) { + for (int i = 0; i < 5; i++) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), i % 2 == 0); + } + 
assertFalse(rs.next()); + } + + assertNotNull(observedProjectedColumns[0]); + assertEquals(observedProjectedColumns[0].length, 1); + assertEquals(observedProjectedColumns[0][0], 3); + assertEquals(materializedBySourceColumn[0], 0); + assertEquals(materializedBySourceColumn[1], 5); + assertEquals(materializedBySourceColumn[2], 0); + assertEquals(materializedBySourceColumn[3], 5); + assertEquals(materializedBySourceColumn[4], 5); + assertEquals(materializedBySourceColumn[5], 0); + } + } + + public static void test_table_function_thread_options_smoke() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + try (Statement setup = conn.createStatement()) { + setup.execute("PRAGMA threads=4"); + } + + conn.registerTableFunction( + "tf_threadsafe", + new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public synchronized int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }, + new TableFunctionDefinition().withProjectionPushdown(true), + new TableFunctionOptions().threadSafe(true).maxThreads(4)); + + conn.registerTableFunction("tf_singlethread", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] 
parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }, new TableFunctionOptions().threadSafe(false).maxThreads(8)); + + for (int i = 0; i < 20; i++) { + try (ResultSet rs = stmt.executeQuery("SELECT sum(i) FROM tf_threadsafe(100000)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 4999950000L); + } + try (ResultSet rs = stmt.executeQuery("SELECT sum(i) FROM tf_singlethread(100000)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 4999950000L); + } + } + } + } + + public static void test_table_function_thread_options_mixed_typed_outputs() throws Exception { + final int rowCount = 100000; + final long expectedSumInt = ((long) rowCount * (rowCount - 1)) / 2; + final double expectedSumDouble = expectedSumInt * 0.5d; + final long expectedTrueCount = rowCount / 2; + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + try (Statement setup = conn.createStatement()) { + setup.execute("PRAGMA threads=4"); + } + + conn.registerTableFunction("tf_threadsafe_mixed", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"i", "d", "b", "txt"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, 
DuckDBColumnType.DOUBLE, + DuckDBColumnType.BOOLEAN, DuckDBColumnType.VARCHAR}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public synchronized int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + out.setDouble(1, produced, current * 0.5d); + out.setBoolean(2, produced, current % 2 == 1); + out.setString(3, produced, "x"); + } + st[0] = current; + return produced; + } + }, new TableFunctionOptions().threadSafe(true).maxThreads(4)); + + conn.registerTableFunction("tf_singlethread_mixed", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"i", "d", "b", "txt"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.DOUBLE, + DuckDBColumnType.BOOLEAN, DuckDBColumnType.VARCHAR}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + out.setDouble(1, produced, current * 0.5d); + out.setBoolean(2, produced, current % 2 == 1); + out.setString(3, produced, "x"); + } + st[0] = current; + return produced; + } + }, new 
TableFunctionOptions().threadSafe(false).maxThreads(8)); + + for (int i = 0; i < 20; i++) { + try (ResultSet rs = + stmt.executeQuery("SELECT sum(i), sum(d), sum(CASE WHEN b THEN 1 ELSE 0 END), count(txt) " + + "FROM tf_threadsafe_mixed(" + rowCount + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), expectedSumInt); + assertEquals(rs.getDouble(2), expectedSumDouble, 0.0001d); + assertEquals(rs.getLong(3), expectedTrueCount); + assertEquals(rs.getLong(4), (long) rowCount); + assertFalse(rs.next()); + } + try (ResultSet rs = + stmt.executeQuery("SELECT sum(i), sum(d), sum(CASE WHEN b THEN 1 ELSE 0 END), count(txt) " + + "FROM tf_singlethread_mixed(" + rowCount + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), expectedSumInt); + assertEquals(rs.getDouble(2), expectedSumDouble, 0.0001d); + assertEquals(rs.getLong(3), expectedTrueCount); + assertEquals(rs.getLong(4), (long) rowCount); + assertFalse(rs.next()); + } + } + } + } + + public static void test_table_function_mixed_typed_outputs_exception_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_mixed_error", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult( + new String[] {"i", "txt"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.VARCHAR}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) throws Exception { + int[] st = (int[]) state.getState(); + int current = st[0]; + if (current >= 10) { + throw new 
Exception("tf_mixed_error"); + } + int produced = 0; + for (; produced < 10 && current < st[1]; produced++, current++) { + out.setInt(0, produced, current); + out.setString(1, produced, "e" + current); + } + st[0] = current; + return produced; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + assertThrows(() -> { stmt.executeQuery("SELECT * FROM tf_mixed_error(100)"); }, SQLException.class); + } + } + + public static void test_table_function_init_exception_message_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_init_error", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) throws Exception { + throw new Exception("tf_init_error"); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) { + return 0; + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT * FROM tf_init_error(3)"); }, SQLException.class); + assertTrue(message != null && message.contains("tf_init_error")); + } + } + + public static void test_table_function_main_exception_message_propagation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerTableFunction("tf_main_error", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext 
ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + return new TableState(new int[] {0, ((Number) bind.getBindState()).intValue()}); + } + + @Override + public int produce(TableState state, org.duckdb.UdfOutputAppender out) throws Exception { + throw new Exception("tf_main_error"); + } + }, new TableFunctionDefinition().withParameterTypes(new DuckDBColumnType[] {DuckDBColumnType.INTEGER})); + + String message = + assertThrows(() -> { stmt.executeQuery("SELECT * FROM tf_main_error(3)"); }, SQLException.class); + assertTrue(message != null && message.contains("tf_main_error")); + } + } + + private static long usedHeapBytes() { + Runtime runtime = Runtime.getRuntime(); + return runtime.totalMemory() - runtime.freeMemory(); + } + + public static void test_udf_lifecycle_hardening_repetition() throws Exception { + long baselineUsedHeap = -1; + final int iterations = 250; + for (int i = 0; i < iterations; i++) { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("life_scalar", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + 1); + } + }); + conn.registerTableFunction("life_table", new org.duckdb.udf.TableFunction() { + @Override + public TableBindResult bind(org.duckdb.udf.BindContext ctx, Object[] parameters) { + return new TableBindResult(new String[] {"i"}, + new DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + ((Number) parameters[0]).intValue()); + } + + @Override + public TableState init(org.duckdb.udf.InitContext ctx, TableBindResult bind) { + int end = ((Number) bind.getBindState()).intValue(); + return new TableState(new int[] {0, end}); + } + + @Override + public 
int produce(TableState state, org.duckdb.UdfOutputAppender out) { + int[] st = (int[]) state.getState(); + int current = st[0]; + int end = st[1]; + int produced = 0; + for (; produced < 1024 && current < end; produced++, current++) { + out.setInt(0, produced, current); + } + st[0] = current; + return produced; + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(life_scalar(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500500L); + } + try (ResultSet rs = stmt.executeQuery("SELECT sum(i) FROM life_table(1000)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 499500L); + } + } + if (i % 50 == 0) { + System.gc(); + } + if (i == 25) { + System.gc(); + Thread.sleep(20); + baselineUsedHeap = usedHeapBytes(); + } + } + + System.gc(); + Thread.sleep(50); + long finalUsedHeap = usedHeapBytes(); + assertTrue(baselineUsedHeap > 0); + // Heuristic guardrail: allow the JVM heap to fluctuate, but not grow without bound. + assertTrue(finalUsedHeap <= baselineUsedHeap + (64L * 1024L * 1024L)); + } + + public static void test_java_scalar_udf_add_one() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add_one", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add_one(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500500L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_output_writer_api() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add_one_writer", new 
org.duckdb.udf.ScalarUdf() { + @Override + public void apply(org.duckdb.udf.UdfContext ctx, UdfReader[] args, UdfScalarWriter out, int rowCount) { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add_one_writer(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500500L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_output_writer_object_methods() throws Exception { + UdfLogicalType decimal18_3 = UdfLogicalType.decimal(18, 3); + UUID uuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("writer_obj_decimal", new UdfLogicalType[] {decimal18_3}, decimal18_3, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setObject(row, args[0].isNull(row) ? null : args[0].getBigDecimal(row)); + } + }); + conn.registerScalarUdf("writer_obj_date", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.DATE)}, + UdfLogicalType.of(DuckDBColumnType.DATE), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setLocalDate(row, + args[0].isNull(row) ? null : args[0].getLocalDate(row)); + } + }); + conn.registerScalarUdf("writer_obj_time", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.TIME)}, + UdfLogicalType.of(DuckDBColumnType.TIME), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setLocalTime(row, + args[0].isNull(row) ? 
null : args[0].getLocalTime(row)); + } + }); + conn.registerScalarUdf( + "writer_obj_timetz", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.TIME_WITH_TIME_ZONE)}, + UdfLogicalType.of(DuckDBColumnType.TIME_WITH_TIME_ZONE), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setOffsetTime(row, args[0].isNull(row) ? null : args[0].getOffsetTime(row)); + } + }); + conn.registerScalarUdf( + "writer_obj_ts", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.TIMESTAMP)}, + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setLocalDateTime(row, args[0].isNull(row) ? null : args[0].getLocalDateTime(row)); + } + }); + conn.registerScalarUdf( + "writer_obj_tstz", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE)}, + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setOffsetDateTime(row, args[0].isNull(row) ? null : args[0].getOffsetDateTime(row)); + } + }); + conn.registerScalarUdf("writer_obj_uuid", new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.UUID)}, + UdfLogicalType.of(DuckDBColumnType.UUID), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setUUID(row, args[0].isNull(row) ? 
null : args[0].getUUID(row)); + } + }); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT " + + "writer_obj_decimal(123.450::DECIMAL(18,3)) = 123.450::DECIMAL(18,3), " + + "writer_obj_decimal(NULL::DECIMAL(18,3)) IS NULL, " + + "writer_obj_date(DATE '2024-01-03') = DATE '2024-01-03', " + + "writer_obj_time(TIME '01:02:03.123456') = TIME '01:02:03.123456', " + + "writer_obj_timetz(TIME WITH TIME ZONE '01:02:03+05:30') = TIME WITH TIME ZONE '01:02:03+05:30', " + + "writer_obj_ts(TIMESTAMP '2024-01-03 04:05:06.123456') = TIMESTAMP '2024-01-03 04:05:06.123456', " + + "writer_obj_tstz(TIMESTAMP WITH TIME ZONE '2024-01-03 04:05:06+00') = TIMESTAMP WITH TIME ZONE " + + "'2024-01-03 04:05:06+00', " + + "writer_obj_uuid('" + uuid + "'::UUID) = '" + uuid + "'::UUID")) { + assertTrue(rs.next()); + for (int col = 1; col <= 8; col++) { + assertEquals(rs.getBoolean(col), true); + } + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_output_writer_decimal_number_coercion_exact() throws Exception { + final UdfLogicalType decimal38_0 = UdfLogicalType.decimal(38, 0); + final BigInteger hugeIntValue = new BigInteger("12345678901234567890123456789012345678"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("writer_obj_decimal_num_exact", + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER)}, decimal38_0, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + int selector = args[0].getInt(row); + if (selector == 0) { + out.setObject(row, Long.valueOf(9007199254740993L)); + } else { + out.setObject(row, hugeIntValue); + } + } + }); + + try ( + ResultSet rs = stmt.executeQuery( + "SELECT CAST(writer_obj_decimal_num_exact(i::INTEGER) AS VARCHAR) FROM range(2) t(i) ORDER BY i")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "9007199254740993"); + assertTrue(rs.next()); + 
assertEquals(rs.getString(1), hugeIntValue.toString()); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_reader_api() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add_one_reader", new org.duckdb.udf.ScalarUdf() { + @Override + public void apply(org.duckdb.udf.UdfContext ctx, UdfReader[] args, UdfScalarWriter out, int rowCount) { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add_one_reader(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500500L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_reader_object_accessors() throws Exception { + UdfLogicalType decimal18_3 = UdfLogicalType.decimal(18, 3); + UUID expectedUuid = UUID.fromString("550e8400-e29b-41d4-a716-446655440000"); + LocalDate expectedDate = LocalDate.of(2024, 1, 3); + LocalTime expectedTime = LocalTime.of(1, 2, 3, 123456000); + OffsetTime expectedOffsetTime = OffsetTime.parse("01:02:03+05:30"); + LocalDateTime expectedTimestamp = LocalDateTime.of(2024, 1, 3, 4, 5, 6, 123456000); + Instant expectedTzTimestampInstant = Instant.parse("2024-01-03T04:05:06Z"); + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf( + "reader_objects_ok", + new UdfLogicalType[] {decimal18_3, UdfLogicalType.of(DuckDBColumnType.DATE), + UdfLogicalType.of(DuckDBColumnType.TIME), + UdfLogicalType.of(DuckDBColumnType.TIME_WITH_TIME_ZONE), + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP), + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE), + 
UdfLogicalType.of(DuckDBColumnType.UUID)}, + UdfLogicalType.of(DuckDBColumnType.BOOLEAN), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row) || args[1].isNull(row) || args[2].isNull(row) || args[3].isNull(row) || + args[4].isNull(row) || args[5].isNull(row) || args[6].isNull(row)) { + out.setNull(row); + continue; + } + + boolean ok = args[0].getBigDecimal(row).compareTo(new BigDecimal("123.450")) == 0; + ok = ok && expectedDate.equals(args[1].getLocalDate(row)); + java.util.Date dateValue = args[1].getDate(row); + ok = ok && dateValue instanceof java.sql.Date && + expectedDate.equals(((java.sql.Date) dateValue).toLocalDate()); + ok = ok && expectedTime.equals(args[2].getLocalTime(row)); + ok = ok && expectedOffsetTime.equals(args[3].getOffsetTime(row)); + ok = ok && expectedTimestamp.equals(args[4].getLocalDateTime(row)); + OffsetDateTime offsetDateTime = args[5].getOffsetDateTime(row); + ok = ok && offsetDateTime != null && + expectedTzTimestampInstant.equals(offsetDateTime.toInstant()); + ok = ok && expectedUuid.equals(args[6].getUUID(row)); + + out.setBoolean(row, ok); + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT reader_objects_ok(123.450::DECIMAL(18,3), DATE '2024-01-03', " + + "TIME '01:02:03.123456', '01:02:03+05:30'::TIME WITH TIME ZONE, " + + "TIMESTAMP '2024-01-03 04:05:06.123456', " + + "'2024-01-03 04:05:06+00'::TIMESTAMP WITH TIME ZONE, " + + "'550e8400-e29b-41d4-a716-446655440000'::UUID), " + + "reader_objects_ok(NULL::DECIMAL(18,3), DATE '2024-01-03', " + + "TIME '01:02:03.123456', '01:02:03+05:30'::TIME WITH TIME ZONE, " + + "TIMESTAMP '2024-01-03 04:05:06.123456', " + + "'2024-01-03 04:05:06+00'::TIMESTAMP WITH TIME ZONE, " + + "'550e8400-e29b-41d4-a716-446655440000'::UUID) IS NULL")) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), true); + assertEquals(rs.getBoolean(2), true); + assertFalse(rs.next()); + } + + conn.registerScalarUdf("reader_object_type_error", new 
DuckDBColumnType[] {DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + args[0].getUUID(row); + out.setInt(row, 0); + } + }); + + assertThrows(() -> { + try (ResultSet rs = stmt.executeQuery("SELECT reader_object_type_error(1)")) { + rs.next(); + } + }, SQLException.class); + } + } + + public static void test_java_scalar_udf_logical_type_registration() throws Exception { + UdfLogicalType decimal18_2 = UdfLogicalType.decimal(18, 2); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add_one_logical", + new UdfLogicalType[] {UdfLogicalType.of(DuckDBColumnType.INTEGER)}, + UdfLogicalType.of(DuckDBColumnType.INTEGER), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add_one_logical(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 500500L); + assertFalse(rs.next()); + } + + conn.registerScalarUdf("dec_identity_logical", new UdfLogicalType[] {decimal18_2}, decimal18_2, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setBigDecimal(row, args[0].getBigDecimal(row)); + } + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT dec_identity_logical(42.75::DECIMAL(18,2)) = 42.75::DECIMAL(18,2)")) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), true); + assertFalse(rs.next()); + } + + assertThrows( + () + -> conn.registerScalarUdf( + "bad_logical_arg", + new UdfLogicalType[] {UdfLogicalType.list(UdfLogicalType.of(DuckDBColumnType.INTEGER))}, + UdfLogicalType.of(DuckDBColumnType.INTEGER), (ctx, args, out, rowCount) -> {}), + 
SQLFeatureNotSupportedException.class); + } + } + + public static void test_udf_logical_type_decimal_factory_validation() throws Exception { + UdfLogicalType decimal = UdfLogicalType.decimal(18, 2); + assertEquals(decimal.getType(), DuckDBColumnType.DECIMAL); + assertEquals(decimal.getDecimalWidth(), 18); + assertEquals(decimal.getDecimalScale(), 2); + + assertThrows(() -> { UdfLogicalType.decimal(0, 0); }, IllegalArgumentException.class); + assertThrows(() -> { UdfLogicalType.decimal(39, 0); }, IllegalArgumentException.class); + assertThrows(() -> { UdfLogicalType.decimal(18, -1); }, IllegalArgumentException.class); + assertThrows(() -> { UdfLogicalType.decimal(18, 19); }, IllegalArgumentException.class); + } + + public static void test_java_scalar_udf_decimal_exact_width_scale_and_overflow() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + int[] widths = new int[] {4, 9, 18, 30}; + int[] scales = new int[] {1, 4, 6, 10}; + String[] literals = + new String[] {"-99.9", "12345.6789", "123456789012.123456", "12345678901234567890.1234567890"}; + + for (int i = 0; i < widths.length; i++) { + int width = widths[i]; + int scale = scales[i]; + UdfLogicalType decimalType = UdfLogicalType.decimal(width, scale); + String fnName = "f_decimal_exact_" + width + "_" + scale; + String typedLiteral = literals[i] + "::DECIMAL(" + width + "," + scale + ")"; + + conn.registerScalarUdf(fnName, new UdfLogicalType[] {decimalType}, decimalType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setBigDecimal(row, args[0].getBigDecimal(row)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT CAST(" + fnName + "(" + typedLiteral + ") AS VARCHAR), " + + "CAST(" + typedLiteral + " AS VARCHAR), " + fnName + + "(NULL::DECIMAL(" + width + "," + scale + ")) IS NULL")) { + 
assertTrue(rs.next()); + assertEquals(rs.getString(1), rs.getString(2)); + assertEquals(rs.getBoolean(3), true); + assertFalse(rs.next()); + } + } + + UdfLogicalType decimal4_1 = UdfLogicalType.decimal(4, 1); + conn.registerScalarUdf("f_decimal_overflow_4_1", new UdfLogicalType[] {decimal4_1}, decimal4_1, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setBigDecimal(row, new BigDecimal("1000.0")); + } + } + }); + + assertThrows(() -> { + try (ResultSet rs = stmt.executeQuery("SELECT f_decimal_overflow_4_1(1.0::DECIMAL(4,1))")) { + rs.next(); + } + }, SQLException.class); + } + } + + public static void test_java_scalar_udf_decimal_boundary_values_per_width_scale() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + int[] widths = new int[] {4, 9, 18, 30, 38}; + int[] scales = new int[] {1, 4, 6, 10, 10}; + String[] maxValues = + new String[] {"999.9", "99999.9999", "999999999999.999999", "99999999999999999999.9999999999", + "9999999999999999999999999999.9999999999"}; + String[] minValues = + new String[] {"-999.9", "-99999.9999", "-999999999999.999999", "-99999999999999999999.9999999999", + "-9999999999999999999999999999.9999999999"}; + + for (int i = 0; i < widths.length; i++) { + int width = widths[i]; + int scale = scales[i]; + UdfLogicalType decimalType = UdfLogicalType.decimal(width, scale); + String fnName = "f_decimal_boundaries_" + width + "_" + scale; + String maxTyped = maxValues[i] + "::DECIMAL(" + width + "," + scale + ")"; + String minTyped = minValues[i] + "::DECIMAL(" + width + "," + scale + ")"; + + conn.registerScalarUdf(fnName, new UdfLogicalType[] {decimalType}, decimalType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setBigDecimal(row, 
args[0].getBigDecimal(row)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT CAST(" + fnName + "(" + maxTyped + ") AS VARCHAR), " + + "CAST(" + maxTyped + " AS VARCHAR), CAST(" + fnName + "(" + + minTyped + ") AS VARCHAR), CAST(" + minTyped + " AS VARCHAR)")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), rs.getString(2)); + assertEquals(rs.getString(3), rs.getString(4)); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_reader_object_accessors_null_special_handling() throws Exception { + UdfLogicalType decimal18_3 = UdfLogicalType.decimal(18, 3); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("reader_objects_nulls", + new UdfLogicalType[] {decimal18_3, UdfLogicalType.of(DuckDBColumnType.DATE), + UdfLogicalType.of(DuckDBColumnType.TIME), + UdfLogicalType.of(DuckDBColumnType.TIME_WITH_TIME_ZONE), + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP), + UdfLogicalType.of(DuckDBColumnType.TIMESTAMP_WITH_TIME_ZONE), + UdfLogicalType.of(DuckDBColumnType.UUID)}, + UdfLogicalType.of(DuckDBColumnType.BOOLEAN), (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + boolean ok = args[0].getBigDecimal(row) == null; + ok = ok && args[1].getDate(row) == null; + ok = ok && args[1].getLocalDate(row) == null; + ok = ok && args[2].getLocalTime(row) == null; + ok = ok && args[3].getOffsetTime(row) == null; + ok = ok && args[4].getLocalDateTime(row) == null; + ok = ok && args[5].getOffsetDateTime(row) == null; + ok = ok && args[6].getUUID(row) == null; + out.setBoolean(row, ok); + } + }, new UdfOptions().nullSpecialHandling(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT reader_objects_nulls(NULL::DECIMAL(18,3), NULL::DATE, " + + "NULL::TIME, NULL::TIME WITH TIME ZONE, NULL::TIMESTAMP, " + + "NULL::TIMESTAMP WITH TIME ZONE, NULL::UUID)")) { + assertTrue(rs.next()); + 
assertEquals(rs.getBoolean(1), true); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_null_default_handling() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_default_null", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 5); + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT f_default_null(NULL::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_null_special_handling() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_special_null", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (!args[0].isNull(row)) { + throw new IllegalStateException("Expected NULL input row"); + } + out.setInt(row, 5); + } + }, new UdfOptions().nullSpecialHandling(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT f_special_null(NULL::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 5); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_exception_abort() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_throw_abort", + (ctx, args, out, rowCount) -> { throw new RuntimeException("kaboom"); }); + + assertThrows(() -> { + stmt.executeQuery("SELECT f_throw_abort(i::INTEGER) FROM range(3) t(i)"); + }, SQLException.class); + } + } + + public static void test_java_scalar_udf_exception_return_null() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_throw_null", (ctx, args, out, rowCount) -> { + throw new RuntimeException("kaboom"); + }, new UdfOptions().returnNullOnException(true)); + + try (ResultSet rs = stmt.executeQuery( + "SELECT count(*) total, count(f_throw_null(i::INTEGER)) non_null FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1000L); + assertEquals(rs.getLong(2), 0L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_deterministic_caches_constant_input() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_rand_det", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, ThreadLocalRandom.current().nextInt()); + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT count(DISTINCT f_rand_det(1::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_volatile_varies_per_row() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("f_rand_vol", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, ThreadLocalRandom.current().nextInt()); + } + }, new UdfOptions().deterministic(false)); + + try (ResultSet rs = + stmt.executeQuery("SELECT count(DISTINCT f_rand_vol(1::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertTrue(rs.getLong(1) > 1L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_add2() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add2", new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row) || args[1].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add2(i::INTEGER, 10::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 509500L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_mul3() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf( + "mul3", + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row) || args[1].isNull(row) || args[2].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) * args[1].getInt(row) * args[2].getInt(row)); + } + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT sum(mul3(i::INTEGER, 2::INTEGER, 3::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 2997000L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_registration_overloads_column_types() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("plus1_overload", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) -> { + for (int row = 0; row 
< rowCount; row++) { + out.setInt(row, args[0].getInt(row) + 1); + } + }); + + conn.registerScalarUdf("sum2_overload", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row)); + } + }); + + conn.registerScalarUdf("sum3_overload", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, + args[0].getInt(row) + args[1].getInt(row) + args[2].getInt(row)); + } + }); + + conn.registerScalarUdf("sum4_overload", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row) + + args[2].getInt(row) + args[3].getInt(row)); + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT plus1_overload(41::INTEGER), sum2_overload(1::INTEGER,2::INTEGER), " + + "sum3_overload(1::INTEGER,2::INTEGER,3::INTEGER), " + + "sum4_overload(1::INTEGER,2::INTEGER,3::INTEGER,4::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 42); + assertEquals(rs.getInt(2), 3); + assertEquals(rs.getInt(3), 6); + assertEquals(rs.getInt(4), 10); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_registration_overloads_logical_types() throws Exception { + UdfLogicalType integerType = UdfLogicalType.of(DuckDBColumnType.INTEGER); + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("logical_sum2_overload", integerType, integerType, integerType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 
args[0].getInt(row) + args[1].getInt(row)); + } + }); + + conn.registerScalarUdf("logical_sum3_overload", integerType, integerType, integerType, integerType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, + args[0].getInt(row) + args[1].getInt(row) + args[2].getInt(row)); + } + }); + + conn.registerScalarUdf("logical_sum4_overload", integerType, integerType, integerType, integerType, + integerType, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row) + + args[2].getInt(row) + args[3].getInt(row)); + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT logical_sum2_overload(1::INTEGER,2::INTEGER), " + + "logical_sum3_overload(1::INTEGER,2::INTEGER,3::INTEGER), " + + "logical_sum4_overload(1::INTEGER,2::INTEGER,3::INTEGER,4::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 3); + assertEquals(rs.getInt(2), 6); + assertEquals(rs.getInt(3), 10); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_arity_validation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add2_arity_check", + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row)); + } + }); + + assertThrows(() -> { stmt.executeQuery("SELECT add2_arity_check(1::INTEGER)"); }, SQLException.class); + } + } - ResultSet rs = stmt.executeQuery("SELECT 42 as a"); - assertFalse(stmt.isClosed()); - assertFalse(rs.isClosed()); + public static void test_java_scalar_udf_zero_arguments_registration() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("zero_const", DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + assertEquals(args.length, 0); + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 42); + } + }); - assertTrue(rs.next()); - int res = rs.getInt(1); - assertEquals(res, 42); - assertFalse(rs.wasNull()); + conn.registerScalarUdf("zero_throw_null", + UdfLogicalType.of(DuckDBColumnType.INTEGER), (ctx, args, out, rowCount) -> { + throw new RuntimeException("kaboom"); + }, new UdfOptions().returnNullOnException(true)); - res = rs.getInt(1); - assertEquals(res, 42); - assertFalse(rs.wasNull()); + try (ResultSet rs = stmt.executeQuery("SELECT sum(zero_const()) FROM range(100)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 4200L); + assertFalse(rs.next()); + } - res = rs.getInt("a"); - assertEquals(res, 42); - assertFalse(rs.wasNull()); + try (ResultSet rs = + stmt.executeQuery("SELECT count(*) total, count(zero_throw_null()) non_null FROM range(100)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 100L); + assertEquals(rs.getLong(2), 0L); + assertFalse(rs.next()); + } + } + } - assertThrows(() -> rs.getInt(0), SQLException.class); + public static void test_java_scalar_udf_varargs_registration() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdfVarArgs("sum_varargs", DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + int sum = 0; + boolean anyNull = false; + for (UdfReader arg : args) { + if (arg.isNull(row)) { + anyNull = true; + break; + } + sum += arg.getInt(row); + } + if (anyNull) { + out.setNull(row); + } else { + out.setInt(row, sum); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT 
sum(sum_varargs()) FROM range(100)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 0L); + assertFalse(rs.next()); + } - assertThrows(() -> rs.getInt(2), SQLException.class); + try (ResultSet rs = stmt.executeQuery("SELECT sum(sum_varargs(i::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 499500L); + assertFalse(rs.next()); + } - assertThrows(() -> rs.getInt("b"), SQLException.class); + try (ResultSet rs = stmt.executeQuery( + "SELECT sum(sum_varargs(i::INTEGER, 1::INTEGER, 2::INTEGER)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 502500L); + assertFalse(rs.next()); + } + } + } - assertFalse(rs.next()); - assertFalse(rs.next()); + public static void test_java_scalar_udf_varargs_registration_validation() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class)) { + assertThrows(() + -> conn.registerScalarUdf( + "bad_varargs", + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, + (ctx, args, out, rowCount) + -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 0); + } + }, + new UdfOptions().varArgs(true)), + SQLException.class); + } + } - rs.close(); - rs.close(); - assertTrue(rs.isClosed()); + public static void test_java_scalar_udf_java_class_type_mapper_registration() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("class_plus1", Integer.class, Integer.class, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setInt(row, args[0].getInt(row) + 1); + } + }); - assertThrows(() -> rs.getInt(1), SQLException.class); + conn.registerScalarUdf("class_concat", new Class[] {String.class, String.class}, String.class, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < 
rowCount; row++) { + out.setString(row, args[0].getString(row) + args[1].getString(row)); + } + }); + + conn.registerScalarUdf("class_zero", Integer.class, (ctx, args, out, rowCount) -> { + assertEquals(args.length, 0); + for (int row = 0; row < rowCount; row++) { + out.setInt(row, 7); + } + }); - stmt.close(); - stmt.close(); - assertTrue(stmt.isClosed()); + try (ResultSet rs = stmt.executeQuery( + "SELECT class_plus1(41::INTEGER), class_concat('a','b'), sum(class_zero()) FROM range(10)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 42); + assertEquals(rs.getString(2), "ab"); + assertEquals(rs.getLong(3), 70L); + assertFalse(rs.next()); + } + } + } - conn.close(); - conn.close(); - assertFalse(conn.isValid(0)); - assertTrue(conn.isClosed()); + public static void test_java_scalar_udf_java_class_type_mapper_decimal_requires_explicit_logical_type() + throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class)) { + assertThrows(() + -> conn.registerScalarUdf("class_decimal_forbidden", BigDecimal.class, BigDecimal.class, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setBigDecimal(row, BigDecimal.ONE); + } + }), + SQLException.class); + } + } - assertThrows(conn::createStatement, SQLException.class); + public static void test_java_scalar_udf_java_class_type_mapper_biginteger_roundtrip() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("class_biginteger_add1", BigInteger.class, BigInteger.class, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setObject(row, args[0].getBigInteger(row).add(BigInteger.ONE)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT CAST(class_biginteger_add1(" + + 
"170141183460469231731687303715884105726::HUGEINT) AS VARCHAR), " + + "class_biginteger_add1(NULL::HUGEINT) IS NULL")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "170141183460469231731687303715884105727"); + assertEquals(rs.getBoolean(2), true); + assertFalse(rs.next()); + } + } } - public static void test_execute_exception() throws Exception { - Connection conn = DriverManager.getConnection(JDBC_URL); - Statement stmt = conn.createStatement(); + public static void test_java_scalar_udf_ergonomic_registration_preserves_options_semantics() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("class_null_special_opt", Integer.class, + Integer.class, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setInt(row, 99); + } else { + out.setInt(row, args[0].getInt(row)); + } + } + }, new UdfOptions().nullSpecialHandling(true)); + + conn.registerScalarUdfVarArgs("varargs_throw_opt", DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + throw new RuntimeException("kaboom"); + }, new UdfOptions().returnNullOnException(true).deterministic(false)); + + try (ResultSet rs = stmt.executeQuery("SELECT class_null_special_opt(NULL::INTEGER)")) { + assertTrue(rs.next()); + assertEquals(rs.getInt(1), 99); + assertFalse(rs.next()); + } - assertThrows(() -> { - ResultSet rs = stmt.executeQuery("SELECT"); - rs.next(); - }, SQLException.class); + try (ResultSet rs = stmt.executeQuery( + "SELECT count(*) total, count(varargs_throw_opt(i::INTEGER)) non_null FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 1000L); + assertEquals(rs.getLong(2), 0L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_arity_above_four() throws Exception { + try (DuckDBConnection conn = 
DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf( + "sum5", + new DuckDBColumnType[] {DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER, + DuckDBColumnType.INTEGER, DuckDBColumnType.INTEGER}, + DuckDBColumnType.INTEGER, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row) || args[1].isNull(row) || args[2].isNull(row) || args[3].isNull(row) || + args[4].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + args[1].getInt(row) + args[2].getInt(row) + + args[3].getInt(row) + args[4].getInt(row)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(sum5(i::INTEGER, 1, 2, 3, 4)) FROM range(1000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 509500L); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_core_type_registration_surface() throws Exception { + DuckDBColumnType[] unsupportedTypes = new DuckDBColumnType[] { + DuckDBColumnType.INTERVAL, DuckDBColumnType.LIST, DuckDBColumnType.ARRAY, DuckDBColumnType.STRUCT, + DuckDBColumnType.MAP, DuckDBColumnType.UNION, DuckDBColumnType.ENUM}; + + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType supportedType : scalarCoreTypes()) { + String fnName = "f_supported_" + supportedType.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {supportedType}, supportedType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setNull(row); + } + }); + try (ResultSet rs = + stmt.executeQuery("SELECT " + fnName + "(" + nonNullLiteralForType(supportedType) + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertFalse(rs.next()); + } + } + + for (DuckDBColumnType unsupportedType : 
unsupportedTypes) { + String fnName = "f_unsupported_" + unsupportedType.name().toLowerCase(); + assertThrows(() -> { + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {unsupportedType}, unsupportedType, + (ctx, args, out, rowCount) -> {}); + }, SQLFeatureNotSupportedException.class); + } + } + } + + public static void test_java_scalar_udf_extended_type_registration_surface() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType supportedType : scalarExtendedTypes()) { + String fnName = "f_supported_ext_" + supportedType.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {supportedType}, supportedType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setNull(row); + } + }); + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + + nonNullLiteralForExtendedType(supportedType) + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_unsigned_and_special_type_registration_surface() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType supportedType : scalarUnsignedAndSpecialTypes()) { + String fnName = "f_supported_unsigned_special_" + supportedType.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {supportedType}, supportedType, + (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + out.setNull(row); + } + }); + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + + nonNullLiteralForUnsignedAndSpecialType(supportedType) + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertFalse(rs.next()); + } + } + } + } + + public static void 
test_java_scalar_udf_extended_roundtrip_and_nulls() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarExtendedTypes()) { + String fnName = "f_identity_ext_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + continue; + } + switch (type) { + case DECIMAL: + out.setBigDecimal(row, args[0].getBigDecimal(row)); + break; + case BLOB: + out.setBytes(row, args[0].getBytes(row)); + break; + case DATE: + out.setInt(row, args[0].getInt(row)); + break; + case TIME: + case TIME_NS: + case TIMESTAMP: + case TIMESTAMP_S: + case TIMESTAMP_MS: + case TIMESTAMP_NS: + out.setLong(row, args[0].getLong(row)); + break; + default: + throw new IllegalStateException("Unexpected extended type: " + type); + } + } + }); + + String literal = nonNullLiteralForExtendedType(type); + String nullLiteral = nullLiteralForType(type); + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + literal + ") = " + literal + ", " + + fnName + "(" + nullLiteral + ") IS NULL")) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), true); + assertEquals(rs.getBoolean(2), true); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_unsigned_and_special_roundtrip_and_nulls() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarUnsignedAndSpecialTypes()) { + String fnName = "f_identity_unsigned_special_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if 
(args[0].isNull(row)) { + out.setNull(row); + continue; + } + switch (type) { + case UTINYINT: + case USMALLINT: + out.setInt(row, args[0].getInt(row)); + break; + case UINTEGER: + case UBIGINT: + case TIME_WITH_TIME_ZONE: + case TIMESTAMP_WITH_TIME_ZONE: + out.setLong(row, args[0].getLong(row)); + break; + case HUGEINT: + case UHUGEINT: + case UUID: + out.setBytes(row, args[0].getBytes(row)); + break; + default: + throw new IllegalStateException("Unexpected unsigned/special type: " + type); + } + } + }); + + String literal = nonNullLiteralForUnsignedAndSpecialType(type); + String nullLiteral = nullLiteralForType(type); + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + literal + ") = " + literal + ", " + + fnName + "(" + nullLiteral + ") IS NULL")) { + assertTrue(rs.next()); + assertEquals(rs.getBoolean(1), true); + assertEquals(rs.getBoolean(2), true); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_core_roundtrip_and_nulls() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarCoreTypes()) { + String fnName = "f_identity_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + continue; + } + switch (type) { + case BOOLEAN: + out.setBoolean(row, args[0].getBoolean(row)); + break; + case TINYINT: + case SMALLINT: + case INTEGER: + out.setInt(row, args[0].getInt(row)); + break; + case BIGINT: + out.setLong(row, args[0].getLong(row)); + break; + case FLOAT: + out.setFloat(row, args[0].getFloat(row)); + break; + case DOUBLE: + out.setDouble(row, args[0].getDouble(row)); + break; + case VARCHAR: + out.setString(row, args[0].getString(row)); + break; + default: + throw new 
IllegalStateException("Unexpected core type: " + type); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + nonNullLiteralForType(type) + "), " + + fnName + "(" + nullLiteralForType(type) + ")")) { + assertTrue(rs.next()); + switch (type) { + case BOOLEAN: + assertEquals(rs.getBoolean(1), true); + break; + case TINYINT: + assertEquals(rs.getInt(1), 7); + break; + case SMALLINT: + assertEquals(rs.getInt(1), 32000); + break; + case INTEGER: + assertEquals(rs.getInt(1), 123456); + break; + case BIGINT: + assertEquals(rs.getLong(1), 9876543210L); + break; + case FLOAT: + assertEquals(rs.getFloat(1), 1.25f, 0.0001f); + break; + case DOUBLE: + assertEquals(rs.getDouble(1), 2.5d, 0.0000001d); + break; + case VARCHAR: + assertEquals(rs.getString(1), "duck"); + break; + default: + throw new IllegalStateException("Unexpected core type: " + type); + } + assertEquals(rs.getObject(2), null); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_core_null_special_handling() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarCoreTypes()) { + String fnName = "f_special_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (!args[0].isNull(row)) { + throw new IllegalStateException("Expected NULL input row"); + } + switch (type) { + case BOOLEAN: + out.setBoolean(row, true); + break; + case TINYINT: + out.setInt(row, 12); + break; + case SMALLINT: + out.setInt(row, 1234); + break; + case INTEGER: + out.setInt(row, 123456); + break; + case BIGINT: + out.setLong(row, 123456789L); + break; + case FLOAT: + out.setFloat(row, 9.25f); + break; + case DOUBLE: + out.setDouble(row, 19.5d); + break; + case VARCHAR: + out.setString(row, "null-special"); + 
break; + default: + throw new IllegalStateException("Unexpected core type: " + type); + } + } + }, new UdfOptions().nullSpecialHandling(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + nullLiteralForType(type) + ")")) { + assertTrue(rs.next()); + switch (type) { + case BOOLEAN: + assertEquals(rs.getBoolean(1), true); + break; + case TINYINT: + assertEquals(rs.getInt(1), 12); + break; + case SMALLINT: + assertEquals(rs.getInt(1), 1234); + break; + case INTEGER: + assertEquals(rs.getInt(1), 123456); + break; + case BIGINT: + assertEquals(rs.getLong(1), 123456789L); + break; + case FLOAT: + assertEquals(rs.getFloat(1), 9.25f, 0.0001f); + break; + case DOUBLE: + assertEquals(rs.getDouble(1), 19.5d, 0.0000001d); + break; + case VARCHAR: + assertEquals(rs.getString(1), "null-special"); + break; + default: + throw new IllegalStateException("Unexpected core type: " + type); + } + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_core_exception_return_null() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarCoreTypes()) { + String fnName = "f_throw_null_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, (ctx, args, out, rowCount) -> { + throw new RuntimeException("kaboom"); + }, new UdfOptions().returnNullOnException(true)); + + try (ResultSet rs = stmt.executeQuery("SELECT " + fnName + "(" + nonNullLiteralForType(type) + ")")) { + assertTrue(rs.next()); + assertEquals(rs.getObject(1), null); + assertFalse(rs.next()); + } + } + } + } + + public static void test_java_scalar_udf_core_exception_abort() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + for (DuckDBColumnType type : scalarCoreTypes()) { + 
String fnName = "f_throw_abort_" + type.name().toLowerCase(); + conn.registerScalarUdf(fnName, new DuckDBColumnType[] {type}, type, + (ctx, args, out, rowCount) -> { throw new RuntimeException("kaboom"); }); + assertThrows(() -> { + stmt.executeQuery("SELECT " + fnName + "(" + nonNullLiteralForType(type) + ")"); + }, SQLException.class); + } + } + } + + public static void test_java_scalar_udf_reverse_varchar() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("reverse_java", new DuckDBColumnType[] {DuckDBColumnType.VARCHAR}, + DuckDBColumnType.VARCHAR, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setString( + row, new StringBuilder(args[0].getString(row)).reverse().toString()); + } + } + }); + + try (ResultSet rs = + stmt.executeQuery("SELECT reverse_java('abcd'), reverse_java('cafe'), reverse_java('hello')")) { + assertTrue(rs.next()); + assertEquals(rs.getString(1), "dcba"); + assertEquals(rs.getString(2), "efac"); + assertEquals(rs.getString(3), "olleh"); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_concat_varchar() throws Exception { + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("concat_java", + new DuckDBColumnType[] {DuckDBColumnType.VARCHAR, DuckDBColumnType.VARCHAR}, + DuckDBColumnType.VARCHAR, (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row) || args[1].isNull(row)) { + out.setNull(row); + } else { + out.setString(row, args[0].getString(row) + args[1].getString(row)); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT concat_java('Hello ', 'world')")) { + assertTrue(rs.next()); + 
assertEquals(rs.getString(1), "Hello world"); + assertFalse(rs.next()); + } + } + } + + public static void test_java_scalar_udf_add_one_10m_benchmark() throws Exception { + if (!Boolean.getBoolean("duckdb.udf.benchmark")) { + return; + } + try (DuckDBConnection conn = DriverManager.getConnection(JDBC_URL).unwrap(DuckDBConnection.class); + Statement stmt = conn.createStatement()) { + conn.registerScalarUdf("add_one", (ctx, args, out, rowCount) -> { + for (int row = 0; row < rowCount; row++) { + if (args[0].isNull(row)) { + out.setNull(row); + } else { + out.setInt(row, args[0].getInt(row) + 1); + } + } + }); + + try (ResultSet rs = stmt.executeQuery("SELECT sum(add_one(i::INTEGER)) FROM range(10000000) t(i)")) { + assertTrue(rs.next()); + assertEquals(rs.getLong(1), 50000005000000L); + assertFalse(rs.next()); + } + } } public static void test_autocommit_off() throws Exception { @@ -251,60 +3900,60 @@ public static void test_enum() throws Exception { "CREATE TYPE enum_long AS ENUM ('enum0' ,'enum1' ,'enum2' ,'enum3' ,'enum4' ,'enum5' ,'enum6'" + ",'enum7' ,'enum8' ,'enum9' ,'enum10' ,'enum11' ,'enum12' ,'enum13' ,'enum14' ,'enum15' ,'enum16' ,'enum17'" - + - ",'enum18' ,'enum19' ,'enum20' ,'enum21' ,'enum22' ,'enum23' ,'enum24' ,'enum25' ,'enum26' ,'enum27' ,'enum28'" - + - ",'enum29' ,'enum30' ,'enum31' ,'enum32' ,'enum33' ,'enum34' ,'enum35' ,'enum36' ,'enum37' ,'enum38' ,'enum39'" - + - ",'enum40' ,'enum41' ,'enum42' ,'enum43' ,'enum44' ,'enum45' ,'enum46' ,'enum47' ,'enum48' ,'enum49' ,'enum50'" - + - ",'enum51' ,'enum52' ,'enum53' ,'enum54' ,'enum55' ,'enum56' ,'enum57' ,'enum58' ,'enum59' ,'enum60' ,'enum61'" - + - ",'enum62' ,'enum63' ,'enum64' ,'enum65' ,'enum66' ,'enum67' ,'enum68' ,'enum69' ,'enum70' ,'enum71' ,'enum72'" - + - ",'enum73' ,'enum74' ,'enum75' ,'enum76' ,'enum77' ,'enum78' ,'enum79' ,'enum80' ,'enum81' ,'enum82' ,'enum83'" - + - ",'enum84' ,'enum85' ,'enum86' ,'enum87' ,'enum88' ,'enum89' ,'enum90' ,'enum91' ,'enum92' ,'enum93' ,'enum94'" + + 
",'enum18' ,'enum19' ,'enum20' ,'enum21' ,'enum22' ,'enum23' ,'enum24' ,'enum25' ,'enum26' ,'enum27' " + + ",'enum28'" + + ",'enum29' ,'enum30' ,'enum31' ,'enum32' ,'enum33' ,'enum34' ,'enum35' ,'enum36' ,'enum37' ,'enum38' " + + ",'enum39'" + + ",'enum40' ,'enum41' ,'enum42' ,'enum43' ,'enum44' ,'enum45' ,'enum46' ,'enum47' ,'enum48' ,'enum49' " + + ",'enum50'" + + ",'enum51' ,'enum52' ,'enum53' ,'enum54' ,'enum55' ,'enum56' ,'enum57' ,'enum58' ,'enum59' ,'enum60' " + + ",'enum61'" + + ",'enum62' ,'enum63' ,'enum64' ,'enum65' ,'enum66' ,'enum67' ,'enum68' ,'enum69' ,'enum70' ,'enum71' " + + ",'enum72'" + + ",'enum73' ,'enum74' ,'enum75' ,'enum76' ,'enum77' ,'enum78' ,'enum79' ,'enum80' ,'enum81' ,'enum82' " + + ",'enum83'" + + ",'enum84' ,'enum85' ,'enum86' ,'enum87' ,'enum88' ,'enum89' ,'enum90' ,'enum91' ,'enum92' ,'enum93' " + + ",'enum94'" + ",'enum95' ,'enum96' ,'enum97' ,'enum98' ,'enum99' ,'enum100' ,'enum101' ,'enum102' ,'enum103' ,'enum104' " - + - ",'enum105' ,'enum106' ,'enum107' ,'enum108' ,'enum109' ,'enum110' ,'enum111' ,'enum112' ,'enum113' ,'enum114'" - + - ",'enum115' ,'enum116' ,'enum117' ,'enum118' ,'enum119' ,'enum120' ,'enum121' ,'enum122' ,'enum123' ,'enum124'" - + - ",'enum125' ,'enum126' ,'enum127' ,'enum128' ,'enum129' ,'enum130' ,'enum131' ,'enum132' ,'enum133' ,'enum134'" - + - ",'enum135' ,'enum136' ,'enum137' ,'enum138' ,'enum139' ,'enum140' ,'enum141' ,'enum142' ,'enum143' ,'enum144'" - + - ",'enum145' ,'enum146' ,'enum147' ,'enum148' ,'enum149' ,'enum150' ,'enum151' ,'enum152' ,'enum153' ,'enum154'" - + - ",'enum155' ,'enum156' ,'enum157' ,'enum158' ,'enum159' ,'enum160' ,'enum161' ,'enum162' ,'enum163' ,'enum164'" - + - ",'enum165' ,'enum166' ,'enum167' ,'enum168' ,'enum169' ,'enum170' ,'enum171' ,'enum172' ,'enum173' ,'enum174'" - + - ",'enum175' ,'enum176' ,'enum177' ,'enum178' ,'enum179' ,'enum180' ,'enum181' ,'enum182' ,'enum183' ,'enum184'" - + - ",'enum185' ,'enum186' ,'enum187' ,'enum188' ,'enum189' ,'enum190' ,'enum191' 
,'enum192' ,'enum193' ,'enum194'" - + - ",'enum195' ,'enum196' ,'enum197' ,'enum198' ,'enum199' ,'enum200' ,'enum201' ,'enum202' ,'enum203' ,'enum204'" - + - ",'enum205' ,'enum206' ,'enum207' ,'enum208' ,'enum209' ,'enum210' ,'enum211' ,'enum212' ,'enum213' ,'enum214'" - + - ",'enum215' ,'enum216' ,'enum217' ,'enum218' ,'enum219' ,'enum220' ,'enum221' ,'enum222' ,'enum223' ,'enum224'" - + - ",'enum225' ,'enum226' ,'enum227' ,'enum228' ,'enum229' ,'enum230' ,'enum231' ,'enum232' ,'enum233' ,'enum234'" - + - ",'enum235' ,'enum236' ,'enum237' ,'enum238' ,'enum239' ,'enum240' ,'enum241' ,'enum242' ,'enum243' ,'enum244'" - + - ",'enum245' ,'enum246' ,'enum247' ,'enum248' ,'enum249' ,'enum250' ,'enum251' ,'enum252' ,'enum253' ,'enum254'" - + - ",'enum255' ,'enum256' ,'enum257' ,'enum258' ,'enum259' ,'enum260' ,'enum261' ,'enum262' ,'enum263' ,'enum264'" - + - ",'enum265' ,'enum266' ,'enum267' ,'enum268' ,'enum269' ,'enum270' ,'enum271' ,'enum272' ,'enum273' ,'enum274'" - + - ",'enum275' ,'enum276' ,'enum277' ,'enum278' ,'enum279' ,'enum280' ,'enum281' ,'enum282' ,'enum283' ,'enum284'" - + - ",'enum285' ,'enum286' ,'enum287' ,'enum288' ,'enum289' ,'enum290' ,'enum291' ,'enum292' ,'enum293' ,'enum294'" + + ",'enum105' ,'enum106' ,'enum107' ,'enum108' ,'enum109' ,'enum110' ,'enum111' ,'enum112' ,'enum113' " + + ",'enum114'" + + ",'enum115' ,'enum116' ,'enum117' ,'enum118' ,'enum119' ,'enum120' ,'enum121' ,'enum122' ,'enum123' " + + ",'enum124'" + + ",'enum125' ,'enum126' ,'enum127' ,'enum128' ,'enum129' ,'enum130' ,'enum131' ,'enum132' ,'enum133' " + + ",'enum134'" + + ",'enum135' ,'enum136' ,'enum137' ,'enum138' ,'enum139' ,'enum140' ,'enum141' ,'enum142' ,'enum143' " + + ",'enum144'" + + ",'enum145' ,'enum146' ,'enum147' ,'enum148' ,'enum149' ,'enum150' ,'enum151' ,'enum152' ,'enum153' " + + ",'enum154'" + + ",'enum155' ,'enum156' ,'enum157' ,'enum158' ,'enum159' ,'enum160' ,'enum161' ,'enum162' ,'enum163' " + + ",'enum164'" + + ",'enum165' ,'enum166' ,'enum167' 
,'enum168' ,'enum169' ,'enum170' ,'enum171' ,'enum172' ,'enum173' " + + ",'enum174'" + + ",'enum175' ,'enum176' ,'enum177' ,'enum178' ,'enum179' ,'enum180' ,'enum181' ,'enum182' ,'enum183' " + + ",'enum184'" + + ",'enum185' ,'enum186' ,'enum187' ,'enum188' ,'enum189' ,'enum190' ,'enum191' ,'enum192' ,'enum193' " + + ",'enum194'" + + ",'enum195' ,'enum196' ,'enum197' ,'enum198' ,'enum199' ,'enum200' ,'enum201' ,'enum202' ,'enum203' " + + ",'enum204'" + + ",'enum205' ,'enum206' ,'enum207' ,'enum208' ,'enum209' ,'enum210' ,'enum211' ,'enum212' ,'enum213' " + + ",'enum214'" + + ",'enum215' ,'enum216' ,'enum217' ,'enum218' ,'enum219' ,'enum220' ,'enum221' ,'enum222' ,'enum223' " + + ",'enum224'" + + ",'enum225' ,'enum226' ,'enum227' ,'enum228' ,'enum229' ,'enum230' ,'enum231' ,'enum232' ,'enum233' " + + ",'enum234'" + + ",'enum235' ,'enum236' ,'enum237' ,'enum238' ,'enum239' ,'enum240' ,'enum241' ,'enum242' ,'enum243' " + + ",'enum244'" + + ",'enum245' ,'enum246' ,'enum247' ,'enum248' ,'enum249' ,'enum250' ,'enum251' ,'enum252' ,'enum253' " + + ",'enum254'" + + ",'enum255' ,'enum256' ,'enum257' ,'enum258' ,'enum259' ,'enum260' ,'enum261' ,'enum262' ,'enum263' " + + ",'enum264'" + + ",'enum265' ,'enum266' ,'enum267' ,'enum268' ,'enum269' ,'enum270' ,'enum271' ,'enum272' ,'enum273' " + + ",'enum274'" + + ",'enum275' ,'enum276' ,'enum277' ,'enum278' ,'enum279' ,'enum280' ,'enum281' ,'enum282' ,'enum283' " + + ",'enum284'" + + ",'enum285' ,'enum286' ,'enum287' ,'enum288' ,'enum289' ,'enum290' ,'enum291' ,'enum292' ,'enum293' " + + ",'enum294'" + ",'enum295' ,'enum296' ,'enum297' ,'enum298' ,'enum299');"); stmt.execute("CREATE TABLE t2 (id INT, e1 enum_long);"); @@ -631,7 +4280,6 @@ public static void test_read_only() throws Exception { rs1.next(); assertEquals(rs1.getInt(1), 42); } - try (Statement stmt2 = conn_ro2.createStatement(); ResultSet rs2 = stmt2.executeQuery("SELECT * FROM test")) { rs2.next(); @@ -656,8 +4304,8 @@ public static void test_temporal_types() throws 
Exception { Connection conn = DriverManager.getConnection(JDBC_URL); Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery( - "SELECT '2019-11-26 21:11:00'::timestamp ts, '2019-11-26'::date dt, interval '5 days' iv, '21:11:00'::time te"); + ResultSet rs = stmt.executeQuery("SELECT '2019-11-26 21:11:00'::timestamp ts, '2019-11-26'::date dt, " + + "interval '5 days' iv, '21:11:00'::time te"); assertTrue(rs.next()); assertEquals(rs.getObject("ts"), Timestamp.valueOf("2019-11-26 21:11:00")); assertEquals(rs.getTimestamp("ts"), Timestamp.valueOf("2019-11-26 21:11:00")); @@ -1518,11 +5166,15 @@ private static OffsetDateTime localDateTimeToOffset(LocalDateTime ldt) { correct_answer_map.put("uint", asList(0L, 4294967295L, null)); correct_answer_map.put("ubigint", asList(BigInteger.ZERO, new BigInteger("18446744073709551615"), null)); correct_answer_map.put( - "bignum", - asList( - "-179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", - "179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", - null)); + "bignum", asList("-17976931348623157081452742373170435679807056752584499659891747680315726078002" + + "853876058955863276687817154045895351438246423432132688946418276846754670353751" + + "698604991057655128207624549009038932894407586850845513394230458323690322294816" + + "5808559332123348274797826204144723168738177180919299881250404026184124858368", + "179769313486231570814527423731704356798070567525844996598917476803157260780028" + + 
"538760589558632766878171540458953514382464234321326889464182768467546703537516" + + "986049910576551282076245490090389328944075868508455133942304583236903222948165" + + "808559332123348274797826204144723168738177180919299881250404026184124858368", + null)); correct_answer_map.put("time", asList(LocalTime.of(0, 0), LocalTime.parse("23:59:59.999999"), null)); correct_answer_map.put("time_ns", asList(LocalTime.of(0, 0), LocalTime.parse("23:59:59.999999"), null)); correct_answer_map.put("float", asList(-3.4028234663852886e+38f, 3.4028234663852886e+38f, null)); @@ -1597,14 +5249,14 @@ public static void test_all_types() throws Exception { TimeZone.setDefault(ALL_TYPES_TIME_ZONE); try { Logger logger = Logger.getAnonymousLogger(); - String sql = - "select * EXCLUDE(time, time_ns, time_tz)" - + "\n , CASE WHEN time = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE time END AS time" - + - "\n , CASE WHEN time_ns = '24:00:00'::TIME_NS THEN '23:59:59.999999'::TIME_NS ELSE time_ns END AS time_ns" - + - "\n , CASE WHEN time_tz = '24:00:00-15:59:59'::TIMETZ THEN '23:59:59.999999-15:59:59'::TIMETZ ELSE time_tz END AS time_tz" - + "\nfrom test_all_types()"; + String sql = "select * EXCLUDE(time, time_ns, time_tz)" + + + "\n , CASE WHEN time = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE time END AS time" + + "\n , CASE WHEN time_ns = '24:00:00'::TIME_NS THEN '23:59:59.999999'::TIME_NS ELSE " + + "time_ns END AS time_ns" + + "\n , CASE WHEN time_tz = '24:00:00-15:59:59'::TIMETZ THEN " + + "'23:59:59.999999-15:59:59'::TIMETZ ELSE time_tz END AS time_tz" + + "\nfrom test_all_types()"; try (Connection conn = DriverManager.getConnection(JDBC_URL); PreparedStatement stmt = conn.prepareStatement(sql)) { @@ -2026,14 +5678,12 @@ public static void test_query_progress() throws Exception { return null; } }); - assertThrows( - () - -> stmt.executeQuery( - "WITH RECURSIVE cte AS NOT MATERIALIZED (" - + - "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, cte.p + cte.f from 
cte WHERE cte.i < 200000) " - + "SELECT avg(f) FROM cte"), - SQLException.class); + assertThrows(() + -> stmt.executeQuery("WITH RECURSIVE cte AS NOT MATERIALIZED (" + + "SELECT * from test_fib1 UNION ALL SELECT cte.i + 1, cte.f, " + + "cte.p + cte.f from cte WHERE cte.i < 200000) " + + "SELECT avg(f) FROM cte"), + SQLException.class); QueryProgress qpRunning = future.get(); assertNotNull(qpRunning);