From 28163426b8173911be3f18a34081bf6f3b217206 Mon Sep 17 00:00:00 2001 From: Qiyuan Dong Date: Mon, 9 Dec 2024 15:47:00 +0100 Subject: [PATCH] Address pr comments --- .../internal/InternalScanFileUtils.java | 32 +++++++++++ .../delta/kernel/internal/SnapshotImpl.java | 10 ---- .../delta/kernel/internal/TableFeatures.java | 40 +++++++------- .../kernel/internal/TransactionImpl.java | 4 +- .../internal/rowtracking/RowTracking.java | 20 +++---- .../kernel/internal/TableFeaturesSuite.scala | 10 ++-- .../kernel/defaults/RowTrackingSuite.scala | 54 +++++++------------ 7 files changed, 88 insertions(+), 82 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java index ec0ec463a22..0756501dff5 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java @@ -32,6 +32,7 @@ import java.net.URI; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * Utilities to extract information out of the scan file rows returned by {@link @@ -87,6 +88,11 @@ private InternalScanFileUtils() {} private static final int ADD_FILE_DV_ORDINAL = ADD_FILE_SCHEMA.indexOf("deletionVector"); + private static final int ADD_FILE_BASE_ROW_ID_ORDINAL = ADD_FILE_SCHEMA.indexOf("baseRowId"); + + private static final int ADD_FILE_DEFAULT_ROW_COMMIT_VERSION = + ADD_FILE_SCHEMA.indexOf("defaultRowCommitVersion"); + private static final int TABLE_ROOT_ORDINAL = SCAN_FILE_SCHEMA.indexOf(TABLE_ROOT_COL_NAME); public static final int ADD_FILE_STATS_ORDINAL = AddFile.SCHEMA_WITH_STATS.indexOf("stats"); @@ -190,4 +196,30 @@ public static DeletionVectorDescriptor getDeletionVectorDescriptorFromRow(Row sc public static Column getPartitionValuesParsedRefInAddFile(String partitionColName) { return new Column(new String[] {"add", "partitionValues_parsed", partitionColName}); } + + /** + * Get the base row id from the given scan file row. + * + * @param scanFile {@link Row} representing one scan file. + * @return base row id if present, otherwise empty. + */ + public static Optional getBaseRowId(Row scanFile) { + Row addFile = getAddFileEntry(scanFile); + return addFile.isNullAt(ADD_FILE_BASE_ROW_ID_ORDINAL) + ? Optional.empty() + : Optional.of(addFile.getLong(ADD_FILE_BASE_ROW_ID_ORDINAL)); + } + + /** + * Get the default row commit version from the given scan file row. + * + * @param scanFile {@link Row} representing one scan file. + * @return default row commit version if present, otherwise empty. + */ + public static Optional getDefaultRowCommitVersion(Row scanFile) { + Row addFile = getAddFileEntry(scanFile); + return addFile.isNullAt(ADD_FILE_DEFAULT_ROW_COMMIT_VERSION) + ? Optional.empty() + : Optional.of(addFile.getLong(ADD_FILE_DEFAULT_ROW_COMMIT_VERSION)); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/SnapshotImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/SnapshotImpl.java index 54ac6563110..72ee67492fa 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/SnapshotImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/SnapshotImpl.java @@ -151,16 +151,6 @@ public LogSegment getLogSegment() { return logSegment; } - /** - * Returns the log replay object. Visible for testing, where we need to access all the active - * AddFiles for a snapshot. - * - * @return the {@link LogReplay} object - */ - public LogReplay getLogReplay() { - return logReplay; - } - public CreateCheckpointIterator getCreateCheckpointIterator(Engine engine) { long minFileRetentionTimestampMillis = System.currentTimeMillis() - TOMBSTONE_RETENTION.fromMetadata(metadata); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java index e9953647646..bc8b7ad3b1c 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java @@ -142,6 +142,14 @@ public static void validateWriteSupportedTable( throw unsupportedWriterFeature(tablePath, writerFeature); } } + // Eventually we may have a way to declare and enforce dependencies between features. + // By putting this check for row tracking here, it makes it easier to spot that row + // tracking defines such a dependency that can be implicitly checked. + if (isRowTrackingSupported(protocol) && !isDomainMetadataSupported(protocol)) { + throw new KernelException( + "Feature 'rowTracking' is supported and depends on feature `domainMetadata`" + + "but 'domainMetadata' is unsupported"); + } break; default: throw unsupportedWriterProtocol(tablePath, minWriterVersion); @@ -195,12 +203,7 @@ public static Set extractAutomaticallyEnabledWriterFeatures( * @return true if the "domainMetadata" feature is supported, false otherwise */ public static boolean isDomainMetadataSupported(Protocol protocol) { - List writerFeatures = protocol.getWriterFeatures(); - if (writerFeatures == null) { - return false; - } - return writerFeatures.contains(DOMAIN_METADATA_FEATURE_NAME) - && protocol.getMinWriterVersion() >= TABLE_FEATURES_MIN_WRITER_VERSION; + return isWriterFeatureSupported(protocol, DOMAIN_METADATA_FEATURE_NAME); } /** @@ -210,21 +213,7 @@ public static boolean isDomainMetadataSupported(Protocol protocol) { * @return true if the protocol supports row tracking, false otherwise */ public static boolean isRowTrackingSupported(Protocol protocol) { - List writerFeatures = protocol.getWriterFeatures(); - if (writerFeatures == null) { - return false; - } - boolean rowTrackingSupported = - writerFeatures.contains(ROW_TRACKING_FEATURE_NAME) - && protocol.getMinWriterVersion() >= TABLE_FEATURES_MIN_WRITER_VERSION; - boolean domainMetadataSupported = isDomainMetadataSupported(protocol); - - if (rowTrackingSupported && !domainMetadataSupported) { - // This should not happen. Row tracking should automatically bring in domain metadata. - throw new KernelException( - "Feature 'rowTracking' is supported but 'domainMetadata' is unsupported"); - } - return rowTrackingSupported; + return isWriterFeatureSupported(protocol, ROW_TRACKING_FEATURE_NAME); } /** @@ -283,4 +272,13 @@ private static void validateNoInvariants(StructType tableSchema) { throw columnInvariantsNotSupported(); } } + + private static boolean isWriterFeatureSupported(Protocol protocol, String featureName) { + List writerFeatures = protocol.getWriterFeatures(); + if (writerFeatures == null) { + return false; + } + return writerFeatures.contains(featureName) + && protocol.getMinWriterVersion() >= TABLE_FEATURES_MIN_WRITER_VERSION; + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java index 9b8601b26ce..4739074d7e2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java @@ -150,7 +150,9 @@ public TransactionCommitResult commit(Engine engine, CloseableIterable data // If row tracking is supported, assign base row IDs and default row commit versions to any // AddFile or RemoveFile actions that do not yet have them. If the row ID high watermark // changes, emit a DomainMetadata action. - RowTracking.updateHighWaterMark(protocol, readSnapshot, domainMetadatas, dataActions); + Optional highWaterMark = + RowTracking.createNewHighWaterMarkIfNeeded(protocol, readSnapshot, dataActions); + highWaterMark.ifPresent(domainMetadatas::add); dataActions = RowTracking.assignBaseRowIdAndDefaultRowCommitVersion( protocol, readSnapshot, commitAsVersion, dataActions); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java index 35843a4ef0e..b095099a261 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java @@ -24,7 +24,7 @@ import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.DataFileStatistics; import java.io.IOException; -import java.util.List; +import java.util.Optional; /** A collection of helper methods for working with row tracking. */ public class RowTracking { @@ -106,16 +106,12 @@ public CloseableIterator iterator() { * * @param protocol the protocol to check for row tracking support * @param snapshot the current snapshot of the table - * @param domainMetadatas the list of domain metadata actions to append to if needed * @param dataActions the iterable of data actions that may update the high watermark */ - public static void updateHighWaterMark( - Protocol protocol, - SnapshotImpl snapshot, - List domainMetadatas, - CloseableIterable dataActions) { + public static Optional createNewHighWaterMarkIfNeeded( + Protocol protocol, SnapshotImpl snapshot, CloseableIterable dataActions) { if (!TableFeatures.isRowTrackingSupported(protocol)) { - return; + return Optional.empty(); } final long prevRowIdHighWatermark = readRowIdHighWaterMark(snapshot); @@ -130,11 +126,9 @@ public static void updateHighWaterMark( } }); - // Emit a DomainMetadata action to update the high watermark if it has changed - if (newRowIdHighWatermark[0] != prevRowIdHighWatermark) { - domainMetadatas.add( - new RowTrackingMetadataDomain(newRowIdHighWatermark[0]).toDomainMetadata()); - } + return (newRowIdHighWatermark[0] != prevRowIdHighWatermark) + ? Optional.of(new RowTrackingMetadataDomain(newRowIdHighWatermark[0]).toDomainMetadata()) + : Optional.empty(); } /** diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/TableFeaturesSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/TableFeaturesSuite.scala index f373b2f2865..9ed288e5987 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/TableFeaturesSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/TableFeaturesSuite.scala @@ -69,10 +69,14 @@ class TableFeaturesSuite extends AnyFunSuite { } Seq("appendOnly", "inCommitTimestamp", "columnMapping", "typeWidening-preview", "typeWidening", - "domainMetadata", "rowTracking") - .foreach { supportedWriterFeature => + "domainMetadata", "rowTracking").foreach { supportedWriterFeature => test(s"validateWriteSupported: protocol 7 with $supportedWriterFeature") { - checkSupported(createTestProtocol(minWriterVersion = 7, supportedWriterFeature)) + val protocol = if (supportedWriterFeature == "rowTracking") { + createTestProtocol(minWriterVersion = 7, supportedWriterFeature, "domainMetadata") + } else { + createTestProtocol(minWriterVersion = 7, supportedWriterFeature) + } + checkSupported(protocol) } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala index 55d084d152f..50ca3735f94 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala @@ -19,7 +19,7 @@ import io.delta.kernel.defaults.internal.parquet.ParquetSuiteBase import io.delta.kernel.engine.Engine import io.delta.kernel.exceptions.KernelException import io.delta.kernel.expressions.Literal -import io.delta.kernel.internal.{SnapshotImpl, TableImpl} +import io.delta.kernel.internal.{SnapshotImpl, TableImpl, InternalScanFileUtils} import io.delta.kernel.internal.actions.{AddFile, Protocol, SingleAction} import io.delta.kernel.internal.util.Utils.toCloseableIterator import io.delta.kernel.internal.rowtracking.RowTrackingMetadataDomain @@ -65,21 +65,13 @@ class RowTrackingSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBase { val table = TableImpl.forPath(engine, tablePath) val snapshot = table.getLatestSnapshot(engine).asInstanceOf[SnapshotImpl] - val AddFileActionsBatches = snapshot.getLogReplay.getAddFilesAsColumnarBatches( - engine, - false, /* shouldReadStats */ - Optional.empty() - ).asScala - - val modificationTimeOrdinal = AddFile.SCHEMA_WITHOUT_STATS.indexOf("modificationTime") - val baseRowIdOrdinal = AddFile.SCHEMA_WITHOUT_STATS.indexOf("baseRowId") - - val sortedBaseRowIds = AddFileActionsBatches - .flatMap(_.getRows.asScala) - .toSeq - .map(_.getStruct(0)) - .sortBy(_.getLong(modificationTimeOrdinal)) - .map(_.getLong(baseRowIdOrdinal)) + val scanFileRows = collectScanFileRows( + snapshot.getScanBuilder(engine).build() + ) + val sortedBaseRowIds = scanFileRows + .map(InternalScanFileUtils.getBaseRowId) + .map(_.orElse(-1)) + .sorted assert(sortedBaseRowIds === expectedValue) } @@ -91,24 +83,15 @@ class RowTrackingSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBase { val table = TableImpl.forPath(engine, tablePath) val snapshot = table.getLatestSnapshot(engine).asInstanceOf[SnapshotImpl] - val AddFileActionsBatches = snapshot.getLogReplay.getAddFilesAsColumnarBatches( - engine, - false, /* shouldReadStats */ - Optional.empty() - ).asScala - - val modificationTimeOrdinal = AddFile.SCHEMA_WITHOUT_STATS.indexOf("modificationTime") - val defaultRowCommitVersionOrdinal = - AddFile.SCHEMA_WITHOUT_STATS.indexOf("defaultRowCommitVersion") - - val AddFileDefaultRowCommitVersionsSorted = AddFileActionsBatches - .flatMap(_.getRows.asScala) - .toSeq - .map(_.getStruct(0)) - .sortBy(_.getLong(modificationTimeOrdinal)) - .map(_.getLong(defaultRowCommitVersionOrdinal)) + val scanFileRows = collectScanFileRows( + snapshot.getScanBuilder(engine).build() + ) + val sortedAddFileDefaultRowCommitVersions = scanFileRows + .map(InternalScanFileUtils.getDefaultRowCommitVersion) + .map(_.orElse(-1)) + .sorted - assert(AddFileDefaultRowCommitVersionsSorted === expectedValue) + assert(sortedAddFileDefaultRowCommitVersions === expectedValue) } private def verifyHighWatermark(engine: Engine, tablePath: String, expectedValue: Long) = { @@ -261,7 +244,10 @@ class RowTrackingSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBase { assert( e.getMessage - .contains("Feature 'rowTracking' is supported but 'domainMetadata' is unsupported") + .contains( + "Feature 'rowTracking' is supported and depends on feature `domainMetadata`" + + "but 'domainMetadata' is unsupported" + ) ) }) }