AbstractThriftReader.java
@@ -75,6 +75,7 @@ public abstract class AbstractThriftReader extends DorisReader {
private int readCount = 0;

private final Boolean datetimeJava8ApiEnabled;
private final Boolean useTimestampNtz;

protected AbstractThriftReader(DorisReaderPartition partition) throws Exception {
super(partition);
@@ -112,6 +113,7 @@ protected AbstractThriftReader(DorisReaderPartition partition) throws Exception
this.asyncThread = null;
}
this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
this.useTimestampNtz = config.getValue(DorisOptions.DORIS_READ_TIMESTAMP_NTZ_ENABLED);
}

private void runAsync() throws DorisException, InterruptedException {
@@ -128,7 +130,7 @@ private void runAsync() throws DorisException, InterruptedException {
});
endOfStream.set(nextResult.isEos());
if (!endOfStream.get()) {
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled, useTimestampNtz);
offset += rowBatch.getReadRowCount();
rowBatch.close();
rowBatchQueue.put(rowBatch);
@@ -182,7 +184,7 @@ public boolean hasNext() throws DorisException {
});
endOfStream.set(nextResult.isEos());
if (!endOfStream.get()) {
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(nextResult, dorisSchema, datetimeJava8ApiEnabled, useTimestampNtz);
}
}
hasNext = !endOfStream.get();
DorisFlightSqlReader.java
@@ -61,6 +61,7 @@ public class DorisFlightSqlReader extends DorisReader {
private AdbcConnection connection;
private final ArrowReader arrowReader;
private final Boolean datetimeJava8ApiEnabled;
private final Boolean useTimestampNtz;

public DorisFlightSqlReader(DorisReaderPartition partition) throws Exception {
super(partition);
@@ -85,6 +86,7 @@ public DorisFlightSqlReader(DorisReaderPartition partition) throws Exception {
this.schema = processDorisSchema(partition);
this.arrowReader = executeQuery();
this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
this.useTimestampNtz = config.getValue(DorisOptions.DORIS_READ_TIMESTAMP_NTZ_ENABLED);
}

@Override
@@ -96,7 +98,7 @@ public boolean hasNext() throws DorisException {
throw new DorisException(e);
}
if (!endOfStream.get()) {
rowBatch = new RowBatch(arrowReader, schema, datetimeJava8ApiEnabled);
rowBatch = new RowBatch(arrowReader, schema, datetimeJava8ApiEnabled, useTimestampNtz);
}
}
return !endOfStream.get();
RowBatch.java
@@ -105,13 +105,19 @@ public class RowBatch implements Serializable {
private List<FieldVector> fieldVectors;

private final Boolean datetimeJava8ApiEnabled;
private final Boolean useTimestampNtz;

public RowBatch(TScanBatchResult nextResult, Schema schema, Boolean datetimeJava8ApiEnabled) throws DorisException {
this(nextResult, schema, datetimeJava8ApiEnabled, false);
}

public RowBatch(TScanBatchResult nextResult, Schema schema, Boolean datetimeJava8ApiEnabled, Boolean useTimestampNtz) throws DorisException {

this.rootAllocator = new RootAllocator(Integer.MAX_VALUE);
this.arrowReader = new ArrowStreamReader(new ByteArrayInputStream(nextResult.getRows()), rootAllocator);
this.schema = schema;
this.datetimeJava8ApiEnabled = datetimeJava8ApiEnabled;
this.useTimestampNtz = useTimestampNtz != null ? useTimestampNtz : false;

try {
VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
@@ -128,10 +134,15 @@ public RowBatch(TScanBatchResult nextResult, Schema schema, Boolean datetimeJava
}

public RowBatch(ArrowReader reader, Schema schema, Boolean datetimeJava8ApiEnabled) throws DorisException {
this(reader, schema, datetimeJava8ApiEnabled, false);
}

public RowBatch(ArrowReader reader, Schema schema, Boolean datetimeJava8ApiEnabled, Boolean useTimestampNtz) throws DorisException {

this.arrowReader = reader;
this.schema = schema;
this.datetimeJava8ApiEnabled = datetimeJava8ApiEnabled;
this.useTimestampNtz = useTimestampNtz != null ? useTimestampNtz : false;

try {
VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
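The three-argument constructors now delegate with a default of false, and the four-argument ones null-guard the boxed flag, so existing callers and unset config values keep the legacy behaviour. A minimal Scala sketch of that compatibility pattern (names are illustrative, not connector code):

class Batch(schema: String, useTimestampNtz: java.lang.Boolean) {
  // A null flag (option unset) degrades to the old default of false.
  val ntz: Boolean = if (useTimestampNtz != null) useTimestampNtz else false
  // The legacy constructor keeps pre-existing call sites compiling unchanged.
  def this(schema: String) = this(schema, false)
}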
@@ -409,10 +420,18 @@ public void convertArrowToRowBatch() throws DorisException {
String stringValue = completeMilliseconds(new String(varCharVector.get(rowIndex),
StandardCharsets.UTF_8));
LocalDateTime dateTime = LocalDateTime.parse(stringValue, dateTimeV2Formatter);
if (datetimeJava8ApiEnabled) {

// TimestampNTZ (Spark 3.4+): keep the LocalDateTime as-is, with no timezone conversion
if (useTimestampNtz) {
addValueToRow(rowIndex, dateTime);
} else if (datetimeJava8ApiEnabled) {
// For TimestampType with Java8 API, convert to Instant with timezone
Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
addValueToRow(rowIndex, instant);
} else {
// For TimestampType without Java8 API, use Timestamp
addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
}
}
@@ -424,10 +443,18 @@ public void convertArrowToRowBatch() throws DorisException {
continue;
}
LocalDateTime dateTime = getDateTime(rowIndex, timeStampVector);
if (datetimeJava8ApiEnabled) {

// TimestampNTZ (Spark 3.4+): keep the LocalDateTime as-is, with no timezone conversion
if (useTimestampNtz) {
addValueToRow(rowIndex, dateTime);
} else if (datetimeJava8ApiEnabled) {
// For TimestampType with Java8 API, convert to Instant with timezone
Instant instant = dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
addValueToRow(rowIndex, instant);
} else {
// For TimestampType without Java8 API, use Timestamp
addValueToRow(rowIndex, Timestamp.valueOf(dateTime));
}
}
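For a single parsed DATETIME value, the three branches above produce these representations; a self-contained Scala sketch, with ZoneId.systemDefault() standing in for the connector's DEFAULT_ZONE_ID:

import java.sql.Timestamp
import java.time.{Instant, LocalDateTime, ZoneId}

object BranchDemo extends App {
  val defaultZone: ZoneId = ZoneId.systemDefault() // assumption: connector default zone
  val dt = LocalDateTime.parse("2024-01-02T03:04:05.123456")

  val ntzValue: LocalDateTime = dt                           // useTimestampNtz: no zone math
  val java8Value: Instant = dt.atZone(defaultZone).toInstant // Java 8 API: zone-anchored Instant
  val legacyValue: Timestamp = Timestamp.valueOf(dt)         // default: java.sql.Timestamp
  println(s"$ntzValue | $java8Value | $legacyValue")
}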
DorisOptions.java
@@ -144,5 +144,12 @@ public class DorisOptions {

public static final ConfigOption<Integer> DORIS_SINK_NET_BUFFER_SIZE = ConfigOptions.name("doris.sink.net.buffer.size").intType().defaultValue(1024 * 1024).withDescription("");

/**
* Enable TIMESTAMP_NTZ (timestamp without time zone) support on Spark 3.4+.
* When enabled, Doris DATETIME/DATETIMEV2 columns map to Spark's TimestampNTZType instead of TimestampType.
* Defaults to false to preserve backward compatibility.
*/
public static final ConfigOption<Boolean> DORIS_READ_TIMESTAMP_NTZ_ENABLED = ConfigOptions.name("doris.read.timestamp.ntz.enabled").booleanType().defaultValue(false).withDescription("Enable TIMESTAMP_NTZ type support for Spark 3.4+. Default: false");


}
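A usage sketch for the new option: the option key comes from this diff, while the `doris` source name and the connection values are illustrative placeholders:

import org.apache.spark.sql.SparkSession

object NtzReadDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("ntz-demo").getOrCreate()
  val df = spark.read.format("doris")
    .option("doris.table.identifier", "db.tbl")         // placeholder
    .option("doris.fenodes", "fe_host:8030")            // placeholder
    .option("doris.read.timestamp.ntz.enabled", "true") // flag added in this PR
    .load()
  df.printSchema() // on Spark 3.4+, DATETIME columns surface as timestamp_ntz
}

The readers pick the flag up through config.getValue(DorisOptions.DORIS_READ_TIMESTAMP_NTZ_ENABLED), so it can be set per read like any other connector option.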
DorisRelation.scala
@@ -45,8 +45,9 @@ private[sql] class DorisRelation(
val tableIdentifier = cfg.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER)
val tableIdentifierArr = tableIdentifier.split("\\.").map(_.replaceAll("`", ""))
val dorisSchema = frontend.getTableSchema(tableIdentifierArr(0), tableIdentifierArr(1))
val useTimestampNtz = cfg.getValue(DorisOptions.DORIS_READ_TIMESTAMP_NTZ_ENABLED)
StructType(dorisSchema.getProperties.asScala.map(field => {
StructField(field.getName, SchemaConvertors.toCatalystType(field.getType, field.getPrecision, field.getScale))
StructField(field.getName, SchemaConvertors.toCatalystType(field.getType, field.getPrecision, field.getScale, useTimestampNtz))
}))

}
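On a Spark 3.4+ classpath, the flag flips the catalyst type chosen for DATETIME columns. A sketch of the resulting schema shape (field name illustrative; referencing TimestampNTZType directly requires compiling against Spark 3.4+):

import org.apache.spark.sql.types._

object SchemaDemo extends App {
  def datetimeType(useTimestampNtz: Boolean): DataType =
    if (useTimestampNtz) TimestampNTZType else TimestampType

  println(StructType(Seq(StructField("created_at", datetimeType(true)))).simpleString)
  // struct<created_at:timestamp_ntz>
}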
RowConvertors.scala
@@ -27,13 +27,34 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalDateTime}
import java.util
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.mutable

object RowConvertors {

/**
* Resolve TimestampNTZType via reflection so this class also loads on Spark
* versions earlier than 3.4, where the type does not exist.
*/
private lazy val timestampNTZTypeOption: Option[DataType] = {
try {
val timestampNTZClass = Class.forName("org.apache.spark.sql.types.TimestampNTZType$")
val instance = timestampNTZClass.getField("MODULE$").get(null)
Some(instance.asInstanceOf[DataType])
} catch {
case _: ClassNotFoundException | _: NoSuchFieldException | _: NoSuchMethodException =>
None
}
}

/**
* Check if a DataType is TimestampNTZType (for Spark 3.4+).
*/
private def isTimestampNTZType(dt: DataType): Boolean = {
timestampNTZTypeOption.exists(_.getClass == dt.getClass)
}

private val MAPPER = JsonMapper.builder().addModule(DefaultScalaModule)
.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true).build()

@@ -67,72 +88,103 @@ object RowConvertors {
private def asScalaValue(row: SpecializedGetters, dataType: DataType, ordinal: Int): Any = {
if (row.isNullAt(ordinal)) null
else {
dataType match {
case NullType => NULL_VALUE
case BooleanType => row.getBoolean(ordinal)
case ByteType => row.getByte(ordinal)
case ShortType => row.getShort(ordinal)
case IntegerType => row.getInt(ordinal)
case LongType => row.getLong(ordinal)
case FloatType => row.getFloat(ordinal)
case DoubleType => row.getDouble(ordinal)
case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(NULL_VALUE)
case TimestampType =>
DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)).toString
case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
case BinaryType => row.getBinary(ordinal)
case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal
case at: ArrayType =>
val arrayData = row.getArray(ordinal)
if (arrayData == null) NULL_VALUE
else {
(0 until arrayData.numElements()).map(i => {
if (arrayData.isNullAt(i)) null else asScalaValue(arrayData, at.elementType, i)
}).mkString("[", ",", "]")
}
case mt: MapType =>
val mapData = row.getMap(ordinal)
if (mapData.numElements() == 0) "{}"
else {
val keys = mapData.keyArray()
val values = mapData.valueArray()
val map = mutable.HashMap[Any, Any]()
// Check for TimestampNTZType first: it cannot appear as a pattern below without
// a compile-time dependency on Spark 3.4+, and would otherwise hit a MatchError
if (isTimestampNTZType(dataType)) {
// Render the microsecond value as a LocalDateTime string via
// DateTimeUtils.localDateTimeFromMicros, resolved reflectively
try {
val method = Class.forName("org.apache.spark.sql.catalyst.util.DateTimeUtils")
.getMethod("localDateTimeFromMicros", classOf[Long])
val localDateTime = method.invoke(null, Long.box(row.getLong(ordinal))).asInstanceOf[LocalDateTime]
localDateTime.toString
} catch {
case _: Exception =>
// Fallback: emit the raw microsecond value as a string
row.getLong(ordinal).toString
}
} else {
dataType match {
case NullType => NULL_VALUE
case BooleanType => row.getBoolean(ordinal)
case ByteType => row.getByte(ordinal)
case ShortType => row.getShort(ordinal)
case IntegerType => row.getInt(ordinal)
case LongType => row.getLong(ordinal)
case FloatType => row.getFloat(ordinal)
case DoubleType => row.getDouble(ordinal)
case StringType => Option(row.getUTF8String(ordinal)).map(_.toString).getOrElse(NULL_VALUE)
case TimestampType =>
DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)).toString
case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
case BinaryType => row.getBinary(ordinal)
case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale).toJavaBigDecimal
case at: ArrayType =>
val arrayData = row.getArray(ordinal)
if (arrayData == null) NULL_VALUE
else {
(0 until arrayData.numElements()).map(i => {
if (arrayData.isNullAt(i)) null else asScalaValue(arrayData, at.elementType, i)
}).mkString("[", ",", "]")
}
case mt: MapType =>
val mapData = row.getMap(ordinal)
if (mapData.numElements() == 0) "{}"
else {
val keys = mapData.keyArray()
val values = mapData.valueArray()
val map = mutable.HashMap[Any, Any]()
var i = 0
while (i < keys.numElements()) {
map += asScalaValue(keys, mt.keyType, i) -> asScalaValue(values, mt.valueType, i)
i += 1
}
MAPPER.writeValueAsString(map)
}
case st: StructType =>
val structData = row.getStruct(ordinal, st.length)
val map = new java.util.TreeMap[String, Any]()
var i = 0
while (i < keys.numElements()) {
map += asScalaValue(keys, mt.keyType, i) -> asScalaValue(values, mt.valueType, i)
while (i < structData.numFields) {
val field = st.fields(i)
map.put(field.name, asScalaValue(structData, field.dataType, i))
i += 1
}
MAPPER.writeValueAsString(map)
}
case st: StructType =>
val structData = row.getStruct(ordinal, st.length)
val map = new java.util.TreeMap[String, Any]()
var i = 0
while (i < structData.numFields) {
val field = st.fields(i)
map.put(field.name, asScalaValue(structData, field.dataType, i))
i += 1
}
MAPPER.writeValueAsString(map)
case _ => throw new Exception(s"Unsupported spark type: ${dataType.typeName}")
case _ => throw new Exception(s"Unsupported spark type: ${dataType.typeName}")
}
}
}
}

def convertValue(v: Any, dataType: DataType, datetimeJava8ApiEnabled: Boolean): Any = {
dataType match {
case StringType => UTF8String.fromString(v.asInstanceOf[String])
case TimestampType if datetimeJava8ApiEnabled => DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
case TimestampType => DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
case DateType if datetimeJava8ApiEnabled => v.asInstanceOf[LocalDate].toEpochDay.toInt
case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
case _: MapType =>
val map = v.asInstanceOf[java.util.Map[String, String]].asScala
val keys = map.keys.toArray.map(UTF8String.fromString)
val values = map.values.toArray.map(UTF8String.fromString)
ArrayBasedMapData(keys, values)
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: DecimalType => v
case _ => throw new Exception(s"Unsupported spark type: ${dataType.typeName}")
// Handle TimestampNTZType outside the match: there is no compile-time reference
// to the Spark 3.4+ type, so an unhandled case would otherwise raise a MatchError
if (isTimestampNTZType(dataType)) {
// TimestampNTZType: convert LocalDateTime to microsecond timestamp without timezone conversion
v match {
case localDateTime: LocalDateTime =>
// Microseconds since epoch (1970-01-01T00:00:00), anchoring the wall-clock value at UTC:
// epoch seconds * 1_000_000 plus the nanosecond component truncated to microseconds
val seconds = localDateTime.atZone(java.time.ZoneOffset.UTC).toEpochSecond
val nanos = localDateTime.getNano
seconds * 1000000L + nanos / 1000
case null => null
case _ => throw new Exception(s"TimestampNTZType expects LocalDateTime, but got ${v.getClass}")
}
} else {
dataType match {
case StringType => UTF8String.fromString(v.asInstanceOf[String])
case TimestampType if datetimeJava8ApiEnabled => DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
case TimestampType => DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
case DateType if datetimeJava8ApiEnabled => v.asInstanceOf[LocalDate].toEpochDay.toInt
case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
case _: MapType =>
val map = v.asInstanceOf[java.util.Map[String, String]].asScala
val keys = map.keys.toArray.map(UTF8String.fromString)
val values = map.values.toArray.map(UTF8String.fromString)
ArrayBasedMapData(keys, values)
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: DecimalType => v
case _ => throw new Exception(s"Unsupported spark type: ${dataType.typeName}")
}
}
}
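The arithmetic in the LocalDateTime branch can be cross-checked against java.time directly; a sketch assuming the same UTC anchoring:

import java.time.{Instant, LocalDateTime, ZoneOffset}
import java.time.temporal.ChronoUnit

object MicrosDemo extends App {
  val ldt = LocalDateTime.of(2024, 1, 2, 3, 4, 5, 123456789)
  // Hand-rolled, as in convertValue: UTC epoch seconds plus nanos truncated to micros.
  val handRolled = ldt.atZone(ZoneOffset.UTC).toEpochSecond * 1000000L + ldt.getNano / 1000
  // Reference computation via ChronoUnit over the same UTC instant.
  val reference = ChronoUnit.MICROS.between(Instant.EPOCH, ldt.toInstant(ZoneOffset.UTC))
  println(handRolled == reference) // true
}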
