From 91dabcdd196252c7eb66106ed995844f3b3ead4d Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 18 Dec 2024 17:34:07 -0500 Subject: [PATCH] Align client bulk write API with legacy bulk write. --- .../client_side_encryption_test.go | 2 +- internal/integration/client_test.go | 46 ++- internal/integration/crud_prose_test.go | 320 +++++++++++------- internal/integration/csot_prose_test.go | 15 +- .../unified/client_operation_execution.go | 172 +++++----- mongo/client.go | 23 +- mongo/client_bulk_write.go | 43 +-- mongo/client_bulk_write_models.go | 87 ++--- mongo/client_bulk_write_test.go | 4 +- mongo/client_test.go | 11 +- x/mongo/driver/operation.go | 34 +- x/mongo/driver/wiremessage/wiremessage.go | 34 -- 12 files changed, 423 insertions(+), 368 deletions(-) diff --git a/internal/integration/client_side_encryption_test.go b/internal/integration/client_side_encryption_test.go index 84708601bc..9f1c5c757f 100644 --- a/internal/integration/client_side_encryption_test.go +++ b/internal/integration/client_side_encryption_test.go @@ -395,7 +395,7 @@ func TestClientSideEncryptionCustomCrypt(t *testing.T) { assert.Equal(mt, cc.numCloseCalls, 0, "expected 0 calls to Close, got %v", cc.numCloseCalls) assert.Equal(mt, cc.numBypassAutoEncryptionCalls, 1, - "expected 2 calls to BypassAutoEncryption, got %v", cc.numBypassAutoEncryptionCalls) + "expected 1 calls to BypassAutoEncryption, got %v", cc.numBypassAutoEncryptionCalls) }) } diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 921aabf9bc..9f22fdd83d 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -725,23 +725,39 @@ func TestClient(t *testing.T) { testCases := []struct { name string - models *mongo.ClientWriteModels + writes []mongo.ClientBulkWrite }{ { - name: "DeleteOne", - models: (&mongo.ClientWriteModels{}).AppendDeleteOne("foo", "bar", mongo.NewClientDeleteOneModel()), + name: "DeleteOne", + writes: []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: mongo.NewClientDeleteOneModel(), + }}, }, { - name: "DeleteMany", - models: (&mongo.ClientWriteModels{}).AppendDeleteMany("foo", "bar", mongo.NewClientDeleteManyModel()), + name: "DeleteMany", + writes: []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: mongo.NewClientDeleteManyModel(), + }}, }, { - name: "UpdateOne", - models: (&mongo.ClientWriteModels{}).AppendUpdateOne("foo", "bar", mongo.NewClientUpdateOneModel()), + name: "UpdateOne", + writes: []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: mongo.NewClientUpdateOneModel(), + }}, }, { - name: "UpdateMany", - models: (&mongo.ClientWriteModels{}).AppendUpdateMany("foo", "bar", mongo.NewClientUpdateManyModel()), + name: "UpdateMany", + writes: []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: mongo.NewClientUpdateManyModel(), + }}, }, } for _, tc := range testCases { @@ -750,7 +766,7 @@ func TestClient(t *testing.T) { mt.Run(tc.name, func(mt *mtest.T) { mt.Parallel() - _, err := mt.Client.BulkWrite(context.Background(), tc.models) + _, err := mt.Client.BulkWrite(context.Background(), tc.writes) require.ErrorContains(mt, err, "filter is required") }) } @@ -779,11 +795,13 @@ func TestClient(t *testing.T) { mt.Run(tc.name, func(mt *mtest.T) { mt.Parallel() - var models *mongo.ClientWriteModels - insertOneModel := mongo.NewClientInsertOneModel().SetDocument(bson.D{{"x", 1}}) - models = (&mongo.ClientWriteModels{}).AppendInsertOne("foo", "bar", insertOneModel) - res, err := mt.Client.BulkWrite(context.Background(), models, tc.opts) + writes := []mongo.ClientBulkWrite{{ + Database: "foo", + Collection: "bar", + Model: insertOneModel, + }} + res, err := mt.Client.BulkWrite(context.Background(), writes, tc.opts) require.NoError(mt, err, "BulkWrite error: %v", err) require.NotNil(mt, res, "expected a ClientBulkWriteResult") assert.Equal(mt, res.Acknowledged, tc.want, "expected Acknowledged: %v, got: %v", tc.want, res.Acknowledged) diff --git a/internal/integration/crud_prose_test.go b/internal/integration/crud_prose_test.go index 3b5a069535..81adf0a112 100644 --- a/internal/integration/crud_prose_test.go +++ b/internal/integration/crud_prose_test.go @@ -420,7 +420,7 @@ func TestClientBulkWrite(t *testing.T) { mtOpts := mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).ClientType(mtest.Pinned) mt := mtest.New(t, mtOpts) - mt.Run("bulkWrite batch splits a writeModels input with greater than maxWriteBatchSize operations", func(mt *mtest.T) { + mt.Run("3. MongoClient.bulkWrite batch splits a writeModels input with greater than maxWriteBatchSize operations", func(mt *mtest.T) { var opsCnt []int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { @@ -440,23 +440,26 @@ func TestClientBulkWrite(t *testing.T) { } err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error: %v", err) - models := &mongo.ClientWriteModels{} - numModels := hello.MaxWriteBatchSize + 1 - for i := 0; i < numModels; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + var writes []mongo.ClientBulkWrite + num := hello.MaxWriteBatchSize + 1 + for i := 0; i < num; i++ { + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", "b"}}, - }) + }, + }) } - result, err := mt.Client.BulkWrite(context.Background(), models) + result, err := mt.Client.BulkWrite(context.Background(), writes) require.NoError(mt, err, "BulkWrite error: %v", err) - assert.Equal(mt, numModels, int(result.InsertedCount), "expected InsertedCount: %d, got %d", numModels, result.InsertedCount) + assert.Equal(mt, num, int(result.InsertedCount), "expected InsertedCount: %d, got %d", num, result.InsertedCount) require.Len(mt, opsCnt, 2, "expected %d bulkWrite commands, got: %d", 2, len(opsCnt)) - assert.Equal(mt, numModels-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", numModels-1, opsCnt[0]) + assert.Equal(mt, num-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", num-1, opsCnt[0]) assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got: %d", 1, opsCnt[1]) }) - mt.Run("bulkWrite batch splits when an ops payload exceeds maxMessageSizeBytes", func(mt *mtest.T) { + mt.Run("4. MongoClient.bulkWrite batch splits when an ops payload exceeds maxMessageSizeBytes", func(mt *mtest.T) { var opsCnt []int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { @@ -477,23 +480,26 @@ func TestClientBulkWrite(t *testing.T) { } err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error: %v", err) - models := &mongo.ClientWriteModels{} - numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 - for i := 0; i < numModels; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + var writes []mongo.ClientBulkWrite + num := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 + for i := 0; i < num; i++ { + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, - }) + }, + }) } - result, err := mt.Client.BulkWrite(context.Background(), models) + result, err := mt.Client.BulkWrite(context.Background(), writes) require.NoError(mt, err, "BulkWrite error: %v", err) - assert.Equal(mt, numModels, int(result.InsertedCount), "expected InsertedCount: %d, got: %d", numModels, result.InsertedCount) + assert.Equal(mt, num, int(result.InsertedCount), "expected InsertedCount: %d, got: %d", num, result.InsertedCount) require.Len(mt, opsCnt, 2, "expected %d bulkWrite commands, got: %d", 2, len(opsCnt)) - assert.Equal(mt, numModels-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", numModels-1, opsCnt[0]) + assert.Equal(mt, num-1, opsCnt[0], "expected %d firstEvent.command.ops, got: %d", num-1, opsCnt[0]) assert.Equal(mt, 1, opsCnt[1], "expected %d secondEvent.command.ops, got: %d", 1, opsCnt[1]) }) - mt.Run("bulkWrite collects WriteConcernErrors across batches", func(mt *mtest.T) { + mt.Run("5. MongoClient.bulkWrite collects WriteConcernErrors across batches", func(mt *mtest.T) { var eventCnt int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { @@ -523,26 +529,29 @@ func TestClientBulkWrite(t *testing.T) { }, }) - models := &mongo.ClientWriteModels{} - numModels := hello.MaxWriteBatchSize + 1 - for i := 0; i < numModels; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + var writes []mongo.ClientBulkWrite + num := hello.MaxWriteBatchSize + 1 + for i := 0; i < num; i++ { + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", "b"}}, - }) + }, + }) } - _, err = mt.Client.BulkWrite(context.Background(), models) + _, err = mt.Client.BulkWrite(context.Background(), writes) require.Error(mt, err, "expected a BulkWrite error") bwe, ok := err.(mongo.ClientBulkWriteException) require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) assert.Len(mt, bwe.WriteConcernErrors, 2, "expected %d writeConcernErrors, got: %d", 2, len(bwe.WriteConcernErrors)) require.NotNil(mt, bwe.PartialResult) - assert.Equal(mt, numModels, int(bwe.PartialResult.InsertedCount), - "expected InsertedCount: %d, got: %d", numModels, bwe.PartialResult.InsertedCount) + assert.Equal(mt, num, int(bwe.PartialResult.InsertedCount), + "expected InsertedCount: %d, got: %d", num, bwe.PartialResult.InsertedCount) require.Equal(mt, 2, eventCnt, "expected %d bulkWrite commands, got: %d", 2, eventCnt) }) - mt.Run("bulkWrite handles individual WriteErrors across batches", func(mt *mtest.T) { + mt.Run("6. MongoClient.bulkWrite handles individual WriteErrors across batches", func(mt *mtest.T) { var eventCnt int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { @@ -565,19 +574,22 @@ func TestClientBulkWrite(t *testing.T) { _, err = coll.InsertOne(context.Background(), bson.D{{"_id", 1}}) require.NoError(mt, err, "InsertOne error: %v", err) - models := &mongo.ClientWriteModels{} + var writes []mongo.ClientBulkWrite numModels := hello.MaxWriteBatchSize + 1 for i := 0; i < numModels; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"_id", 1}}, - }) + }, + }) } mt.Run("unordered", func(mt *mtest.T) { eventCnt = 0 mt.ResetClient(options.Client().SetMonitor(monitor)) - _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false)) + _, err := mt.Client.BulkWrite(context.Background(), writes, options.ClientBulkWrite().SetOrdered(false)) require.Error(mt, err, "expected a BulkWrite error") bwe, ok := err.(mongo.ClientBulkWriteException) require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) @@ -587,7 +599,7 @@ func TestClientBulkWrite(t *testing.T) { mt.Run("ordered", func(mt *mtest.T) { eventCnt = 0 mt.ResetClient(options.Client().SetMonitor(monitor)) - _, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(true)) + _, err := mt.Client.BulkWrite(context.Background(), writes, options.ClientBulkWrite().SetOrdered(true)) require.Error(mt, err, "expected a BulkWrite error") bwe, ok := err.(mongo.ClientBulkWriteException) require.True(mt, ok, "expected a BulkWriteException, got %T: %v", err, err) @@ -596,7 +608,7 @@ func TestClientBulkWrite(t *testing.T) { }) }) - mt.Run("bulkWrite handles a cursor requiring a getMore", func(mt *mtest.T) { + mt.Run("7. MongoClient.bulkWrite handles a cursor requiring a getMore", func(mt *mtest.T) { var getMoreCalled int monitor := &event.CommandMonitor{ Started: func(_ context.Context, e *event.CommandStartedEvent) { @@ -617,17 +629,26 @@ func TestClientBulkWrite(t *testing.T) { require.NoError(mt, err, "Drop error: %v", err) upsert := true - models := (&mongo.ClientWriteModels{}). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }) + models := []mongo.ClientBulkWrite{ + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + } result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) require.NoError(mt, err, "BulkWrite error: %v", err) assert.Equal(mt, int64(2), result.UpsertedCount, "expected InsertedCount: %d, got: %d", 2, result.UpsertedCount) @@ -635,7 +656,7 @@ func TestClientBulkWrite(t *testing.T) { assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got: %d", 1, getMoreCalled) }) - mt.RunOpts("bulkWrite handles a cursor requiring a getMore within a transaction", + mt.RunOpts("8. MongoClient.bulkWrite handles a cursor requiring getMore within a transaction", mtest.NewOptions().MinServerVersion("8.0").AtlasDataLake(false).ClientType(mtest.Pinned). Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.ShardedReplicaSet), func(mt *mtest.T) { @@ -663,17 +684,26 @@ func TestClientBulkWrite(t *testing.T) { defer session.EndSession(context.Background()) upsert := true - models := (&mongo.ClientWriteModels{}). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }) + models := []mongo.ClientBulkWrite{ + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + } result, err := session.WithTransaction(context.Background(), func(ctx context.Context) (interface{}, error) { return mt.Client.BulkWrite(ctx, models, options.ClientBulkWrite().SetVerboseResults(true)) }) @@ -685,7 +715,7 @@ func TestClientBulkWrite(t *testing.T) { assert.Equal(mt, 1, getMoreCalled, "expected %d getMore call, got: %d", 1, getMoreCalled) }) - mt.Run("bulkWrite handles a getMore error", func(mt *mtest.T) { + mt.Run("9. MongoClient.bulkWrite handles a getMore error", func(mt *mtest.T) { var getMoreCalled int var killCursorsCalled int monitor := &event.CommandMonitor{ @@ -721,17 +751,26 @@ func TestClientBulkWrite(t *testing.T) { require.NoError(mt, err, "Drop error: %v", err) upsert := true - models := (&mongo.ClientWriteModels{}). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }). - AppendUpdateOne("db", "coll", &mongo.ClientUpdateOneModel{ - Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, - Update: bson.D{{"$set", bson.D{{"x", 1}}}}, - Upsert: &upsert, - }) + models := []mongo.ClientBulkWrite{ + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("a", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + { + Database: "db", + Collection: "coll", + Model: &mongo.ClientUpdateOneModel{ + Filter: bson.D{{"_id", strings.Repeat("b", hello.MaxBsonObjectSize/2)}}, + Update: bson.D{{"$set", bson.D{{"x", 1}}}}, + Upsert: &upsert, + }, + }, + } _, err = mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetVerboseResults(true)) assert.Error(mt, err, "expected a BulkWrite error") bwe, ok := err.(mongo.ClientBulkWriteException) @@ -745,7 +784,7 @@ func TestClientBulkWrite(t *testing.T) { assert.Equal(mt, 1, killCursorsCalled, "expected %d killCursors call, got: %d", 1, killCursorsCalled) }) - mt.Run("bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) { + mt.Run("11. MongoClient.bulkWrite batch splits when the addition of a new namespace exceeds the maximum message size", func(mt *mtest.T) { type cmd struct { Ops []bson.D NsInfo []struct { @@ -771,60 +810,73 @@ func TestClientBulkWrite(t *testing.T) { err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error: %v", err) - newModels := func() (int, *mongo.ClientWriteModels) { + newWrites := func() (int, []mongo.ClientBulkWrite) { maxBsonObjectSize := hello.MaxBsonObjectSize opsBytes := hello.MaxMessageSizeBytes - 1122 - numModels := opsBytes / maxBsonObjectSize - - models := &mongo.ClientWriteModels{} - n := numModels - for i := 0; i < n; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + num := opsBytes / maxBsonObjectSize + + var writes []mongo.ClientBulkWrite + for i := 0; i < num; i++ { + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", maxBsonObjectSize-57)}}, - }) + }, + }) } if remainderBytes := opsBytes % maxBsonObjectSize; remainderBytes > 217 { - n++ - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + num++ + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", remainderBytes-57)}}, - }) + }, + }) } - return n, models + return num, writes } - mt.Run("no batch-splitting required", func(mt *mtest.T) { + mt.Run("Case 1: No batch-splitting required", func(mt *mtest.T) { bwCmd = bwCmd[:0] mt.ResetClient(options.Client().SetMonitor(monitor)) - numModels, models := newModels() - models.AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ - Document: bson.D{{"a", "b"}}, + num, writes := newWrites() + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }, }) - result, err := mt.Client.BulkWrite(context.Background(), models) + result, err := mt.Client.BulkWrite(context.Background(), writes) require.NoError(mt, err, "BulkWrite error: %v", err) - assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + assert.Equal(mt, num+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", num+1, result.InsertedCount) require.Len(mt, bwCmd, 1, "expected %d bulkWrite call, got: %d", 1, len(bwCmd)) - assert.Len(mt, bwCmd[0].Ops, numModels+1, "expected %d ops, got: %d", numModels+1, len(bwCmd[0].Ops)) + assert.Len(mt, bwCmd[0].Ops, num+1, "expected %d ops, got: %d", num+1, len(bwCmd[0].Ops)) require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) }) - mt.Run("batch-splitting required", func(mt *mtest.T) { + mt.Run("Case 2: Batch-splitting required", func(mt *mtest.T) { bwCmd = bwCmd[:0] mt.ResetClient(options.Client().SetMonitor(monitor)) coll := strings.Repeat("c", 200) - numModels, models := newModels() - models.AppendInsertOne("db", coll, &mongo.ClientInsertOneModel{ - Document: bson.D{{"a", "b"}}, + num, writes := newWrites() + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: coll, + Model: &mongo.ClientInsertOneModel{ + Document: bson.D{{"a", "b"}}, + }, }) - result, err := mt.Client.BulkWrite(context.Background(), models) + result, err := mt.Client.BulkWrite(context.Background(), writes) require.NoError(mt, err, "BulkWrite error: %v", err) - assert.Equal(mt, numModels+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", numModels+1, result.InsertedCount) + assert.Equal(mt, num+1, int(result.InsertedCount), "expected insertedCound: %d, got: %d", num+1, result.InsertedCount) require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got: %d", 2, len(bwCmd)) - assert.Len(mt, bwCmd[0].Ops, numModels, "expected %d ops, got: %d", numModels, len(bwCmd[0].Ops)) + assert.Len(mt, bwCmd[0].Ops, num, "expected %d ops, got: %d", num, len(bwCmd[0].Ops)) require.Len(mt, bwCmd[0].NsInfo, 1, "expected %d nsInfo, got: %d", 1, len(bwCmd[0].NsInfo)) assert.Equal(mt, "db.coll", bwCmd[0].NsInfo[0].Ns, "expected namespace: %s, got: %s", "db.coll", bwCmd[0].NsInfo[0].Ns) @@ -834,32 +886,38 @@ func TestClientBulkWrite(t *testing.T) { }) }) - mt.Run("bulkWrite returns an error if no operations can be added to ops", func(mt *mtest.T) { + mt.Run("12. MongoClient.bulkWrite returns an error if no operations can be added to ops", func(mt *mtest.T) { mt.ResetClient(options.Client()) var hello struct { MaxMessageSizeBytes int } err := mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error: %v", err) - mt.Run("document too large", func(mt *mtest.T) { - models := (&mongo.ClientWriteModels{}). - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + mt.Run("Case 1: document too large", func(mt *mtest.T) { + writes := []mongo.ClientBulkWrite{{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", hello.MaxMessageSizeBytes)}}, - }) - _, err := mt.Client.BulkWrite(context.Background(), models) + }, + }} + _, err := mt.Client.BulkWrite(context.Background(), writes) require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) }) - mt.Run("namespace too large", func(mt *mtest.T) { - models := (&mongo.ClientWriteModels{}). - AppendInsertOne("db", strings.Repeat("c", hello.MaxMessageSizeBytes), &mongo.ClientInsertOneModel{ + mt.Run("Case 2: namespace too large", func(mt *mtest.T) { + writes := []mongo.ClientBulkWrite{{ + Database: "db", + Collection: strings.Repeat("c", hello.MaxMessageSizeBytes), + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", "b"}}, - }) - _, err := mt.Client.BulkWrite(context.Background(), models) + }, + }} + _, err := mt.Client.BulkWrite(context.Background(), writes) require.EqualError(mt, err, driver.ErrDocumentTooLarge.Error()) }) }) - mt.Run("bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) { + mt.Run("13. MongoClient.bulkWrite returns an error if auto-encryption is configured", func(mt *mtest.T) { if os.Getenv("DOCKER_RUNNING") != "" { mt.Skip("skipping test in docker environment") } @@ -873,15 +931,18 @@ func TestClientBulkWrite(t *testing.T) { }, }) mt.ResetClient(options.Client().SetAutoEncryptionOptions(autoEncryptionOpts)) - models := (&mongo.ClientWriteModels{}). - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + writes := []mongo.ClientBulkWrite{{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", "b"}}, - }) - _, err := mt.Client.BulkWrite(context.Background(), models) + }, + }} + _, err := mt.Client.BulkWrite(context.Background(), writes) require.ErrorContains(mt, err, "bulkWrite does not currently support automatic encryption") }) - mt.Run("bulkWrite with unacknowledged write concern uses w:0 for all batches", func(mt *mtest.T) { + mt.Run("15. MongoClient.bulkWrite with unacknowledged write concern uses w:0 for all batches", func(mt *mtest.T) { type cmd struct { Ops []bson.D WriteConcern struct { @@ -912,20 +973,23 @@ func TestClientBulkWrite(t *testing.T) { err = coll.Drop(context.Background()) require.NoError(mt, err, "Drop error: %v", err) - numModels := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 - models := &mongo.ClientWriteModels{} - for i := 0; i < numModels; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + num := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 + var writes []mongo.ClientBulkWrite + for i := 0; i < num; i++ { + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, - }) + }, + }) } - result, err := mt.Client.BulkWrite(context.Background(), models, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) + result, err := mt.Client.BulkWrite(context.Background(), writes, options.ClientBulkWrite().SetOrdered(false).SetWriteConcern(writeconcern.Unacknowledged())) require.NoError(mt, err, "BulkWrite error: %v", err) assert.False(mt, result.Acknowledged) require.Len(mt, bwCmd, 2, "expected %d bulkWrite calls, got: %d", 2, len(bwCmd)) - assert.Len(mt, bwCmd[0].Ops, numModels-1, "expected %d ops, got: %d", numModels-1, len(bwCmd[0].Ops)) + assert.Len(mt, bwCmd[0].Ops, num-1, "expected %d ops, got: %d", num-1, len(bwCmd[0].Ops)) assert.Equal(mt, int32(0), bwCmd[0].WriteConcern.W, "expected writeConcern: %d, got: %v", 0, bwCmd[0].WriteConcern.W) assert.Len(mt, bwCmd[1].Ops, 1, "expected %d ops, got: %d", 1, len(bwCmd[1].Ops)) @@ -933,6 +997,6 @@ func TestClientBulkWrite(t *testing.T) { n, err := coll.CountDocuments(context.Background(), bson.D{}) require.NoError(mt, err, "CountDocuments error: %v", err) - assert.Equal(mt, numModels, int(n), "expected %d documents, got: %d", numModels, n) + assert.Equal(mt, num, int(n), "expected %d documents, got: %d", num, n) }) } diff --git a/internal/integration/csot_prose_test.go b/internal/integration/csot_prose_test.go index ec944e9a9b..38dd254edc 100644 --- a/internal/integration/csot_prose_test.go +++ b/internal/integration/csot_prose_test.go @@ -203,13 +203,16 @@ func TestCSOTProse(t *testing.T) { err = mt.DB.RunCommand(context.Background(), bson.D{{"hello", 1}}).Decode(&hello) require.NoError(mt, err, "Hello error: %v", err) - models := &mongo.ClientWriteModels{} + var writes []mongo.ClientBulkWrite n := hello.MaxMessageSizeBytes/hello.MaxBsonObjectSize + 1 for i := 0; i < n; i++ { - models. - AppendInsertOne("db", "coll", &mongo.ClientInsertOneModel{ + writes = append(writes, mongo.ClientBulkWrite{ + Database: "db", + Collection: "coll", + Model: &mongo.ClientInsertOneModel{ Document: bson.D{{"a", strings.Repeat("b", hello.MaxBsonObjectSize-500)}}, - }) + }, + }) } var cnt int @@ -227,8 +230,8 @@ func TestCSOTProse(t *testing.T) { integtest.AddTestServerAPIVersion(cliOptions) cli, err := mongo.Connect(cliOptions) require.NoError(mt, err, "Connect error: %v", err) - _, err = cli.BulkWrite(context.Background(), models) - assert.ErrorContains(mt, err, "context deadline exceeded", "expected a timeout error, got: %v", err) + _, err = cli.BulkWrite(context.Background(), writes) + assert.ErrorIs(mt, err, context.DeadlineExceeded, "expected a timeout error, got: %v", err) assert.Equal(mt, 2, cnt, "expected bulkWrite calls: %d, got: %d", 2, cnt) }) } diff --git a/internal/integration/unified/client_operation_execution.go b/internal/integration/unified/client_operation_execution.go index 212b2c59dd..08ae2d52e4 100644 --- a/internal/integration/unified/client_operation_execution.go +++ b/internal/integration/unified/client_operation_execution.go @@ -173,7 +173,7 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return nil, err } - wirteModels := &mongo.ClientWriteModels{} + var writes []mongo.ClientBulkWrite opts := options.ClientBulkWrite() elems, err := operation.Arguments.Elements() @@ -192,10 +192,27 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati } for _, m := range models { model := m.Document().Index(0) - err = appendClientBulkWriteModel(model.Key(), model.Value().Document(), wirteModels) + var op *mongo.ClientBulkWrite + switch key := model.Key(); key { + case "insertOne": + op, err = createClientInsertOneModel(model.Value().Document()) + case "updateOne": + op, err = createClientUpdateOneModel(model.Value().Document()) + case "updateMany": + op, err = createClientUpdateManyModel(model.Value().Document()) + case "replaceOne": + op, err = createClientReplaceOneModel(model.Value().Document()) + case "deleteOne": + op, err = createClientDeleteOneModel(model.Value().Document()) + case "deleteMany": + op, err = createClientDeleteManyModel(model.Value().Document()) + default: + err = fmt.Errorf("unrecognized bulkWrite model %q", key) + } if err != nil { return nil, err } + writes = append(writes, *op) } case "bypassDocumentValidation": opts.SetBypassDocumentValidation(val.Boolean()) @@ -223,7 +240,7 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati } } - res, err := client.BulkWrite(ctx, wirteModels, opts) + res, err := client.BulkWrite(ctx, writes, opts) var bwe mongo.ClientBulkWriteException if errors.As(err, &bwe) { res = bwe.PartialResult @@ -283,69 +300,26 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return newDocumentResult(rawBuilder.Build(), err), nil } -func appendClientBulkWriteModel(key string, value bson.Raw, model *mongo.ClientWriteModels) error { - switch key { - case "insertOne": - namespace, m, err := createClientInsertOneModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendInsertOne(ns[0], ns[1], m) - case "updateOne": - namespace, m, err := createClientUpdateOneModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendUpdateOne(ns[0], ns[1], m) - case "updateMany": - namespace, m, err := createClientUpdateManyModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendUpdateMany(ns[0], ns[1], m) - case "replaceOne": - namespace, m, err := createClientReplaceOneModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendReplaceOne(ns[0], ns[1], m) - case "deleteOne": - namespace, m, err := createClientDeleteOneModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendDeleteOne(ns[0], ns[1], m) - case "deleteMany": - namespace, m, err := createClientDeleteManyModel(value) - if err != nil { - return err - } - ns := strings.SplitN(namespace, ".", 2) - model.AppendDeleteMany(ns[0], ns[1], m) - } - return nil -} - -func createClientInsertOneModel(value bson.Raw) (string, *mongo.ClientInsertOneModel, error) { +func createClientInsertOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Document bson.Raw } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } - return v.Namespace, &mongo.ClientInsertOneModel{ - Document: v.Document, + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: &mongo.ClientInsertOneModel{ + Document: v.Document, + }, }, nil } -func createClientUpdateOneModel(value bson.Raw) (string, *mongo.ClientUpdateOneModel, error) { +func createClientUpdateOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Filter bson.Raw @@ -357,13 +331,13 @@ func createClientUpdateOneModel(value bson.Raw) (string, *mongo.ClientUpdateOneM } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return "", nil, err + return nil, err } } model := &mongo.ClientUpdateOneModel{ @@ -376,11 +350,15 @@ func createClientUpdateOneModel(value bson.Raw) (string, *mongo.ClientUpdateOneM if len(v.ArrayFilters) > 0 { model.ArrayFilters = v.ArrayFilters } - return v.Namespace, model, nil - + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: model, + }, nil } -func createClientUpdateManyModel(value bson.Raw) (string, *mongo.ClientUpdateManyModel, error) { +func createClientUpdateManyModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Filter bson.Raw @@ -392,13 +370,13 @@ func createClientUpdateManyModel(value bson.Raw) (string, *mongo.ClientUpdateMan } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return "", nil, err + return nil, err } } model := &mongo.ClientUpdateManyModel{ @@ -411,10 +389,15 @@ func createClientUpdateManyModel(value bson.Raw) (string, *mongo.ClientUpdateMan if len(v.ArrayFilters) > 0 { model.ArrayFilters = v.ArrayFilters } - return v.Namespace, model, nil + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: model, + }, nil } -func createClientReplaceOneModel(value bson.Raw) (string, *mongo.ClientReplaceOneModel, error) { +func createClientReplaceOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Filter bson.Raw @@ -425,25 +408,30 @@ func createClientReplaceOneModel(value bson.Raw) (string, *mongo.ClientReplaceOn } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return "", nil, err + return nil, err } } - return v.Namespace, &mongo.ClientReplaceOneModel{ - Filter: v.Filter, - Replacement: v.Replacement, - Collation: v.Collation, - Hint: hint, - Upsert: v.Upsert, + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: &mongo.ClientReplaceOneModel{ + Filter: v.Filter, + Replacement: v.Replacement, + Collation: v.Collation, + Hint: hint, + Upsert: v.Upsert, + }, }, nil } -func createClientDeleteOneModel(value bson.Raw) (string, *mongo.ClientDeleteOneModel, error) { +func createClientDeleteOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Filter bson.Raw @@ -452,23 +440,28 @@ func createClientDeleteOneModel(value bson.Raw) (string, *mongo.ClientDeleteOneM } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return "", nil, err + return nil, err } } - return v.Namespace, &mongo.ClientDeleteOneModel{ - Filter: v.Filter, - Collation: v.Collation, - Hint: hint, + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: &mongo.ClientDeleteOneModel{ + Filter: v.Filter, + Collation: v.Collation, + Hint: hint, + }, }, nil } -func createClientDeleteManyModel(value bson.Raw) (string, *mongo.ClientDeleteManyModel, error) { +func createClientDeleteManyModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string Filter bson.Raw @@ -477,18 +470,23 @@ func createClientDeleteManyModel(value bson.Raw) (string, *mongo.ClientDeleteMan } err := bson.Unmarshal(value, &v) if err != nil { - return "", nil, err + return nil, err } var hint interface{} if v.Hint != nil { hint, err = createHint(*v.Hint) if err != nil { - return "", nil, err + return nil, err } } - return v.Namespace, &mongo.ClientDeleteManyModel{ - Filter: v.Filter, - Collation: v.Collation, - Hint: hint, + ns := strings.SplitN(v.Namespace, ".", 2) + return &mongo.ClientBulkWrite{ + Database: ns[0], + Collection: ns[1], + Model: &mongo.ClientDeleteManyModel{ + Filter: v.Filter, + Collation: v.Collation, + Hint: hint, + }, }, nil } diff --git a/mongo/client.go b/mongo/client.go index 7c308f722a..09535f2ba6 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -872,16 +872,23 @@ func (c *Client) createBaseCursorOptions() driver.CursorOptions { } } +// ClientBulkWrite is a struct that can be used in a client-level BulkWrite operation. +type ClientBulkWrite struct { + Database string + Collection string + Model ClientWriteModel +} + // BulkWrite performs a client-level bulk write operation. -func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, +func (c *Client) BulkWrite(ctx context.Context, writes []ClientBulkWrite, opts ...options.Lister[options.ClientBulkWriteOptions]) (*ClientBulkWriteResult, error) { // TODO(GODRIVER-3403): Remove after support for QE with Client.bulkWrite. if c.isAutoEncryptionSet { return nil, errors.New("bulkWrite does not currently support automatic encryption") } - if models == nil { - return nil, ErrNilValue + if len(writes) == 0 { + return nil, ErrEmptySlice } bwo, err := mongoutil.NewOptions(opts...) if err != nil { @@ -930,8 +937,16 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels, } selector := makePinnedSelector(sess, writeSelector) + writePairs := make([]clientBulkWritePair, len(writes)) + for i, w := range writes { + writePairs[i] = clientBulkWritePair{ + namespace: fmt.Sprintf("%s.%s", w.Database, w.Collection), + model: w.Model, + } + } + op := clientBulkWrite{ - models: models.models, + writePairs: writePairs, ordered: bwo.Ordered, bypassDocumentValidation: bwo.BypassDocumentValidation, comment: bwo.Comment, diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 4560654cd5..b2fd551a82 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -29,8 +29,13 @@ const ( database = "admin" ) +type clientBulkWritePair struct { + namespace string + model interface{} +} + type clientBulkWrite struct { - models []clientWriteModel + writePairs []clientBulkWritePair errorsOnly bool ordered *bool bypassDocumentValidation *bool @@ -45,21 +50,21 @@ type clientBulkWrite struct { } func (bw *clientBulkWrite) execute(ctx context.Context) error { - if len(bw.models) == 0 { + if len(bw.writePairs) == 0 { return ErrEmptySlice } - for _, m := range bw.models { + for _, m := range bw.writePairs { if m.model == nil { return ErrNilDocument } } batches := &modelBatches{ - session: bw.session, - client: bw.client, - ordered: bw.ordered == nil || *bw.ordered, - models: bw.models, - result: &bw.result, - retryMode: driver.RetryOnce, + session: bw.session, + client: bw.client, + ordered: bw.ordered == nil || *bw.ordered, + writePairs: bw.writePairs, + result: &bw.result, + retryMode: driver.RetryOnce, } err := driver.Operation{ CommandFn: bw.newCommand(), @@ -106,7 +111,7 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error { _, ok := batches.writeErrors[0] hasSuccess = !ok } else { - hasSuccess = len(batches.writeErrors) < len(bw.models) + hasSuccess = len(batches.writeErrors) < len(bw.writePairs) } if hasSuccess { exception.PartialResult = batches.result @@ -177,8 +182,8 @@ type modelBatches struct { session *session.Client client *Client - ordered bool - models []clientWriteModel + ordered bool + writePairs []clientBulkWritePair offset int @@ -199,16 +204,16 @@ func (mb *modelBatches) IsOrdered() *bool { func (mb *modelBatches) AdvanceBatches(n int) { mb.offset += n - if mb.offset > len(mb.models) { - mb.offset = len(mb.models) + if mb.offset > len(mb.writePairs) { + mb.offset = len(mb.writePairs) } } func (mb *modelBatches) Size() int { - if mb.offset > len(mb.models) { + if mb.offset > len(mb.writePairs) { return 0 } - return len(mb.models) - mb.offset + return len(mb.writePairs) - mb.offset } func (mb *modelBatches) AppendBatchSequence(dst []byte, maxCount, totalSize int) (int, []byte, error) { @@ -279,17 +284,17 @@ func (mb *modelBatches) appendBatches(fn functionSet, dst []byte, maxCount, tota totalSize -= 1000 size := len(dst) + len(nsDst) var n int - for i := mb.offset; i < len(mb.models); i++ { + for i := mb.offset; i < len(mb.writePairs); i++ { if n == maxCount { break } - ns := mb.models[i].namespace + ns := mb.writePairs[i].namespace nsIdx, exists := getNsIndex(ns) var doc bsoncore.Document var err error - switch model := mb.models[i].model.(type) { + switch model := mb.writePairs[i].model.(type) { case *ClientInsertOneModel: mb.cursorHandlers = append(mb.cursorHandlers, mb.appendInsertResult) var id interface{} diff --git a/mongo/client_bulk_write_models.go b/mongo/client_bulk_write_models.go index 51208a8710..fdcac3d9ef 100644 --- a/mongo/client_bulk_write_models.go +++ b/mongo/client_bulk_write_models.go @@ -7,76 +7,17 @@ package mongo import ( - "fmt" - "go.mongodb.org/mongo-driver/v2/mongo/options" ) -// ClientWriteModels is a struct that can be used in a client-level BulkWrite operation. -type ClientWriteModels struct { - models []clientWriteModel -} -type clientWriteModel struct { - namespace string - model interface{} -} - -// AppendInsertOne appends ClientInsertOneModels. -func (m *ClientWriteModels) AppendInsertOne(database, collection string, models ...*ClientInsertOneModel) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } - for _, model := range models { - m.models = append(m.models, clientWriteModel{ - namespace: fmt.Sprintf("%s.%s", database, collection), - model: model, - }) - } - return m -} - -// appendModels is a helper function to append models to ClientWriteModels. -func appendModels[T ClientUpdateOneModel | - ClientUpdateManyModel | - ClientReplaceOneModel | - ClientDeleteOneModel | - ClientDeleteManyModel]( - m *ClientWriteModels, database, collection string, models []*T) *ClientWriteModels { - if m == nil { - m = &ClientWriteModels{} - } - for _, model := range models { - m.models = append(m.models, clientWriteModel{ - namespace: fmt.Sprintf("%s.%s", database, collection), - model: model, - }) - } - return m -} - -// AppendUpdateOne appends ClientUpdateOneModels. -func (m *ClientWriteModels) AppendUpdateOne(database, collection string, models ...*ClientUpdateOneModel) *ClientWriteModels { - return appendModels(m, database, collection, models) -} - -// AppendUpdateMany appends ClientUpdateManyModels. -func (m *ClientWriteModels) AppendUpdateMany(database, collection string, models ...*ClientUpdateManyModel) *ClientWriteModels { - return appendModels(m, database, collection, models) -} - -// AppendReplaceOne appends ClientReplaceOneModels. -func (m *ClientWriteModels) AppendReplaceOne(database, collection string, models ...*ClientReplaceOneModel) *ClientWriteModels { - return appendModels(m, database, collection, models) -} - -// AppendDeleteOne appends ClientDeleteOneModels. -func (m *ClientWriteModels) AppendDeleteOne(database, collection string, models ...*ClientDeleteOneModel) *ClientWriteModels { - return appendModels(m, database, collection, models) -} - -// AppendDeleteMany appends ClientDeleteManyModels. -func (m *ClientWriteModels) AppendDeleteMany(database, collection string, models ...*ClientDeleteManyModel) *ClientWriteModels { - return appendModels(m, database, collection, models) +// ClientWriteModel is an interface implemented by models that can be used in a client-level BulkWrite operation. Each +// ClientWriteModel represents a write. +// +// This interface is implemented by ClientDeleteOneModel, ClientDeleteManyModel, ClientInsertOneModel, +// ClientReplaceOneModel, ClientUpdateOneModel, and ClientUpdateManyModel. Custom implementations of this interface must +// not be used. +type ClientWriteModel interface { + clientWriteModel() } // ClientInsertOneModel is used to insert a single document in a client-level BulkWrite operation. @@ -91,6 +32,8 @@ func NewClientInsertOneModel() *ClientInsertOneModel { return &ClientInsertOneModel{} } +func (*ClientInsertOneModel) clientWriteModel() {} + // SetDocument specifies the document to be inserted. The document cannot be nil. If it does not have an _id field when // transformed into BSON, one will be added automatically to the marshalled document. The original document will not be // modified. @@ -116,6 +59,8 @@ func NewClientUpdateOneModel() *ClientUpdateOneModel { return &ClientUpdateOneModel{} } +func (*ClientUpdateOneModel) clientWriteModel() {} + // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (uom *ClientUpdateOneModel) SetHint(hint interface{}) *ClientUpdateOneModel { @@ -177,6 +122,8 @@ func NewClientUpdateManyModel() *ClientUpdateManyModel { return &ClientUpdateManyModel{} } +func (*ClientUpdateManyModel) clientWriteModel() {} + // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (umm *ClientUpdateManyModel) SetHint(hint interface{}) *ClientUpdateManyModel { @@ -236,6 +183,8 @@ func NewClientReplaceOneModel() *ClientReplaceOneModel { return &ClientReplaceOneModel{} } +func (*ClientReplaceOneModel) clientWriteModel() {} + // SetHint specifies the index to use for the operation. This should either be the index name as a string or the index // specification as a document. The default value is nil, which means that no hint will be sent. func (rom *ClientReplaceOneModel) SetHint(hint interface{}) *ClientReplaceOneModel { @@ -287,6 +236,8 @@ func NewClientDeleteOneModel() *ClientDeleteOneModel { return &ClientDeleteOneModel{} } +func (*ClientDeleteOneModel) clientWriteModel() {} + // SetFilter specifies a filter to use to select the document to delete. The filter must be a document containing query // operators. It cannot be nil. If the filter matches multiple documents, one will be selected from the matching // documents. @@ -323,6 +274,8 @@ func NewClientDeleteManyModel() *ClientDeleteManyModel { return &ClientDeleteManyModel{} } +func (*ClientDeleteManyModel) clientWriteModel() {} + // SetFilter specifies a filter to use to select documents to delete. The filter must be a document containing query // operators. It cannot be nil. func (dmm *ClientDeleteManyModel) SetFilter(filter interface{}) *ClientDeleteManyModel { diff --git a/mongo/client_bulk_write_test.go b/mongo/client_bulk_write_test.go index 2b01444824..7eb4fd9907 100644 --- a/mongo/client_bulk_write_test.go +++ b/mongo/client_bulk_write_test.go @@ -18,7 +18,7 @@ func TestBatches(t *testing.T) { t.Parallel() batches := &modelBatches{ - models: make([]clientWriteModel, 2), + writePairs: make([]clientBulkWritePair, 2), } batches.AdvanceBatches(3) size := batches.Size() @@ -33,7 +33,7 @@ func TestAppendBatchSequence(t *testing.T) { require.NoError(t, err, "NewClient error: %v", err) return &modelBatches{ client: client, - models: []clientWriteModel{ + writePairs: []clientBulkWritePair{ {"ns0", nil}, {"ns1", &ClientInsertOneModel{ Document: bson.D{{"foo", 42}}, diff --git a/mongo/client_test.go b/mongo/client_test.go index 8d0c4245dc..4e595036c4 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -540,12 +540,13 @@ func TestClient(t *testing.T) { } document := bson.D{{"largeField", strings.Repeat("a", 16777216-100)}} // Adjust size to account for BSON overhead - models := &ClientWriteModels{} - models = models.AppendInsertOne("db", "x", NewClientInsertOneModel().SetDocument(document)) - models = models.AppendInsertOne("db", "x", NewClientInsertOneModel().SetDocument(document)) - models = models.AppendInsertOne("db", "x", NewClientInsertOneModel().SetDocument(document)) + writes := []ClientBulkWrite{ + {"db", "x", NewClientInsertOneModel().SetDocument(document)}, + {"db", "x", NewClientInsertOneModel().SetDocument(document)}, + {"db", "x", NewClientInsertOneModel().SetDocument(document)}, + } - _, err = client.BulkWrite(context.Background(), models) + _, err = client.BulkWrite(context.Background(), writes) require.NoError(t, err) assert.Equal(t, 2, bulkWrites, "expected %d bulkWrites, got %d", 2, bulkWrites) }) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 8659c92e90..3d0b5cb9a4 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1407,7 +1407,7 @@ func (op Operation) createWireMessage( for b := dst[batchOffset:]; len(b) > 0; /* nothing */ { var seq []byte var ok bool - seq, b, ok = wiremessage.DocumentSequenceToArray(b) + seq, b, ok = documentSequenceToArray(b) if !ok { break } @@ -2232,3 +2232,35 @@ func sessionsSupported(wireVersion *description.VersionRange) bool { func retryWritesSupported(s description.Server) bool { return s.SessionTimeoutMinutes != nil && s.Kind != description.ServerKindStandalone } + +func documentSequenceToArray(src []byte) (dst, rem []byte, ok bool) { + stype, rem, ok := wiremessage.ReadMsgSectionType(src) + if !ok || stype != wiremessage.DocumentSequence { + return nil, src, false + } + var identifier string + var ret []byte + identifier, rem, ret, ok = wiremessage.ReadMsgSectionRawDocumentSequence(rem) + if !ok { + return nil, src, false + } + + aidx, dst := bsoncore.AppendArrayElementStart(nil, identifier) + i := 0 + for { + var doc bsoncore.Document + doc, rem, ok = bsoncore.ReadDocument(rem) + if !ok { + break + } + dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) + i++ + } + if len(rem) > 0 { + return nil, src, false + } + + dst, _ = bsoncore.AppendArrayEnd(dst, aidx) + + return dst, ret, true +} diff --git a/x/mongo/driver/wiremessage/wiremessage.go b/x/mongo/driver/wiremessage/wiremessage.go index 9330499242..dd16cb7be0 100644 --- a/x/mongo/driver/wiremessage/wiremessage.go +++ b/x/mongo/driver/wiremessage/wiremessage.go @@ -16,7 +16,6 @@ package wiremessage import ( "bytes" "encoding/binary" - "strconv" "strings" "sync/atomic" @@ -423,39 +422,6 @@ func ReadMsgSectionRawDocumentSequence(src []byte) (identifier string, data []by return identifier, rem, rest, true } -// DocumentSequenceToArray converts a document sequence in byte slice to an array. -func DocumentSequenceToArray(src []byte) (dst, rem []byte, ok bool) { - stype, rem, ok := ReadMsgSectionType(src) - if !ok || stype != DocumentSequence { - return nil, src, false - } - var identifier string - var ret []byte - identifier, rem, ret, ok = ReadMsgSectionRawDocumentSequence(rem) - if !ok { - return nil, src, false - } - - aidx, dst := bsoncore.AppendArrayElementStart(nil, identifier) - i := 0 - for { - var doc bsoncore.Document - doc, rem, ok = bsoncore.ReadDocument(rem) - if !ok { - break - } - dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc) - i++ - } - if len(rem) > 0 { - return nil, src, false - } - - dst, _ = bsoncore.AppendArrayEnd(dst, aidx) - - return dst, ret, true -} - // ReadMsgChecksum reads a checksum from src. func ReadMsgChecksum(src []byte) (checksum uint32, rem []byte, ok bool) { i32, rem, ok := readi32(src)