Bulk Load CDK: Unwrap multipart streaming upload (airbytehq#48810)
johnny-schmidt authored Dec 6, 2024
1 parent 7a45e94 commit 8beb1d1
Showing 13 changed files with 197 additions and 33 deletions.
@@ -49,6 +49,8 @@ class ReservationManager(val totalCapacityBytes: Long) {

val remainingCapacityBytes: Long
get() = totalCapacityBytes - usedBytes.get()
+ val totalBytesReserved: Long
+ get() = usedBytes.get()

/* Attempt to reserve memory. If enough memory is not available, waits until it is, then reserves. */
suspend fun <T> reserve(bytes: Long, reservedFor: T): Reserved<T> {
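The new totalBytesReserved property simply exposes the running usedBytes counter. A tiny usage sketch (the function name and values below are illustrative, not part of this commit):

    // reserve() suspends until enough capacity is free, then records the reservation;
    // totalBytesReserved then reflects the running total alongside remainingCapacityBytes.
    suspend fun observeReservation(manager: ReservationManager) {
        manager.reserve(bytes = 1_024L, reservedFor = "spill-file")
        println("${manager.totalBytesReserved} reserved, ${manager.remainingCapacityBytes} free")
    }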
@@ -76,7 +76,6 @@ class DefaultSpillToDiskTask(

// reserve enough room for the record
diskManager.reserve(wrapped.sizeBytes)
-
// calculate whether we should flush
val rangeProcessed = range.withNextAdjacentValue(wrapped.index)
val bytesProcessed = sizeBytes + wrapped.sizeBytes
@@ -32,4 +32,12 @@ interface ObjectStorageClient<T : RemoteObject<*>> {
streamProcessor: StreamProcessor<V>? = null,
block: suspend (OutputStream) -> Unit
): T
+
+ /** Experimental sane replacement interface */
+ suspend fun startStreamingUpload(key: String, metadata: Map<String, String>): StreamingUpload<T>
}
+
+ interface StreamingUpload<T : RemoteObject<*>> {
+ suspend fun uploadPart(part: ByteArray)
+ suspend fun complete(): T
+ }
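A minimal sketch of how a caller might drive the new interface (uploadAll and its parameters are hypothetical helpers, not part of this commit):

    // Start the upload, push each pre-chunked part in order, then finalize
    // to obtain the resulting remote object.
    suspend fun <T : RemoteObject<*>> uploadAll(
        client: ObjectStorageClient<T>,
        key: String,
        parts: List<ByteArray>,
    ): T {
        val upload = client.startStreamingUpload(key, metadata = emptyMap())
        parts.forEach { upload.uploadPart(it) }
        return upload.complete()
    }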
@@ -8,6 +8,7 @@ import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.object_storage.AvroFormatConfiguration
import io.airbyte.cdk.load.command.object_storage.CSVFormatConfiguration
import io.airbyte.cdk.load.command.object_storage.JsonFormatConfiguration
+ import io.airbyte.cdk.load.command.object_storage.ObjectStorageCompressionConfigurationProvider
import io.airbyte.cdk.load.command.object_storage.ObjectStorageFormatConfigurationProvider
import io.airbyte.cdk.load.command.object_storage.ParquetFormatConfiguration
import io.airbyte.cdk.load.data.ObjectType
@@ -19,6 +20,7 @@ import io.airbyte.cdk.load.data.dataWithAirbyteMeta
import io.airbyte.cdk.load.data.json.toJson
import io.airbyte.cdk.load.data.parquet.ParquetMapperPipelineFactory
import io.airbyte.cdk.load.data.withAirbyteMeta
+ import io.airbyte.cdk.load.file.StreamProcessor
import io.airbyte.cdk.load.file.avro.toAvroWriter
import io.airbyte.cdk.load.file.csv.toCsvPrinterWithHeader
import io.airbyte.cdk.load.file.parquet.ParquetWriter
@@ -29,6 +31,7 @@ import io.airbyte.cdk.load.util.write
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
import jakarta.inject.Singleton
+ import java.io.ByteArrayOutputStream
import java.io.Closeable
import java.io.OutputStream
import org.apache.avro.Schema
@@ -75,6 +78,7 @@ class JsonFormattingWriter(
private val outputStream: OutputStream,
private val rootLevelFlattening: Boolean,
) : ObjectStorageFormattingWriter {
+
override fun accept(record: DestinationRecord) {
val data =
record.dataWithAirbyteMeta(stream, rootLevelFlattening).toJson().serializeToString()
@@ -92,6 +96,7 @@ class CSVFormattingWriter(
outputStream: OutputStream,
private val rootLevelFlattening: Boolean
) : ObjectStorageFormattingWriter {
+
private val finalSchema = stream.schema.withAirbyteMeta(rootLevelFlattening)
private val printer = finalSchema.toCsvPrinterWithHeader(outputStream)
override fun accept(record: DestinationRecord) {
@@ -124,11 +129,9 @@ class AvroFormattingWriter(
}

override fun accept(record: DestinationRecord) {
- val dataMapped =
- pipeline
- .map(record.data, record.meta?.changes)
- .withAirbyteMeta(stream, record.emittedAtMs, rootLevelFlattening)
- writer.write(dataMapped.toAvroRecord(mappedSchema, avroSchema))
+ val dataMapped = pipeline.map(record.data, record.meta?.changes)
+ val withMeta = dataMapped.withAirbyteMeta(stream, record.emittedAtMs, rootLevelFlattening)
+ writer.write(withMeta.toAvroRecord(mappedSchema, avroSchema))
}

override fun close() {
@@ -155,11 +158,60 @@ class ParquetFormattingWriter(
}

override fun accept(record: DestinationRecord) {
- val dataMapped =
- pipeline
- .map(record.data, record.meta?.changes)
- .withAirbyteMeta(stream, record.emittedAtMs, rootLevelFlattening)
- writer.write(dataMapped.toAvroRecord(mappedSchema, avroSchema))
+ val dataMapped = pipeline.map(record.data, record.meta?.changes)
+ val withMeta = dataMapped.withAirbyteMeta(stream, record.emittedAtMs, rootLevelFlattening)
+ writer.write(withMeta.toAvroRecord(mappedSchema, avroSchema))
}

override fun close() {
writer.close()
}
}
+
+ @Singleton
+ @Secondary
+ class BufferedFormattingWriterFactory<T : OutputStream>(
+ private val writerFactory: ObjectStorageFormattingWriterFactory,
+ private val compressionConfigurationProvider: ObjectStorageCompressionConfigurationProvider<T>,
+ ) {
+ fun create(stream: DestinationStream): BufferedFormattingWriter<T> {
+ val outputStream = ByteArrayOutputStream()
+ val processor =
+ compressionConfigurationProvider.objectStorageCompressionConfiguration.compressor
+ val wrappingBuffer = processor.wrapper.invoke(outputStream)
+ val writer = writerFactory.create(stream, wrappingBuffer)
+ return BufferedFormattingWriter(writer, outputStream, processor, wrappingBuffer)
+ }
+ }
+
+ class BufferedFormattingWriter<T : OutputStream>(
+ private val writer: ObjectStorageFormattingWriter,
+ private val buffer: ByteArrayOutputStream,
+ private val streamProcessor: StreamProcessor<T>,
+ private val wrappingBuffer: T
+ ) : ObjectStorageFormattingWriter {
+ val bufferSize: Int
+ get() = buffer.size()
+
+ override fun accept(record: DestinationRecord) {
+ writer.accept(record)
+ }
+
+ fun takeBytes(): ByteArray {
+ wrappingBuffer.flush()
+ val bytes = buffer.toByteArray()
+ buffer.reset()
+ return bytes
+ }
+
+ fun finish(): ByteArray? {
+ writer.close()
+ streamProcessor.partFinisher.invoke(wrappingBuffer)
+ return if (buffer.size() > 0) {
+ buffer.toByteArray()
+ } else {
+ null
+ }
+ }
+
+ override fun close() {
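The buffering contract above: records are formatted through the compressor into an in-memory buffer, takeBytes() drains whatever has accumulated as one upload part, and finish() closes the formatter and runs the compressor's partFinisher (for a framed codec such as GZIP this presumably flushes trailer bytes), so a final undersized chunk may remain. A sketch of the intended drain loop (drainInParts and emit are illustrative helpers, not part of this commit):

    fun <T : OutputStream> drainInParts(
        writer: BufferedFormattingWriter<T>,
        records: Sequence<DestinationRecord>,
        partSizeBytes: Int,
        emit: (ByteArray) -> Unit,
    ) {
        records.forEach {
            writer.accept(it)
            // emit a part whenever the buffered (compressed) bytes cross the threshold
            if (writer.bufferSize >= partSizeBytes) emit(writer.takeBytes())
        }
        writer.finish()?.let(emit) // trailing bytes, if any
    }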
@@ -8,10 +8,10 @@ import com.google.common.annotations.VisibleForTesting
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.command.object_storage.ObjectStorageCompressionConfigurationProvider
- import io.airbyte.cdk.load.file.NoopProcessor
+ import io.airbyte.cdk.load.command.object_storage.ObjectStorageUploadConfigurationProvider
import io.airbyte.cdk.load.file.StreamProcessor
+ import io.airbyte.cdk.load.file.object_storage.BufferedFormattingWriterFactory
import io.airbyte.cdk.load.file.object_storage.ObjectStorageClient
- import io.airbyte.cdk.load.file.object_storage.ObjectStorageFormattingWriterFactory
import io.airbyte.cdk.load.file.object_storage.ObjectStoragePathFactory
import io.airbyte.cdk.load.file.object_storage.RemoteObject
import io.airbyte.cdk.load.message.Batch
@@ -31,33 +31,41 @@ import java.util.concurrent.atomic.AtomicLong

@Singleton
@Secondary
- class ObjectStorageStreamLoaderFactory<T : RemoteObject<*>>(
+ class ObjectStorageStreamLoaderFactory<T : RemoteObject<*>, U : OutputStream>(
private val client: ObjectStorageClient<T>,
- private val compressionConfig: ObjectStorageCompressionConfigurationProvider<*>? = null,
private val pathFactory: ObjectStoragePathFactory,
- private val writerFactory: ObjectStorageFormattingWriterFactory,
+ private val bufferedWriterFactory: BufferedFormattingWriterFactory<U>,
+ private val compressionConfigurationProvider:
+ ObjectStorageCompressionConfigurationProvider<U>? =
+ null,
private val destinationStateManager: DestinationStateManager<ObjectStorageDestinationState>,
+ private val uploadConfigurationProvider: ObjectStorageUploadConfigurationProvider,
) {
fun create(stream: DestinationStream): StreamLoader {
return ObjectStorageStreamLoader(
stream,
client,
- compressionConfig?.objectStorageCompressionConfiguration?.compressor ?: NoopProcessor,
+ compressionConfigurationProvider?.objectStorageCompressionConfiguration?.compressor,
pathFactory,
- writerFactory,
- destinationStateManager
+ bufferedWriterFactory,
+ destinationStateManager,
+ uploadConfigurationProvider.objectStorageUploadConfiguration.streamingUploadPartSize,
)
}
}

- @SuppressFBWarnings("NP_NONNULL_PARAM_VIOLATION", justification = "Kotlin async continuation")
+ @SuppressFBWarnings(
+ value = ["NP_NONNULL_PARAM_VIOLATION", "NP_NULL_ON_SOME_PATH_FROM_RETURN_VALUE"],
+ justification = "Kotlin async continuation"
+ )
class ObjectStorageStreamLoader<T : RemoteObject<*>, U : OutputStream>(
override val stream: DestinationStream,
private val client: ObjectStorageClient<T>,
- private val compressor: StreamProcessor<U>,
+ private val compressor: StreamProcessor<U>?,
private val pathFactory: ObjectStoragePathFactory,
- private val writerFactory: ObjectStorageFormattingWriterFactory,
+ private val bufferedWriterFactory: BufferedFormattingWriterFactory<U>,
private val destinationStateManager: DestinationStateManager<ObjectStorageDestinationState>,
+ private val partSize: Long,
) : StreamLoader {
private val log = KotlinLogging.logger {}

@@ -95,12 +103,18 @@ class ObjectStorageStreamLoader<T : RemoteObject<*>, U : OutputStream>(
)

val metadata = ObjectStorageDestinationState.metadataFor(stream)
- val obj =
- client.streamingUpload(key, metadata, streamProcessor = compressor) { outputStream ->
- writerFactory.create(stream, outputStream).use { writer ->
- records.forEach { writer.accept(it) }
+ val upload = client.startStreamingUpload(key, metadata)
+ bufferedWriterFactory.create(stream).use { writer ->
+ records.forEach {
+ writer.accept(it)
+ if (writer.bufferSize >= partSize) {
+ upload.uploadPart(writer.takeBytes())
+ }
+ }
+ writer.finish()?.let { upload.uploadPart(it) }
}
+ val obj = upload.complete()

log.info { "Finished writing records to $key, persisting state" }
destinationStateManager.persistState(stream)
return RemoteObject(remoteObject = obj, partNumber = partNumber)
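One sizing constraint worth noting here (a documented S3 limit, not something this diff states): multipart uploads reject parts smaller than 5 MiB, except the last one, and cap an upload at 10,000 parts, so streamingUploadPartSize also bounds the largest object a single upload can produce:

    // Illustrative arithmetic with an assumed 5 MiB part size.
    val partSizeBytes = 5L * 1024 * 1024          // assumed example value
    val maxParts = 10_000                         // S3's per-upload part limit
    val maxObjectBytes = partSizeBytes * maxParts // ~48.8 GiB at 5 MiB parts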
@@ -6,8 +6,8 @@ package io.airbyte.cdk.load.write.object_storage

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.file.StreamProcessor
+ import io.airbyte.cdk.load.file.object_storage.BufferedFormattingWriterFactory
import io.airbyte.cdk.load.file.object_storage.ObjectStorageClient
- import io.airbyte.cdk.load.file.object_storage.ObjectStorageFormattingWriterFactory
import io.airbyte.cdk.load.file.object_storage.ObjectStoragePathFactory
import io.airbyte.cdk.load.file.object_storage.RemoteObject
import io.airbyte.cdk.load.message.DestinationFile
@@ -31,9 +31,11 @@ class ObjectStorageStreamLoaderTest {
private val client: ObjectStorageClient<RemoteObject<Int>> = mockk(relaxed = true)
private val compressor: StreamProcessor<ByteArrayOutputStream> = mockk(relaxed = true)
private val pathFactory: ObjectStoragePathFactory = mockk(relaxed = true)
- private val writerFactory: ObjectStorageFormattingWriterFactory = mockk(relaxed = true)
+ private val writerFactory: BufferedFormattingWriterFactory<ByteArrayOutputStream> =
+ mockk(relaxed = true)
private val destinationStateManager: DestinationStateManager<ObjectStorageDestinationState> =
mockk(relaxed = true)
+ private val partSize: Long = 1

private val objectStorageStreamLoader =
spyk(
@@ -43,7 +45,8 @@
compressor,
pathFactory,
writerFactory,
- destinationStateManager
+ destinationStateManager,
+ partSize
)
)

@@ -8,6 +8,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.load.file.StreamProcessor
import io.airbyte.cdk.load.file.object_storage.ObjectStorageClient
import io.airbyte.cdk.load.file.object_storage.RemoteObject
+ import io.airbyte.cdk.load.file.object_storage.StreamingUpload
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
import java.io.ByteArrayOutputStream
@@ -81,4 +82,11 @@ class MockObjectStorageClient : ObjectStorageClient<MockRemoteObject> {
override suspend fun delete(remoteObject: MockRemoteObject) {
objects.remove(remoteObject.key)
}
+
+ override suspend fun startStreamingUpload(
+ key: String,
+ metadata: Map<String, String>
+ ): StreamingUpload<MockRemoteObject> {
+ TODO("Not yet implemented")
+ }
}
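The mock leaves startStreamingUpload unimplemented. One possible in-memory version, sketched with an assumed MockRemoteObject(key, bytes) constructor (its real shape is not visible in this diff):

    class MockStreamingUpload(private val key: String) : StreamingUpload<MockRemoteObject> {
        private val parts = mutableListOf<ByteArray>()

        override suspend fun uploadPart(part: ByteArray) {
            parts.add(part) // buffer parts in arrival order
        }

        override suspend fun complete(): MockRemoteObject {
            // concatenate the parts the way a real backend would assemble the object
            val bytes = parts.fold(ByteArray(0)) { acc, p -> acc + p }
            return MockRemoteObject(key, bytes) // hypothetical constructor
        }
    }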
@@ -30,6 +30,7 @@ import io.airbyte.cdk.load.file.NoopProcessor
import io.airbyte.cdk.load.file.StreamProcessor
import io.airbyte.cdk.load.file.object_storage.ObjectStorageClient
import io.airbyte.cdk.load.file.object_storage.RemoteObject
+ import io.airbyte.cdk.load.file.object_storage.StreamingUpload
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Factory
import io.micronaut.context.annotation.Secondary
@@ -176,6 +177,33 @@ class S3Client(
upload.runUsing(block)
return S3Object(key, bucketConfig)
}
+
+ override suspend fun startStreamingUpload(
+ key: String,
+ metadata: Map<String, String>
+ ): StreamingUpload<S3Object> {
+ // TODO: Remove permit handling once we control concurrency with # of accumulators
+ if (uploadPermits != null) {
+ log.info {
+ "Attempting to acquire upload permit for $key (${uploadPermits.availablePermits} available)"
+ }
+ uploadPermits.acquire()
+ log.info {
+ "Acquired upload permit for $key (${uploadPermits.availablePermits} available)"
+ }
+ }
+
+ val request = CreateMultipartUploadRequest {
+ this.bucket = bucketConfig.s3BucketName
+ this.key = key
+ this.metadata = metadata
+ }
+ val response = client.createMultipartUpload(request)
+
+ log.info { "Starting multipart upload for $key (uploadId=${response.uploadId})" }
+
+ return S3StreamingUpload(client, bucketConfig, response, uploadPermits)
+ }
}
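S3StreamingUpload itself did not load in this view. A plausible shape, sketched against the AWS Kotlin SDK (the real class may differ; permit release and error handling are omitted):

    import aws.sdk.kotlin.services.s3.S3Client
    import aws.sdk.kotlin.services.s3.model.CompleteMultipartUploadRequest
    import aws.sdk.kotlin.services.s3.model.CompletedMultipartUpload
    import aws.sdk.kotlin.services.s3.model.CompletedPart
    import aws.sdk.kotlin.services.s3.model.CreateMultipartUploadResponse
    import aws.sdk.kotlin.services.s3.model.UploadPartRequest
    import aws.smithy.kotlin.runtime.content.ByteStream

    class S3StreamingUploadSketch(
        private val client: S3Client,
        private val response: CreateMultipartUploadResponse,
    ) {
        private val completedParts = mutableListOf<CompletedPart>()
        private var nextPartNumber = 1

        suspend fun uploadPart(part: ByteArray) {
            val number = nextPartNumber++
            // Each part is sent under the uploadId minted by createMultipartUpload.
            val uploaded =
                client.uploadPart(
                    UploadPartRequest {
                        bucket = response.bucket
                        key = response.key
                        uploadId = response.uploadId
                        partNumber = number
                        body = ByteStream.fromBytes(part)
                    }
                )
            // S3 requires the part number and ETag of every part at completion time.
            completedParts.add(
                CompletedPart {
                    partNumber = number
                    eTag = uploaded.eTag
                }
            )
        }

        suspend fun complete() {
            client.completeMultipartUpload(
                CompleteMultipartUploadRequest {
                    bucket = response.bucket
                    key = response.key
                    uploadId = response.uploadId
                    multipartUpload = CompletedMultipartUpload { parts = completedParts }
                }
            )
        }
    }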

[The remaining hunks and the other changed files did not load in the captured page.]
