diff --git a/RELEASES.md b/RELEASES.md
index 13d983cf2803..0f8d57aea220 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -723,7 +723,7 @@ This version is backwards compatible to [v1.10.0](https://github.com/ava-labs/av
 - Add workflow to mark stale issues and PRs by @joshua-kim in https://github.com/ava-labs/avalanchego/pull/1443
 - Enforce inlining functions with a single error return in `require.NoError` by @dhrubabasu in https://github.com/ava-labs/avalanchego/pull/1500
 - `x/sync` / `x/merkledb` -- add `SyncableDB` interface by @danlaine in https://github.com/ava-labs/avalanchego/pull/1555
-- Rename beacon to boostrapper, define bootstrappers in JSON file for cross-language compatiblity by @gyuho in https://github.com/ava-labs/avalanchego/pull/1439
+- Rename beacon to bootstrapper, define bootstrappers in JSON file for cross-language compatibility by @gyuho in https://github.com/ava-labs/avalanchego/pull/1439
 - add P-chain height indexing by @dhrubabasu in https://github.com/ava-labs/avalanchego/pull/1447
 - Add P-chain `GetBlockByHeight` API method by @dhrubabasu in https://github.com/ava-labs/avalanchego/pull/1448
 - `x/sync` -- use for sending Range Proofs by @danlaine in https://github.com/ava-labs/avalanchego/pull/1537
diff --git a/api/info/service.go b/api/info/service.go
index 65b5fb2fdf7f..47112e55b630 100644
--- a/api/info/service.go
+++ b/api/info/service.go
@@ -208,7 +208,7 @@ type PeersArgs struct {
 type Peer struct {
	peer.Info

-	Benched []ids.ID `json:"benched"`
+	Benched []string `json:"benched"`
 }

 // PeersReply are the results from calling Peers
@@ -229,9 +229,18 @@ func (i *Info) Peers(_ *http.Request, args *PeersArgs, reply *PeersReply) error
	peers := i.networking.PeerInfo(args.NodeIDs)
	peerInfo := make([]Peer, len(peers))
	for index, peer := range peers {
+		benchedIDs := i.benchlist.GetBenched(peer.ID)
+		benchedAliases := make([]string, len(benchedIDs))
+		for idx, id := range benchedIDs {
+			alias, err := i.chainManager.PrimaryAlias(id)
+			if err != nil {
+				return fmt.Errorf("failed to get primary alias for chain ID %s: %w", id, err)
+			}
+			benchedAliases[idx] = alias
+		}
		peerInfo[index] = Peer{
			Info:    peer,
-			Benched: i.benchlist.GetBenched(peer.ID),
+			Benched: benchedAliases,
		}
	}
diff --git a/ipcs/socket/socket_test.go b/ipcs/socket/socket_test.go
index 4204d032285a..a56329b28c3e 100644
--- a/ipcs/socket/socket_test.go
+++ b/ipcs/socket/socket_test.go
@@ -8,6 +8,8 @@ import (
	"testing"

	"github.com/stretchr/testify/require"
+
+	"github.com/ava-labs/avalanchego/utils/logging"
 )

 func TestSocketSendAndReceive(t *testing.T) {
@@ -21,7 +23,7 @@ func TestSocketSendAndReceive(t *testing.T) {
	)

	// Create socket and client; wait for client to connect
-	socket := NewSocket(socketName, nil)
+	socket := NewSocket(socketName, logging.NoLog{})
	socket.accept, connCh = newTestAcceptFn(t)
	require.NoError(socket.Listen())
diff --git a/snow/consensus/snowman/poll/set.go b/snow/consensus/snowman/poll/set.go
index e31821476bc8..e58059f20c3d 100644
--- a/snow/consensus/snowman/poll/set.go
+++ b/snow/consensus/snowman/poll/set.go
@@ -4,6 +4,7 @@
 package poll

 import (
+	"errors"
	"fmt"
	"strings"
	"time"
@@ -19,6 +20,11 @@ import (
	"github.com/ava-labs/avalanchego/utils/metric"
 )

+var (
+	errFailedPollsMetric         = errors.New("failed to register polls metric")
+	errFailedPollDurationMetrics = errors.New("failed to register poll_duration metrics")
+)
+
 type pollHolder interface {
	GetPoll() Poll
	StartTime() time.Time
@@ -52,16 +58,14 @@ func NewSet(
	log logging.Logger,
	namespace string,
	reg prometheus.Registerer,
-) Set {
+) (Set, error) {
	numPolls := prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "polls",
		Help:      "Number of pending network polls",
	})
	if err := reg.Register(numPolls); err != nil {
-		log.Error("failed to register polls statistics",
-			zap.Error(err),
-		)
+		return nil, fmt.Errorf("%w: %w", errFailedPollsMetric, err)
	}

	durPolls, err := metric.NewAverager(
@@ -71,9 +75,7 @@ func NewSet(
		reg,
	)
	if err != nil {
-		log.Error("failed to register poll_duration statistics",
-			zap.Error(err),
-		)
+		return nil, fmt.Errorf("%w: %w", errFailedPollDurationMetrics, err)
	}

	return &set{
@@ -82,7 +84,7 @@ func NewSet(
		durPolls: durPolls,
		factory:  factory,
		polls:    linkedhashmap.New[uint32, pollHolder](),
-	}
+	}, nil
 }

 // Add to the current set of polls
diff --git a/snow/consensus/snowman/poll/set_test.go b/snow/consensus/snowman/poll/set_test.go
index 8200f25dc5f0..70830e3da36c 100644
--- a/snow/consensus/snowman/poll/set_test.go
+++ b/snow/consensus/snowman/poll/set_test.go
@@ -28,7 +28,7 @@ var (
	vdr5 = ids.NodeID{5}
 )

-func TestNewSetErrorOnMetrics(t *testing.T) {
+func TestNewSetErrorOnPollsMetrics(t *testing.T) {
	require := require.New(t)

	factory := NewEarlyTermNoTraversalFactory(1, 1)
@@ -37,13 +37,29 @@ func TestNewSetErrorOnMetrics(t *testing.T) {
	registerer := prometheus.NewRegistry()

	require.NoError(registerer.Register(prometheus.NewCounter(prometheus.CounterOpts{
-		Name: "polls",
+		Namespace: namespace,
+		Name:      "polls",
	})))
+
+	_, err := NewSet(factory, log, namespace, registerer)
+	require.ErrorIs(err, errFailedPollsMetric)
+}
+
+func TestNewSetErrorOnPollDurationMetrics(t *testing.T) {
+	require := require.New(t)
+
+	factory := NewEarlyTermNoTraversalFactory(1, 1)
+	log := logging.NoLog{}
+	namespace := ""
+	registerer := prometheus.NewRegistry()
+
	require.NoError(registerer.Register(prometheus.NewCounter(prometheus.CounterOpts{
-		Name: "poll_duration",
+		Namespace: namespace,
+		Name:      "poll_duration_count",
	})))

-	require.NotNil(NewSet(factory, log, namespace, registerer))
+	_, err := NewSet(factory, log, namespace, registerer)
+	require.ErrorIs(err, errFailedPollDurationMetrics)
 }

 func TestCreateAndFinishPollOutOfOrder_NewerFinishesFirst(t *testing.T) {
@@ -56,7 +72,8 @@ func TestCreateAndFinishPollOutOfOrder_NewerFinishesFirst(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	// create two polls for the two blocks
	vdrBag := bag.Of(vdrs...)
@@ -92,7 +109,8 @@ func TestCreateAndFinishPollOutOfOrder_OlderFinishesFirst(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	// create two polls for the two blocks
	vdrBag := bag.Of(vdrs...)
@@ -128,7 +146,8 @@ func TestCreateAndFinishPollOutOfOrder_UnfinishedPollsGaps(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	// create three polls for the two blocks
	vdrBag := bag.Of(vdrs...)
@@ -172,7 +191,8 @@ func TestCreateAndFinishSuccessfulPoll(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	require.Zero(s.Len())

@@ -204,7 +224,8 @@ func TestCreateAndFinishFailedPoll(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	require.Zero(s.Len())

@@ -233,7 +254,8 @@ func TestSetString(t *testing.T) {
	log := logging.NoLog{}
	namespace := ""
	registerer := prometheus.NewRegistry()
-	s := NewSet(factory, log, namespace, registerer)
+	s, err := NewSet(factory, log, namespace, registerer)
+	require.NoError(err)

	expected := `current polls: (Size = 1)
    RequestID 0:
diff --git a/snow/engine/common/test_sender.go b/snow/engine/common/test_sender.go
index 5896f48dfa25..5b76f3b6a2f4 100644
--- a/snow/engine/common/test_sender.go
+++ b/snow/engine/common/test_sender.go
@@ -6,7 +6,6 @@ package common
 import (
	"context"
	"errors"
-	"testing"

	"github.com/stretchr/testify/require"

@@ -27,7 +26,7 @@ var (

 // SenderTest is a test sender
 type SenderTest struct {
-	T *testing.T
+	T require.TestingT

	CantAccept,
	CantSendGetStateSummaryFrontier, CantSendStateSummaryFrontier,
diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go
index 803c03237c96..b4c5e3e54b51 100644
--- a/snow/engine/snowman/transitive.go
+++ b/snow/engine/snowman/transitive.go
@@ -116,6 +116,16 @@ func newTransitive(config Config) (*Transitive, error) {
		config.Params.AlphaPreference,
		config.Params.AlphaConfidence,
	)
+	polls, err := poll.NewSet(
+		factory,
+		config.Ctx.Log,
+		"",
+		config.Ctx.Registerer,
+	)
+	if err != nil {
+		return nil, err
+	}
+
	t := &Transitive{
		Config:                      config,
		StateSummaryFrontierHandler: common.NewNoOpStateSummaryFrontierHandler(config.Ctx.Log),
@@ -129,12 +139,7 @@ func newTransitive(config Config) (*Transitive, error) {
		nonVerifieds:                ancestor.NewTree(),
		nonVerifiedCache:            nonVerifiedCache,
		acceptedFrontiers:           acceptedFrontiers,
-		polls: poll.NewSet(
-			factory,
-			config.Ctx.Log,
-			"",
-			config.Ctx.Registerer,
-		),
+		polls:                       polls,
	}

	return t, t.metrics.Initialize("", config.Ctx.Registerer)
diff --git a/snow/validators/logger.go b/snow/validators/logger.go
index 124aef423fc4..2e672a1827ba 100644
--- a/snow/validators/logger.go
+++ b/snow/validators/logger.go
@@ -7,7 +7,6 @@ import (
	"go.uber.org/zap"

	"github.com/ava-labs/avalanchego/ids"
-	"github.com/ava-labs/avalanchego/utils"
	"github.com/ava-labs/avalanchego/utils/crypto/bls"
	"github.com/ava-labs/avalanchego/utils/logging"
	"github.com/ava-labs/avalanchego/utils/set"
@@ -18,7 +17,6 @@ var _ SetCallbackListener = (*logger)(nil)

 type logger struct {
	log      logging.Logger
-	enabled  *utils.Atomic[bool]
	subnetID ids.ID
	nodeIDs  set.Set[ids.NodeID]
 }
@@ -27,14 +25,12 @@ type logger struct {
 // the specified validators
 func NewLogger(
	log logging.Logger,
-	enabled *utils.Atomic[bool],
	subnetID ids.ID,
	nodeIDs ...ids.NodeID,
 ) SetCallbackListener {
	nodeIDSet := set.Of(nodeIDs...)
	return &logger{
		log:      log,
-		enabled:  enabled,
		subnetID: subnetID,
		nodeIDs:  nodeIDSet,
	}
@@ -46,7 +42,7 @@ func (l *logger) OnValidatorAdded(
	txID ids.ID,
	weight uint64,
 ) {
-	if l.enabled.Get() && l.nodeIDs.Contains(nodeID) {
+	if l.nodeIDs.Contains(nodeID) {
		var pkBytes []byte
		if pk != nil {
			pkBytes = bls.PublicKeyToBytes(pk)
@@ -65,7 +61,7 @@ func (l *logger) OnValidatorRemoved(
	nodeID ids.NodeID,
	weight uint64,
 ) {
-	if l.enabled.Get() && l.nodeIDs.Contains(nodeID) {
+	if l.nodeIDs.Contains(nodeID) {
		l.log.Info("node removed from validator set",
			zap.Stringer("subnetID", l.subnetID),
			zap.Stringer("nodeID", nodeID),
@@ -79,7 +75,7 @@ func (l *logger) OnValidatorWeightChanged(
	oldWeight uint64,
	newWeight uint64,
 ) {
-	if l.enabled.Get() && l.nodeIDs.Contains(nodeID) {
+	if l.nodeIDs.Contains(nodeID) {
		l.log.Info("validator weight changed",
			zap.Stringer("subnetID", l.subnetID),
			zap.Stringer("nodeID", nodeID),
diff --git a/tests/e2e/banff/suites.go b/tests/e2e/banff/suites.go
index 6adeb1476cfa..5bc071d6e004 100644
--- a/tests/e2e/banff/suites.go
+++ b/tests/e2e/banff/suites.go
@@ -25,7 +25,7 @@ var _ = ginkgo.Describe("[Banff]", func() {
	ginkgo.It("can send custom assets X->P and P->X", func() {
		keychain := e2e.Env.NewKeychain(1)
-		wallet := e2e.Env.NewWallet(keychain, e2e.Env.GetRandomNodeURI())
+		wallet := e2e.NewWallet(keychain, e2e.Env.GetRandomNodeURI())

		// Get the P-chain and the X-chain wallets
		pWallet := wallet.P()
diff --git a/tests/e2e/c/dynamic_fees.go b/tests/e2e/c/dynamic_fees.go
index edfbef2671a8..c8e005621983 100644
--- a/tests/e2e/c/dynamic_fees.go
+++ b/tests/e2e/c/dynamic_fees.go
@@ -51,7 +51,7 @@ var _ = e2e.DescribeCChain("[Dynamic Fees]", func() {
			NodeID: node.GetID(),
			URI:    node.GetProcessContext().URI,
		}
-		ethClient := e2e.Env.NewEthClient(nodeURI)
+		ethClient := e2e.NewEthClient(nodeURI)

		ginkgo.By("initializing a transaction signer")
		cChainID, err := ethClient.ChainID(e2e.DefaultContext())
diff --git a/tests/e2e/c/interchain_workflow.go b/tests/e2e/c/interchain_workflow.go
index 8bed85eb1bd9..d4881255ddff 100644
--- a/tests/e2e/c/interchain_workflow.go
+++ b/tests/e2e/c/interchain_workflow.go
@@ -34,7 +34,7 @@ var _ = e2e.DescribeCChain("[Interchain Workflow]", func() {
		// the wallet to avoid having to verify that all nodes are at
		// the same height before initializing the wallet.
		nodeURI := e2e.Env.GetRandomNodeURI()
-		ethClient := e2e.Env.NewEthClient(nodeURI)
+		ethClient := e2e.NewEthClient(nodeURI)

		ginkgo.By("allocating a pre-funded key to send from and a recipient key to deliver to")
		senderKey := e2e.Env.AllocateFundedKey()
@@ -79,7 +79,7 @@ var _ = e2e.DescribeCChain("[Interchain Workflow]", func() {
		// matches on-chain state.
		ginkgo.By("initializing a keychain and associated wallet")
		keychain := secp256k1fx.NewKeychain(senderKey, recipientKey)
-		baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+		baseWallet := e2e.NewWallet(keychain, nodeURI)
		xWallet := baseWallet.X()
		cWallet := baseWallet.C()
		pWallet := baseWallet.P()
diff --git a/tests/e2e/e2e.go b/tests/e2e/e2e.go
index 130f33f1197c..44c5d911e8dd 100644
--- a/tests/e2e/e2e.go
+++ b/tests/e2e/e2e.go
@@ -127,16 +127,29 @@ func (te *TestEnvironment) NewKeychain(count int) *secp256k1fx.Keychain {
	return secp256k1fx.NewKeychain(keys...)
 }

+// Create a new private network that is not shared with other tests.
+func (te *TestEnvironment) NewPrivateNetwork() testnet.Network {
+	// Load the shared network to retrieve its path and exec path
+	sharedNetwork, err := local.ReadNetwork(te.NetworkDir)
+	te.require.NoError(err)
+
+	// The private networks dir is under the shared network dir to ensure it
+	// will be included in the artifact uploaded in CI.
+	privateNetworksDir := filepath.Join(sharedNetwork.Dir, PrivateNetworksDirName)
+	te.require.NoError(os.MkdirAll(privateNetworksDir, perms.ReadWriteExecute))
+
+	return StartLocalNetwork(sharedNetwork.ExecPath, privateNetworksDir)
+}
+
 // Create a new wallet for the provided keychain against the specified node URI.
-// TODO(marun) Make this a regular function.
-func (te *TestEnvironment) NewWallet(keychain *secp256k1fx.Keychain, nodeURI testnet.NodeURI) primary.Wallet {
+func NewWallet(keychain *secp256k1fx.Keychain, nodeURI testnet.NodeURI) primary.Wallet {
	tests.Outf("{{blue}} initializing a new wallet for node %s with URI: %s {{/}}\n", nodeURI.NodeID, nodeURI.URI)
	baseWallet, err := primary.MakeWallet(DefaultContext(), &primary.WalletConfig{
		URI:          nodeURI.URI,
		AVAXKeychain: keychain,
		EthKeychain:  keychain,
	})
-	te.require.NoError(err)
+	require.NoError(ginkgo.GinkgoT(), err)
	return primary.NewWalletWithOptions(
		baseWallet,
		common.WithPostIssuanceFunc(
@@ -148,30 +161,15 @@ func (te *TestEnvironment) NewWallet(keychain *secp256k1fx.Keychain, nodeURI tes
 }

 // Create a new eth client targeting the specified node URI.
-// TODO(marun) Make this a regular function.
-func (te *TestEnvironment) NewEthClient(nodeURI testnet.NodeURI) ethclient.Client {
+func NewEthClient(nodeURI testnet.NodeURI) ethclient.Client {
	tests.Outf("{{blue}} initializing a new eth client for node %s with URI: %s {{/}}\n", nodeURI.NodeID, nodeURI.URI)
	nodeAddress := strings.Split(nodeURI.URI, "//")[1]
	uri := fmt.Sprintf("ws://%s/ext/bc/C/ws", nodeAddress)
	client, err := ethclient.Dial(uri)
-	te.require.NoError(err)
+	require.NoError(ginkgo.GinkgoT(), err)
	return client
 }

-// Create a new private network that is not shared with other tests.
-func (te *TestEnvironment) NewPrivateNetwork() testnet.Network {
-	// Load the shared network to retrieve its path and exec path
-	sharedNetwork, err := local.ReadNetwork(te.NetworkDir)
-	te.require.NoError(err)
-
-	// The private networks dir is under the shared network dir to ensure it
-	// will be included in the artifact uploaded in CI.
-	privateNetworksDir := filepath.Join(sharedNetwork.Dir, PrivateNetworksDirName)
-	te.require.NoError(os.MkdirAll(privateNetworksDir, perms.ReadWriteExecute))
-
-	return StartLocalNetwork(sharedNetwork.ExecPath, privateNetworksDir)
-}
-
 // Helper simplifying use of a timed context by canceling the context on ginkgo teardown.
 func ContextWithTimeout(duration time.Duration) context.Context {
	ctx, cancel := context.WithTimeout(context.Background(), duration)
diff --git a/tests/e2e/p/interchain_workflow.go b/tests/e2e/p/interchain_workflow.go
index 729418adbd97..9bea416294cb 100644
--- a/tests/e2e/p/interchain_workflow.go
+++ b/tests/e2e/p/interchain_workflow.go
@@ -53,7 +53,7 @@ var _ = e2e.DescribePChain("[Interchain Workflow]", ginkgo.Label(e2e.UsesCChainL
		keychain := e2e.Env.NewKeychain(1)
		keychain.Add(recipientKey)
		nodeURI := e2e.Env.GetRandomNodeURI()
-		baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+		baseWallet := e2e.NewWallet(keychain, nodeURI)
		xWallet := baseWallet.X()
		cWallet := baseWallet.C()
		pWallet := baseWallet.P()
@@ -202,7 +202,7 @@ var _ = e2e.DescribePChain("[Interchain Workflow]", ginkgo.Label(e2e.UsesCChainL
		})

		ginkgo.By("initializing a new eth client")
-		ethClient := e2e.Env.NewEthClient(nodeURI)
+		ethClient := e2e.NewEthClient(nodeURI)

		ginkgo.By("importing AVAX from the P-Chain to the C-Chain", func() {
			_, err := cWallet.IssueImportTx(
diff --git a/tests/e2e/p/permissionless_subnets.go b/tests/e2e/p/permissionless_subnets.go
index 1369685bf077..ab2909228365 100644
--- a/tests/e2e/p/permissionless_subnets.go
+++ b/tests/e2e/p/permissionless_subnets.go
@@ -32,7 +32,7 @@ var _ = e2e.DescribePChain("[Permissionless Subnets]", func() {
			nodeURI := e2e.Env.GetRandomNodeURI()

			keychain := e2e.Env.NewKeychain(1)
-			baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+			baseWallet := e2e.NewWallet(keychain, nodeURI)

			pWallet := baseWallet.P()
			xWallet := baseWallet.X()
diff --git a/tests/e2e/p/staking_rewards.go b/tests/e2e/p/staking_rewards.go
index df64088103ce..09b169dfb8f4 100644
--- a/tests/e2e/p/staking_rewards.go
+++ b/tests/e2e/p/staking_rewards.go
@@ -90,7 +90,7 @@ var _ = ginkgo.Describe("[Staking Rewards]", func() {
		fundedKey := e2e.Env.AllocateFundedKey()
		keychain.Add(fundedKey)
		nodeURI := e2e.Env.GetRandomNodeURI()
-		baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+		baseWallet := e2e.NewWallet(keychain, nodeURI)
		pWallet := baseWallet.P()

		ginkgo.By("retrieving alpha node id and pop")
@@ -261,7 +261,7 @@ var _ = ginkgo.Describe("[Staking Rewards]", func() {
		rewardBalances := make(map[ids.ShortID]uint64, len(rewardKeys))
		for _, rewardKey := range rewardKeys {
			keychain := secp256k1fx.NewKeychain(rewardKey)
-			baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+			baseWallet := e2e.NewWallet(keychain, nodeURI)
			pWallet := baseWallet.P()

			balances, err := pWallet.Builder().GetBalance()
			require.NoError(err)
diff --git a/tests/e2e/p/workflow.go b/tests/e2e/p/workflow.go
index 96bf8bafc02c..c0fda23775fd 100644
--- a/tests/e2e/p/workflow.go
+++ b/tests/e2e/p/workflow.go
@@ -36,7 +36,7 @@ var _ = e2e.DescribePChain("[Workflow]", func() {
		func() {
			nodeURI := e2e.Env.GetRandomNodeURI()
			keychain := e2e.Env.NewKeychain(2)
-			baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+			baseWallet := e2e.NewWallet(keychain, nodeURI)
			pWallet := baseWallet.P()

			avaxAssetID := baseWallet.P().AVAXAssetID()
diff --git a/tests/e2e/x/interchain_workflow.go b/tests/e2e/x/interchain_workflow.go
index 6d335199b5b9..373e567a666a 100644
--- a/tests/e2e/x/interchain_workflow.go
+++ b/tests/e2e/x/interchain_workflow.go
@@ -36,7 +36,7 @@ var _ = e2e.DescribeXChain("[Interchain Workflow]", ginkgo.Label(e2e.UsesCChainL
		require.NoError(err)
		keychain := e2e.Env.NewKeychain(1)
		keychain.Add(recipientKey)
-		baseWallet := e2e.Env.NewWallet(keychain, nodeURI)
+		baseWallet := e2e.NewWallet(keychain, nodeURI)
		xWallet := baseWallet.X()
		cWallet := baseWallet.C()
		pWallet := baseWallet.P()
@@ -103,7 +103,7 @@ var _ = e2e.DescribeXChain("[Interchain Workflow]", ginkgo.Label(e2e.UsesCChainL
		})

		ginkgo.By("initializing a new eth client")
-		ethClient := e2e.Env.NewEthClient(nodeURI)
+		ethClient := e2e.NewEthClient(nodeURI)

		ginkgo.By("importing AVAX from the X-Chain to the C-Chain", func() {
			_, err := cWallet.IssueImportTx(
diff --git a/tests/e2e/x/transfer/virtuous.go b/tests/e2e/x/transfer/virtuous.go
index d0ee950a53d6..119e28b9d1e9 100644
--- a/tests/e2e/x/transfer/virtuous.go
+++ b/tests/e2e/x/transfer/virtuous.go
@@ -83,7 +83,7 @@ var _ = e2e.DescribeXChainSerial("[Virtuous Transfer Tx AVAX]", func() {
		}

		keychain := secp256k1fx.NewKeychain(testKeys...)
-		baseWallet := e2e.Env.NewWallet(keychain, e2e.Env.GetRandomNodeURI())
+		baseWallet := e2e.NewWallet(keychain, e2e.Env.GetRandomNodeURI())
		avaxAssetID := baseWallet.X().AVAXAssetID()

		wallets := make([]primary.Wallet, len(testKeys))
diff --git a/utils/compare/compare.go b/utils/compare/compare.go
deleted file mode 100644
index 13ec52f386cb..000000000000
--- a/utils/compare/compare.go
+++ /dev/null
@@ -1,27 +0,0 @@
-// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
-// See the file LICENSE for licensing terms.
-
-package compare
-
-// Returns true iff the slices have the same elements, regardless of order.
-func UnsortedEquals[T comparable](a, b []T) bool {
-	if len(a) != len(b) {
-		return false
-	}
-	m := make(map[T]int, len(a))
-	for _, v := range a {
-		m[v]++
-	}
-	for _, v := range b {
-		switch count := m[v]; count {
-		case 0:
-			// There were more instances of [v] in [b] than [a].
-			return false
-		case 1:
-			delete(m, v)
-		default:
-			m[v] = count - 1
-		}
-	}
-	return len(m) == 0
-}
diff --git a/utils/compare/compare_test.go b/utils/compare/compare_test.go
deleted file mode 100644
index e46bc838f72b..000000000000
--- a/utils/compare/compare_test.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
-// See the file LICENSE for licensing terms.
-
-package compare
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/require"
-)
-
-func TestUnsortedEquals(t *testing.T) {
-	require := require.New(t)
-
-	require.True(UnsortedEquals([]int{}, []int{}))
-	require.True(UnsortedEquals(nil, []int{}))
-	require.True(UnsortedEquals([]int{}, nil))
-	require.False(UnsortedEquals([]int{1}, nil))
-	require.False(UnsortedEquals(nil, []int{1}))
-	require.True(UnsortedEquals([]int{1}, []int{1}))
-	require.False(UnsortedEquals([]int{1, 2}, []int{}))
-	require.False(UnsortedEquals([]int{1, 2}, []int{1}))
-	require.False(UnsortedEquals([]int{1}, []int{1, 2}))
-	require.True(UnsortedEquals([]int{2, 1}, []int{1, 2}))
-	require.True(UnsortedEquals([]int{1, 2}, []int{2, 1}))
-}
diff --git a/vms/platformvm/block/builder/helpers_test.go b/vms/platformvm/block/builder/helpers_test.go
index 187c0eb92a70..ed55791147d8 100644
--- a/vms/platformvm/block/builder/helpers_test.go
+++ b/vms/platformvm/block/builder/helpers_test.go
@@ -236,7 +236,6 @@ func defaultState(
		ctx,
		metrics.Noop,
		rewards,
-		&utils.Atomic[bool]{},
	)
	require.NoError(err)
diff --git a/vms/platformvm/block/executor/helpers_test.go b/vms/platformvm/block/executor/helpers_test.go
index 7d5a67566472..5e79d52c9b10 100644
--- a/vms/platformvm/block/executor/helpers_test.go
+++ b/vms/platformvm/block/executor/helpers_test.go
@@ -275,7 +275,6 @@ func defaultState(
		ctx,
		metrics.Noop,
		rewards,
-		&utils.Atomic[bool]{},
	)
	if err != nil {
		panic(err)
	}
diff --git a/vms/platformvm/block/executor/proposal_block_test.go b/vms/platformvm/block/executor/proposal_block_test.go
index 880817414ea8..4300ad9606d9 100644
--- a/vms/platformvm/block/executor/proposal_block_test.go
+++ b/vms/platformvm/block/executor/proposal_block_test.go
@@ -947,7 +947,7 @@ func TestBanffProposalBlockTrackedSubnet(t *testing.T) {
			require.NoError(propBlk.Accept(context.Background()))
			require.NoError(commitBlk.Accept(context.Background()))
			_, ok := env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID)
-			require.Equal(tracked, ok)
+			require.True(ok)
		})
	}
 }
diff --git a/vms/platformvm/block/executor/standard_block_test.go b/vms/platformvm/block/executor/standard_block_test.go
index 76ae7ca55de6..110def5c5987 100644
--- a/vms/platformvm/block/executor/standard_block_test.go
+++ b/vms/platformvm/block/executor/standard_block_test.go
@@ -747,7 +747,7 @@ func TestBanffStandardBlockTrackedSubnet(t *testing.T) {
			require.NoError(block.Verify(context.Background()))
			require.NoError(block.Accept(context.Background()))
			_, ok := env.config.Validators.GetValidator(subnetID, subnetValidatorNodeID)
-			require.Equal(tracked, ok)
+			require.True(ok)
		})
	}
 }
diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go
index 1e0265798d33..41ce946a12e0 100644
--- a/vms/platformvm/state/mock_state.go
+++ b/vms/platformvm/state/mock_state.go
@@ -144,20 +144,6 @@ func (mr *MockStateMockRecorder) AddUTXO(arg0 interface{}) *gomock.Call {
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUTXO", reflect.TypeOf((*MockState)(nil).AddUTXO), arg0)
 }

-// ApplyCurrentValidators mocks base method.
-func (m *MockState) ApplyCurrentValidators(arg0 ids.ID, arg1 validators.Manager) error {
-	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "ApplyCurrentValidators", arg0, arg1)
-	ret0, _ := ret[0].(error)
-	return ret0
-}
-
-// ApplyCurrentValidators indicates an expected call of ApplyCurrentValidators.
-func (mr *MockStateMockRecorder) ApplyCurrentValidators(arg0, arg1 interface{}) *gomock.Call {
-	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyCurrentValidators", reflect.TypeOf((*MockState)(nil).ApplyCurrentValidators), arg0, arg1)
-}
-
 // ApplyValidatorPublicKeyDiffs mocks base method.
 func (m *MockState) ApplyValidatorPublicKeyDiffs(arg0 context.Context, arg1 map[ids.NodeID]*validators.GetValidatorOutput, arg2, arg3 uint64) error {
	m.ctrl.T.Helper()
diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go
index 872d9b669219..e64d24d434b1 100644
--- a/vms/platformvm/state/state.go
+++ b/vms/platformvm/state/state.go
@@ -140,10 +140,6 @@ type State interface {

	GetBlockIDAtHeight(height uint64) (ids.ID, error)

-	// ApplyCurrentValidators adds all the current validators and delegators of
-	// [subnetID] into [vdrs].
-	ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error
-
	// ApplyValidatorWeightDiffs iterates from [startHeight] towards the genesis
	// block until it has applied all of the diffs up to and including
	// [endHeight]. Applying the diffs modifies [validators].
@@ -290,11 +286,10 @@ type stateBlk struct {
 type state struct {
	validatorState

-	cfg          *config.Config
-	ctx          *snow.Context
-	metrics      metrics.Metrics
-	rewards      reward.Calculator
-	bootstrapped *utils.Atomic[bool]
+	cfg     *config.Config
+	ctx     *snow.Context
+	metrics metrics.Metrics
+	rewards reward.Calculator

	baseDB *versiondb.Database

@@ -461,7 +456,6 @@ func New(
	ctx *snow.Context,
	metrics metrics.Metrics,
	rewards reward.Calculator,
-	bootstrapped *utils.Atomic[bool],
 ) (State, error) {
	s, err := newState(
		db,
@@ -471,7 +465,6 @@ func New(
		ctx,
		metricsReg,
		rewards,
-		bootstrapped,
	)
	if err != nil {
		return nil, err
@@ -516,7 +509,6 @@ func newState(
	ctx *snow.Context,
	metricsReg prometheus.Registerer,
	rewards reward.Calculator,
-	bootstrapped *utils.Atomic[bool],
 ) (*state, error) {
	blockIDCache, err := metercacher.New[uint64, ids.ID](
		"block_id_cache",
@@ -635,12 +627,11 @@ func newState(
	return &state{
		validatorState: newValidatorState(),

-		cfg:          cfg,
-		ctx:          ctx,
-		metrics:      metrics,
-		rewards:      rewards,
-		bootstrapped: bootstrapped,
-		baseDB:       baseDB,
+		cfg:     cfg,
+		ctx:     ctx,
+		metrics: metrics,
+		rewards: rewards,
+		baseDB:  baseDB,

		addedBlockIDs: make(map[uint64]ids.ID),
		blockIDCache:  blockIDCache,
@@ -1139,26 +1130,6 @@ func (s *state) SetCurrentSupply(subnetID ids.ID, cs uint64) {
	}
 }

-func (s *state) ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error {
-	for nodeID, validator := range s.currentStakers.validators[subnetID] {
-		staker := validator.validator
-		if err := vdrs.AddStaker(subnetID, nodeID, staker.PublicKey, staker.TxID, staker.Weight); err != nil {
-			return err
-		}
-
-		delegatorIterator := NewTreeIterator(validator.delegators)
-		for delegatorIterator.Next() {
-			staker := delegatorIterator.Value()
-			if err := vdrs.AddWeight(subnetID, nodeID, staker.Weight); err != nil {
-				delegatorIterator.Release()
-				return err
-			}
-		}
-		delegatorIterator.Release()
-	}
-	return nil
-}
-
 func (s *state) ApplyValidatorWeightDiffs(
	ctx context.Context,
	validators map[ids.NodeID]*validators.GetValidatorOutput,
@@ -1689,17 +1660,29 @@ func (s *state) loadPendingValidators() error {

 // Invariant: initValidatorSets requires loadCurrentValidators to have already
 // been called.
 func (s *state) initValidatorSets() error {
-	if s.cfg.Validators.Count(constants.PrimaryNetworkID) != 0 {
-		// Enforce the invariant that the validator set is empty here.
-		return errValidatorSetAlreadyPopulated
-	}
-	err := s.ApplyCurrentValidators(constants.PrimaryNetworkID, s.cfg.Validators)
-	if err != nil {
-		return err
-	}
+	for subnetID, validators := range s.currentStakers.validators {
+		if s.cfg.Validators.Count(subnetID) != 0 {
+			// Enforce the invariant that the validator set is empty here.
+			return fmt.Errorf("%w: %s", errValidatorSetAlreadyPopulated, subnetID)
+		}

-	vl := validators.NewLogger(s.ctx.Log, s.bootstrapped, constants.PrimaryNetworkID, s.ctx.NodeID)
-	s.cfg.Validators.RegisterCallbackListener(constants.PrimaryNetworkID, vl)
+		for nodeID, validator := range validators {
+			validatorStaker := validator.validator
+			if err := s.cfg.Validators.AddStaker(subnetID, nodeID, validatorStaker.PublicKey, validatorStaker.TxID, validatorStaker.Weight); err != nil {
+				return err
+			}
+
+			delegatorIterator := NewTreeIterator(validator.delegators)
+			for delegatorIterator.Next() {
+				delegatorStaker := delegatorIterator.Value()
+				if err := s.cfg.Validators.AddWeight(subnetID, nodeID, delegatorStaker.Weight); err != nil {
+					delegatorIterator.Release()
+					return err
+				}
+			}
+			delegatorIterator.Release()
+		}
+	}

	s.metrics.SetLocalStake(s.cfg.Validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID))
	totalWeight, err := s.cfg.Validators.TotalWeight(constants.PrimaryNetworkID)
@@ -1707,20 +1690,6 @@ func (s *state) initValidatorSets() error {
		return fmt.Errorf("failed to get total weight of primary network validators: %w", err)
	}
	s.metrics.SetTotalStake(totalWeight)
-
-	for subnetID := range s.cfg.TrackedSubnets {
-		if s.cfg.Validators.Count(subnetID) != 0 {
-			// Enforce the invariant that the validator set is empty here.
-			return errValidatorSetAlreadyPopulated
-		}
-		err := s.ApplyCurrentValidators(subnetID, s.cfg.Validators)
-		if err != nil {
-			return err
-		}
-
-		vl := validators.NewLogger(s.ctx.Log, s.bootstrapped, subnetID, s.ctx.NodeID)
-		s.cfg.Validators.RegisterCallbackListener(subnetID, vl)
-	}
	return nil
 }

@@ -2109,11 +2078,6 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64) error
			continue
		}

-		// We only track the current validator set of tracked subnets.
-		if subnetID != constants.PrimaryNetworkID && !s.cfg.TrackedSubnets.Contains(subnetID) {
-			continue
-		}
-
		if weightDiff.Decrease {
			err = s.cfg.Validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount)
		} else {
diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go
index 5a29619c1beb..ae79415f4bbf 100644
--- a/vms/platformvm/state/state_test.go
+++ b/vms/platformvm/state/state_test.go
@@ -21,7 +21,6 @@ import (
	"github.com/ava-labs/avalanchego/snow"
	"github.com/ava-labs/avalanchego/snow/choices"
	"github.com/ava-labs/avalanchego/snow/validators"
-	"github.com/ava-labs/avalanchego/utils"
	"github.com/ava-labs/avalanchego/utils/constants"
	"github.com/ava-labs/avalanchego/utils/crypto/bls"
	"github.com/ava-labs/avalanchego/utils/units"
@@ -178,7 +177,6 @@ func newStateFromDB(require *require.Assertions, db database.Database) State {
			MintingPeriod: 365 * 24 * time.Hour,
			SupplyCap:     720 * units.MegaAvax,
		}),
-		&utils.Atomic[bool]{},
	)
	require.NoError(err)
	require.NotNil(state)
diff --git a/vms/platformvm/txs/executor/advance_time_test.go b/vms/platformvm/txs/executor/advance_time_test.go
index 9bf5aafed7ac..5f9eabe0553c 100644
--- a/vms/platformvm/txs/executor/advance_time_test.go
+++ b/vms/platformvm/txs/executor/advance_time_test.go
@@ -617,7 +617,7 @@ func TestTrackedSubnet(t *testing.T) {
			env.state.SetHeight(dummyHeight)
			require.NoError(env.state.Commit())
			_, ok := env.config.Validators.GetValidator(subnetID, ids.NodeID(subnetValidatorNodeID))
-			require.Equal(tracked, ok)
+			require.True(ok)
		})
	}
 }
diff --git a/vms/platformvm/txs/executor/helpers_test.go b/vms/platformvm/txs/executor/helpers_test.go
index aee0d184d5f8..b1993584cb7a 100644
--- a/vms/platformvm/txs/executor/helpers_test.go
+++ b/vms/platformvm/txs/executor/helpers_test.go
@@ -224,7 +224,6 @@ func defaultState(
		ctx,
		metrics.Noop,
		rewards,
-		&utils.Atomic[bool]{},
	)
	if err != nil {
		panic(err)
	}
diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go
index fb7c314c90a7..a4c5c87a3040 100644
--- a/vms/platformvm/validators/manager.go
+++ b/vms/platformvm/validators/manager.go
@@ -48,10 +48,6 @@ type State interface {
	GetLastAccepted() ids.ID
	GetStatelessBlock(blockID ids.ID) (block.Block, error)

-	// ApplyCurrentValidators adds all the current validators and delegators of
-	// [subnetID] into [vdrs].
-	ApplyCurrentValidators(subnetID ids.ID, vdrs validators.Manager) error
-
	// ApplyValidatorWeightDiffs iterates from [startHeight] towards the genesis
	// block until it has applied all of the diffs up to and including
	// [endHeight]. Applying the diffs modifies [validators].
@@ -346,22 +342,7 @@ func (m *manager) getCurrentValidatorSets(
	ctx context.Context,
	subnetID ids.ID,
 ) (map[ids.NodeID]*validators.GetValidatorOutput, map[ids.NodeID]*validators.GetValidatorOutput, uint64, error) {
-	subnetManager := m.cfg.Validators
-	if subnetManager.Count(subnetID) == 0 {
-		// If this subnet isn't tracked, there will not be any registered
-		// validators. To calculate the current validators we need to first
-		// fetch them from state. We generate a new manager as we don't want to
-		// modify that long-lived reference.
-		//
-		// TODO: remove this once all subnets are included in the validator
-		// manager.
-		subnetManager = validators.NewManager()
-		if err := m.state.ApplyCurrentValidators(subnetID, subnetManager); err != nil {
-			return nil, nil, 0, err
-		}
-	}
-
-	subnetMap := subnetManager.GetMap(subnetID)
+	subnetMap := m.cfg.Validators.GetMap(subnetID)
	primaryMap := m.cfg.Validators.GetMap(constants.PrimaryNetworkID)
	currentHeight, err := m.getCurrentHeight(ctx)
	return subnetMap, primaryMap, currentHeight, err
diff --git a/vms/platformvm/validators/manager_benchmark_test.go b/vms/platformvm/validators/manager_benchmark_test.go
index 54d0e264e63e..d88fca7d7e32 100644
--- a/vms/platformvm/validators/manager_benchmark_test.go
+++ b/vms/platformvm/validators/manager_benchmark_test.go
@@ -17,7 +17,6 @@ import (
	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/snow"
	"github.com/ava-labs/avalanchego/snow/validators"
-	"github.com/ava-labs/avalanchego/utils"
	"github.com/ava-labs/avalanchego/utils/constants"
	"github.com/ava-labs/avalanchego/utils/crypto/bls"
	"github.com/ava-labs/avalanchego/utils/formatting"
@@ -129,7 +128,6 @@ func BenchmarkGetValidatorSet(b *testing.B) {
			MintingPeriod: 365 * 24 * time.Hour,
			SupplyCap:     720 * units.MegaAvax,
		}),
-		new(utils.Atomic[bool]),
	)
	require.NoError(err)
diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go
index c1a911f5b919..8522d3ae4715 100644
--- a/vms/platformvm/vm.go
+++ b/vms/platformvm/vm.go
@@ -142,7 +142,6 @@ func (vm *VM) Initialize(
		vm.ctx,
		vm.metrics,
		rewards,
-		&vm.bootstrapped,
	)
	if err != nil {
		return err
@@ -304,17 +303,21 @@ func (vm *VM) onNormalOperationsStarted() error {
	}

	primaryVdrIDs := vm.Validators.GetValidatorIDs(constants.PrimaryNetworkID)
-
	if err := vm.uptimeManager.StartTracking(primaryVdrIDs, constants.PrimaryNetworkID); err != nil {
		return err
	}

+	vl := validators.NewLogger(vm.ctx.Log, constants.PrimaryNetworkID, vm.ctx.NodeID)
+	vm.Validators.RegisterCallbackListener(constants.PrimaryNetworkID, vl)
+
	for subnetID := range vm.TrackedSubnets {
		vdrIDs := vm.Validators.GetValidatorIDs(subnetID)
-
		if err := vm.uptimeManager.StartTracking(vdrIDs, subnetID); err != nil {
			return err
		}
+
+		vl := validators.NewLogger(vm.ctx.Log, subnetID, vm.ctx.NodeID)
+		vm.Validators.RegisterCallbackListener(subnetID, vl)
	}

	if err := vm.state.Commit(); err != nil {
diff --git a/vms/platformvm/vm_regression_test.go b/vms/platformvm/vm_regression_test.go
index 8416c4114662..e94a0ea99406 100644
--- a/vms/platformvm/vm_regression_test.go
+++ b/vms/platformvm/vm_regression_test.go
@@ -25,7 +25,6 @@ import (
	"github.com/ava-labs/avalanchego/snow/engine/common"
	"github.com/ava-labs/avalanchego/snow/uptime"
	"github.com/ava-labs/avalanchego/snow/validators"
-	"github.com/ava-labs/avalanchego/utils"
	"github.com/ava-labs/avalanchego/utils/constants"
	"github.com/ava-labs/avalanchego/utils/crypto/bls"
	"github.com/ava-labs/avalanchego/utils/crypto/secp256k1"
@@ -659,7 +658,6 @@ func TestRejectedStateRegressionInvalidValidatorTimestamp(t *testing.T) {
		vm.ctx,
		metrics.Noop,
		reward.NewCalculator(vm.Config.RewardConfig),
-		&utils.Atomic[bool]{},
	)
	require.NoError(err)

@@ -968,7 +966,6 @@ func TestRejectedStateRegressionInvalidValidatorReward(t *testing.T) {
		vm.ctx,
		metrics.Noop,
		reward.NewCalculator(vm.Config.RewardConfig),
-		&utils.Atomic[bool]{},
	)
	require.NoError(err)

@@ -1397,9 +1394,6 @@ func TestRemovePermissionedValidatorDuringPendingToCurrentTransitionTracked(t *t
	require.NoError(createSubnetBlock.Accept(context.Background()))
	require.NoError(vm.SetPreference(context.Background(), vm.manager.LastAccepted()))

-	vm.TrackedSubnets.Add(createSubnetTx.ID())
-	require.NoError(vm.state.ApplyCurrentValidators(createSubnetTx.ID(), vm.Validators))
-
	addSubnetValidatorTx, err := vm.txBuilder.NewAddSubnetValidatorTx(
		defaultMaxValidatorStake,
		uint64(validatorStartTime.Unix()),
diff --git a/x/merkledb/README.md b/x/merkledb/README.md
index 467a60e19b08..6c7d9d68775c 100644
--- a/x/merkledb/README.md
+++ b/x/merkledb/README.md
@@ -76,8 +76,8 @@ The node serialization format is as follows:
 Where:

 * `Value existence flag` is `1` if this node has a value, otherwise `0`.
-* `Value length` is the length of the value, if it exists (i.e. if `Value existince flag` is `1`.) Otherwise not serialized.
-* `Value` is the value, if it exists (i.e. if `Value existince flag` is `1`.) Otherwise not serialized.
+* `Value length` is the length of the value, if it exists (i.e. if `Value existence flag` is `1`.) Otherwise not serialized.
+* `Value` is the value, if it exists (i.e. if `Value existence flag` is `1`.) Otherwise not serialized.
 * `Number of children` is the number of children this node has.
 * `Child index` is the index of a child node within the list of the node's children.
 * `Child compressed key length` is the length of the child node's compressed key.
@@ -197,8 +197,8 @@ Where:
 * `Child index` is the index of a child node within the list of the node's children.
 * `Child ID` is the child node's ID.
 * `Value existence flag` is `1` if this node has a value, otherwise `0`.
-* `Value length` is the length of the value, if it exists (i.e. if `Value existince flag` is `1`.) Otherwise not serialized.
-* `Value` is the value, if it exists (i.e. if `Value existince flag` is `1`.) Otherwise not serialized.
+* `Value length` is the length of the value, if it exists (i.e. if `Value existence flag` is `1`.) Otherwise not serialized.
+* `Value` is the value, if it exists (i.e. if `Value existence flag` is `1`.) Otherwise not serialized.
 * `Key length` is the number of nibbles in this node's key.
 * `Key` is the node's key.
diff --git a/x/merkledb/codec.go b/x/merkledb/codec.go
index e7ef1eddb7f5..a7decc6f6436 100644
--- a/x/merkledb/codec.go
+++ b/x/merkledb/codec.go
@@ -44,7 +44,6 @@ var (
	trueBytes  = []byte{trueByte}
	falseBytes = []byte{falseByte}

-	errTooManyChildren    = errors.New("length of children list is larger than branching factor")
	errChildIndexTooLarge = errors.New("invalid child index. Must be less than branching factor")
	errLeadingZeroes      = errors.New("varint has leading zeroes")
	errInvalidBool        = errors.New("decoded bool is neither true nor false")
@@ -63,13 +62,15 @@ type encoderDecoder interface {
 type encoder interface {
	// Assumes [n] is non-nil.
	encodeDBNode(n *dbNode) []byte
-	// Assumes [hv] is non-nil.
-	encodeHashValues(hv *hashValues) []byte
+
+	// Returns the bytes that will be hashed to generate [n]'s ID.
+	// Assumes [n] is non-nil.
+	encodeHashValues(n *node) []byte
 }

 type decoder interface {
	// Assumes [n] is non-nil.
-	decodeDBNode(bytes []byte, n *dbNode, factor BranchFactor) error
+	decodeDBNode(bytes []byte, n *dbNode) error
 }

 func newCodec() encoderDecoder {
@@ -114,9 +115,9 @@ func (c *codecImpl) encodeDBNode(n *dbNode) []byte {
	return buf.Bytes()
 }

-func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
+func (c *codecImpl) encodeHashValues(n *node) []byte {
	var (
-		numChildren = len(hv.Children)
+		numChildren = len(n.children)
		// Estimate size [hv] to prevent memory allocations
		estimatedLen = minVarIntLen + numChildren*hashValuesChildLen + estimatedValueLen + estimatedKeyLen
		buf          = bytes.NewBuffer(make([]byte, 0, estimatedLen))
@@ -125,19 +126,20 @@ func (c *codecImpl) encodeHashValues(hv *hashValues) []byte {
	c.encodeUint(buf, uint64(numChildren))

	// ensure that the order of entries is consistent
-	for index := 0; BranchFactor(index) < hv.Key.branchFactor; index++ {
-		if entry, ok := hv.Children[byte(index)]; ok {
-			c.encodeUint(buf, uint64(index))
-			_, _ = buf.Write(entry.id[:])
-		}
+	keys := maps.Keys(n.children)
+	slices.Sort(keys)
+	for _, index := range keys {
+		entry := n.children[index]
+		c.encodeUint(buf, uint64(index))
+		_, _ = buf.Write(entry.id[:])
	}
-	c.encodeMaybeByteSlice(buf, hv.Value)
-	c.encodeKey(buf, hv.Key)
+	c.encodeMaybeByteSlice(buf, n.valueDigest)
+	c.encodeKey(buf, n.key)

	return buf.Bytes()
 }

-func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor) error {
+func (c *codecImpl) decodeDBNode(b []byte, n *dbNode) error {
	if minDBNodeLen > len(b) {
		return io.ErrUnexpectedEOF
	}
@@ -154,25 +156,23 @@ func (c *codecImpl) decodeDBNode(b []byte, n *dbNode, branchFactor BranchFactor)
	switch {
	case err != nil:
		return err
-	case numChildren > uint64(branchFactor):
-		return errTooManyChildren
	case numChildren > uint64(src.Len()/minChildLen):
		return io.ErrUnexpectedEOF
	}

-	n.children = make(map[byte]child, branchFactor)
+	n.children = make(map[byte]child, numChildren)
	var previousChild uint64
	for i := uint64(0); i < numChildren; i++ {
		index, err := c.decodeUint(src)
		if err != nil {
			return err
		}
-		if index >= uint64(branchFactor) || (i != 0 && index <= previousChild) {
+		if (i != 0 && index <= previousChild) || index > math.MaxUint8 {
			return errChildIndexTooLarge
		}
		previousChild = index

-		compressedKey, err := c.decodeKey(src, branchFactor)
+		compressedKey, err := c.decodeKey(src)
		if err != nil {
			return err
		}
@@ -331,11 +331,11 @@ func (*codecImpl) decodeID(src *bytes.Reader) (ids.ID, error) {
 }

 func (c *codecImpl) encodeKey(dst *bytes.Buffer, key Key) {
-	c.encodeUint(dst, uint64(key.tokenLength))
+	c.encodeUint(dst, uint64(key.length))
	_, _ = dst.Write(key.Bytes())
 }

-func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor) (Key, error) {
+func (c *codecImpl) decodeKey(src *bytes.Reader) (Key, error) {
	if minKeyLen > src.Len() {
		return Key{}, io.ErrUnexpectedEOF
	}
@@ -347,9 +347,10 @@ func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor) (Key
	if length > math.MaxInt {
		return Key{}, errIntOverflow
	}
-	result := emptyKey(branchFactor)
-	result.tokenLength = int(length)
-	keyBytesLen := result.bytesNeeded(result.tokenLength)
+	result := Key{
+		length: int(length),
+	}
+	keyBytesLen := bytesNeeded(result.length)
	if keyBytesLen > src.Len() {
		return Key{}, io.ErrUnexpectedEOF
	}
@@ -363,8 +364,8 @@ func (c *codecImpl) decodeKey(src *bytes.Reader, branchFactor BranchFactor) (Key
	if result.hasPartialByte() {
		// Confirm that the padding bits in the partial byte are 0.
		// We want to only look at the bits to the right of the last token, which is at index length-1.
-		// Generate a mask with (8-bitsToShift) 0s followed by bitsToShift 1s.
-		paddingMask := byte(0xFF >> (8 - result.bitsToShift(result.tokenLength-1)))
+		// Generate a mask where the (result.length % 8) left bits are 0.
+		paddingMask := byte(0xFF >> (result.length % 8))
		if buffer[keyBytesLen-1]&paddingMask != 0 {
			return Key{}, errNonZeroKeyPadding
		}
diff --git a/x/merkledb/codec_test.go b/x/merkledb/codec_test.go
index cb83e1ce582c..00e5790b3171 100644
--- a/x/merkledb/codec_test.go
+++ b/x/merkledb/codec_test.go
@@ -80,24 +80,22 @@ func FuzzCodecKey(f *testing.F) {
			b []byte,
		) {
			require := require.New(t)
-			for _, branchFactor := range branchFactors {
-				codec := codec.(*codecImpl)
-				reader := bytes.NewReader(b)
-				startLen := reader.Len()
-				got, err := codec.decodeKey(reader, branchFactor)
-				if err != nil {
-					t.SkipNow()
-				}
-				endLen := reader.Len()
-				numRead := startLen - endLen
-
-				// Encoding [got] should be the same as [b].
-				var buf bytes.Buffer
-				codec.encodeKey(&buf, got)
-				bufBytes := buf.Bytes()
-				require.Len(bufBytes, numRead)
-				require.Equal(b[:numRead], bufBytes)
+			codec := codec.(*codecImpl)
+			reader := bytes.NewReader(b)
+			startLen := reader.Len()
+			got, err := codec.decodeKey(reader)
+			if err != nil {
+				t.SkipNow()
			}
+			endLen := reader.Len()
+			numRead := startLen - endLen
+
+			// Encoding [got] should be the same as [b].
+			var buf bytes.Buffer
+			codec.encodeKey(&buf, got)
+			bufBytes := buf.Bytes()
+			require.Len(bufBytes, numRead)
+			require.Equal(b[:numRead], bufBytes)
		},
	)
 }
@@ -109,17 +107,15 @@ func FuzzCodecDBNodeCanonical(f *testing.F) {
			b []byte,
		) {
			require := require.New(t)
-			for _, branchFactor := range branchFactors {
-				codec := codec.(*codecImpl)
-				node := &dbNode{}
-				if err := codec.decodeDBNode(b, node, branchFactor); err != nil {
-					t.SkipNow()
-				}
-
-				// Encoding [node] should be the same as [b].
-				buf := codec.encodeDBNode(node)
-				require.Equal(b, buf)
+			codec := codec.(*codecImpl)
+			node := &dbNode{}
+			if err := codec.decodeDBNode(b, node); err != nil {
+				t.SkipNow()
			}
+
+			// Encoding [node] should be the same as [b].
+			buf := codec.encodeDBNode(node)
+			require.Equal(b, buf)
		},
	)
 }
@@ -133,7 +129,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
			valueBytes []byte,
		) {
			require := require.New(t)
-			for _, branchFactor := range branchFactors {
+			for _, bf := range validBranchFactors {
				r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404

				value := maybe.Nothing[[]byte]()
@@ -148,7 +144,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
					value = maybe.Some(valueBytes)
				}

-				numChildren := r.Intn(int(branchFactor)) // #nosec G404
+				numChildren := r.Intn(int(bf)) // #nosec G404

				children := map[byte]child{}
				for i := 0; i < numChildren; i++ {
@@ -159,7 +155,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
					_, _ = r.Read(childKeyBytes) // #nosec G404

					children[byte(i)] = child{
-						compressedKey: ToKey(childKeyBytes, branchFactor),
+						compressedKey: ToKey(childKeyBytes),
						id:            childID,
					}
				}
@@ -171,7 +167,7 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
				nodeBytes := codec.encodeDBNode(&node)

				var gotNode dbNode
-				require.NoError(codec.decodeDBNode(nodeBytes, &gotNode, branchFactor))
+				require.NoError(codec.decodeDBNode(nodeBytes, &gotNode))
				require.Equal(node, gotNode)

				nodeBytes2 := codec.encodeDBNode(&gotNode)
@@ -181,31 +177,15 @@ func FuzzCodecDBNodeDeterministic(f *testing.F) {
	)
 }

-func TestCodecDecodeDBNode(t *testing.T) {
+func TestCodecDecodeDBNode_TooShort(t *testing.T) {
	require := require.New(t)

	var (
		parsedDBNode  dbNode
		tooShortBytes = make([]byte, minDBNodeLen-1)
	)
-	err := codec.decodeDBNode(tooShortBytes, &parsedDBNode, BranchFactor16)
+	err := codec.decodeDBNode(tooShortBytes, &parsedDBNode)
	require.ErrorIs(err, io.ErrUnexpectedEOF)
-
-	proof := dbNode{
-		value:    maybe.Some([]byte{1}),
-		children: map[byte]child{},
-	}
-
-	nodeBytes := codec.encodeDBNode(&proof)
-	// Remove num children (0) from end
-	nodeBytes = nodeBytes[:len(nodeBytes)-minVarIntLen]
-	proofBytesBuf := bytes.NewBuffer(nodeBytes)
-
-	// Put num children > branch factor
-	codec.(*codecImpl).encodeUint(proofBytesBuf, uint64(BranchFactor16+1))
-
-	err = codec.decodeDBNode(proofBytesBuf.Bytes(), &parsedDBNode, BranchFactor16)
-	require.ErrorIs(err, errTooManyChildren)
 }

 // Ensure that encodeHashValues is deterministic
@@ -219,18 +199,18 @@ func FuzzEncodeHashValues(f *testing.F) {
			randSeed int,
		) {
			require := require.New(t)
-			for _, branchFactor := range branchFactors { // Create a random *hashValues
+			for _, bf := range validBranchFactors { // Create a random node
				r := rand.New(rand.NewSource(int64(randSeed))) // #nosec G404

				children := map[byte]child{}
-				numChildren := r.Intn(int(branchFactor)) // #nosec G404
+				numChildren := r.Intn(int(bf)) // #nosec G404
				for i := 0; i < numChildren; i++ {
					compressedKeyLen := r.Intn(32) // #nosec G404
					compressedKeyBytes := make([]byte, compressedKeyLen)
					_, _ = r.Read(compressedKeyBytes) // #nosec G404

					children[byte(i)] = child{
-						compressedKey: ToKey(compressedKeyBytes, branchFactor),
+						compressedKey: ToKey(compressedKeyBytes),
						id:            ids.GenerateTestID(),
						hasValue:      r.Intn(2) == 1, // #nosec G404
					}
@@ -247,13 +227,15 @@ func FuzzEncodeHashValues(f *testing.F) {
				key := make([]byte, r.Intn(32)) // #nosec G404
				_, _ = r.Read(key)              // #nosec G404

-				hv := &hashValues{
-					Children: children,
-					Value:    value,
-					Key:      ToKey(key, branchFactor),
+				hv := &node{
+					key: ToKey(key),
+					dbNode: dbNode{
+						children: children,
+						value:    value,
+					},
				}

-				// Serialize the *hashValues with both codecs
+				// Serialize hv with both codecs
				hvBytes1 := codec1.encodeHashValues(hv)
				hvBytes2 := codec2.encodeHashValues(hv)
@@ -267,6 +249,6 @@ func FuzzEncodeHashValues(f *testing.F) {
 func TestCodecDecodeKeyLengthOverflowRegression(t *testing.T) {
	codec := codec.(*codecImpl)
	bytes := bytes.NewReader(binary.AppendUvarint(nil, math.MaxInt))
-	_, err := codec.decodeKey(bytes, BranchFactor16)
+	_, err := codec.decodeKey(bytes)
	require.ErrorIs(t, err, io.ErrUnexpectedEOF)
 }
diff --git a/x/merkledb/db.go b/x/merkledb/db.go
index 87439010b1f0..88dd667ae22a 100644
--- a/x/merkledb/db.go
+++ b/x/merkledb/db.go
@@ -204,8 +204,7 @@ type merkleDB struct {
	// [calculateNodeIDsHelper] at any given time.
	calculateNodeIDsSema *semaphore.Weighted

-	toKey   func(p []byte) Key
-	rootKey Key
+	tokenSize int
 }

 // New returns a new merkle database.
@@ -223,17 +222,13 @@ func newDatabase(
	config Config,
	metrics merkleMetrics,
 ) (*merkleDB, error) {
-	rootGenConcurrency := uint(runtime.NumCPU())
-	if config.RootGenConcurrency != 0 {
-		rootGenConcurrency = config.RootGenConcurrency
-	}
-
	if err := config.BranchFactor.Valid(); err != nil {
		return nil, err
	}

-	toKey := func(b []byte) Key {
-		return ToKey(b, config.BranchFactor)
+	rootGenConcurrency := uint(runtime.NumCPU())
+	if config.RootGenConcurrency != 0 {
+		rootGenConcurrency = config.RootGenConcurrency
	}

	// Share a sync.Pool of []byte between the intermediateNodeDB and valueNodeDB
@@ -246,15 +241,14 @@ func newDatabase(
	trieDB := &merkleDB{
		metrics:            metrics,
		baseDB:             db,
-		valueNodeDB:        newValueNodeDB(db, bufferPool, metrics, int(config.ValueNodeCacheSize), config.BranchFactor),
-		intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, int(config.IntermediateNodeCacheSize), int(config.EvictionBatchSize)),
-		history:            newTrieHistory(int(config.HistoryLength), toKey),
+		valueNodeDB:        newValueNodeDB(db, bufferPool, metrics, int(config.ValueNodeCacheSize)),
+		intermediateNodeDB: newIntermediateNodeDB(db, bufferPool, metrics, int(config.IntermediateNodeCacheSize), int(config.EvictionBatchSize), BranchFactorToTokenSize[config.BranchFactor]),
+		history:            newTrieHistory(int(config.HistoryLength)),
		debugTracer:        getTracerIfEnabled(config.TraceLevel, DebugTrace, config.Tracer),
		infoTracer:         getTracerIfEnabled(config.TraceLevel, InfoTrace, config.Tracer),
		childViews:         make([]*trieView, 0, defaultPreallocationSize),
		calculateNodeIDsSema: semaphore.NewWeighted(int64(rootGenConcurrency)),
-		toKey:                toKey,
-		rootKey:              toKey(rootKey),
+		tokenSize:            BranchFactorToTokenSize[config.BranchFactor],
	}

	root, err := trieDB.initializeRootIfNeeded()
@@ -292,7 +286,7 @@ func newDatabase(
 // Deletes every intermediate node and rebuilds them by re-adding every key/value.
 // TODO: make this more efficient by only clearing out the stale portions of the trie.
 func (db *merkleDB) rebuild(ctx context.Context, cacheSize int) error {
-	db.root = newNode(nil, db.rootKey)
+	db.root = newNode(Key{})

	// Delete intermediate nodes.
	if err := database.ClearPrefix(db.baseDB, intermediateNodePrefix, rebuildIntermediateDeletionWriteSize); err != nil {
@@ -474,7 +468,7 @@ func (db *merkleDB) PrefetchPath(key []byte) error {
 }

 func (db *merkleDB) prefetchPath(view *trieView, keyBytes []byte) error {
-	return view.visitPathToKey(db.toKey(keyBytes), func(n *node) error {
+	return view.visitPathToKey(ToKey(keyBytes), func(n *node) error {
		if !n.hasValue() {
			return db.intermediateNodeDB.nodeCache.Put(n.key, n)
		}
@@ -503,7 +497,7 @@ func (db *merkleDB) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []e
	values := make([][]byte, len(keys))
	errors := make([]error, len(keys))
	for i, key := range keys {
-		values[i], errors[i] = db.getValueCopy(db.toKey(key))
+		values[i], errors[i] = db.getValueCopy(ToKey(key))
	}
	return values, errors
 }
@@ -517,7 +511,7 @@ func (db *merkleDB) GetValue(ctx context.Context, key []byte) ([]byte, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

-	return db.getValueCopy(db.toKey(key))
+	return db.getValueCopy(ToKey(key))
 }

 // getValueCopy returns a copy of the value for the given [key].
@@ -783,7 +777,7 @@ func (db *merkleDB) Has(k []byte) (bool, error) {
		return false, database.ErrClosed
	}

-	_, err := db.getValueWithoutLock(db.toKey(k))
+	_, err := db.getValueWithoutLock(ToKey(k))
	if err == database.ErrNotFound {
		return false, nil
	}
@@ -921,7 +915,7 @@ func (db *merkleDB) commitChanges(ctx context.Context, trieToCommit *trieView) e
		return nil
	}

-	rootChange, ok := changes.nodes[db.rootKey]
+	rootChange, ok := changes.nodes[Key{}]
	if !ok {
		return errNoNewRoot
	}
@@ -1020,7 +1014,7 @@ func (db *merkleDB) VerifyChangeProof(
		return err
	}

-	smallestKey := maybe.Bind(start, db.toKey)
+	smallestKey := maybe.Bind(start, ToKey)

	// Make sure the start proof, if given, is well-formed.
	if err := verifyProofPath(proof.StartProof, smallestKey); err != nil {
@@ -1030,12 +1024,12 @@ func (db *merkleDB) VerifyChangeProof(
	// Find the greatest key in [proof.KeyChanges]
	// Note that [proof.EndProof] is a proof for this key.
	// [largestKey] is also used when we add children of proof nodes to [trie] below.
-	largestKey := maybe.Bind(end, db.toKey)
+	largestKey := maybe.Bind(end, ToKey)
	if len(proof.KeyChanges) > 0 {
		// If [proof] has key-value pairs, we should insert children
		// greater than [end] to ancestors of the node containing [end]
		// so that we get the expected root ID.
-		largestKey = maybe.Some(db.toKey(proof.KeyChanges[len(proof.KeyChanges)-1].Key))
+		largestKey = maybe.Some(ToKey(proof.KeyChanges[len(proof.KeyChanges)-1].Key))
	}

	// Make sure the end proof, if given, is well-formed.
@@ -1045,7 +1039,7 @@ func (db *merkleDB) VerifyChangeProof(
	keyValues := make(map[Key]maybe.Maybe[[]byte], len(proof.KeyChanges))
	for _, keyValue := range proof.KeyChanges {
-		keyValues[db.toKey(keyValue.Key)] = keyValue.Value
+		keyValues[ToKey(keyValue.Key)] = keyValue.Value
	}

	// want to prevent commit writes to DB, but not prevent DB reads
@@ -1149,9 +1143,9 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
	// not sure if the root exists or had a value or not
	// check under both prefixes
	var err error
-	db.root, err = db.intermediateNodeDB.Get(db.rootKey)
+	db.root, err = db.intermediateNodeDB.Get(Key{})
	if err == database.ErrNotFound {
-		db.root, err = db.valueNodeDB.Get(db.rootKey)
+		db.root, err = db.valueNodeDB.Get(Key{})
	}
	if err == nil {
		// Root already exists, so calculate its id
@@ -1163,12 +1157,12 @@ func (db *merkleDB) initializeRootIfNeeded() (ids.ID, error) {
	}

	// Root doesn't exist; make a new one.
-	db.root = newNode(nil, db.rootKey)
+	db.root = newNode(Key{})

	// update its ID
	db.root.calculateID(db.metrics)

-	if err := db.intermediateNodeDB.Put(db.rootKey, db.root); err != nil {
+	if err := db.intermediateNodeDB.Put(Key{}, db.root); err != nil {
		return ids.Empty, err
	}

@@ -1248,7 +1242,7 @@ func (db *merkleDB) getNode(key Key, hasValue bool) (*node, error) {
	switch {
	case db.closed:
		return nil, database.ErrClosed
-	case key == db.rootKey:
+	case key == Key{}:
		return db.root, nil
	case hasValue:
		return db.valueNodeDB.Get(key)
diff --git a/x/merkledb/db_test.go b/x/merkledb/db_test.go
index d4f09803cdaf..2d6439fd294b 100644
--- a/x/merkledb/db_test.go
+++ b/x/merkledb/db_test.go
@@ -63,7 +63,7 @@ func Test_MerkleDB_Get_Safety(t *testing.T) {
	val, err := db.Get(keyBytes)
	require.NoError(err)

-	n, err := db.getNode(ToKey(keyBytes, BranchFactor16), true)
+	n, err := db.getNode(ToKey(keyBytes), true)
	require.NoError(err)

	// node's value shouldn't be affected by the edit
@@ -96,7 +96,7 @@ func Test_MerkleDB_GetValues_Safety(t *testing.T) {
 }

 func Test_MerkleDB_DB_Interface(t *testing.T) {
-	for _, bf := range branchFactors {
+	for _, bf := range validBranchFactors {
		for _, test := range database.Tests {
			db, err := getBasicDBWithBranchFactor(bf)
			require.NoError(t, err)
@@ -108,7 +108,7 @@ func Test_MerkleDB_DB_Interface(t *testing.T) {
 func Benchmark_MerkleDB_DBInterface(b *testing.B) {
	for _, size := range database.BenchmarkSizes {
		keys, values := database.SetupBenchmark(b, size[0], size[1], size[2])
-		for _, bf := range branchFactors {
+		for _, bf := range validBranchFactors {
			for _, bench := range database.Benchmarks {
				db, err := getBasicDBWithBranchFactor(bf)
				require.NoError(b, err)
@@ -785,7 +785,7 @@ func FuzzMerkleDBEmptyRandomizedActions(f *testing.F) {
		}
		require := require.New(t)
		r := rand.New(rand.NewSource(randSeed)) // #nosec G404
-		for _, bf := range branchFactors {
+		for _, ts := range validTokenSizes {
			runRandDBTest(
				require,
				r,
@@ -795,7 +795,7 @@ func FuzzMerkleDBEmptyRandomizedActions(f *testing.F) {
					size,
					0.01, /*checkHashProbability*/
				),
-				bf,
+				ts,
			)
		}
	})
@@ -813,7 +813,7 @@ func FuzzMerkleDBInitialValuesRandomizedActions(f *testing.F) {
		}
		require := require.New(t)
		r := rand.New(rand.NewSource(randSeed)) // #nosec G404
-		for _, bf := range branchFactors {
+		for _, ts := range validTokenSizes {
			runRandDBTest(
				require,
				r,
@@ -824,7 +824,7 @@ func FuzzMerkleDBInitialValuesRandomizedActions(f *testing.F) {
					numSteps,
					0.001, /*checkHashProbability*/
				),
-				bf,
+				ts,
			)
		}
	})
@@ -851,8 +851,8 @@ const (
	opMax // boundary value, not an actual op
 )

runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf BranchFactor) { - db, err := getBasicDBWithBranchFactor(bf) +func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, tokenSize int) { + db, err := getBasicDBWithBranchFactor(tokenSizeToBranchFactor[tokenSize]) require.NoError(err) const ( @@ -877,13 +877,13 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br case opUpdate: require.NoError(currentBatch.Put(step.key, step.value)) - uncommittedKeyValues[ToKey(step.key, bf)] = step.value - uncommittedDeletes.Remove(ToKey(step.key, bf)) + uncommittedKeyValues[ToKey(step.key)] = step.value + uncommittedDeletes.Remove(ToKey(step.key)) case opDelete: require.NoError(currentBatch.Delete(step.key)) - uncommittedDeletes.Add(ToKey(step.key, bf)) - delete(uncommittedKeyValues, ToKey(step.key, bf)) + uncommittedDeletes.Add(ToKey(step.key)) + delete(uncommittedKeyValues, ToKey(step.key)) case opGenerateRangeProof: root, err := db.GetMerkleRoot(context.Background()) require.NoError(err) @@ -910,6 +910,7 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br start, end, root, + tokenSize, )) case opGenerateChangeProof: root, err := db.GetMerkleRoot(context.Background()) @@ -937,7 +938,7 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br require.NoError(err) require.LessOrEqual(len(changeProof.KeyChanges), maxProofLen) - changeProofDB, err := getBasicDBWithBranchFactor(bf) + changeProofDB, err := getBasicDBWithBranchFactor(tokenSizeToBranchFactor[tokenSize]) require.NoError(err) require.NoError(changeProofDB.VerifyChangeProof( @@ -984,10 +985,10 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br require.ErrorIs(err, database.ErrNotFound) } - want := values[ToKey(step.key, bf)] + want := values[ToKey(step.key)] require.True(bytes.Equal(want, v)) // Use bytes.Equal so nil treated equal to []byte{} - trieValue, err := getNodeValueWithBranchFactor(db, string(step.key), bf) + trieValue, err := getNodeValue(db, string(step.key)) if err != nil { require.ErrorIs(err, database.ErrNotFound) } @@ -995,7 +996,7 @@ func runRandDBTest(require *require.Assertions, r *rand.Rand, rt randTest, bf Br require.True(bytes.Equal(want, trieValue)) // Use bytes.Equal so nil treated equal to []byte{} case opCheckhash: // Create a view with the same key-values as [db] - newDB, err := getBasicDBWithBranchFactor(bf) + newDB, err := getBasicDBWithBranchFactor(tokenSizeToBranchFactor[tokenSize]) require.NoError(err) ops := make([]database.BatchOp, 0, len(values)) @@ -1093,7 +1094,7 @@ func generateRandTestWithKeys( step.value = genEnd(step.key) case opCheckhash: // this gets really expensive so control how often it happens - if r.Float64() < checkHashProbability { + if r.Float64() > checkHashProbability { continue } } diff --git a/x/merkledb/helpers_test.go b/x/merkledb/helpers_test.go index 3cd84ce11e7c..b7a2908ff377 100644 --- a/x/merkledb/helpers_test.go +++ b/x/merkledb/helpers_test.go @@ -52,13 +52,13 @@ func writeBasicBatch(t *testing.T, db *merkleDB) { func newRandomProofNode(r *rand.Rand) ProofNode { key := make([]byte, r.Intn(32)) // #nosec G404 _, _ = r.Read(key) // #nosec G404 - serializedKey := ToKey(key, BranchFactor16) + serializedKey := ToKey(key) val := make([]byte, r.Intn(64)) // #nosec G404 _, _ = r.Read(val) // #nosec G404 children := map[byte]ids.ID{} - for j := 0; j < int(BranchFactor16); j++ { + for j := 0; j < 16; j++ { if r.Float64() < 0.5 { var 
childID ids.ID _, _ = r.Read(childID[:]) // #nosec G404 diff --git a/x/merkledb/history.go b/x/merkledb/history.go index c82fbb1e5f78..103c4c9357e8 100644 --- a/x/merkledb/history.go +++ b/x/merkledb/history.go @@ -32,8 +32,6 @@ type trieHistory struct { // Each change is tagged with this monotonic increasing number. nextInsertNumber uint64 - - toKey func([]byte) Key } // Tracks the beginning and ending state of a value. @@ -65,12 +63,11 @@ func newChangeSummary(estimatedSize int) *changeSummary { } } -func newTrieHistory(maxHistoryLookback int, toKey func([]byte) Key) *trieHistory { +func newTrieHistory(maxHistoryLookback int) *trieHistory { return &trieHistory{ maxHistoryLen: maxHistoryLookback, history: buffer.NewUnboundedDeque[*changeSummaryAndInsertNumber](maxHistoryLookback), lastChanges: make(map[ids.ID]*changeSummaryAndInsertNumber), - toKey: toKey, } } @@ -158,8 +155,8 @@ func (th *trieHistory) getValueChanges( // in order to stay within the [maxLength] limit if necessary. changedKeys = set.Set[Key]{} - startKey = maybe.Bind(start, th.toKey) - endKey = maybe.Bind(end, th.toKey) + startKey = maybe.Bind(start, ToKey) + endKey = maybe.Bind(end, ToKey) // For each element in the history in the range between [startRoot]'s // last appearance (exclusive) and [endRoot]'s last appearance (inclusive), @@ -237,8 +234,8 @@ func (th *trieHistory) getChangesToGetToRoot(rootID ids.ID, start maybe.Maybe[[] } var ( - startKey = maybe.Bind(start, th.toKey) - endKey = maybe.Bind(end, th.toKey) + startKey = maybe.Bind(start, ToKey) + endKey = maybe.Bind(end, ToKey) combinedChanges = newChangeSummary(defaultPreallocationSize) mostRecentChangeInsertNumber = th.nextInsertNumber - 1 mostRecentChangeIndex = th.history.Len() - 1 diff --git a/x/merkledb/history_test.go b/x/merkledb/history_test.go index 1261c92b22df..f27c1293cde0 100644 --- a/x/merkledb/history_test.go +++ b/x/merkledb/history_test.go @@ -37,7 +37,7 @@ func Test_History_Simple(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("key"), []byte("value0"))) @@ -45,7 +45,7 @@ func Test_History_Simple(t *testing.T) { newProof, err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("key1"), []byte("value1"))) @@ -54,7 +54,7 @@ func Test_History_Simple(t *testing.T) { newProof, err = db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("k"), []byte("v"))) @@ -62,7 +62,7 @@ 
func Test_History_Simple(t *testing.T) { newProof, err = db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Delete([]byte("k"))) @@ -78,7 +78,7 @@ func Test_History_Simple(t *testing.T) { newProof, err = db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_Large(t *testing.T) { @@ -141,7 +141,7 @@ func Test_History_Large(t *testing.T) { require.NoError(err) require.NotNil(proof) - require.NoError(proof.Verify(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), roots[i])) + require.NoError(proof.Verify(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), roots[i], BranchFactorToTokenSize[config.BranchFactor])) } } } @@ -240,6 +240,7 @@ func Test_History_Trigger_History_Queue_Looping(t *testing.T) { maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, + db.tokenSize, )) // write a new value into the db, now there should be 2 roots in the history @@ -256,6 +257,7 @@ func Test_History_Trigger_History_Queue_Looping(t *testing.T) { maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, + db.tokenSize, )) // trigger a new root to be added to the history, which should cause rollover since there can only be 2 @@ -312,10 +314,10 @@ func Test_History_Values_Lookup_Over_Queue_Break(t *testing.T) { // changes should still be collectable even though the history has had to loop due to hitting max size changes, err := db.history.getValueChanges(startRoot, endRoot, maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 10) require.NoError(err) - require.Contains(changes.values, ToKey([]byte("key1"), BranchFactor16)) - require.Equal([]byte("value1"), changes.values[ToKey([]byte("key1"), BranchFactor16)].after.Value()) - require.Contains(changes.values, ToKey([]byte("key2"), BranchFactor16)) - require.Equal([]byte("value3"), changes.values[ToKey([]byte("key2"), BranchFactor16)].after.Value()) + require.Contains(changes.values, ToKey([]byte("key1"))) + require.Equal([]byte("value1"), changes.values[ToKey([]byte("key1"))].after.Value()) + require.Contains(changes.values, ToKey([]byte("key2"))) + require.Equal([]byte("value3"), changes.values[ToKey([]byte("key2"))].after.Value()) } func Test_History_RepeatedRoot(t *testing.T) { @@ -337,7 +339,7 @@ func Test_History_RepeatedRoot(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("key1"), []byte("other"))) @@ -347,7 +349,7 @@ func Test_History_RepeatedRoot(t *testing.T) { newProof, 
err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) // revert state to be the same as in orig proof batch = db.NewBatch() @@ -359,7 +361,7 @@ func Test_History_RepeatedRoot(t *testing.T) { newProof, err = db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_ExcessDeletes(t *testing.T) { @@ -379,7 +381,7 @@ func Test_History_ExcessDeletes(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Delete([]byte("key1"))) @@ -391,7 +393,7 @@ func Test_History_ExcessDeletes(t *testing.T) { newProof, err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_DontIncludeAllNodes(t *testing.T) { @@ -411,7 +413,7 @@ func Test_History_DontIncludeAllNodes(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("z"), []byte("z"))) @@ -419,7 +421,7 @@ func Test_History_DontIncludeAllNodes(t *testing.T) { newProof, err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_Branching2Nodes(t *testing.T) { @@ -439,7 +441,7 @@ func Test_History_Branching2Nodes(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() 
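These `Verify` call sites all gain a trailing token-size argument: a proof's keys no longer know which branch factor produced them, so the verifier must be told how wide a token is. Inside the package the tests pass the unexported `db.tokenSize`; an external caller would derive it from the branch factor via the exported map added in `key.go` later in this diff. A sketch under those assumptions (the helper name and the use of the `MerkleDB` interface are illustrative, not from this diff):

```go
package sketch

import (
	"context"

	"github.com/ava-labs/avalanchego/ids"
	"github.com/ava-labs/avalanchego/utils/maybe"
	"github.com/ava-labs/avalanchego/x/merkledb"
)

// verifyHistorical fetches a range proof at [rootID] and verifies it with an
// explicit token size, assuming [db] was configured with BranchFactor16.
func verifyHistorical(ctx context.Context, db merkledb.MerkleDB, rootID ids.ID) error {
	start := maybe.Some([]byte("k"))
	end := maybe.Some([]byte("key3"))
	proof, err := db.GetRangeProofAtRoot(ctx, rootID, start, end, 10)
	if err != nil {
		return err
	}
	// BranchFactor16 -> 4-bit tokens, per the exported BranchFactorToTokenSize map.
	return proof.Verify(ctx, start, end, rootID, merkledb.BranchFactorToTokenSize[merkledb.BranchFactor16])
}
```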
require.NoError(batch.Put([]byte("k"), []byte("v"))) @@ -447,7 +449,7 @@ func Test_History_Branching2Nodes(t *testing.T) { newProof, err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_Branching3Nodes(t *testing.T) { @@ -467,7 +469,7 @@ func Test_History_Branching3Nodes(t *testing.T) { require.NoError(err) require.NotNil(origProof) origRootID := db.root.id - require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(origProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) batch = db.NewBatch() require.NoError(batch.Put([]byte("key321"), []byte("value321"))) @@ -475,7 +477,7 @@ func Test_History_Branching3Nodes(t *testing.T) { newProof, err := db.GetRangeProofAtRoot(context.Background(), origRootID, maybe.Some([]byte("k")), maybe.Some([]byte("key3")), 10) require.NoError(err) require.NotNil(newProof) - require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID)) + require.NoError(newProof.Verify(context.Background(), maybe.Some([]byte("k")), maybe.Some([]byte("key3")), origRootID, db.tokenSize)) } func Test_History_MaxLength(t *testing.T) { @@ -572,9 +574,7 @@ func TestHistoryRecord(t *testing.T) { require := require.New(t) maxHistoryLen := 3 - th := newTrieHistory(maxHistoryLen, func(bytes []byte) Key { - return ToKey(bytes, BranchFactor16) - }) + th := newTrieHistory(maxHistoryLen) changes := []*changeSummary{} for i := 0; i < maxHistoryLen; i++ { // Fill the history @@ -647,22 +647,20 @@ func TestHistoryRecord(t *testing.T) { func TestHistoryGetChangesToRoot(t *testing.T) { maxHistoryLen := 3 - history := newTrieHistory(maxHistoryLen, func(bytes []byte) Key { - return ToKey(bytes, BranchFactor16) - }) + history := newTrieHistory(maxHistoryLen) changes := []*changeSummary{} for i := 0; i < maxHistoryLen; i++ { // Fill the history changes = append(changes, &changeSummary{ rootID: ids.GenerateTestID(), nodes: map[Key]*change[*node]{ - history.toKey([]byte{byte(i)}): { + ToKey([]byte{byte(i)}): { before: &node{id: ids.GenerateTestID()}, after: &node{id: ids.GenerateTestID()}, }, }, values: map[Key]*change[maybe.Maybe[[]byte]]{ - history.toKey([]byte{byte(i)}): { + ToKey([]byte{byte(i)}): { before: maybe.Some([]byte{byte(i)}), after: maybe.Some([]byte{byte(i + 1)}), }, @@ -701,7 +699,7 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 1) require.Len(got.values, 1) reversedChanges := changes[maxHistoryLen-1] - removedKey := history.toKey([]byte{byte(maxHistoryLen - 1)}) + removedKey := ToKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges.nodes[removedKey].before, got.nodes[removedKey].after) require.Equal(reversedChanges.values[removedKey].before, got.values[removedKey].after) require.Equal(reversedChanges.values[removedKey].after, got.values[removedKey].before) @@ -714,12 +712,12 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 2) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := 
history.toKey([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := ToKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].before, got.values[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].after, got.values[removedKey1].before) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := ToKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].before, got.values[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].after, got.values[removedKey2].before) @@ -733,12 +731,12 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 1) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := history.toKey([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := ToKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].before, got.values[removedKey1].after) require.Equal(reversedChanges1.values[removedKey1].after, got.values[removedKey1].before) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := ToKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) }, }, @@ -750,10 +748,10 @@ func TestHistoryGetChangesToRoot(t *testing.T) { require.Len(got.nodes, 2) require.Len(got.values, 1) reversedChanges1 := changes[maxHistoryLen-1] - removedKey1 := history.toKey([]byte{byte(maxHistoryLen - 1)}) + removedKey1 := ToKey([]byte{byte(maxHistoryLen - 1)}) require.Equal(reversedChanges1.nodes[removedKey1].before, got.nodes[removedKey1].after) reversedChanges2 := changes[maxHistoryLen-2] - removedKey2 := history.toKey([]byte{byte(maxHistoryLen - 2)}) + removedKey2 := ToKey([]byte{byte(maxHistoryLen - 2)}) require.Equal(reversedChanges2.nodes[removedKey2].before, got.nodes[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].before, got.values[removedKey2].after) require.Equal(reversedChanges2.values[removedKey2].after, got.values[removedKey2].before) diff --git a/x/merkledb/intermediate_node_db.go b/x/merkledb/intermediate_node_db.go index e146b943d6c2..1602aa05bbaa 100644 --- a/x/merkledb/intermediate_node_db.go +++ b/x/merkledb/intermediate_node_db.go @@ -13,7 +13,7 @@ const defaultBufferLength = 256 // Holds intermediate nodes. That is, those without values. // Changes to this database aren't written to [baseDB] until -// they're evicted from the [nodeCache] or Flush is called.. +// they're evicted from the [nodeCache] or Flush is called. 
type intermediateNodeDB struct { // Holds unused []byte bufferPool *sync.Pool @@ -31,6 +31,7 @@ type intermediateNodeDB struct { // the number of bytes to evict during an eviction batch evictionBatchSize int metrics merkleMetrics + tokenSize int } func newIntermediateNodeDB( @@ -39,12 +40,14 @@ func newIntermediateNodeDB( metrics merkleMetrics, size int, evictionBatchSize int, + tokenSize int, ) *intermediateNodeDB { result := &intermediateNodeDB{ metrics: metrics, baseDB: db, bufferPool: bufferPool, evictionBatchSize: evictionBatchSize, + tokenSize: tokenSize, } result.nodeCache = newOnEvictCache( size, @@ -121,15 +124,15 @@ func (db *intermediateNodeDB) Get(key Key) (*node, error) { // constructDBKey returns a key that can be used in [db.baseDB]. // We need to be able to differentiate between two keys of equal -// byte length but different token length, so we add padding to differentiate. +// byte length but different bit length, so we add padding to differentiate. // Additionally, we add a prefix indicating it is part of the intermediateNodeDB. func (db *intermediateNodeDB) constructDBKey(key Key) []byte { - if key.branchFactor == BranchFactor256 { - // For BranchFactor256, no padding is needed since byte length == token length + if db.tokenSize == 8 { + // For tokens of size byte, no padding is needed since byte length == token length return addPrefixToKey(db.bufferPool, intermediateNodePrefix, key.Bytes()) } - return addPrefixToKey(db.bufferPool, intermediateNodePrefix, key.Append(1).Bytes()) + return addPrefixToKey(db.bufferPool, intermediateNodePrefix, key.Extend(ToToken(1, db.tokenSize)).Bytes()) } func (db *intermediateNodeDB) Put(key Key, n *node) error { diff --git a/x/merkledb/intermediate_node_db_test.go b/x/merkledb/intermediate_node_db_test.go index 3d40aa7f8a05..027798017e95 100644 --- a/x/merkledb/intermediate_node_db_test.go +++ b/x/merkledb/intermediate_node_db_test.go @@ -23,7 +23,7 @@ import ( func Test_IntermediateNodeDB(t *testing.T) { require := require.New(t) - n := newNode(nil, ToKey([]byte{0x00}, BranchFactor16)) + n := newNode(ToKey([]byte{0x00})) n.setValue(maybe.Some([]byte{byte(0x02)})) nodeSize := cacheEntrySize(n.key, n) @@ -39,11 +39,12 @@ func Test_IntermediateNodeDB(t *testing.T) { &mockMetrics{}, cacheSize, evictionBatchSize, + 4, ) // Put a key-node pair - node1Key := ToKey([]byte{0x01}, BranchFactor16) - node1 := newNode(nil, node1Key) + node1Key := ToKey([]byte{0x01}) + node1 := newNode(node1Key) node1.setValue(maybe.Some([]byte{byte(0x01)})) require.NoError(db.Put(node1Key, node1)) @@ -53,7 +54,7 @@ func Test_IntermediateNodeDB(t *testing.T) { require.Equal(node1, node1Read) // Overwrite the key-node pair - node1Updated := newNode(nil, node1Key) + node1Updated := newNode(node1Key) node1Updated.setValue(maybe.Some([]byte{byte(0x02)})) require.NoError(db.Put(node1Key, node1Updated)) @@ -73,8 +74,8 @@ func Test_IntermediateNodeDB(t *testing.T) { expectedSize := 0 added := 0 for { - key := ToKey([]byte{byte(added)}, BranchFactor16) - node := newNode(nil, emptyKey(BranchFactor16)) + key := ToKey([]byte{byte(added)}) + node := newNode(Key{}) node.setValue(maybe.Some([]byte{byte(added)})) newExpectedSize := expectedSize + cacheEntrySize(key, node) if newExpectedSize > cacheSize { @@ -93,8 +94,8 @@ func Test_IntermediateNodeDB(t *testing.T) { // Put one more element in the cache, which should trigger an eviction // of all but 2 elements. 2 elements remain rather than 1 element because of // the added key prefix increasing the size tracked by the batch. 
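The padding that `constructDBKey` appends is easiest to see with concrete bits: a 4-bit key and an 8-bit key can serialize to the same byte, and the trailing 1-token is what keeps their database keys distinct. A self-contained sketch using the exported key helpers from this diff (the byte values are arbitrary):

```go
package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	const tokenSize = 4 // BranchFactor16

	full := merkledb.ToKey([]byte{0b1010_0000}) // 8 bits
	short := full.Take(4)                       // 4 bits, but the same single byte

	// Without padding, the two keys would collide on disk:
	fmt.Printf("%08b %08b\n", full.Bytes()[0], short.Bytes()[0]) // 10100000 10100000

	// Appending a single 1-token before storing disambiguates them:
	fmt.Printf("%08b\n", full.Extend(merkledb.ToToken(1, tokenSize)).Bytes())  // [10100000 00010000]
	fmt.Printf("%08b\n", short.Extend(merkledb.ToToken(1, tokenSize)).Bytes()) // [10100001]
}
```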
- key := ToKey([]byte{byte(added)}, BranchFactor16) - node := newNode(nil, emptyKey(BranchFactor16)) + key := ToKey([]byte{byte(added)}) + node := newNode(Key{}) node.setValue(maybe.Some([]byte{byte(added)})) require.NoError(db.Put(key, node)) @@ -102,7 +103,7 @@ func Test_IntermediateNodeDB(t *testing.T) { require.Equal(1, db.nodeCache.fifo.Len()) gotKey, _, ok := db.nodeCache.fifo.Oldest() require.True(ok) - require.Equal(ToKey([]byte{byte(added)}, BranchFactor16), gotKey) + require.Equal(ToKey([]byte{byte(added)}), gotKey) // Get a node from the base database // Use an early key that has been evicted from the cache @@ -134,41 +135,45 @@ func FuzzIntermediateNodeDBConstructDBKey(f *testing.F) { cacheSize := 200 evictionBatchSize := cacheSize baseDB := memdb.New() - db := newIntermediateNodeDB( - baseDB, - &sync.Pool{ - New: func() interface{} { return make([]byte, 0) }, - }, - &mockMetrics{}, - cacheSize, - evictionBatchSize, - ) + f.Fuzz(func( t *testing.T, key []byte, tokenLength uint, ) { require := require.New(t) - for _, branchFactor := range branchFactors { - p := ToKey(key, branchFactor) - if p.tokenLength <= int(tokenLength) { + for _, tokenSize := range validTokenSizes { + db := newIntermediateNodeDB( + baseDB, + &sync.Pool{ + New: func() interface{} { return make([]byte, 0) }, + }, + &mockMetrics{}, + cacheSize, + evictionBatchSize, + tokenSize, + ) + + p := ToKey(key) + uBitLength := tokenLength * uint(tokenSize) + if uBitLength >= uint(p.length) { t.SkipNow() } - p = p.Take(int(tokenLength)) + p = p.Take(int(uBitLength)) constructedKey := db.constructDBKey(p) baseLength := len(p.value) + len(intermediateNodePrefix) require.Equal(intermediateNodePrefix, constructedKey[:len(intermediateNodePrefix)]) switch { - case branchFactor == BranchFactor256: + case tokenSize == 8: // for keys with tokens of size byte, no padding is added require.Equal(p.Bytes(), constructedKey[len(intermediateNodePrefix):]) case p.hasPartialByte(): require.Len(constructedKey, baseLength) - require.Equal(p.Append(1).Bytes(), constructedKey[len(intermediateNodePrefix):]) + require.Equal(p.Extend(ToToken(1, tokenSize)).Bytes(), constructedKey[len(intermediateNodePrefix):]) default: // when a whole number of bytes, there is an extra padding byte require.Len(constructedKey, baseLength+1) - require.Equal(p.Append(1).Bytes(), constructedKey[len(intermediateNodePrefix):]) + require.Equal(p.Extend(ToToken(1, tokenSize)).Bytes(), constructedKey[len(intermediateNodePrefix):]) } } }) @@ -187,10 +192,11 @@ func Test_IntermediateNodeDB_ConstructDBKey_DirtyBuffer(t *testing.T) { &mockMetrics{}, cacheSize, evictionBatchSize, + 4, ) db.bufferPool.Put([]byte{0xFF, 0xFF, 0xFF}) - constructedKey := db.constructDBKey(ToKey([]byte{}, BranchFactor16)) + constructedKey := db.constructDBKey(ToKey([]byte{})) require.Len(constructedKey, 2) require.Equal(intermediateNodePrefix, constructedKey[:len(intermediateNodePrefix)]) require.Equal(byte(16), constructedKey[len(constructedKey)-1]) @@ -201,9 +207,9 @@ func Test_IntermediateNodeDB_ConstructDBKey_DirtyBuffer(t *testing.T) { }, } db.bufferPool.Put([]byte{0xFF, 0xFF, 0xFF}) - p := ToKey([]byte{0xF0}, BranchFactor16).Take(1) + p := ToKey([]byte{0xF0}).Take(4) constructedKey = db.constructDBKey(p) require.Len(constructedKey, 2) require.Equal(intermediateNodePrefix, constructedKey[:len(intermediateNodePrefix)]) - require.Equal(p.Append(1).Bytes(), constructedKey[len(intermediateNodePrefix):]) + require.Equal(p.Extend(ToToken(1, 4)).Bytes(), constructedKey[len(intermediateNodePrefix):]) 
} diff --git a/x/merkledb/key.go b/x/merkledb/key.go index 461372a2baa8..b92ac2d7ceec 100644 --- a/x/merkledb/key.go +++ b/x/merkledb/key.go @@ -8,112 +8,135 @@ import ( "fmt" "strings" "unsafe" + + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) var ( - errInvalidBranchFactor = errors.New("invalid branch factor") - - branchFactorToTokenConfig = map[BranchFactor]tokenConfig{ - BranchFactor2: { - branchFactor: BranchFactor2, - tokenBitSize: 1, - tokensPerByte: 8, - singleTokenMask: 0b0000_0001, - }, - BranchFactor4: { - branchFactor: BranchFactor4, - tokenBitSize: 2, - tokensPerByte: 4, - singleTokenMask: 0b0000_0011, - }, - BranchFactor16: { - branchFactor: BranchFactor16, - tokenBitSize: 4, - tokensPerByte: 2, - singleTokenMask: 0b0000_1111, - }, - BranchFactor256: { - branchFactor: BranchFactor256, - tokenBitSize: 8, - tokensPerByte: 1, - singleTokenMask: 0b1111_1111, - }, + ErrInvalidBranchFactor = errors.New("branch factor must match one of the predefined branch factors") + + BranchFactorToTokenSize = map[BranchFactor]int{ + BranchFactor2: 1, + BranchFactor4: 2, + BranchFactor16: 4, + BranchFactor256: 8, + } + + tokenSizeToBranchFactor = map[int]BranchFactor{ + 1: BranchFactor2, + 2: BranchFactor4, + 4: BranchFactor16, + 8: BranchFactor256, + } + + validTokenSizes = maps.Keys(tokenSizeToBranchFactor) + + validBranchFactors = []BranchFactor{ + BranchFactor2, + BranchFactor4, + BranchFactor16, + BranchFactor256, } ) type BranchFactor int const ( - BranchFactor2 BranchFactor = 2 - BranchFactor4 BranchFactor = 4 - BranchFactor16 BranchFactor = 16 - BranchFactor256 BranchFactor = 256 + BranchFactor2 = BranchFactor(2) + BranchFactor4 = BranchFactor(4) + BranchFactor16 = BranchFactor(16) + BranchFactor256 = BranchFactor(256) ) -func (f BranchFactor) Valid() error { - if _, ok := branchFactorToTokenConfig[f]; ok { - return nil +// Valid checks that BranchFactor [b] is one of the predefined valid branch factors +func (b BranchFactor) Valid() error { + for _, validBF := range validBranchFactors { + if validBF == b { + return nil + } } - return fmt.Errorf("%w: %d", errInvalidBranchFactor, f) + return fmt.Errorf("%w: %d", ErrInvalidBranchFactor, b) } -type tokenConfig struct { - branchFactor BranchFactor - tokensPerByte int - tokenBitSize byte - singleTokenMask byte +// ToToken creates a key version of the passed byte with bit length equal to tokenSize +func ToToken(val byte, tokenSize int) Key { + return Key{ + value: string([]byte{val << dualBitIndex(tokenSize)}), + length: tokenSize, + } } -type Key struct { - tokenLength int - value string - tokenConfig +// Token returns the token at the specified index. +// Assumes that bitIndex + tokenSize doesn't cross a byte boundary +func (k Key) Token(bitIndex int, tokenSize int) byte { + storageByte := k.value[bitIndex/8] + // Shift the byte right to get the token's last bit to the rightmost position. + storageByte >>= dualBitIndex((bitIndex + tokenSize) % 8) + // Apply a mask to remove any other bits in the byte.
+ return storageByte & (0xFF >> dualBitIndex(tokenSize)) } -func emptyKey(bf BranchFactor) Key { - return Key{ - tokenConfig: branchFactorToTokenConfig[bf], +// iteratedHasPrefix checks if the provided prefix key is a prefix of the current key starting after the [bitsOffset]th bit +// this has better performance than constructing the actual key via Skip() then calling HasPrefix because it avoids an allocation +func (k Key) iteratedHasPrefix(prefix Key, bitsOffset int, tokenSize int) bool { + if k.length-bitsOffset < prefix.length { + return false } + for i := 0; i < prefix.length; i += tokenSize { + if k.Token(bitsOffset+i, tokenSize) != prefix.Token(i, tokenSize) { + return false + } + } + return true } -// ToKey returns [keyBytes] as a new key with the given [branchFactor]. -// Assumes [branchFactor] is valid. -func ToKey(keyBytes []byte, branchFactor BranchFactor) Key { - tc := branchFactorToTokenConfig[branchFactor] - return Key{ - value: byteSliceToString(keyBytes), - tokenConfig: tc, - tokenLength: len(keyBytes) * tc.tokensPerByte, - } +type Key struct { + // The number of bits in the key. + length int + // The string representation of the key + value string } -// TokensLength returns the number of tokens in [k]. -func (k Key) TokensLength() int { - return k.tokenLength +// ToKey returns [keyBytes] as a new key +// Assumes all bits of the keyBytes are part of the Key, call Key.Take if that is not the case +// Creates a copy of [keyBytes], so keyBytes are safe to edit after the call +func ToKey(keyBytes []byte) Key { + return toKey(slices.Clone(keyBytes)) +} + +// toKey returns [keyBytes] as a new key +// Assumes all bits of the keyBytes are part of the Key, call Key.Take if that is not the case +// Caller must not modify [keyBytes] after this call. +func toKey(keyBytes []byte) Key { + return Key{ + value: byteSliceToString(keyBytes), + length: len(keyBytes) * 8, + } } // hasPartialByte returns true iff the key fits into a non-whole number of bytes func (k Key) hasPartialByte() bool { - return k.tokenLength%k.tokensPerByte > 0 + return k.length%8 > 0 } // HasPrefix returns true iff [prefix] is a prefix of [k] or equal to it. func (k Key) HasPrefix(prefix Key) bool { // [prefix] must be shorter than [k] to be a prefix. - if k.tokenLength < prefix.tokenLength { + if k.length < prefix.length { return false } // The number of tokens in the last byte of [prefix], or zero // if [prefix] fits into a whole number of bytes. - remainderTokensCount := prefix.tokenLength % k.tokensPerByte - if remainderTokensCount == 0 { + remainderBitCount := prefix.length % 8 + if remainderBitCount == 0 { return strings.HasPrefix(k.value, prefix.value) } // check that the tokens in the partially filled final byte of [prefix] are // equal to the tokens in the final byte of [k]. - remainderBitsMask := byte(0xFF >> (remainderTokensCount * int(k.tokenBitSize))) + remainderBitsMask := byte(0xFF >> remainderBitCount) prefixRemainderTokens := prefix.value[len(prefix.value)-1] | remainderBitsMask remainderTokens := k.value[len(prefix.value)-1] | remainderBitsMask @@ -134,130 +157,64 @@ func (k Key) HasStrictPrefix(prefix Key) bool { return k != prefix && k.HasPrefix(prefix) } -// Token returns the token at the specified index, -func (k Key) Token(index int) byte { - // Find the index in [k.value] of the byte containing the token at [index]. - storageByteIndex := index / k.tokensPerByte - storageByte := k.value[storageByteIndex] - // Shift the byte right to get the token to the rightmost position. 
- storageByte >>= k.bitsToShift(index) - // Apply a mask to remove any other tokens in the byte. - return storageByte & k.singleTokenMask -} - -// Append returns a new Path that equals the current -// Path with [token] appended to the end. -func (k Key) Append(token byte) Key { - buffer := make([]byte, k.bytesNeeded(k.tokenLength+1)) - k.appendIntoBuffer(buffer, token) - return Key{ - value: byteSliceToString(buffer), - tokenLength: k.tokenLength + 1, - tokenConfig: k.tokenConfig, - } +// Length returns the number of bits in the Key +func (k Key) Length() int { + return k.length } // Greater returns true if current Key is greater than other Key func (k Key) Greater(other Key) bool { - return k.value > other.value || (k.value == other.value && k.tokenLength > other.tokenLength) + return k.value > other.value || (k.value == other.value && k.length > other.length) } // Less returns true if current Key is less than other Key func (k Key) Less(other Key) bool { - return k.value < other.value || (k.value == other.value && k.tokenLength < other.tokenLength) + return k.value < other.value || (k.value == other.value && k.length < other.length) } -// bitsToShift returns the number of bits to right shift a token -// within its storage byte to get it to the rightmost -// position in the byte. Equivalently, this is the number of bits -// to left shift a raw token value to get it to the correct position -// within its storage byte. -// Example with branch factor 16: -// Suppose the token array is -// [0x01, 0x02, 0x03, 0x04] -// The byte representation of this array is -// [0b0001_0010, 0b0011_0100] -// To get the token at index 0 (0b0001) to the rightmost position -// in its storage byte (i.e. to make 0b0001_0010 into 0b0000_0001), -// we need to shift 0b0001_0010 to the right by 4 bits. -// Similarly: -// * Token at index 1 (0b0010) needs to be shifted by 0 bits -// * Token at index 2 (0b0011) needs to be shifted by 4 bits -// * Token at index 3 (0b0100) needs to be shifted by 0 bits -func (k Key) bitsToShift(index int) byte { - // [tokenIndex] is the index of the token in the byte. - // For example, if the branch factor is 16, then each byte contains 2 tokens. - // The first is at index 0, and the second is at index 1, by this definition. - tokenIndex := index % k.tokensPerByte - // The bit within the byte that the token starts at. - startBitIndex := k.tokenBitSize * byte(tokenIndex) - // The bit within the byte that the token ends at. - endBitIndex := startBitIndex + k.tokenBitSize - 1 - // We want to right shift until [endBitIndex] is at the last index, so return - // the distance from the end of the byte to the end of the token. - // Note that 7 is the index of the last bit in a byte. - return 7 - endBitIndex -} - -// bytesNeeded returns the number of bytes needed to store the passed number of -// tokens. -// -// Invariant: [tokens] is a non-negative, but otherwise untrusted, input and -// this method must never overflow. 
-func (k Key) bytesNeeded(tokens int) int { - size := tokens / k.tokensPerByte - if tokens%k.tokensPerByte != 0 { - size++ +// Extend returns a new Key that is the in-order aggregation of Key [k] with [keys] +func (k Key) Extend(keys ...Key) Key { + totalBitLength := k.length + for _, key := range keys { + totalBitLength += key.length } - return size -} - -func (k Key) AppendExtend(token byte, extensionKey Key) Key { - appendBytes := k.bytesNeeded(k.tokenLength + 1) - totalLength := k.tokenLength + 1 + extensionKey.tokenLength - buffer := make([]byte, k.bytesNeeded(totalLength)) - k.appendIntoBuffer(buffer[:appendBytes], token) - - // the extension path will be shifted based on the number of tokens in the partial byte - tokenRemainder := (k.tokenLength + 1) % k.tokensPerByte - result := Key{ - value: byteSliceToString(buffer), - tokenLength: totalLength, - tokenConfig: k.tokenConfig, + buffer := make([]byte, bytesNeeded(totalBitLength)) + copy(buffer, k.value) + currentTotal := k.length + for _, key := range keys { + extendIntoBuffer(buffer, key, currentTotal) + currentTotal += key.length } - extensionBuffer := buffer[appendBytes-1:] - if extensionKey.tokenLength == 0 { - return result + return Key{ + value: byteSliceToString(buffer), + length: totalBitLength, } +} - // If the existing value fits into a whole number of bytes, - // the extension path can be copied directly into the buffer. - if tokenRemainder == 0 { - copy(extensionBuffer[1:], extensionKey.value) - return result +func extendIntoBuffer(buffer []byte, val Key, bitsOffset int) { + if val.length == 0 { + return + } + bytesOffset := bytesNeeded(bitsOffset) + bitsRemainder := bitsOffset % 8 + if bitsRemainder == 0 { + copy(buffer[bytesOffset:], val.value) + return } - // The existing path doesn't fit into a whole number of bytes. - // Figure out how many bits to shift. - shift := extensionKey.bitsToShift(tokenRemainder - 1) // Fill the partial byte with the first [shift] bits of the extension path - extensionBuffer[0] |= extensionKey.value[0] >> (8 - shift) + buffer[bytesOffset-1] |= val.value[0] >> bitsRemainder // copy the rest of the extension path bytes into the buffer, // shifted byte shift bits - shiftCopy(extensionBuffer[1:], extensionKey.value, shift) - - return result + shiftCopy(buffer[bytesOffset:], val.value, dualBitIndex(bitsRemainder)) } -func (k Key) appendIntoBuffer(buffer []byte, token byte) { - copy(buffer, k.value) - - // Shift [token] to the left such that it's at the correct - // index within its storage byte, then OR it with its storage - // byte to write the token into the byte. - buffer[len(buffer)-1] |= token << k.bitsToShift(k.tokenLength) +// dualBitIndex gets the dual of the bit index +// ex: in a byte, the bit 5 from the right is the same as the bit 3 from the left +func dualBitIndex(shift int) int { + return (8 - shift) % 8 } // Treats [src] as a bit array and copies it into [dst] shifted by [shift] bits. @@ -266,10 +223,11 @@ func (k Key) appendIntoBuffer(buffer []byte, token byte) { // Assumes len(dst) >= len(src)-1. // If len(dst) == len(src)-1 the last byte of [src] is only partially copied // (i.e. the rightmost bits are not copied). 
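The bit arithmetic above (`ToToken`, `Token`, `dualBitIndex`) and the `shiftCopy` helper in the next hunk are the core of the branch-factor-free `Key`. A standalone sketch that mirrors the helpers as the diff defines them, so the bit movement can be checked by eye (the `main` wrapper and lowercase names are illustrative):

```go
package main

import "fmt"

// dualBitIndex mirrors the diff's helper: the index of a bit measured from
// the other end of the byte, e.g. bit 5 from the right is bit 3 from the left.
func dualBitIndex(shift int) int {
	return (8 - shift) % 8
}

// token mirrors Key.Token: read the tokenSize-bit token starting at bitIndex,
// assuming the token doesn't cross a byte boundary.
func token(key []byte, bitIndex, tokenSize int) byte {
	b := key[bitIndex/8]
	// Shift right so the token's last bit lands in the rightmost position.
	b >>= dualBitIndex((bitIndex + tokenSize) % 8)
	// Mask off any higher bits remaining in the byte.
	return b & (0xFF >> dualBitIndex(tokenSize))
}

// shiftCopy mirrors the diff's helper: treat [src] as a bit array and copy it
// into [dst] shifted left by [shift] bits (shift stays in 1..7 here, as it
// does at both call sites in the diff).
func shiftCopy(dst []byte, src string, shift int) {
	i := 0
	dualShift := dualBitIndex(shift)
	for ; i < len(src)-1; i++ {
		dst[i] = src[i]<<shift | src[i+1]>>dualShift
	}
	if i < len(dst) {
		// The last byte only has bits from src[i]; there is no src[i+1].
		dst[i] = src[i] << shift
	}
}

func main() {
	key := []byte{0xAB} // 0b1010_1011
	fmt.Println(token(key, 0, 4)) // high nibble: 10
	fmt.Println(token(key, 4, 4)) // low nibble: 11
	fmt.Println(token(key, 6, 2)) // final 2-bit token: 3

	dst := make([]byte, 2)
	shiftCopy(dst, string([]byte{0b0001_0010, 0b0011_0100}), 4)
	fmt.Printf("%08b %08b\n", dst[0], dst[1]) // 00100011 01000000
}
```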
-func shiftCopy(dst []byte, src string, shift byte) { +func shiftCopy(dst []byte, src string, shift int) { i := 0 + dualShift := dualBitIndex(shift) for ; i < len(src)-1; i++ { - dst[i] = src[i]<<shift | src[i+1]>>(8-shift) + dst[i] = src[i]<<shift | src[i+1]>>dualShift } if i < len(dst) { @@ -279,59 +237,56 @@ } // Skip returns a new Key that contains the last -// k.length-tokensToSkip tokens of [k]. -func (k Key) Skip(tokensToSkip int) Key { - if k.tokenLength == tokensToSkip { - return emptyKey(k.branchFactor) +// k.length-bitsToSkip bits of [k]. +func (k Key) Skip(bitsToSkip int) Key { + if k.length <= bitsToSkip { + return Key{} } result := Key{ - value: k.value[tokensToSkip/k.tokensPerByte:], - tokenLength: k.tokenLength - tokensToSkip, - tokenConfig: k.tokenConfig, + value: k.value[bitsToSkip/8:], + length: k.length - bitsToSkip, } // if the tokens to skip is a whole number of bytes, // the remaining bytes exactly equals the new key. - if tokensToSkip%k.tokensPerByte == 0 { + if bitsToSkip%8 == 0 { return result } - // tokensToSkip does not remove a whole number of bytes. + // bitsToSkip does not remove a whole number of bytes. // copy the remaining shifted bytes into a new buffer. - buffer := make([]byte, k.bytesNeeded(result.tokenLength)) - bitsSkipped := tokensToSkip * int(k.tokenBitSize) - bitsRemovedFromFirstRemainingByte := byte(bitsSkipped % 8) + buffer := make([]byte, bytesNeeded(result.length)) + bitsRemovedFromFirstRemainingByte := bitsToSkip % 8 shiftCopy(buffer, result.value, bitsRemovedFromFirstRemainingByte) result.value = byteSliceToString(buffer) return result } -// Take returns a new Key that contains the first tokensToTake tokens of the current Key -func (k Key) Take(tokensToTake int) Key { - if k.tokenLength <= tokensToTake { +// Take returns a new Key that contains the first bitsToTake bits of the current Key +func (k Key) Take(bitsToTake int) Key { + if k.length <= bitsToTake { return k } result := Key{ - tokenLength: tokensToTake, - tokenConfig: k.tokenConfig, + length: bitsToTake, } - if !result.hasPartialByte() { - result.value = k.value[:tokensToTake/k.tokensPerByte] + remainderBits := result.length % 8 + if remainderBits == 0 { + result.value = k.value[:bitsToTake/8] return result } // We need to zero out some bits of the last byte so a simple slice will not work // Create a new []byte to store the altered value - buffer := make([]byte, k.bytesNeeded(tokensToTake)) + buffer := make([]byte, bytesNeeded(bitsToTake)) copy(buffer, k.value) - // We want to zero out everything to the right of the last token, which is at index [tokensToTake] - 1 - // Mask will be (8-bitsToShift) number of 1's followed by (bitsToShift) number of 0's - mask := byte(0xFF << k.bitsToShift(tokensToTake-1)) - buffer[len(buffer)-1] &= mask + // We want to zero out everything to the right of the last token, which is at index bitsToTake-1 + // Mask will be (8-remainderBits) number of 1's followed by (remainderBits) number of 0's + buffer[len(buffer)-1] &= byte(0xFF << dualBitIndex(remainderBits)) result.value = byteSliceToString(buffer) return result @@ -345,20 +300,6 @@ func (k Key) Bytes() []byte { return stringToByteSlice(k.value) } -// iteratedHasPrefix checks if the provided prefix path is a prefix of the current path after having skipped [skipTokens] tokens first -// this has better performance than constructing the actual path via Skip() then calling HasPrefix because it avoids the []byte allocation -func (k Key) iteratedHasPrefix(skipTokens int, prefix Key) bool {
- if k.tokenLength-skipTokens < prefix.tokenLength { - return false - } - for i := 0; i < prefix.tokenLength; i++ { - if k.Token(skipTokens+i) != prefix.Token(i) { - return false - } - } - return true -} - // byteSliceToString converts the []byte to a string // Invariant: The input []byte must not be modified. func byteSliceToString(bs []byte) string { @@ -374,3 +315,12 @@ func stringToByteSlice(value string) []byte { // "safe" because we never edit the []byte return unsafe.Slice(unsafe.StringData(value), len(value)) } + +// Returns the number of bytes needed to store [bits] bits. +func bytesNeeded(bits int) int { + size := bits / 8 + if bits%8 != 0 { + size++ + } + return size +} diff --git a/x/merkledb/key_test.go b/x/merkledb/key_test.go index e56ee1a98050..f0819483b1a8 100644 --- a/x/merkledb/key_test.go +++ b/x/merkledb/key_test.go @@ -5,48 +5,52 @@ package merkledb import ( "fmt" + "strconv" "testing" "github.com/stretchr/testify/require" ) -var branchFactors = []BranchFactor{ - BranchFactor2, - BranchFactor4, - BranchFactor16, - BranchFactor256, +func TestBranchFactor_Valid(t *testing.T) { + require := require.New(t) + for _, bf := range validBranchFactors { + require.NoError(bf.Valid()) + } + var empty BranchFactor + err := empty.Valid() + require.ErrorIs(err, ErrInvalidBranchFactor) } func TestHasPartialByte(t *testing.T) { - for _, branchFactor := range branchFactors { - t.Run(fmt.Sprint(branchFactor), func(t *testing.T) { + for _, ts := range validTokenSizes { + t.Run(strconv.Itoa(ts), func(t *testing.T) { require := require.New(t) - key := emptyKey(branchFactor) + key := Key{} require.False(key.hasPartialByte()) - if branchFactor == BranchFactor256 { + if ts == 8 { // Tokens are an entire byte so // there is never a partial byte. - key = key.Append(0) + key = key.Extend(ToToken(1, ts)) require.False(key.hasPartialByte()) - key = key.Append(0) + key = key.Extend(ToToken(0, ts)) require.False(key.hasPartialByte()) return } // Fill all but the last token of the first byte. - for i := 0; i < key.tokensPerByte-1; i++ { - key = key.Append(0) + for i := 0; i < 8-ts; i += ts { + key = key.Extend(ToToken(1, ts)) require.True(key.hasPartialByte()) } // Fill the last token of the first byte. - key = key.Append(0) + key = key.Extend(ToToken(0, ts)) require.False(key.hasPartialByte()) // Fill the first token of the second byte. 
- key = key.Append(0) + key = key.Extend(ToToken(0, ts)) require.True(key.hasPartialByte()) }) } @@ -55,66 +59,71 @@ func TestHasPartialByte(t *testing.T) { func Test_Key_Has_Prefix(t *testing.T) { type test struct { name string - keyA func(bf BranchFactor) Key - keyB func(bf BranchFactor) Key + keyA func(ts int) Key + keyB func(ts int) Key isStrictPrefix bool isPrefix bool } key := "Key" - keyLength := map[BranchFactor]int{} - for _, branchFactor := range branchFactors { - config := branchFactorToTokenConfig[branchFactor] - keyLength[branchFactor] = len(key) * config.tokensPerByte - } tests := []test{ { name: "equal keys", - keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, - keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, + keyA: func(ts int) Key { return ToKey([]byte(key)) }, + keyB: func(ts int) Key { return ToKey([]byte(key)) }, isPrefix: true, isStrictPrefix: false, }, { - name: "one key has one fewer token", - keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf) }, - keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, + name: "one key has one fewer token", + keyA: func(ts int) Key { return ToKey([]byte(key)) }, + keyB: func(ts int) Key { + return ToKey([]byte(key)).Take(len(key)*8 - ts) + }, isPrefix: true, isStrictPrefix: true, }, { - name: "equal keys, both have one fewer token", - keyA: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, - keyB: func(bf BranchFactor) Key { return ToKey([]byte(key), bf).Take(keyLength[bf] - 1) }, + name: "equal keys, both have one fewer token", + keyA: func(ts int) Key { + return ToKey([]byte(key)).Take(len(key)*8 - ts) + }, + keyB: func(ts int) Key { + return ToKey([]byte(key)).Take(len(key)*8 - ts) + }, isPrefix: true, isStrictPrefix: false, }, { name: "different keys", - keyA: func(bf BranchFactor) Key { return ToKey([]byte{0xF7}, bf) }, - keyB: func(bf BranchFactor) Key { return ToKey([]byte{0xF0}, bf) }, + keyA: func(ts int) Key { return ToKey([]byte{0xF7}) }, + keyB: func(ts int) Key { return ToKey([]byte{0xF0}) }, isPrefix: false, isStrictPrefix: false, }, { - name: "same bytes, different lengths", - keyA: func(bf BranchFactor) Key { return ToKey([]byte{0x10, 0x00}, bf).Take(1) }, - keyB: func(bf BranchFactor) Key { return ToKey([]byte{0x10, 0x00}, bf).Take(2) }, + name: "same bytes, different lengths", + keyA: func(ts int) Key { + return ToKey([]byte{0x10, 0x00}).Take(ts) + }, + keyB: func(ts int) Key { + return ToKey([]byte{0x10, 0x00}).Take(ts * 2) + }, isPrefix: false, isStrictPrefix: false, }, } for _, tt := range tests { - for _, bf := range branchFactors { - t.Run(tt.name+" bf "+fmt.Sprint(bf), func(t *testing.T) { + for _, ts := range validTokenSizes { + t.Run(tt.name+" ts "+strconv.Itoa(ts), func(t *testing.T) { require := require.New(t) - keyA := tt.keyA(bf) - keyB := tt.keyB(bf) + keyA := tt.keyA(ts) + keyB := tt.keyB(ts) require.Equal(tt.isPrefix, keyA.HasPrefix(keyB)) - require.Equal(tt.isPrefix, keyA.iteratedHasPrefix(0, keyB)) + require.Equal(tt.isPrefix, keyA.iteratedHasPrefix(keyB, 0, ts)) require.Equal(tt.isStrictPrefix, keyA.HasStrictPrefix(keyB)) }) } @@ -124,30 +133,29 @@ func Test_Key_Has_Prefix(t *testing.T) { func Test_Key_Skip(t *testing.T) { require := require.New(t) - for _, bf := range branchFactors { - empty := emptyKey(bf) - require.Equal(ToKey([]byte{0}, bf).Skip(empty.tokensPerByte), empty) - if bf == BranchFactor256 { + empty := Key{} + require.Equal(ToKey([]byte{0}).Skip(8), empty) + for _, 
ts := range validTokenSizes { + if ts == 8 { continue } - shortKey := ToKey([]byte{0b0101_0101}, bf) - longKey := ToKey([]byte{0b0101_0101, 0b0101_0101}, bf) - for i := 0; i < shortKey.tokensPerByte; i++ { - shift := byte(i) * shortKey.tokenBitSize - skipKey := shortKey.Skip(i) + shortKey := ToKey([]byte{0b0101_0101}) + longKey := ToKey([]byte{0b0101_0101, 0b0101_0101}) + for shift := 0; shift < 8; shift += ts { + skipKey := shortKey.Skip(shift) + require.Equal(byte(0b0101_0101<<shift), skipKey.value[0]) + skipKey = longKey.Skip(shift) + require.Equal(byte(0b0101_0101<<shift+0b0101_0101>>(8-shift)), skipKey.value[0]) + require.Equal(byte(0b0101_0101<<shift), skipKey.value[1]) + } + } +} + +func FuzzKeyExtend_Any(f *testing.F) { + f.Fuzz(func( + t *testing.T, + first []byte, + second []byte, + tokenByte byte, + forceFirstOdd bool, forceSecondOdd bool, + ) { + require := require.New(t) + for _, ts := range validTokenSizes { + key1 := ToKey(first) + if forceFirstOdd && key1.length > ts { + key1 = key1.Take(key1.length - ts) + } + key2 := ToKey(second) + if forceSecondOdd && key2.length > ts { + key2 = key2.Take(key2.length - ts) + } + token := byte(int(tokenByte) % int(tokenSizeToBranchFactor[ts])) + extendedP := key1.Extend(ToToken(token, ts), key2) + require.Equal(key1.length+key2.length+ts, extendedP.length) + firstIndex := 0 + for ; firstIndex < key1.length; firstIndex += ts { + require.Equal(key1.Token(firstIndex, ts), extendedP.Token(firstIndex, ts)) + } + require.Equal(token, extendedP.Token(firstIndex, ts)) + firstIndex += ts + for secondIndex := 0; secondIndex < key2.length; secondIndex += ts { + require.Equal(key2.Token(secondIndex, ts), extendedP.Token(firstIndex+secondIndex, ts)) + } + } + }) +} + +func FuzzKeyDoubleExtend_Any(f *testing.F) { + f.Fuzz(func( + t *testing.T, + baseKeyBytes []byte, + firstKeyBytes []byte, + secondKeyBytes []byte, + forceBaseOdd bool, forceFirstOdd bool, forceSecondOdd bool, ) { require := require.New(t) - for _, branchFactor := range branchFactors { - key1 := ToKey(first, branchFactor) - if forceFirstOdd && key1.tokenLength > 0 { - key1 = key1.Take(key1.tokenLength - 1) + for _, ts := range validTokenSizes { + baseKey := ToKey(baseKeyBytes) + if forceBaseOdd && baseKey.length > ts { + baseKey = baseKey.Take(baseKey.length - ts) + } + firstKey := ToKey(firstKeyBytes) + if forceFirstOdd && firstKey.length > ts { + firstKey = firstKey.Take(firstKey.length - ts) } - key2 := ToKey(second, branchFactor) - if forceSecondOdd && key2.tokenLength > 0 { - key2 = key2.Take(key2.tokenLength - 1) + + secondKey := ToKey(secondKeyBytes) + if forceSecondOdd && secondKey.length > ts { + secondKey = secondKey.Take(secondKey.length - ts) } - token = byte(int(token) % int(branchFactor)) - extendedP := key1.AppendExtend(token, key2) - require.Equal(key1.tokenLength+key2.tokenLength+1, extendedP.tokenLength) - for i := 0; i < key1.tokenLength; i++ { - require.Equal(key1.Token(i), extendedP.Token(i)) + + extendedP := baseKey.Extend(firstKey, secondKey) + require.Equal(baseKey.length+firstKey.length+secondKey.length, extendedP.length) + totalIndex := 0 + for baseIndex := 0; baseIndex < baseKey.length; baseIndex += ts { + require.Equal(baseKey.Token(baseIndex, ts), extendedP.Token(baseIndex, ts)) } - require.Equal(token, extendedP.Token(key1.tokenLength)) - for i := 0; i < key2.tokenLength; i++ { - require.Equal(key2.Token(i), extendedP.Token(i+1+key1.tokenLength)) + totalIndex += baseKey.length + for firstIndex := 0; firstIndex < firstKey.length; firstIndex += ts { + require.Equal(firstKey.Token(firstIndex, ts), extendedP.Token(totalIndex+firstIndex, ts)) + } + totalIndex += firstKey.length + for secondIndex := 0; secondIndex < secondKey.length; secondIndex += ts { + require.Equal(secondKey.Token(secondIndex, ts), extendedP.Token(totalIndex+secondIndex, ts)) + } } }) @@ -509,15 +478,18 @@ func FuzzKeySkip(f *testing.F) { tokensToSkip uint, ) { require := require.New(t) - for _, branchFactor := range 
branchFactors { - key1 := ToKey(first, branchFactor) - if int(tokensToSkip) >= key1.tokenLength { + key1 := ToKey(first) + for _, ts := range validTokenSizes { + // need bits to be a multiple of token size + ubitsToSkip := tokensToSkip * uint(ts) + if ubitsToSkip >= uint(key1.length) { t.SkipNow() } - key2 := key1.Skip(int(tokensToSkip)) - require.Equal(key1.tokenLength-int(tokensToSkip), key2.tokenLength) - for i := 0; i < key2.tokenLength; i++ { - require.Equal(key1.Token(int(tokensToSkip)+i), key2.Token(i)) + bitsToSkip := int(ubitsToSkip) + key2 := key1.Skip(bitsToSkip) + require.Equal(key1.length-bitsToSkip, key2.length) + for i := 0; i < key2.length; i += ts { + require.Equal(key1.Token(bitsToSkip+i, ts), key2.Token(i, ts)) } } }) @@ -527,19 +499,24 @@ func FuzzKeyTake(f *testing.F) { f.Fuzz(func( t *testing.T, first []byte, - tokensToTake uint, + uTokensToTake uint, ) { require := require.New(t) - for _, branchFactor := range branchFactors { - key1 := ToKey(first, branchFactor) - if int(tokensToTake) >= key1.tokenLength { + for _, ts := range validTokenSizes { + key1 := ToKey(first) + uBitsToTake := uTokensToTake * uint(ts) + if uBitsToTake >= uint(key1.length) { t.SkipNow() } - key2 := key1.Take(int(tokensToTake)) - require.Equal(int(tokensToTake), key2.tokenLength) - - for i := 0; i < key2.tokenLength; i++ { - require.Equal(key1.Token(i), key2.Token(i)) + bitsToTake := int(uBitsToTake) + key2 := key1.Take(bitsToTake) + require.Equal(bitsToTake, key2.length) + if key2.hasPartialByte() { + paddingMask := byte(0xFF >> (key2.length % 8)) + require.Zero(key2.value[len(key2.value)-1] & paddingMask) + } + for i := 0; i < bitsToTake; i += ts { + require.Equal(key1.Token(i, ts), key2.Token(i, ts)) } } }) @@ -550,7 +527,7 @@ func TestShiftCopy(t *testing.T) { dst []byte src []byte expected []byte - shift byte + shift int } tests := []test{ diff --git a/x/merkledb/node.go b/x/merkledb/node.go index 259e048c1793..3fd38021a0c8 100644 --- a/x/merkledb/node.go +++ b/x/merkledb/node.go @@ -14,13 +14,6 @@ import ( const HashLength = 32 -// the values that go into the node's id -type hashValues struct { - Children map[byte]child - Value maybe.Maybe[[]byte] - Key Key -} - // Representation of a node stored in the database. type dbNode struct { value maybe.Maybe[[]byte] @@ -43,24 +36,19 @@ type node struct { } // Returns a new node with the given [key] and no value. -// If [parent] isn't nil, the new node is added as a child of [parent]. -func newNode(parent *node, key Key) *node { - newNode := &node{ +func newNode(key Key) *node { + return &node{ dbNode: dbNode{ - children: make(map[byte]child, key.branchFactor), + children: make(map[byte]child, 2), }, key: key, } - if parent != nil { - parent.addChild(newNode) - } - return newNode } // Parse [nodeBytes] to a node and set its key to [key]. func parseNode(key Key, nodeBytes []byte) (*node, error) { n := dbNode{} - if err := codec.decodeDBNode(nodeBytes, &n, key.branchFactor); err != nil { + if err := codec.decodeDBNode(nodeBytes, &n); err != nil { return nil, err } result := &node{ @@ -101,11 +89,7 @@ func (n *node) calculateID(metrics merkleMetrics) { } metrics.HashCalculated() - bytes := codec.encodeHashValues(&hashValues{ - Children: n.children, - Value: n.valueDigest, - Key: n.key, - }) + bytes := codec.encodeHashValues(n) n.id = hashing.ComputeHash256Array(bytes) } @@ -127,11 +111,11 @@ func (n *node) setValueDigest() { // Adds [child] as a child of [n]. // Assumes [child]'s key is valid as a child of [n]. 
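With the branch factor gone from `Key`, `addChild` in the hunk below needs an explicit token size to know both which child slot the new node occupies and how much of its key to compress away: the slot is the first token of the child's key past the parent's prefix, and the compressed key skips that prefix plus one token. A sketch of the arithmetic using the exported key helpers (the key bytes are chosen for illustration):

```go
package main

import (
	"fmt"

	"github.com/ava-labs/avalanchego/x/merkledb"
)

func main() {
	const tokenSize = 4 // BranchFactor16

	parentKey := merkledb.ToKey([]byte("k"))  // 8 bits
	childKey := merkledb.ToKey([]byte("key")) // 24 bits

	// The child slot is childKey.Token(parentKey.Length(), tokenSize):
	// 'e' is 0x65, so the token after the 8-bit prefix is the high nibble, 6.
	fmt.Printf("%x\n", childKey.Token(parentKey.Length(), tokenSize)) // 6

	// The compressed key skips the parent prefix plus that one token,
	// leaving the remaining 12 bits.
	fmt.Println(childKey.Skip(parentKey.Length() + tokenSize).Length()) // 12
}
```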
// That is, [n.key] is a prefix of [child.key]. -func (n *node) addChild(childNode *node) { +func (n *node) addChild(childNode *node, tokenSize int) { n.setChildEntry( - childNode.key.Token(n.key.tokenLength), + childNode.key.Token(n.key.length, tokenSize), child{ - compressedKey: childNode.key.Skip(n.key.tokenLength + 1), + compressedKey: childNode.key.Skip(n.key.length + tokenSize), id: childNode.id, hasValue: childNode.hasValue(), }, @@ -145,9 +129,9 @@ func (n *node) setChildEntry(index byte, childEntry child) { } // Removes [child] from [n]'s children. -func (n *node) removeChild(child *node) { +func (n *node) removeChild(child *node, tokenSize int) { n.onNodeChanged() - delete(n.children, child.key.Token(n.key.tokenLength)) + delete(n.children, child.key.Token(n.key.length, tokenSize)) } // clone Returns a copy of [n]. diff --git a/x/merkledb/node_test.go b/x/merkledb/node_test.go index 9632b7c7dacb..e0cb4dd04b06 100644 --- a/x/merkledb/node_test.go +++ b/x/merkledb/node_test.go @@ -13,54 +13,57 @@ import ( ) func Test_Node_Marshal(t *testing.T) { - root := newNode(nil, emptyKey(BranchFactor16)) + root := newNode(Key{}) require.NotNil(t, root) - fullKey := ToKey([]byte("key"), BranchFactor16) - childNode := newNode(root, fullKey) + fullKey := ToKey([]byte("key")) + childNode := newNode(fullKey) + root.addChild(childNode, 4) childNode.setValue(maybe.Some([]byte("value"))) require.NotNil(t, childNode) childNode.calculateID(&mockMetrics{}) - root.addChild(childNode) + root.addChild(childNode, 4) data := root.bytes() - rootParsed, err := parseNode(ToKey([]byte(""), BranchFactor16), data) + rootParsed, err := parseNode(ToKey([]byte("")), data) require.NoError(t, err) require.Len(t, rootParsed.children, 1) - rootIndex := getSingleChildKey(root).Token(root.key.tokenLength) - parsedIndex := getSingleChildKey(rootParsed).Token(rootParsed.key.tokenLength) + rootIndex := getSingleChildKey(root, 4).Token(0, 4) + parsedIndex := getSingleChildKey(rootParsed, 4).Token(0, 4) rootChildEntry := root.children[rootIndex] parseChildEntry := rootParsed.children[parsedIndex] require.Equal(t, rootChildEntry.id, parseChildEntry.id) } func Test_Node_Marshal_Errors(t *testing.T) { - root := newNode(nil, emptyKey(BranchFactor16)) + root := newNode(Key{}) require.NotNil(t, root) - fullKey := ToKey([]byte{255}, BranchFactor16) - childNode1 := newNode(root, fullKey) + fullKey := ToKey([]byte{255}) + childNode1 := newNode(fullKey) + root.addChild(childNode1, 4) childNode1.setValue(maybe.Some([]byte("value1"))) require.NotNil(t, childNode1) childNode1.calculateID(&mockMetrics{}) - root.addChild(childNode1) + root.addChild(childNode1, 4) - fullKey = ToKey([]byte{237}, BranchFactor16) - childNode2 := newNode(root, fullKey) + fullKey = ToKey([]byte{237}) + childNode2 := newNode(fullKey) + root.addChild(childNode2, 4) childNode2.setValue(maybe.Some([]byte("value2"))) require.NotNil(t, childNode2) childNode2.calculateID(&mockMetrics{}) - root.addChild(childNode2) + root.addChild(childNode2, 4) data := root.bytes() for i := 1; i < len(data); i++ { broken := data[:i] - _, err := parseNode(ToKey([]byte(""), BranchFactor16), broken) + _, err := parseNode(ToKey([]byte("")), broken) require.ErrorIs(t, err, io.ErrUnexpectedEOF) } } diff --git a/x/merkledb/proof.go b/x/merkledb/proof.go index 63ea34542c9b..f750158d4c11 100644 --- a/x/merkledb/proof.go +++ b/x/merkledb/proof.go @@ -31,8 +31,6 @@ var ( ErrNonIncreasingValues = errors.New("keys sent are not in increasing order") ErrStateFromOutsideOfRange = errors.New("state key 
falls outside of the start->end range") ErrNonIncreasingProofNodes = errors.New("each proof node key must be a strict prefix of the next") - ErrExtraProofNodes = errors.New("extra proof nodes in path") - ErrDataInMissingRootProof = errors.New("there should be no state or deleted keys in a change proof that had a missing root") ErrNoMerkleProof = errors.New("empty key response must include merkle proof") ErrShouldJustBeRoot = errors.New("end proof should only contain root") ErrNoStartProof = errors.New("no start proof") @@ -42,7 +40,6 @@ var ( ErrProofValueDoesntMatch = errors.New("the provided value does not match the proof node for the provided key's value") ErrProofNodeHasUnincludedValue = errors.New("the provided proof has a value for a key within the range that is not present in the provided key/values") ErrInvalidMaybe = errors.New("maybe is nothing but has value") - ErrInvalidChildIndex = errors.New("child index must be less than branch factor") ErrNilProofNode = errors.New("proof node is nil") ErrNilValueOrHash = errors.New("proof node's valueOrHash field is nil") ErrNilKey = errors.New("key is nil") @@ -69,7 +66,7 @@ type ProofNode struct { func (node *ProofNode) ToProto() *pb.ProofNode { pbNode := &pb.ProofNode{ Key: &pb.Key{ - Length: uint64(node.Key.tokenLength), + Length: uint64(node.Key.length), Value: node.Key.Bytes(), }, ValueOrHash: &pb.MaybeBytes{ @@ -87,7 +84,7 @@ func (node *ProofNode) ToProto() *pb.ProofNode { return pbNode } -func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode, bf BranchFactor) error { +func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode) error { switch { case pbNode == nil: return ErrNilProofNode @@ -97,17 +94,14 @@ func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode, bf BranchFactor) err return ErrInvalidMaybe case pbNode.Key == nil: return ErrNilKey - } - node.Key = ToKey(pbNode.Key.Value, bf).Take(int(pbNode.Key.Length)) - - if len(pbNode.Key.Value) != node.Key.bytesNeeded(node.Key.tokenLength) { + case len(pbNode.Key.Value) != bytesNeeded(int(pbNode.Key.Length)): return ErrInvalidKeyLength } - + node.Key = ToKey(pbNode.Key.Value).Take(int(pbNode.Key.Length)) node.Children = make(map[byte]ids.ID, len(pbNode.Children)) for childIndex, childIDBytes := range pbNode.Children { - if childIndex >= uint32(bf) { - return ErrInvalidChildIndex + if childIndex > math.MaxUint8 { + return errChildIndexTooLarge } childID, err := ids.ToID(childIDBytes) if err != nil { @@ -123,7 +117,7 @@ func (node *ProofNode) UnmarshalProto(pbNode *pb.ProofNode, bf BranchFactor) err return nil } -// An inclusion/exclustion proof of a key. +// Proof represents an inclusion/exclusion proof of a key. type Proof struct { // Nodes in the proof path from root --> target key // (or node that would be where key is if it doesn't exist). @@ -140,7 +134,7 @@ type Proof struct { // Returns nil if the trie given in [proof] has root [expectedRootID]. // That is, this is a valid proof that [proof.Key] exists/doesn't exist // in the trie with root [expectedRootID]. -func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID) error { +func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID, tokenSize int) error { // Make sure the proof is well-formed. if len(proof.Path) == 0 { return ErrNoProof @@ -172,7 +166,7 @@ func (proof *Proof) Verify(ctx context.Context, expectedRootID ids.ID) error { } // Don't bother locking [view] -- nobody else has a reference to it. 
- view, err := getStandaloneTrieView(ctx, nil, proof.Key.branchFactor) + view, err := getStandaloneTrieView(ctx, nil, tokenSize) if err != nil { return err } @@ -215,7 +209,7 @@ func (proof *Proof) ToProto() *pb.Proof { return pbProof } -func (proof *Proof) UnmarshalProto(pbProof *pb.Proof, bf BranchFactor) error { +func (proof *Proof) UnmarshalProto(pbProof *pb.Proof) error { switch { case pbProof == nil: return ErrNilProof @@ -225,7 +219,7 @@ func (proof *Proof) UnmarshalProto(pbProof *pb.Proof, bf BranchFactor) error { return ErrInvalidMaybe } - proof.Key = ToKey(pbProof.Key, bf) + proof.Key = ToKey(pbProof.Key) if !pbProof.Value.IsNothing { proof.Value = maybe.Some(pbProof.Value.Value) @@ -233,7 +227,7 @@ func (proof *Proof) UnmarshalProto(pbProof *pb.Proof, bf BranchFactor) error { proof.Path = make([]ProofNode, len(pbProof.Proof)) for i, pbNode := range pbProof.Proof { - if err := proof.Path[i].UnmarshalProto(pbNode, bf); err != nil { + if err := proof.Path[i].UnmarshalProto(pbNode); err != nil { return err } } @@ -287,6 +281,7 @@ func (proof *RangeProof) Verify( start maybe.Maybe[[]byte], end maybe.Maybe[[]byte], expectedRootID ids.ID, + tokenSize int, ) error { switch { case start.HasValue() && end.HasValue() && bytes.Compare(start.Value(), end.Value()) > 0: @@ -301,15 +296,6 @@ func (proof *RangeProof) Verify( return ErrNoEndProof } - // determine branch factor based on proof paths - var branchFactor BranchFactor - if len(proof.StartProof) > 0 { - branchFactor = proof.StartProof[0].Key.branchFactor - } else { - // safe because invariants prevent both start proof and end proof from being empty at the same time - branchFactor = proof.EndProof[0].Key.branchFactor - } - // Make sure the key-value pairs are sorted and in [start, end]. if err := verifyKeyValues(proof.KeyValues, start, end); err != nil { return err @@ -322,24 +308,21 @@ func (proof *RangeProof) Verify( // If [largestProvenPath] is Nothing, [proof] should // provide and prove all keys > [smallestProvenPath]. // If both are Nothing, [proof] should prove the entire trie. - smallestProvenPath := maybe.Bind(start, func(b []byte) Key { - return ToKey(b, branchFactor) - }) + smallestProvenPath := maybe.Bind(start, ToKey) + + largestProvenPath := maybe.Bind(end, ToKey) - largestProvenPath := maybe.Bind(end, func(b []byte) Key { - return ToKey(b, branchFactor) - }) if len(proof.KeyValues) > 0 { // If [proof] has key-value pairs, we should insert children // greater than [largestProvenPath] to ancestors of the node containing // [largestProvenPath] so that we get the expected root ID. - largestProvenPath = maybe.Some(ToKey(proof.KeyValues[len(proof.KeyValues)-1].Key, branchFactor)) + largestProvenPath = maybe.Some(ToKey(proof.KeyValues[len(proof.KeyValues)-1].Key)) } // The key-value pairs (allegedly) proven by [proof]. keyValues := make(map[Key][]byte, len(proof.KeyValues)) for _, keyValue := range proof.KeyValues { - keyValues[ToKey(keyValue.Key, branchFactor)] = keyValue.Value + keyValues[ToKey(keyValue.Key)] = keyValue.Value } // Ensure that the start proof is valid and contains values that @@ -380,7 +363,7 @@ func (proof *RangeProof) Verify( } // Don't need to lock [view] because nobody else has a reference to it. 
- view, err := getStandaloneTrieView(ctx, ops, branchFactor) + view, err := getStandaloneTrieView(ctx, ops, tokenSize) if err != nil { return err } @@ -444,21 +427,21 @@ func (proof *RangeProof) ToProto() *pb.RangeProof { } } -func (proof *RangeProof) UnmarshalProto(pbProof *pb.RangeProof, bf BranchFactor) error { +func (proof *RangeProof) UnmarshalProto(pbProof *pb.RangeProof) error { if pbProof == nil { return ErrNilRangeProof } proof.StartProof = make([]ProofNode, len(pbProof.StartProof)) for i, protoNode := range pbProof.StartProof { - if err := proof.StartProof[i].UnmarshalProto(protoNode, bf); err != nil { + if err := proof.StartProof[i].UnmarshalProto(protoNode); err != nil { return err } } proof.EndProof = make([]ProofNode, len(pbProof.EndProof)) for i, protoNode := range pbProof.EndProof { - if err := proof.EndProof[i].UnmarshalProto(protoNode, bf); err != nil { + if err := proof.EndProof[i].UnmarshalProto(protoNode); err != nil { return err } } @@ -596,21 +579,21 @@ func (proof *ChangeProof) ToProto() *pb.ChangeProof { } } -func (proof *ChangeProof) UnmarshalProto(pbProof *pb.ChangeProof, bf BranchFactor) error { +func (proof *ChangeProof) UnmarshalProto(pbProof *pb.ChangeProof) error { if pbProof == nil { return ErrNilChangeProof } proof.StartProof = make([]ProofNode, len(pbProof.StartProof)) for i, protoNode := range pbProof.StartProof { - if err := proof.StartProof[i].UnmarshalProto(protoNode, bf); err != nil { + if err := proof.StartProof[i].UnmarshalProto(protoNode); err != nil { return err } } proof.EndProof = make([]ProofNode, len(pbProof.EndProof)) for i, protoNode := range pbProof.EndProof { - if err := proof.EndProof[i].UnmarshalProto(protoNode, bf); err != nil { + if err := proof.EndProof[i].UnmarshalProto(protoNode); err != nil { return err } } @@ -754,10 +737,8 @@ func verifyProofPath(proof []ProofNode, key maybe.Maybe[Key]) error { // loop over all but the last node since it will not have the prefix in exclusion proofs for i := 0; i < len(proof)-1; i++ { - nodeKey := proof[i].Key - if key.HasValue() && nodeKey.branchFactor != key.Value().branchFactor { - return ErrInconsistentBranchFactor - } + currentProofNode := proof[i] + nodeKey := currentProofNode.Key // Because the interface only support []byte keys, // a key with a partial byte should store a value @@ -770,11 +751,8 @@ func verifyProofPath(proof []ProofNode, key maybe.Maybe[Key]) error { return ErrProofNodeNotForKey } - // each node should have a key that has a matching BranchFactor and is a prefix of the next node's key + // each node should have a key that has a matching TokenConfig and is a prefix of the next node's key nextKey := proof[i+1].Key - if nextKey.branchFactor != nodeKey.branchFactor { - return ErrInconsistentBranchFactor - } if !nextKey.HasStrictPrefix(nodeKey) { return ErrNonIncreasingProofNodes } @@ -857,12 +835,12 @@ func addPathInfo( // Add [proofNode]'s children which are outside the range // [insertChildrenLessThan, insertChildrenGreaterThan]. 
- compressedPath := emptyKey(key.branchFactor) + compressedKey := Key{} for index, childID := range proofNode.Children { if existingChild, ok := n.children[index]; ok { - compressedPath = existingChild.compressedKey + compressedKey = existingChild.compressedKey } - childPath := key.AppendExtend(index, compressedPath) + childPath := key.Extend(ToToken(index, t.tokenSize), compressedKey) if (shouldInsertLeftChildren && childPath.Less(insertChildrenLessThan.Value())) || (shouldInsertRightChildren && childPath.Greater(insertChildrenGreaterThan.Value())) { // We didn't set the other values on the child entry, but it doesn't matter. @@ -871,7 +849,7 @@ func addPathInfo( index, child{ id: childID, - compressedKey: compressedPath, + compressedKey: compressedKey, }) } } @@ -881,7 +859,7 @@ func addPathInfo( } // getStandaloneTrieView returns a new view that has nothing in it besides the changes due to [ops] -func getStandaloneTrieView(ctx context.Context, ops []database.BatchOp, factor BranchFactor) (*trieView, error) { +func getStandaloneTrieView(ctx context.Context, ops []database.BatchOp, size int) (*trieView, error) { db, err := newDatabase( ctx, memdb.New(), @@ -890,7 +868,7 @@ func getStandaloneTrieView(ctx context.Context, ops []database.BatchOp, factor B Tracer: trace.Noop, ValueNodeCacheSize: verificationCacheSize, IntermediateNodeCacheSize: verificationCacheSize, - BranchFactor: factor, + BranchFactor: tokenSizeToBranchFactor[size], }, &mockMetrics{}, ) diff --git a/x/merkledb/proof_test.go b/x/merkledb/proof_test.go index bf9d9da18996..508e3d545d76 100644 --- a/x/merkledb/proof_test.go +++ b/x/merkledb/proof_test.go @@ -23,7 +23,7 @@ import ( func Test_Proof_Empty(t *testing.T) { proof := &Proof{} - err := proof.Verify(context.Background(), ids.Empty) + err := proof.Verify(context.Background(), ids.Empty, 4) require.ErrorIs(t, err, ErrNoProof) } @@ -43,7 +43,7 @@ func Test_Proof_Simple(t *testing.T) { proof, err := db.GetProof(ctx, []byte{}) require.NoError(err) - require.NoError(proof.Verify(ctx, expectedRoot)) + require.NoError(proof.Verify(ctx, expectedRoot, 4)) } func Test_Proof_Verify_Bad_Data(t *testing.T) { @@ -112,7 +112,7 @@ func Test_Proof_Verify_Bad_Data(t *testing.T) { tt.malform(proof) - err = proof.Verify(context.Background(), db.getMerkleRoot()) + err = proof.Verify(context.Background(), db.getMerkleRoot(), 4) require.ErrorIs(err, tt.expectedErr) }) } @@ -151,6 +151,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) { maybe.Some([]byte{1}), maybe.Some([]byte{5, 5}), db.root.id, + db.tokenSize, )) proof.KeyValues = append(proof.KeyValues, KeyValue{Key: []byte{5}, Value: []byte{5}}) @@ -160,6 +161,7 @@ func Test_RangeProof_Extra_Value(t *testing.T) { maybe.Some([]byte{1}), maybe.Some([]byte{5, 5}), db.root.id, + db.tokenSize, ) require.ErrorIs(err, ErrInvalidProof) } @@ -221,7 +223,7 @@ func Test_RangeProof_Verify_Bad_Data(t *testing.T) { tt.malform(proof) - err = proof.Verify(context.Background(), maybe.Some([]byte{2}), maybe.Some([]byte{3, 0}), db.getMerkleRoot()) + err = proof.Verify(context.Background(), maybe.Some([]byte{2}), maybe.Some([]byte{3, 0}), db.getMerkleRoot(), db.tokenSize) require.ErrorIs(err, tt.expectedErr) }) } @@ -271,19 +273,19 @@ func Test_Proof(t *testing.T) { require.Len(proof.Path, 3) - require.Equal(ToKey([]byte("key1"), BranchFactor16), proof.Path[2].Key) + require.Equal(ToKey([]byte("key1")), proof.Path[2].Key) require.Equal(maybe.Some([]byte("value1")), proof.Path[2].ValueOrHash) - require.Equal(ToKey([]byte{}, BranchFactor16), 
proof.Path[0].Key) + require.Equal(ToKey([]byte{}), proof.Path[0].Key) require.True(proof.Path[0].ValueOrHash.IsNothing()) expectedRootID, err := trie.GetMerkleRoot(context.Background()) require.NoError(err) - require.NoError(proof.Verify(context.Background(), expectedRootID)) + require.NoError(proof.Verify(context.Background(), expectedRootID, dbTrie.tokenSize)) proof.Path[0].ValueOrHash = maybe.Some([]byte("value2")) - err = proof.Verify(context.Background(), expectedRootID) + err = proof.Verify(context.Background(), expectedRootID, dbTrie.tokenSize) require.ErrorIs(err, ErrInvalidProof) } @@ -357,7 +359,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1}, Value: []byte{1}}, {Key: []byte{0}, Value: []byte{0}}, }, - EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, + EndProof: []ProofNode{{Key: Key{}}}, }, expectedErr: ErrNonIncreasingValues, }, @@ -369,7 +371,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { KeyValues: []KeyValue{ {Key: []byte{0}, Value: []byte{0}}, }, - EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, + EndProof: []ProofNode{{Key: Key{}}}, }, expectedErr: ErrStateFromOutsideOfRange, }, @@ -381,7 +383,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { KeyValues: []KeyValue{ {Key: []byte{2}, Value: []byte{0}}, }, - EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, + EndProof: []ProofNode{{Key: Key{}}}, }, expectedErr: ErrStateFromOutsideOfRange, }, @@ -395,13 +397,13 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, StartProof: []ProofNode{ { - Key: ToKey([]byte{2}, BranchFactor16), + Key: ToKey([]byte{2}), }, { - Key: ToKey([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}), }, }, - EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, + EndProof: []ProofNode{{Key: Key{}}}, }, expectedErr: ErrProofNodeNotForKey, }, @@ -415,16 +417,16 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, StartProof: []ProofNode{ { - Key: ToKey([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}), }, { - Key: ToKey([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] + Key: ToKey([]byte{1, 2, 3}), // Not a prefix of [1, 2] }, { - Key: ToKey([]byte{1, 2, 3, 4}, BranchFactor16), + Key: ToKey([]byte{1, 2, 3, 4}), }, }, - EndProof: []ProofNode{{Key: emptyKey(BranchFactor16)}}, + EndProof: []ProofNode{{Key: Key{}}}, }, expectedErr: ErrProofNodeNotForKey, }, @@ -438,39 +440,15 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, EndProof: []ProofNode{ { - Key: ToKey([]byte{2}, BranchFactor16), + Key: ToKey([]byte{2}), }, { - Key: ToKey([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}), }, }, }, expectedErr: ErrProofNodeNotForKey, }, - { - name: "inconsistent branching factor", - start: maybe.Some([]byte{1, 2}), - end: maybe.Some([]byte{1, 2}), - proof: &RangeProof{ - StartProof: []ProofNode{ - { - Key: ToKey([]byte{1}, BranchFactor16), - }, - { - Key: ToKey([]byte{1, 2}, BranchFactor16), - }, - }, - EndProof: []ProofNode{ - { - Key: ToKey([]byte{1}, BranchFactor4), - }, - { - Key: ToKey([]byte{1, 2}, BranchFactor4), - }, - }, - }, - expectedErr: ErrInconsistentBranchFactor, - }, { name: "end proof has node for wrong key", start: maybe.Nothing[[]byte](), @@ -481,13 +459,13 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { }, EndProof: []ProofNode{ { - Key: ToKey([]byte{1}, BranchFactor16), + Key: ToKey([]byte{1}), }, { - Key: ToKey([]byte{1, 2, 3}, BranchFactor16), // Not a prefix of [1, 2] + Key: ToKey([]byte{1, 2, 3}), // Not a prefix of [1, 2] }, { - Key: ToKey([]byte{1, 2, 3, 4}, 
BranchFactor16), + Key: ToKey([]byte{1, 2, 3, 4}), }, }, }, @@ -497,7 +475,7 @@ func Test_RangeProof_Syntactic_Verify(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.proof.Verify(context.Background(), tt.start, tt.end, ids.Empty) + err := tt.proof.Verify(context.Background(), tt.start, tt.end, ids.Empty, 4) require.ErrorIs(t, err, tt.expectedErr) }) } @@ -535,6 +513,7 @@ func Test_RangeProof(t *testing.T) { maybe.Some([]byte{1}), maybe.Some([]byte{3, 5}), db.root.id, + db.tokenSize, )) } @@ -578,15 +557,16 @@ func Test_RangeProof_NilStart(t *testing.T) { require.Equal([]byte("value1"), proof.KeyValues[0].Value) require.Equal([]byte("value2"), proof.KeyValues[1].Value) - require.Equal(ToKey([]byte("key2"), BranchFactor16), proof.EndProof[2].Key, BranchFactor16) - require.Equal(ToKey([]byte("key2"), BranchFactor16).Take(7), proof.EndProof[1].Key) - require.Equal(ToKey([]byte(""), BranchFactor16), proof.EndProof[0].Key, BranchFactor16) + require.Equal(ToKey([]byte("key2")), proof.EndProof[2].Key) + require.Equal(ToKey([]byte("key2")).Take(28), proof.EndProof[1].Key) + require.Equal(ToKey([]byte("")), proof.EndProof[0].Key) require.NoError(proof.Verify( context.Background(), maybe.Nothing[[]byte](), maybe.Some([]byte("key35")), db.root.id, + db.tokenSize, )) } @@ -621,6 +601,7 @@ func Test_RangeProof_NilEnd(t *testing.T) { maybe.Some([]byte{1}), maybe.Nothing[[]byte](), db.root.id, + db.tokenSize, )) } @@ -652,17 +633,18 @@ func Test_RangeProof_EmptyValues(t *testing.T) { require.Empty(proof.KeyValues[2].Value) require.Len(proof.StartProof, 1) - require.Equal(ToKey([]byte("key1"), BranchFactor16), proof.StartProof[0].Key, BranchFactor16) + require.Equal(ToKey([]byte("key1")), proof.StartProof[0].Key) require.Len(proof.EndProof, 3) - require.Equal(ToKey([]byte("key2"), BranchFactor16), proof.EndProof[2].Key, BranchFactor16) - require.Equal(ToKey([]byte{}, BranchFactor16), proof.EndProof[0].Key, BranchFactor16) + require.Equal(ToKey([]byte("key2")), proof.EndProof[2].Key) + require.Equal(ToKey([]byte{}), proof.EndProof[0].Key) require.NoError(proof.Verify( context.Background(), maybe.Some([]byte("key1")), maybe.Some([]byte("key2")), db.root.id, + db.tokenSize, )) } @@ -942,8 +924,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { name: "start proof node has wrong prefix", proof: &ChangeProof{ StartProof: []ProofNode{ - {Key: ToKey([]byte{2}, BranchFactor16)}, - {Key: ToKey([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{2})}, + {Key: ToKey([]byte{2, 3})}, }, }, start: maybe.Some([]byte{1, 2, 3}), @@ -954,8 +936,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { name: "start proof non-increasing", proof: &ChangeProof{ StartProof: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{2, 3})}, }, }, start: maybe.Some([]byte{1, 2, 3}), @@ -969,8 +951,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1, 2}, Value: maybe.Some([]byte{0})}, }, EndProof: []ProofNode{ - {Key: ToKey([]byte{2}, BranchFactor16)}, - {Key: ToKey([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{2})}, + {Key: ToKey([]byte{2, 3})}, }, }, start: maybe.Nothing[[]byte](), @@ -984,8 +966,8 @@ func Test_ChangeProof_Syntactic_Verify(t *testing.T) { {Key: []byte{1, 2, 3}}, }, EndProof: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{2, 3})}, }, }, 
start: maybe.Nothing[[]byte](), @@ -1100,119 +1082,118 @@ func TestVerifyProofPath(t *testing.T) { }, { name: "1 element", - path: []ProofNode{{Key: ToKey([]byte{1}, BranchFactor16)}}, + path: []ProofNode{{Key: ToKey([]byte{1})}}, proofKey: maybe.Nothing[Key](), expectedErr: nil, }, { name: "non-increasing keys", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrNonIncreasingProofNodes, }, { name: "invalid key", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 4}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 4})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrProofNodeNotForKey, }, { name: "extra node inclusion proof", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2})), expectedErr: ErrProofNodeNotForKey, }, { name: "extra node exclusion proof", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 3}, BranchFactor16)}, - {Key: ToKey([]byte{1, 3, 4}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 3})}, + {Key: ToKey([]byte{1, 3, 4})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2})), expectedErr: ErrProofNodeNotForKey, }, { name: "happy path exclusion proof", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 4}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 4})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: nil, }, { name: "happy path inclusion proof", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: nil, }, { name: "repeat nodes", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrNonIncreasingProofNodes, }, { name: "repeat nodes 2", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, 
BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrNonIncreasingProofNodes, }, { name: "repeat nodes 3", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2, 3}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, + {Key: ToKey([]byte{1, 2, 3})}, + {Key: ToKey([]byte{1, 2, 3})}, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrProofNodeNotForKey, }, { name: "oddLength key with value", path: []ProofNode{ - {Key: ToKey([]byte{1}, BranchFactor16)}, - {Key: ToKey([]byte{1, 2}, BranchFactor16)}, + {Key: ToKey([]byte{1})}, + {Key: ToKey([]byte{1, 2})}, { Key: Key{ - value: string([]byte{1, 2, 240}), - tokenLength: 5, - tokenConfig: branchFactorToTokenConfig[BranchFactor16], + value: string([]byte{1, 2, 240}), + length: 20, }, ValueOrHash: maybe.Some([]byte{1}), }, }, - proofKey: maybe.Some(ToKey([]byte{1, 2, 3}, BranchFactor16)), + proofKey: maybe.Some(ToKey([]byte{1, 2, 3})), expectedErr: ErrPartialByteLengthWithValue, }, } @@ -1240,7 +1221,7 @@ func TestProofNodeUnmarshalProtoInvalidMaybe(t *testing.T) { } var unmarshaledNode ProofNode - err := unmarshaledNode.UnmarshalProto(protoNode, BranchFactor16) + err := unmarshaledNode.UnmarshalProto(protoNode) require.ErrorIs(t, err, ErrInvalidMaybe) } @@ -1257,7 +1238,7 @@ func TestProofNodeUnmarshalProtoInvalidChildBytes(t *testing.T) { } var unmarshaledNode ProofNode - err := unmarshaledNode.UnmarshalProto(protoNode, BranchFactor16) + err := unmarshaledNode.UnmarshalProto(protoNode) require.ErrorIs(t, err, hashing.ErrInvalidHashLen) } @@ -1270,11 +1251,11 @@ func TestProofNodeUnmarshalProtoInvalidChildIndex(t *testing.T) { protoNode := node.ToProto() childID := ids.GenerateTestID() - protoNode.Children[uint32(BranchFactor16)] = childID[:] + protoNode.Children[256] = childID[:] var unmarshaledNode ProofNode - err := unmarshaledNode.UnmarshalProto(protoNode, BranchFactor16) - require.ErrorIs(t, err, ErrInvalidChildIndex) + err := unmarshaledNode.UnmarshalProto(protoNode) + require.ErrorIs(t, err, errChildIndexTooLarge) } func TestProofNodeUnmarshalProtoMissingFields(t *testing.T) { @@ -1321,7 +1302,7 @@ func TestProofNodeUnmarshalProtoMissingFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var node ProofNode - err := node.UnmarshalProto(tt.nodeFunc(), BranchFactor16) + err := node.UnmarshalProto(tt.nodeFunc()) require.ErrorIs(t, err, tt.expectedErr) }) } @@ -1340,7 +1321,7 @@ func FuzzProofNodeProtoMarshalUnmarshal(f *testing.F) { // Assert the unmarshaled one is the same as the original. protoNode := node.ToProto() var unmarshaledNode ProofNode - require.NoError(unmarshaledNode.UnmarshalProto(protoNode, BranchFactor16)) + require.NoError(unmarshaledNode.UnmarshalProto(protoNode)) require.Equal(node, unmarshaledNode) // Marshaling again should yield same result. @@ -1397,7 +1378,7 @@ func FuzzRangeProofProtoMarshalUnmarshal(f *testing.F) { // Assert the unmarshaled one is the same as the original. 
var unmarshaledProof RangeProof protoProof := proof.ToProto() - require.NoError(unmarshaledProof.UnmarshalProto(protoProof, BranchFactor16)) + require.NoError(unmarshaledProof.UnmarshalProto(protoProof)) require.Equal(proof, unmarshaledProof) // Marshaling again should yield same result. @@ -1459,7 +1440,7 @@ func FuzzChangeProofProtoMarshalUnmarshal(f *testing.F) { // Assert the unmarshaled one is the same as the original. var unmarshaledProof ChangeProof protoProof := proof.ToProto() - require.NoError(unmarshaledProof.UnmarshalProto(protoProof, BranchFactor16)) + require.NoError(unmarshaledProof.UnmarshalProto(protoProof)) require.Equal(proof, unmarshaledProof) // Marshaling again should yield same result. @@ -1470,7 +1451,7 @@ func FuzzChangeProofProtoMarshalUnmarshal(f *testing.F) { func TestChangeProofUnmarshalProtoNil(t *testing.T) { var proof ChangeProof - err := proof.UnmarshalProto(nil, BranchFactor16) + err := proof.UnmarshalProto(nil) require.ErrorIs(t, err, ErrNilChangeProof) } @@ -1524,7 +1505,7 @@ func TestChangeProofUnmarshalProtoNilValue(t *testing.T) { protoProof.KeyChanges[0].Value = nil var unmarshaledProof ChangeProof - err := unmarshaledProof.UnmarshalProto(protoProof, BranchFactor16) + err := unmarshaledProof.UnmarshalProto(protoProof) require.ErrorIs(t, err, ErrNilMaybeBytes) } @@ -1542,7 +1523,7 @@ func TestChangeProofUnmarshalProtoInvalidMaybe(t *testing.T) { } var proof ChangeProof - err := proof.UnmarshalProto(protoProof, BranchFactor16) + err := proof.UnmarshalProto(protoProof) require.ErrorIs(t, err, ErrInvalidMaybe) } @@ -1575,7 +1556,7 @@ func FuzzProofProtoMarshalUnmarshal(f *testing.F) { } proof := Proof{ - Key: ToKey(key, BranchFactor16), + Key: ToKey(key), Value: value, Path: proofPath, } @@ -1584,7 +1565,7 @@ func FuzzProofProtoMarshalUnmarshal(f *testing.F) { // Assert the unmarshaled one is the same as the original. var unmarshaledProof Proof protoProof := proof.ToProto() - require.NoError(unmarshaledProof.UnmarshalProto(protoProof, BranchFactor16)) + require.NoError(unmarshaledProof.UnmarshalProto(protoProof)) require.Equal(proof, unmarshaledProof) // Marshaling again should yield same result. @@ -1626,7 +1607,7 @@ func TestProofProtoUnmarshal(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var proof Proof - err := proof.UnmarshalProto(tt.proof, BranchFactor16) + err := proof.UnmarshalProto(tt.proof) require.ErrorIs(t, err, tt.expectedErr) }) } @@ -1694,6 +1675,7 @@ func FuzzRangeProofInvariants(f *testing.F) { start, end, rootID, + db.tokenSize, )) // Make sure the start proof doesn't contain any nodes @@ -1732,14 +1714,14 @@ func FuzzRangeProofInvariants(f *testing.F) { proof := Proof{ Path: rangeProof.EndProof, - Key: ToKey(endBytes, BranchFactor16), + Key: ToKey(endBytes), Value: value, } rootID, err := db.GetMerkleRoot(context.Background()) require.NoError(err) - require.NoError(proof.Verify(context.Background(), rootID)) + require.NoError(proof.Verify(context.Background(), rootID, db.tokenSize)) default: require.NotEmpty(rangeProof.EndProof) @@ -1747,14 +1729,14 @@ func FuzzRangeProofInvariants(f *testing.F) { // EndProof should be a proof for largest key-value. 
proof := Proof{ Path: rangeProof.EndProof, - Key: ToKey(greatestKV.Key, BranchFactor16), + Key: ToKey(greatestKV.Key), Value: maybe.Some(greatestKV.Value), } rootID, err := db.GetMerkleRoot(context.Background()) require.NoError(err) - require.NoError(proof.Verify(context.Background(), rootID)) + require.NoError(proof.Verify(context.Background(), rootID, db.tokenSize)) } }) } @@ -1790,7 +1772,7 @@ func FuzzProofVerification(f *testing.F) { rootID, err := db.GetMerkleRoot(context.Background()) require.NoError(err) - require.NoError(proof.Verify(context.Background(), rootID)) + require.NoError(proof.Verify(context.Background(), rootID, db.tokenSize)) // Insert a new key-value pair newKey := make([]byte, 32) diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index 7908c1266af7..e278456649b1 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -19,10 +19,6 @@ import ( ) func getNodeValue(t ReadOnlyTrie, key string) ([]byte, error) { - return getNodeValueWithBranchFactor(t, key, BranchFactor16) -} - -func getNodeValueWithBranchFactor(t ReadOnlyTrie, key string, bf BranchFactor) ([]byte, error) { var view *trieView if asTrieView, ok := t.(*trieView); ok { if err := asTrieView.calculateNodeIDs(context.Background()); err != nil { @@ -38,7 +34,7 @@ func getNodeValueWithBranchFactor(t ReadOnlyTrie, key string, bf BranchFactor) ( view = dbView.(*trieView) } - path := ToKey([]byte(key), bf) + path := ToKey([]byte(key)) var result *node err := view.visitPathToKey(path, func(n *node) error { result = n @@ -123,7 +119,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { trie := trieIntf.(*trieView) var nodePath []*node - require.NoError(trie.visitPathToKey(ToKey(nil, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(nil), func(n *node) error { nodePath = append(nodePath, n) return nil })) @@ -148,7 +144,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { require.NoError(trie.calculateNodeIDs(context.Background())) nodePath = make([]*node, 0, 2) - require.NoError(trie.visitPathToKey(ToKey(key1, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(key1), func(n *node) error { nodePath = append(nodePath, n) return nil })) @@ -156,7 +152,7 @@ func TestTrieViewVisitPathToKey(t *testing.T) { // Root and 1 value require.Len(nodePath, 2) require.Equal(trie.root, nodePath[0]) - require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key1), nodePath[1].key) // Insert another key which is a child of the first key2 := []byte{0, 1} @@ -174,14 +170,14 @@ func TestTrieViewVisitPathToKey(t *testing.T) { require.NoError(trie.calculateNodeIDs(context.Background())) nodePath = make([]*node, 0, 3) - require.NoError(trie.visitPathToKey(ToKey(key2, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(key2), func(n *node) error { nodePath = append(nodePath, n) return nil })) require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) - require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1), nodePath[1].key) + require.Equal(ToKey(key2), nodePath[2].key) // Insert a key which shares no prefix with the others key3 := []byte{255} @@ -199,41 +195,43 @@ func TestTrieViewVisitPathToKey(t *testing.T) { require.NoError(trie.calculateNodeIDs(context.Background())) nodePath = make([]*node, 0, 2) - require.NoError(trie.visitPathToKey(ToKey(key3, BranchFactor16), func(n *node) error { + 
require.NoError(trie.visitPathToKey(ToKey(key3), func(n *node) error { nodePath = append(nodePath, n) return nil })) + require.Len(nodePath, 2) require.Equal(trie.root, nodePath[0]) - require.Equal(ToKey(key3, BranchFactor16), nodePath[1].key) + require.Equal(ToKey(key3), nodePath[1].key) // Other key path not affected nodePath = make([]*node, 0, 3) - require.NoError(trie.visitPathToKey(ToKey(key2, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(key2), func(n *node) error { nodePath = append(nodePath, n) return nil })) require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) - require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1), nodePath[1].key) + require.Equal(ToKey(key2), nodePath[2].key) // Gets closest node when key doesn't exist key4 := []byte{0, 1, 2} nodePath = make([]*node, 0, 3) - require.NoError(trie.visitPathToKey(ToKey(key4, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(key4), func(n *node) error { nodePath = append(nodePath, n) return nil })) + require.Len(nodePath, 3) require.Equal(trie.root, nodePath[0]) - require.Equal(ToKey(key1, BranchFactor16), nodePath[1].key) - require.Equal(ToKey(key2, BranchFactor16), nodePath[2].key) + require.Equal(ToKey(key1), nodePath[1].key) + require.Equal(ToKey(key2), nodePath[2].key) // Gets just root when key doesn't exist and no key shares a prefix key5 := []byte{128} nodePath = make([]*node, 0, 1) - require.NoError(trie.visitPathToKey(ToKey(key5, BranchFactor16), func(n *node) error { + require.NoError(trie.visitPathToKey(ToKey(key5), func(n *node) error { nodePath = append(nodePath, n) return nil })) @@ -320,7 +318,7 @@ func Test_Trie_WriteToDB(t *testing.T) { rawBytes, err := dbTrie.baseDB.Get(prefixedKey) require.NoError(err) - node, err := parseNode(ToKey(key, BranchFactor16), rawBytes) + node, err := parseNode(ToKey(key), rawBytes) require.NoError(err) require.Equal([]byte("value"), node.value.Value()) } @@ -488,7 +486,7 @@ func Test_Trie_ExpandOnKeyPath(t *testing.T) { require.Equal([]byte("value12"), value) } -func Test_Trie_CompressedPaths(t *testing.T) { +func Test_Trie_compressedKeys(t *testing.T) { require := require.New(t) dbTrie, err := getBasicDB() @@ -619,7 +617,7 @@ func Test_Trie_HashCountOnBranch(t *testing.T) { // Make sure the branch node with the common prefix was created. // Note it's only created on call to GetMerkleRoot, not in NewView. 
- _, err = view2.getEditableNode(ToKey(keyPrefix, BranchFactor16), false) + _, err = view2.getEditableNode(ToKey(keyPrefix), false) require.NoError(err) // only hashes the new branch node, the new child node, and root @@ -760,7 +758,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err := newTrie.getEditableNode(emptyKey(BranchFactor16), false) + root, err := newTrie.getEditableNode(Key{}, false) require.NoError(err) require.Len(root.children, 1) @@ -777,7 +775,7 @@ func Test_Trie_ChainDeletion(t *testing.T) { ) require.NoError(err) require.NoError(newTrie.(*trieView).calculateNodeIDs(context.Background())) - root, err = newTrie.getEditableNode(emptyKey(BranchFactor16), false) + root, err = newTrie.getEditableNode(Key{}, false) require.NoError(err) // since all values have been deleted, the nodes should have been cleaned up require.Empty(root.children) @@ -842,15 +840,15 @@ func Test_Trie_NodeCollapse(t *testing.T) { require.NoError(err) require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err := trie.getEditableNode(emptyKey(BranchFactor16), false) + root, err := trie.getEditableNode(Key{}, false) require.NoError(err) require.Len(root.children, 1) - root, err = trie.getEditableNode(emptyKey(BranchFactor16), false) + root, err = trie.getEditableNode(Key{}, false) require.NoError(err) require.Len(root.children, 1) - firstNode, err := trie.getEditableNode(getSingleChildKey(root), true) + firstNode, err := trie.getEditableNode(getSingleChildKey(root, dbTrie.tokenSize), true) require.NoError(err) require.Len(firstNode.children, 1) @@ -868,11 +866,11 @@ func Test_Trie_NodeCollapse(t *testing.T) { require.NoError(err) require.NoError(trie.(*trieView).calculateNodeIDs(context.Background())) - root, err = trie.getEditableNode(emptyKey(BranchFactor16), false) + root, err = trie.getEditableNode(Key{}, false) require.NoError(err) require.Len(root.children, 1) - firstNode, err = trie.getEditableNode(getSingleChildKey(root), true) + firstNode, err = trie.getEditableNode(getSingleChildKey(root, dbTrie.tokenSize), true) require.NoError(err) require.Len(firstNode.children, 2) } @@ -1215,9 +1213,9 @@ func Test_Trie_ConcurrentNewViewAndCommit(t *testing.T) { // Returns the path of the only child of this node. // Assumes this node has exactly one child. -func getSingleChildKey(n *node) Key { +func getSingleChildKey(n *node, tokenSize int) Key { for index, entry := range n.children { - return n.key.AppendExtend(index, entry.compressedKey) + return n.key.Extend(ToToken(index, tokenSize), entry.compressedKey) } return Key{} } diff --git a/x/merkledb/trieview.go b/x/merkledb/trieview.go index 3422379a20cc..c905cb82c218 100644 --- a/x/merkledb/trieview.go +++ b/x/merkledb/trieview.go @@ -35,7 +35,7 @@ var ( ErrPartialByteLengthWithValue = errors.New( "the underlying db only supports whole number of byte keys, so cannot record changes with partial byte lengths", ) - ErrGetPathToFailure = errors.New("GetPathTo failed to return the closest node") + ErrVisitPathToKey = errors.New("failed to visit expected node during insertion") ErrStartAfterEnd = errors.New("start key > end key") ErrNoValidRoot = errors.New("a valid root was not provided to the trieView constructor") ErrParentNotDatabase = errors.New("parent trie is not database") @@ -98,6 +98,8 @@ type trieView struct { // The root of the trie represented by this view. 
root *node + + tokenSize int } // NewView returns a new view on top of this Trie where the passed changes @@ -145,7 +147,7 @@ func newTrieView( parentTrie TrieView, changes ViewChanges, ) (*trieView, error) { - root, err := parentTrie.getEditableNode(db.rootKey, false /* hasValue */) + root, err := parentTrie.getEditableNode(Key{}, false /* hasValue */) if err != nil { if err == database.ErrNotFound { return nil, ErrNoValidRoot @@ -158,6 +160,7 @@ func newTrieView( db: db, parentTrie: parentTrie, changes: newChangeSummary(len(changes.BatchOps) + len(changes.MapOps)), + tokenSize: db.tokenSize, } for _, op := range changes.BatchOps { @@ -173,7 +176,7 @@ func newTrieView( newVal = maybe.Some(slices.Clone(op.Value)) } } - if err := newView.recordValueChange(db.toKey(key), newVal); err != nil { + if err := newView.recordValueChange(toKey(key), newVal); err != nil { return nil, err } } @@ -181,7 +184,7 @@ func newTrieView( if !changes.ConsumeBytes { val = maybe.Bind(val, slices.Clone[[]byte]) } - if err := newView.recordValueChange(db.toKey(stringToByteSlice(key)), val); err != nil { + if err := newView.recordValueChange(toKey(stringToByteSlice(key)), val); err != nil { return nil, err } } @@ -197,7 +200,7 @@ func newHistoricalTrieView( return nil, ErrNoValidRoot } - passedRootChange, ok := changes.nodes[db.rootKey] + passedRootChange, ok := changes.nodes[Key{}] if !ok { return nil, ErrNoValidRoot } @@ -207,6 +210,7 @@ func newHistoricalTrieView( db: db, parentTrie: db, changes: changes, + tokenSize: db.tokenSize, } // since this is a set of historical changes, all nodes have already been calculated // since no new changes have occurred, no new calculations need to be done @@ -269,7 +273,7 @@ func (t *trieView) calculateNodeIDsHelper(n *node) { ) for childIndex, child := range n.children { - childPath := n.key.AppendExtend(childIndex, child.compressedKey) + childPath := n.key.Extend(ToToken(childIndex, t.tokenSize), child.compressedKey) childNodeChange, ok := t.changes.nodes[childPath] if !ok { // This child wasn't changed. @@ -302,9 +306,8 @@ func (t *trieView) calculateNodeIDsHelper(n *node) { wg.Wait() close(updatedChildren) - keyLength := n.key.tokenLength for updatedChild := range updatedChildren { - index := updatedChild.key.Token(keyLength) + index := updatedChild.key.Token(n.key.length, t.tokenSize) n.setChildEntry(index, child{ compressedKey: n.children[index].compressedKey, id: updatedChild.id, @@ -334,7 +337,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { defer span.End() proof := &Proof{ - Key: t.db.toKey(key), + Key: ToKey(key), } var closestNode *node @@ -355,7 +358,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { // There is no node with the given [key]. // If there is a child at the index where the node would be // if it existed, include that child in the proof. 
- nextIndex := proof.Key.Token(closestNode.key.tokenLength) + nextIndex := proof.Key.Token(closestNode.key.length, t.tokenSize) child, ok := closestNode.children[nextIndex] if !ok { return proof, nil @@ -363,7 +366,7 @@ func (t *trieView) getProof(ctx context.Context, key []byte) (*Proof, error) { childNode, err := t.getNodeWithID( child.id, - closestNode.key.AppendExtend(nextIndex, child.compressedKey), + closestNode.key.Extend(ToToken(nextIndex, t.tokenSize), child.compressedKey), child.hasValue, ) if err != nil { @@ -557,7 +560,7 @@ func (t *trieView) GetValues(ctx context.Context, keys [][]byte) ([][]byte, []er valueErrors := make([]error, len(keys)) for i, key := range keys { - results[i], valueErrors[i] = t.getValueCopy(t.db.toKey(key)) + results[i], valueErrors[i] = t.getValueCopy(ToKey(key)) } return results, valueErrors } @@ -568,7 +571,7 @@ func (t *trieView) GetValue(ctx context.Context, key []byte) ([]byte, error) { _, span := t.db.debugTracer.Start(ctx, "MerkleDB.trieview.GetValue") defer span.End() - return t.getValueCopy(t.db.toKey(key)) + return t.getValueCopy(ToKey(key)) } // getValueCopy returns a copy of the value for the given [key]. @@ -654,7 +657,7 @@ func (t *trieView) remove(key Key) error { return err } if parent != nil { - parent.removeChild(nodeToDelete) + parent.removeChild(nodeToDelete, t.tokenSize) // merge the parent node and its child into a single node if possible return t.compressNodePath(grandParent, parent) @@ -692,15 +695,15 @@ func (t *trieView) compressNodePath(parent, node *node) error { // "Cycle" over the key/values to find the only child. // Note this iteration once because len(node.children) == 1. for index, entry := range node.children { - childKey = node.key.AppendExtend(index, entry.compressedKey) + childKey = node.key.Extend(ToToken(index, t.tokenSize), entry.compressedKey) childEntry = entry } // [node] is the first node with multiple children. // combine it with the [node] passed in. 
- parent.setChildEntry(childKey.Token(parent.key.tokenLength), + parent.setChildEntry(childKey.Token(parent.key.length, t.tokenSize), child{ - compressedKey: childKey.Skip(parent.key.tokenLength + 1), + compressedKey: childKey.Skip(parent.key.length + t.tokenSize), id: childEntry.id, hasValue: childEntry.hasValue, }) @@ -722,17 +725,16 @@ func (t *trieView) visitPathToKey(key Key, visitNode func(*node) error) error { return err } // while the entire path hasn't been matched - for currentNode.key.tokenLength < key.tokenLength { + for currentNode.key.length < key.length { // confirm that a child exists and grab its ID before attempting to load it - nextChildEntry, hasChild := currentNode.children[key.Token(currentNode.key.tokenLength)] + nextChildEntry, hasChild := currentNode.children[key.Token(currentNode.key.length, t.tokenSize)] - if !hasChild || !key.iteratedHasPrefix(currentNode.key.tokenLength+1, nextChildEntry.compressedKey) { + if !hasChild || !key.iteratedHasPrefix(nextChildEntry.compressedKey, currentNode.key.length+t.tokenSize, t.tokenSize) { // there was no child along the path or the child that was there doesn't match the remaining path return nil } - // grab the next node along the path - currentNode, err = t.getNodeWithID(nextChildEntry.id, key.Take(currentNode.key.tokenLength+1+nextChildEntry.compressedKey.tokenLength), nextChildEntry.hasValue) + currentNode, err = t.getNodeWithID(nextChildEntry.id, key.Take(currentNode.key.length+t.tokenSize+nextChildEntry.compressedKey.length), nextChildEntry.hasValue) if err != nil { return err } @@ -743,14 +745,6 @@ func (t *trieView) visitPathToKey(key Key, visitNode func(*node) error) error { return nil } -func getLengthOfCommonPrefix(first, second Key, secondOffset int) int { - commonIndex := 0 - for first.tokenLength > commonIndex && second.tokenLength > (commonIndex+secondOffset) && first.Token(commonIndex) == second.Token(commonIndex+secondOffset) { - commonIndex++ - } - return commonIndex -} - // Get a copy of the node matching the passed key from the trie. // Used by views to get nodes from their ancestors. func (t *trieView) getEditableNode(key Key, hadValue bool) (*node, error) { @@ -791,31 +785,27 @@ func (t *trieView) insert( return nil, err } - // a node with that exact path already exists so update its value + // a node with that exact key already exists so update its value if closestNode.key == key { closestNode.setValue(value) // closestNode was already marked as changed in the ancestry loop above return closestNode, nil } - closestNodeKeyLength := closestNode.key.tokenLength - // A node with the exact key doesn't exist so determine the portion of the // key that hasn't been matched yet - // Note that [key] has prefix [closestNodeFullPath] but exactMatch was false, - // so [key] must be longer than [closestNodeFullPath] and the following index and slice won't OOB. - existingChildEntry, hasChild := closestNode.children[key.Token(closestNodeKeyLength)] + // Note that [key] has prefix [closestNode.key], so [key] must be longer + // and the following index won't OOB. 
+ existingChildEntry, hasChild := closestNode.children[key.Token(closestNode.key.length, t.tokenSize)] if !hasChild { - // there are no existing nodes along the path [fullPath], so create a new node to insert [value] - newNode := newNode( - closestNode, - key, - ) + // there are no existing nodes along the key [key], so create a new node to insert [value] + newNode := newNode(key) newNode.setValue(value) + closestNode.addChild(newNode, t.tokenSize) return newNode, t.recordNewNode(newNode) } - // if we have reached this point, then the [fullpath] we are trying to insert and + // if we have reached this point, then the [key] we are trying to insert and // the existing path node have some common prefix. // a new branching node will be created that will represent this common prefix and // have the existing path node and the value being inserted as children. @@ -824,31 +814,32 @@ func (t *trieView) insert( // find how many tokens are common between the existing child's compressed path and // the current key(offset by the closest node's key), // then move all the common tokens into the branch node - commonPrefixLength := getLengthOfCommonPrefix(existingChildEntry.compressedKey, key, closestNodeKeyLength+1) + commonPrefixLength := getLengthOfCommonPrefix( + existingChildEntry.compressedKey, + key, + closestNode.key.length+t.tokenSize, + t.tokenSize, + ) - // If the length of the existing child's compressed path is less than or equal to the branch node's key that implies that the existing child's key matched the key to be inserted. - // Since it matched the key to be inserted, it should have been the last node returned by GetPathTo - if existingChildEntry.compressedKey.tokenLength <= commonPrefixLength { - return nil, ErrGetPathToFailure + if existingChildEntry.compressedKey.length <= commonPrefixLength { + // Since the compressed key is shorter than the common prefix, + // we should have visited [existingChildEntry] in [visitPathToKey]. 
+ return nil, ErrVisitPathToKey } - branchNode := newNode( - closestNode, - key.Take(closestNodeKeyLength+1+commonPrefixLength), - ) + branchNode := newNode(key.Take(closestNode.key.length + t.tokenSize + commonPrefixLength)) + closestNode.addChild(branchNode, t.tokenSize) nodeWithValue := branchNode - if key.tokenLength == branchNode.key.tokenLength { + if key.length == branchNode.key.length { // the branch node has exactly the key to be inserted as its key, so set the value on the branch node branchNode.setValue(value) } else { // the key to be inserted is a child of the branch node // create a new node and add the value to it - newNode := newNode( - branchNode, - key, - ) + newNode := newNode(key) newNode.setValue(value) + branchNode.addChild(newNode, t.tokenSize) if err := t.recordNewNode(newNode); err != nil { return nil, err } @@ -857,9 +848,9 @@ func (t *trieView) insert( // add the existing child onto the branch node branchNode.setChildEntry( - existingChildEntry.compressedKey.Token(commonPrefixLength), + existingChildEntry.compressedKey.Token(commonPrefixLength, t.tokenSize), child{ - compressedKey: existingChildEntry.compressedKey.Skip(commonPrefixLength + 1), + compressedKey: existingChildEntry.compressedKey.Skip(commonPrefixLength + t.tokenSize), id: existingChildEntry.id, hasValue: existingChildEntry.hasValue, }) @@ -867,6 +858,15 @@ func (t *trieView) insert( return nodeWithValue, t.recordNewNode(branchNode) } +func getLengthOfCommonPrefix(first, second Key, secondOffset int, tokenSize int) int { + commonIndex := 0 + for first.length > commonIndex && second.length > commonIndex+secondOffset && + first.Token(commonIndex, tokenSize) == second.Token(commonIndex+secondOffset, tokenSize) { + commonIndex += tokenSize + } + return commonIndex +} + // Records that a node has been created. // Must not be called after [calculateNodeIDs] has returned. func (t *trieView) recordNewNode(after *node) error { @@ -883,7 +883,7 @@ func (t *trieView) recordNodeChange(after *node) error { // Must not be called after [calculateNodeIDs] has returned. func (t *trieView) recordNodeDeleted(after *node) error { // don't delete the root. 
- if after.key.tokenLength == 0 { + if after.key.length == 0 { return t.recordKeyChange(after.key, after, after.hasValue(), false /* newNode */) } return t.recordKeyChange(after.key, nil, after.hasValue(), false /* newNode */) diff --git a/x/merkledb/value_node_db.go b/x/merkledb/value_node_db.go index 8f168560d7fa..339fa25f78f0 100644 --- a/x/merkledb/value_node_db.go +++ b/x/merkledb/value_node_db.go @@ -27,8 +27,7 @@ type valueNodeDB struct { nodeCache cache.Cacher[Key, *node] metrics merkleMetrics - closed utils.Atomic[bool] - branchFactor BranchFactor + closed utils.Atomic[bool] } func newValueNodeDB( @@ -36,14 +35,12 @@ func newValueNodeDB( bufferPool *sync.Pool, metrics merkleMetrics, cacheSize int, - branchFactor BranchFactor, ) *valueNodeDB { return &valueNodeDB{ - metrics: metrics, - baseDB: db, - bufferPool: bufferPool, - nodeCache: cache.NewSizedLRU(cacheSize, cacheEntrySize), - branchFactor: branchFactor, + metrics: metrics, + baseDB: db, + bufferPool: bufferPool, + nodeCache: cache.NewSizedLRU(cacheSize, cacheEntrySize), } } @@ -170,7 +167,7 @@ func (i *iterator) Next() bool { i.db.metrics.DatabaseNodeRead() key := i.nodeIter.Key() key = key[valueNodePrefixLen:] - n, err := parseNode(ToKey(key, i.db.branchFactor), i.nodeIter.Value()) + n, err := parseNode(ToKey(key), i.nodeIter.Value()) if err != nil { i.err = err return false diff --git a/x/merkledb/value_node_db_test.go b/x/merkledb/value_node_db_test.go index 910c6e1e9d6b..96c5b4b038cc 100644 --- a/x/merkledb/value_node_db_test.go +++ b/x/merkledb/value_node_db_test.go @@ -28,11 +28,10 @@ func TestValueNodeDB(t *testing.T) { }, &mockMetrics{}, size, - BranchFactor16, ) // Getting a key that doesn't exist should return an error. - key := ToKey([]byte{0x01}, BranchFactor16) + key := ToKey([]byte{0x01}) _, err := db.Get(key) require.ErrorIs(err, database.ErrNotFound) @@ -124,12 +123,11 @@ func TestValueNodeDBIterator(t *testing.T) { }, &mockMetrics{}, cacheSize, - BranchFactor16, ) // Put key-node pairs. for i := 0; i < cacheSize; i++ { - key := ToKey([]byte{byte(i)}, BranchFactor16) + key := ToKey([]byte{byte(i)}) node := &node{ dbNode: dbNode{ value: maybe.Some([]byte{byte(i)}), @@ -167,7 +165,7 @@ func TestValueNodeDBIterator(t *testing.T) { it.Release() // Put key-node pairs with a common prefix. 
-	key := ToKey([]byte{0xFF, 0x00}, BranchFactor16)
+	key := ToKey([]byte{0xFF, 0x00})
 	n := &node{
 		dbNode: dbNode{
 			value: maybe.Some([]byte{0xFF, 0x00}),
@@ -178,7 +176,7 @@
 	batch.Put(key, n)
 	require.NoError(batch.Write())
 
-	key = ToKey([]byte{0xFF, 0x01}, BranchFactor16)
+	key = ToKey([]byte{0xFF, 0x01})
 	n = &node{
 		dbNode: dbNode{
 			value: maybe.Some([]byte{0xFF, 0x01}),
diff --git a/x/merkledb/view_iterator.go b/x/merkledb/view_iterator.go
index 263aa409e882..fac213bf350b 100644
--- a/x/merkledb/view_iterator.go
+++ b/x/merkledb/view_iterator.go
@@ -26,8 +26,8 @@ func (t *trieView) NewIteratorWithPrefix(prefix []byte) database.Iterator {
 func (t *trieView) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator {
 	var (
 		changes   = make([]KeyChange, 0, len(t.changes.values))
-		startKey  = t.db.toKey(start)
-		prefixKey = t.db.toKey(prefix)
+		startKey  = ToKey(start)
+		prefixKey = ToKey(prefix)
 	)
 
 	for key, change := range t.changes.values {
diff --git a/x/sync/client.go b/x/sync/client.go
index 095f515d41fb..e86c37485c7f 100644
--- a/x/sync/client.go
+++ b/x/sync/client.go
@@ -73,7 +73,7 @@ type client struct {
 	stateSyncMinVersion *version.Application
 	log                 logging.Logger
 	metrics             SyncMetrics
-	branchFactor        merkledb.BranchFactor
+	tokenSize           int
 }
 
 type ClientConfig struct {
@@ -95,7 +95,7 @@ func NewClient(config *ClientConfig) (Client, error) {
 		stateSyncMinVersion: config.StateSyncMinVersion,
 		log:                 config.Log,
 		metrics:             config.Metrics,
-		branchFactor:        config.BranchFactor,
+		tokenSize:           merkledb.BranchFactorToTokenSize[config.BranchFactor],
 	}, nil
 }
 
@@ -124,7 +124,7 @@ func (c *client) GetChangeProof(
 	case *pb.SyncGetChangeProofResponse_ChangeProof:
 		// The server had enough history to send us a change proof
 		var changeProof merkledb.ChangeProof
-		if err := changeProof.UnmarshalProto(changeProofResp.ChangeProof, c.branchFactor); err != nil {
+		if err := changeProof.UnmarshalProto(changeProofResp.ChangeProof); err != nil {
 			return nil, err
 		}
 
@@ -158,7 +158,7 @@ func (c *client) GetChangeProof(
 
 	case *pb.SyncGetChangeProofResponse_RangeProof:
 		var rangeProof merkledb.RangeProof
-		if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof, c.branchFactor); err != nil {
+		if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof); err != nil {
 			return nil, err
 		}
 
@@ -171,6 +171,7 @@ func (c *client) GetChangeProof(
 			startKey,
 			endKey,
 			req.EndRootHash,
+			c.tokenSize,
 		)
 		if err != nil {
 			return nil, err
@@ -208,6 +209,7 @@ func verifyRangeProof(
 	start maybe.Maybe[[]byte],
 	end maybe.Maybe[[]byte],
 	rootBytes []byte,
+	tokenSize int,
 ) error {
 	root, err := ids.ToID(rootBytes)
 	if err != nil {
@@ -227,6 +229,7 @@ func verifyRangeProof(
 		start,
 		end,
 		root,
+		tokenSize,
 	); err != nil {
 		return fmt.Errorf("%w due to %w", errInvalidRangeProof, err)
 	}
@@ -257,7 +260,7 @@ func (c *client) GetRangeProof(
 	endKey := maybeBytesToMaybe(req.EndKey)
 
 	var rangeProof merkledb.RangeProof
-	if err := rangeProof.UnmarshalProto(&rangeProofProto, c.branchFactor); err != nil {
+	if err := rangeProof.UnmarshalProto(&rangeProofProto); err != nil {
 		return nil, err
 	}
 
@@ -268,6 +271,7 @@ func (c *client) GetRangeProof(
 		startKey,
 		endKey,
 		req.RootHash,
+		c.tokenSize,
 	); err != nil {
 		return nil, err
 	}
diff --git a/x/sync/client_test.go b/x/sync/client_test.go
index 08c1a787b474..c9f473bea53e 100644
--- a/x/sync/client_test.go
+++ b/x/sync/client_test.go
@@ -138,7 +138,7 @@ func sendRangeProofRequest(
 	require.NoError(proto.Unmarshal(responseBytes, &responseProto))
 
 	var response merkledb.RangeProof
-	require.NoError(response.UnmarshalProto(&responseProto, merkledb.BranchFactor16))
+	require.NoError(response.UnmarshalProto(&responseProto))
 
 	// modify if needed
 	if modifyResponse != nil {
@@ -456,7 +456,7 @@ func sendChangeProofRequest(
 	if responseProto.GetChangeProof() != nil {
 		// Server responded with a change proof
 		var changeProof merkledb.ChangeProof
-		require.NoError(changeProof.UnmarshalProto(responseProto.GetChangeProof(), merkledb.BranchFactor16))
+		require.NoError(changeProof.UnmarshalProto(responseProto.GetChangeProof()))
 
 		// modify if needed
 		if modifyChangeProof != nil {
@@ -478,7 +478,7 @@ func sendChangeProofRequest(
 
 	// Server responded with a range proof
 	var rangeProof merkledb.RangeProof
-	require.NoError(rangeProof.UnmarshalProto(responseProto.GetRangeProof(), merkledb.BranchFactor16))
+	require.NoError(rangeProof.UnmarshalProto(responseProto.GetRangeProof()))
 
 	// modify if needed
 	if modifyRangeProof != nil {
diff --git a/x/sync/g_db/db_client.go b/x/sync/g_db/db_client.go
index 8bd936a53975..9ce402f5e6ca 100644
--- a/x/sync/g_db/db_client.go
+++ b/x/sync/g_db/db_client.go
@@ -19,16 +19,14 @@ import (
 
 var _ sync.DB = (*DBClient)(nil)
 
-func NewDBClient(client pb.DBClient, branchFactor merkledb.BranchFactor) *DBClient {
+func NewDBClient(client pb.DBClient) *DBClient {
 	return &DBClient{
-		client:       client,
-		branchFactor: branchFactor,
+		client: client,
 	}
 }
 
 type DBClient struct {
-	client       pb.DBClient
-	branchFactor merkledb.BranchFactor
+	client pb.DBClient
 }
 
 func (c *DBClient) GetMerkleRoot(ctx context.Context) (ids.ID, error) {
@@ -70,7 +68,7 @@ func (c *DBClient) GetChangeProof(
 	}
 
 	var proof merkledb.ChangeProof
-	if err := proof.UnmarshalProto(resp.GetChangeProof(), c.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(resp.GetChangeProof()); err != nil {
 		return nil, err
 	}
 	return &proof, nil
@@ -122,7 +120,7 @@ func (c *DBClient) GetProof(ctx context.Context, key []byte) (*merkledb.Proof, e
 	}
 
 	var proof merkledb.Proof
-	if err := proof.UnmarshalProto(resp.Proof, c.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(resp.Proof); err != nil {
 		return nil, err
 	}
 	return &proof, nil
@@ -152,7 +150,7 @@ func (c *DBClient) GetRangeProofAtRoot(
 	}
 
 	var proof merkledb.RangeProof
-	if err := proof.UnmarshalProto(resp.Proof, c.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(resp.Proof); err != nil {
 		return nil, err
 	}
 	return &proof, nil
diff --git a/x/sync/g_db/db_server.go b/x/sync/g_db/db_server.go
index b6471542dcca..da454e2d77cd 100644
--- a/x/sync/g_db/db_server.go
+++ b/x/sync/g_db/db_server.go
@@ -19,18 +19,16 @@ import (
 
 var _ pb.DBServer = (*DBServer)(nil)
 
-func NewDBServer(db sync.DB, branchFactor merkledb.BranchFactor) *DBServer {
+func NewDBServer(db sync.DB) *DBServer {
 	return &DBServer{
-		db:           db,
-		branchFactor: branchFactor,
+		db: db,
 	}
 }
 
 type DBServer struct {
 	pb.UnsafeDBServer
 
-	db           sync.DB
-	branchFactor merkledb.BranchFactor
+	db sync.DB
 }
 
 func (s *DBServer) GetMerkleRoot(
@@ -98,7 +96,7 @@ func (s *DBServer) VerifyChangeProof(
 	req *pb.VerifyChangeProofRequest,
 ) (*pb.VerifyChangeProofResponse, error) {
 	var proof merkledb.ChangeProof
-	if err := proof.UnmarshalProto(req.Proof, s.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(req.Proof); err != nil {
 		return nil, err
 	}
 
@@ -130,7 +128,7 @@ func (s *DBServer) CommitChangeProof(
 	req *pb.CommitChangeProofRequest,
 ) (*emptypb.Empty, error) {
 	var proof merkledb.ChangeProof
-	if err := proof.UnmarshalProto(req.Proof, s.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(req.Proof); err != nil {
 		return nil, err
 	}
 
@@ -201,7 +199,7 @@ func (s *DBServer) CommitRangeProof(
 	req *pb.CommitRangeProofRequest,
 ) (*emptypb.Empty, error) {
 	var proof merkledb.RangeProof
-	if err := proof.UnmarshalProto(req.RangeProof, s.branchFactor); err != nil {
+	if err := proof.UnmarshalProto(req.RangeProof); err != nil {
 		return nil, err
 	}
diff --git a/x/sync/manager.go b/x/sync/manager.go
index 0a13a89eb32b..6bd81e847aee 100644
--- a/x/sync/manager.go
+++ b/x/sync/manager.go
@@ -10,12 +10,15 @@ import (
 	"fmt"
 	"sync"
 
+	"golang.org/x/exp/maps"
+
 	"go.uber.org/zap"
 	"golang.org/x/exp/slices"
 
 	"github.com/ava-labs/avalanchego/ids"
 	"github.com/ava-labs/avalanchego/utils/logging"
 	"github.com/ava-labs/avalanchego/utils/maybe"
+	"github.com/ava-labs/avalanchego/utils/set"
 	"github.com/ava-labs/avalanchego/x/merkledb"
 
 	pb "github.com/ava-labs/avalanchego/proto/pb/sync"
@@ -102,9 +105,9 @@ type Manager struct {
 	cancelCtx context.CancelFunc
 
 	// Set to true when StartSyncing is called.
-	syncing      bool
-	closeOnce    sync.Once
-	branchFactor merkledb.BranchFactor
+	syncing   bool
+	closeOnce sync.Once
+	tokenSize int
 }
 
 type ManagerConfig struct {
@@ -136,7 +139,7 @@ func NewManager(config ManagerConfig) (*Manager, error) {
 		doneChan:        make(chan struct{}),
 		unprocessedWork: newWorkHeap(),
 		processedWork:   newWorkHeap(),
-		branchFactor:    config.BranchFactor,
+		tokenSize:       merkledb.BranchFactorToTokenSize[config.BranchFactor],
 	}
 	m.unprocessedWorkCond.L = &m.workLock
 
@@ -404,7 +407,7 @@ func (m *Manager) findNextKey(
 	// and traversing them from the longest key to the shortest key.
 	// For each node in these proofs, compare if the children of that node exist
 	// or have the same ID in the other proof.
-	proofKeyPath := merkledb.ToKey(lastReceivedKey, m.branchFactor)
+	proofKeyPath := merkledb.ToKey(lastReceivedKey)
 
 	// If the received proof is an exclusion proof, the last node may be for a
 	// key that is after the [lastReceivedKey].
@@ -447,7 +450,7 @@ func (m *Manager) findNextKey(
 
 		// select the deepest proof node from the two proofs
 		switch {
-		case receivedProofNode.Key.TokensLength() > localProofNode.Key.TokensLength():
+		case receivedProofNode.Key.Length() > localProofNode.Key.Length():
 			// there was a branch node in the received proof that isn't in the local proof
 			// see if the received proof node has children not present in the local proof
 			deepestNode = &receivedProofNode
@@ -455,7 +458,7 @@ func (m *Manager) findNextKey(
 			// we have dealt with this received node, so move on to the next received node
 			receivedProofNodeIndex--
 
-		case localProofNode.Key.TokensLength() > receivedProofNode.Key.TokensLength():
+		case localProofNode.Key.Length() > receivedProofNode.Key.Length():
 			// there was a branch node in the local proof that isn't in the received proof
 			// see if the local proof node has children not present in the received proof
 			deepestNode = &localProofNode
@@ -482,20 +485,20 @@ func (m *Manager) findNextKey(
 		// If the deepest node has the same key as [proofKeyPath],
 		// then all of its children have keys greater than the proof key,
 		// so we can start at the 0 token.
-		startingChildToken := byte(0)
+		startingChildToken := 0
 
 		// If the deepest node has a key shorter than the key being proven,
 		// we can look at the next token index of the proof key to determine which of that
 		// node's children have keys larger than [proofKeyPath].
 		// Any child with a token greater than the [proofKeyPath]'s token at that
 		// index will have a larger key.
-		if deepestNode.Key.TokensLength() < proofKeyPath.TokensLength() {
-			startingChildToken = proofKeyPath.Token(deepestNode.Key.TokensLength()) + 1
+		if deepestNode.Key.Length() < proofKeyPath.Length() {
+			startingChildToken = int(proofKeyPath.Token(deepestNode.Key.Length(), m.tokenSize)) + 1
 		}
 
 		// determine if there are any differences in the children for the deepest unhandled node of the two proofs
-		if childIndex, hasDifference := findChildDifference(deepestNode, deepestNodeFromOtherProof, startingChildToken, m.branchFactor); hasDifference {
-			nextKey = maybe.Some(deepestNode.Key.Append(childIndex).Bytes())
+		if childIndex, hasDifference := findChildDifference(deepestNode, deepestNodeFromOtherProof, startingChildToken); hasDifference {
+			nextKey = maybe.Some(deepestNode.Key.Extend(merkledb.ToToken(childIndex, m.tokenSize)).Bytes())
 			break
 		}
 	}
@@ -794,12 +797,27 @@ func midPoint(startMaybe, endMaybe maybe.Maybe[[]byte]) maybe.Maybe[[]byte] {
 
 // findChildDifference returns the first child index that is different between node 1 and node 2 if one exists and
 // a bool indicating if any difference was found
-func findChildDifference(node1, node2 *merkledb.ProofNode, startIndex byte, branchFactor merkledb.BranchFactor) (byte, bool) {
+func findChildDifference(node1, node2 *merkledb.ProofNode, startIndex int) (byte, bool) {
+	// Children indices >= [startIndex] present in at least one of the nodes.
+	childIndices := set.Set[byte]{}
+	for _, node := range []*merkledb.ProofNode{node1, node2} {
+		if node == nil {
+			continue
+		}
+		for key := range node.Children {
+			if int(key) >= startIndex {
+				childIndices.Add(key)
+			}
+		}
+	}
+
+	sortedChildIndices := maps.Keys(childIndices)
+	slices.Sort(sortedChildIndices)
 	var (
 		child1, child2 ids.ID
 		ok1, ok2       bool
 	)
-	for childIndex := startIndex; merkledb.BranchFactor(childIndex) < branchFactor; childIndex++ {
+	for _, childIndex := range sortedChildIndices {
 		if node1 != nil {
 			child1, ok1 = node1.Children[childIndex]
 		}
diff --git a/x/sync/network_server_test.go b/x/sync/network_server_test.go
index 60555498457f..d79b27a14c44 100644
--- a/x/sync/network_server_test.go
+++ b/x/sync/network_server_test.go
@@ -114,7 +114,7 @@ func Test_Server_GetRangeProof(t *testing.T) {
 			require.NoError(proto.Unmarshal(responseBytes, &proofProto))
 
 			var p merkledb.RangeProof
-			require.NoError(p.UnmarshalProto(&proofProto, merkledb.BranchFactor16))
+			require.NoError(p.UnmarshalProto(&proofProto))
 			proof = &p
 		}
 		return nil
diff --git a/x/sync/sync_test.go b/x/sync/sync_test.go
index 9a6a5de2dba9..af908c9d941c 100644
--- a/x/sync/sync_test.go
+++ b/x/sync/sync_test.go
@@ -586,10 +586,11 @@ func TestFindNextKeyRandom(t *testing.T) {
 	)
 	require.NoError(err)
 
+	config := newDefaultDBConfig()
 	localDB, err := merkledb.New(
 		context.Background(),
 		memdb.New(),
-		newDefaultDBConfig(),
+		config,
 	)
 	require.NoError(err)
 
@@ -677,7 +678,7 @@ func TestFindNextKeyRandom(t *testing.T) {
 	for _, node := range remoteProof.EndProof {
 		for childIdx, childID := range node.Children {
 			remoteKeyIDs = append(remoteKeyIDs, keyAndID{
-				key: node.Key.Append(childIdx),
+				key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])),
 				id:  childID,
 			})
 		}
@@ -688,7 +689,7 @@ func TestFindNextKeyRandom(t *testing.T) {
	for _, node := range localProof.Path {
 		for childIdx, childID := range node.Children {
 			localKeyIDs = append(localKeyIDs, keyAndID{
-				key: node.Key.Append(childIdx),
+				key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])),
 				id:  childID,
 			})
 		}
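The recurring change across these files is that per-call `BranchFactor` arguments are replaced by a token size in bits, looked up once via `merkledb.BranchFactorToTokenSize` (see `NewClient` and `NewManager` above), so quantities like `Key.Length()` and the indices passed to `Key.Token(index, tokenSize)` are measured in bits rather than tokens. The sketch below is illustrative only and not part of the patch: the map contents and the `token` helper are assumptions that mirror how the changed call sites use the API, on the premise that a branch factor of 2^n corresponds to n-bit tokens.

// tokensketch.go -- standalone illustration; names and values here are
// assumptions, not the merkledb implementation.
package main

import "fmt"

// Assumed contents of a BranchFactorToTokenSize-style map: a branch
// factor of 2^n corresponds to n-bit tokens.
var branchFactorToTokenSize = map[int]int{2: 1, 4: 2, 16: 4, 256: 8}

// token reads the [tokenSize]-bit token starting at bit [bitIndex] of
// [key]. It assumes bitIndex is token-aligned, so a token never spans
// two bytes (true for token sizes of 1, 2, 4, or 8 bits).
func token(key []byte, bitIndex, tokenSize int) byte {
	b := key[bitIndex/8]
	// Shift the token down to the low bits, then mask off the rest.
	shift := 8 - tokenSize - bitIndex%8
	return (b >> shift) & byte(1<<tokenSize-1)
}

func main() {
	key := []byte{0xAB} // 1010 1011
	ts := branchFactorToTokenSize[16]
	// With 4-bit tokens, byte 0xAB holds token 0xA (bits 0-3) and 0xB (bits 4-7).
	fmt.Printf("%x %x\n", token(key, 0, ts), token(key, 4, ts)) // a b
}

Precomputing the token size once, instead of threading the branch factor through every call, avoids re-deriving the token width on each key operation and lets keys be compared purely by bit length.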
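The new `findChildDifference` in x/sync/manager.go also changes iteration strategy: rather than scanning every possible token value up to the branch factor, it visits only the sorted union of child indices actually present in either proof node. A minimal standalone restatement of that strategy follows, with a plain struct standing in for `merkledb.ProofNode` and string IDs standing in for `ids.ID` (both assumptions for illustration):

// childdiff.go -- standalone restatement of the sorted-union strategy.
package main

import (
	"fmt"
	"sort"
)

// node stands in for merkledb.ProofNode in this sketch.
type node struct {
	children map[byte]string
}

// findChildDifference returns the first child index >= startIndex whose
// presence or ID differs between the two (possibly nil) nodes.
func findChildDifference(node1, node2 *node, startIndex int) (byte, bool) {
	// Union of child indices >= startIndex present in either node.
	indices := map[byte]struct{}{}
	for _, n := range []*node{node1, node2} {
		if n == nil {
			continue
		}
		for idx := range n.children {
			if int(idx) >= startIndex {
				indices[idx] = struct{}{}
			}
		}
	}

	// Visit candidates in ascending order so the first difference wins.
	sorted := make([]byte, 0, len(indices))
	for idx := range indices {
		sorted = append(sorted, idx)
	}
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	for _, idx := range sorted {
		var id1, id2 string
		var ok1, ok2 bool
		if node1 != nil {
			id1, ok1 = node1.children[idx]
		}
		if node2 != nil {
			id2, ok2 = node2.children[idx]
		}
		if ok1 != ok2 || id1 != id2 {
			return idx, true
		}
	}
	return 0, false
}

func main() {
	a := &node{children: map[byte]string{1: "x", 5: "y"}}
	b := &node{children: map[byte]string{1: "x", 5: "z"}}
	idx, found := findChildDifference(a, b, 0)
	fmt.Println(idx, found) // 5 true
}

For sparse nodes this touches only the populated child slots, which matters most at BranchFactor256, where the old loop would probe up to 256 map entries per node pair.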