From 700a46f53f15c31ec54519faa0fbcc956d559236 Mon Sep 17 00:00:00 2001 From: Qiyuan Dong Date: Thu, 12 Dec 2024 14:36:42 +0100 Subject: [PATCH] Address PR comments --- .../io/delta/kernel/internal/DeltaErrors.java | 6 ++ .../delta/kernel/internal/TableFeatures.java | 5 +- .../kernel/internal/TransactionImpl.java | 17 ++-- .../kernel/internal/actions/AddFile.java | 96 ++++++++++++++----- .../internal/rowtracking/RowTracking.java | 49 +++++----- .../kernel/defaults/RowTrackingSuite.scala | 4 +- 6 files changed, 112 insertions(+), 65 deletions(-) diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java index 47fec3751c..37a7f8759d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java @@ -299,6 +299,12 @@ public static KernelException rowIDAssignmentWithoutStats() { + "when writing to a Delta table with the 'rowTracking' table feature supported"); } + public static KernelException rowTrackingSupportedWithDomainMetadataUnsupported() { + return new KernelException( + "Feature 'rowTracking' is supported and depends on feature 'domainMetadata'," + + " but 'domainMetadata' is unsupported"); + } + /* ------------------------ HELPER METHODS ----------------------------- */ private static String formatTimestamp(long millisSinceEpochUTC) { return new Timestamp(millisSinceEpochUTC).toInstant().toString(); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java index bc8b7ad3b1..b6f3bae827 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java @@ -19,7 +19,6 @@ import static io.delta.kernel.internal.DeltaErrors.*; import static io.delta.kernel.internal.TableConfig.IN_COMMIT_TIMESTAMPS_ENABLED; -import io.delta.kernel.exceptions.KernelException; import io.delta.kernel.internal.actions.Metadata; import io.delta.kernel.internal.actions.Protocol; import io.delta.kernel.internal.util.ColumnMapping; @@ -146,9 +145,7 @@ public static void validateWriteSupportedTable( // By putting this check for row tracking here, it makes it easier to spot that row // tracking defines such a dependency that can be implicitly checked. if (isRowTrackingSupported(protocol) && !isDomainMetadataSupported(protocol)) { - throw new KernelException( - "Feature 'rowTracking' is supported and depends on feature `domainMetadata`" - + "but 'domainMetadata' is unsupported"); + throw DeltaErrors.rowTrackingSupportedWithDomainMetadataUnsupported(); } break; default: diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java index 4739074d7e..b54d74157a 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TransactionImpl.java @@ -148,14 +148,15 @@ public TransactionCommitResult commit(Engine engine, CloseableIterable data engine, attemptCommitInfo.getInCommitTimestamp(), readSnapshot.getVersion(engine)); // If row tracking is supported, assign base row IDs and default row commit versions to any - // AddFile or RemoveFile actions that do not yet have them. If the row ID high watermark - // changes, emit a DomainMetadata action. - Optional highWaterMark = - RowTracking.createNewHighWaterMarkIfNeeded(protocol, readSnapshot, dataActions); - highWaterMark.ifPresent(domainMetadatas::add); - dataActions = - RowTracking.assignBaseRowIdAndDefaultRowCommitVersion( - protocol, readSnapshot, commitAsVersion, dataActions); + // AddFile actions that do not yet have them. If the row ID high watermark changes, emit a + // DomainMetadata action to update it. + if (TableFeatures.isRowTrackingSupported(protocol)) { + RowTracking.createNewHighWaterMarkIfNeeded(readSnapshot, dataActions) + .ifPresent(domainMetadatas::add); + dataActions = + RowTracking.assignBaseRowIdAndDefaultRowCommitVersion( + readSnapshot, commitAsVersion, dataActions); + } int numRetries = 0; do { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java index 58e4631ffb..01293f6860 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java @@ -16,10 +16,8 @@ package io.delta.kernel.internal.actions; import static io.delta.kernel.internal.util.InternalUtils.relativizePath; -import static io.delta.kernel.internal.util.InternalUtils.requireNonNull; import static io.delta.kernel.internal.util.PartitionUtils.serializePartitionMap; import static io.delta.kernel.internal.util.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; import io.delta.kernel.data.MapValue; @@ -27,12 +25,15 @@ import io.delta.kernel.expressions.Literal; import io.delta.kernel.internal.data.GenericRow; import io.delta.kernel.internal.fs.Path; +import io.delta.kernel.internal.util.InternalUtils; +import io.delta.kernel.internal.util.VectorUtils; import io.delta.kernel.types.*; import io.delta.kernel.utils.DataFileStatistics; import io.delta.kernel.utils.DataFileStatus; import java.net.URI; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.stream.IntStream; @@ -106,13 +107,10 @@ public static Row convertDataFileStatus( /** * Utility to generate an {@link AddFile} action from an 'AddFile' {@link Row}. * - * @param row the row to read - * @return the extracted {@link AddFile} action + * @throws NullPointerException if row is null */ public static AddFile fromRow(Row row) { - if (row == null) { - return null; - } + Objects.requireNonNull(row, "Cannot generate an AddFile action from a null row"); checkArgument( row.getSchema().equals(FULL_SCHEMA), @@ -121,15 +119,17 @@ public static AddFile fromRow(Row row) { row.getSchema()); return new AddFile( - requireNonNull(row, COL_NAME_TO_ORDINAL.get("path"), "path") + InternalUtils.requireNonNull(row, COL_NAME_TO_ORDINAL.get("path"), "path") .getString(COL_NAME_TO_ORDINAL.get("path")), - requireNonNull(row, COL_NAME_TO_ORDINAL.get("partitionValues"), "partitionValues") + InternalUtils.requireNonNull( + row, COL_NAME_TO_ORDINAL.get("partitionValues"), "partitionValues") .getMap(COL_NAME_TO_ORDINAL.get("partitionValues")), - requireNonNull(row, COL_NAME_TO_ORDINAL.get("size"), "size") + InternalUtils.requireNonNull(row, COL_NAME_TO_ORDINAL.get("size"), "size") .getLong(COL_NAME_TO_ORDINAL.get("size")), - requireNonNull(row, COL_NAME_TO_ORDINAL.get("modificationTime"), "modificationTime") + InternalUtils.requireNonNull( + row, COL_NAME_TO_ORDINAL.get("modificationTime"), "modificationTime") .getLong(COL_NAME_TO_ORDINAL.get("modificationTime")), - requireNonNull(row, COL_NAME_TO_ORDINAL.get("dataChange"), "dataChange") + InternalUtils.requireNonNull(row, COL_NAME_TO_ORDINAL.get("dataChange"), "dataChange") .getBoolean(COL_NAME_TO_ORDINAL.get("dataChange")), Optional.ofNullable( row.isNullAt(COL_NAME_TO_ORDINAL.get("deletionVector")) @@ -178,8 +178,8 @@ public AddFile( Optional baseRowId, Optional defaultRowCommitVersion, Optional stats) { - this.path = requireNonNull(path, "path is null"); - this.partitionValues = requireNonNull(partitionValues, "partitionValues is null"); + this.path = Objects.requireNonNull(path, "path is null"); + this.partitionValues = Objects.requireNonNull(partitionValues, "partitionValues is null"); this.size = size; this.modificationTime = modificationTime; this.dataChange = dataChange; @@ -220,12 +220,7 @@ public Optional getStats() { return stats; } - /** - * Creates a new AddFile instance with the specified base row ID. - * - * @param baseRowId the new base row ID to be assigned - * @return a new AddFile instance with the updated base row ID - */ + /** Creates a new AddFile instance with the specified base row ID. */ public AddFile withNewBaseRowId(long baseRowId) { return new AddFile( path, @@ -240,12 +235,7 @@ public AddFile withNewBaseRowId(long baseRowId) { stats); } - /** - * Creates a new AddFile instance with the specified default row commit version. - * - * @param defaultRowCommitVersion the new default row commit version to be assigned - * @return a new AddFile instance with the updated default row commit version - */ + /** Creates a new AddFile instance with the specified default row commit version. */ public AddFile withNewDefaultRowCommitVersion(long defaultRowCommitVersion) { return new AddFile( path, @@ -259,4 +249,58 @@ public AddFile withNewDefaultRowCommitVersion(long defaultRowCommitVersion) { Optional.of(defaultRowCommitVersion), stats); } + + @Override + public String toString() { + // Explicitly convert the partitionValues and tags to Java Maps + Map partitionValuesJavaMap = VectorUtils.toJavaMap(this.partitionValues); + Optional> tagsJavaMap = this.tags.map(VectorUtils::toJavaMap); + + StringBuilder sb = new StringBuilder(); + sb.append("AddFile{"); + sb.append("path='").append(path).append('\''); + sb.append(", partitionValues=").append(partitionValuesJavaMap); + sb.append(", size=").append(size); + sb.append(", modificationTime=").append(modificationTime); + sb.append(", dataChange=").append(dataChange); + sb.append(", deletionVector=").append(deletionVector); + sb.append(", tags=").append(tagsJavaMap); + sb.append(", baseRowId=").append(baseRowId); + sb.append(", defaultRowCommitVersion=").append(defaultRowCommitVersion); + sb.append(", stats=").append(stats); + sb.append('}'); + return sb.toString(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!(obj instanceof AddFile)) return false; + AddFile other = (AddFile) obj; + return size == other.size + && modificationTime == other.modificationTime + && dataChange == other.dataChange + && Objects.equals(path, other.path) + && Objects.equals(partitionValues, other.partitionValues) + && Objects.equals(deletionVector, other.deletionVector) + && Objects.equals(tags, other.tags) + && Objects.equals(baseRowId, other.baseRowId) + && Objects.equals(defaultRowCommitVersion, other.defaultRowCommitVersion) + && Objects.equals(stats, other.stats); + } + + @Override + public int hashCode() { + return Objects.hash( + path, + partitionValues, + size, + modificationTime, + dataChange, + deletionVector, + tags, + baseRowId, + defaultRowCommitVersion, + stats); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java index b095099a26..6301322ad2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/rowtracking/RowTracking.java @@ -15,6 +15,8 @@ */ package io.delta.kernel.internal.rowtracking; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + import io.delta.kernel.data.Row; import io.delta.kernel.internal.DeltaErrors; import io.delta.kernel.internal.SnapshotImpl; @@ -25,6 +27,7 @@ import io.delta.kernel.utils.DataFileStatistics; import java.io.IOException; import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; /** A collection of helper methods for working with row tracking. */ public class RowTracking { @@ -43,7 +46,6 @@ private RowTracking() { * watermark and increments the watermark accordingly. If a default row commit version is missing, * it assigns the provided commit version. * - * @param protocol the protocol to check for row tracking support * @param snapshot the current snapshot of the table * @param commitVersion the version of the commit for default row commit version assignment * @param dataActions the {@link CloseableIterable} of data actions to process @@ -51,13 +53,11 @@ private RowTracking() { * versions assigned */ public static CloseableIterable assignBaseRowIdAndDefaultRowCommitVersion( - Protocol protocol, - SnapshotImpl snapshot, - long commitVersion, - CloseableIterable dataActions) { - if (!TableFeatures.isRowTrackingSupported(protocol)) { - return dataActions; - } + SnapshotImpl snapshot, long commitVersion, CloseableIterable dataActions) { + checkArgument( + TableFeatures.isRowTrackingSupported(snapshot.getProtocol()), + "Base row ID and default row commit version are assigned " + + "only when feature 'rowTracking' is supported."); return new CloseableIterable() { @Override @@ -68,8 +68,8 @@ public void close() throws IOException { @Override public CloseableIterator iterator() { // Used to keep track of the current high watermark as we iterate through the data actions. - // Use a one-element array here to allow for mutation within the lambda. - final long[] currRowIdHighWatermark = {readRowIdHighWaterMark(snapshot)}; + // Use an AtomicLong to allow for updating the high watermark in the lambda. + final AtomicLong currRowIdHighWatermark = new AtomicLong(readRowIdHighWaterMark(snapshot)); return dataActions .iterator() .map( @@ -83,9 +83,9 @@ public CloseableIterator iterator() { // Assign base row ID if missing if (!addFile.getBaseRowId().isPresent()) { - final long numRecords = getNumRecords(addFile); - addFile = addFile.withNewBaseRowId(currRowIdHighWatermark[0] + 1L); - currRowIdHighWatermark[0] += numRecords; + final long numRecords = getNumRecordsOrThrow(addFile); + addFile = addFile.withNewBaseRowId(currRowIdHighWatermark.get() + 1L); + currRowIdHighWatermark.addAndGet(numRecords); } // Assign default row commit version if missing @@ -104,30 +104,29 @@ public CloseableIterator iterator() { * Emits a {@link DomainMetadata} action if the row ID high watermark has changed due to newly * processed {@link AddFile} actions. * - * @param protocol the protocol to check for row tracking support * @param snapshot the current snapshot of the table * @param dataActions the iterable of data actions that may update the high watermark */ public static Optional createNewHighWaterMarkIfNeeded( - Protocol protocol, SnapshotImpl snapshot, CloseableIterable dataActions) { - if (!TableFeatures.isRowTrackingSupported(protocol)) { - return Optional.empty(); - } + SnapshotImpl snapshot, CloseableIterable dataActions) { + checkArgument( + TableFeatures.isRowTrackingSupported(snapshot.getProtocol()), + "Row ID high watermark is updated only when feature 'rowTracking' is supported."); final long prevRowIdHighWatermark = readRowIdHighWaterMark(snapshot); - // Use a one-element array here to allow for mutation within the lambda. - final long[] newRowIdHighWatermark = {prevRowIdHighWatermark}; + // Use an AtomicLong to allow for updating the high watermark in the lambda + final AtomicLong newRowIdHighWatermark = new AtomicLong(prevRowIdHighWatermark); dataActions.forEach( row -> { if (!row.isNullAt(ADD_FILE_ORDINAL)) { - newRowIdHighWatermark[0] += - getNumRecords(AddFile.fromRow(row.getStruct(ADD_FILE_ORDINAL))); + newRowIdHighWatermark.addAndGet( + getNumRecordsOrThrow(AddFile.fromRow(row.getStruct(ADD_FILE_ORDINAL)))); } }); - return (newRowIdHighWatermark[0] != prevRowIdHighWatermark) - ? Optional.of(new RowTrackingMetadataDomain(newRowIdHighWatermark[0]).toDomainMetadata()) + return (newRowIdHighWatermark.get() != prevRowIdHighWatermark) + ? Optional.of(new RowTrackingMetadataDomain(newRowIdHighWatermark.get()).toDomainMetadata()) : Optional.empty(); } @@ -151,7 +150,7 @@ public static long readRowIdHighWaterMark(SnapshotImpl snapshot) { * @param addFile the AddFile action * @return the number of records */ - private static long getNumRecords(AddFile addFile) { + private static long getNumRecordsOrThrow(AddFile addFile) { return addFile .getStats() .map(DataFileStatistics::getNumRecords) diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala index 50ca3735f9..8e03e4074e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/RowTrackingSuite.scala @@ -245,8 +245,8 @@ class RowTrackingSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBase { assert( e.getMessage .contains( - "Feature 'rowTracking' is supported and depends on feature `domainMetadata`" - + "but 'domainMetadata' is unsupported" + "Feature 'rowTracking' is supported and depends on feature 'domainMetadata'," + + " but 'domainMetadata' is unsupported" ) ) })