Skip to content

Commit

Permalink
Add Expiry Replay Protection (#3379)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Sep 12, 2024
1 parent 6549c2d commit d366a13
Show file tree
Hide file tree
Showing 8 changed files with 536 additions and 4 deletions.
45 changes: 45 additions & 0 deletions vms/platformvm/state/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type diff struct {
// Subnet ID --> supply of native asset of the subnet
currentSupply map[ids.ID]uint64

expiryDiff *expiryDiff

currentStakerDiffs diffStakers
// map of subnetID -> nodeID -> total accrued delegatee rewards
modifiedDelegateeRewards map[ids.ID]map[ids.NodeID]uint64
Expand Down Expand Up @@ -79,6 +81,7 @@ func NewDiff(
timestamp: parentState.GetTimestamp(),
feeState: parentState.GetFeeState(),
accruedFees: parentState.GetAccruedFees(),
expiryDiff: newExpiryDiff(),
subnetOwners: make(map[ids.ID]fx.Owner),
subnetManagers: make(map[ids.ID]chainIDAndAddr),
}, nil
Expand Down Expand Up @@ -146,6 +149,41 @@ func (d *diff) SetCurrentSupply(subnetID ids.ID, currentSupply uint64) {
}
}

func (d *diff) GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error) {
parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}

parentIterator, err := parentState.GetExpiryIterator()
if err != nil {
return nil, err
}

return d.expiryDiff.getExpiryIterator(parentIterator), nil
}

func (d *diff) HasExpiry(entry ExpiryEntry) (bool, error) {
if has, modified := d.expiryDiff.modified[entry]; modified {
return has, nil
}

parentState, ok := d.stateVersions.GetState(d.parentID)
if !ok {
return false, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID)
}

return parentState.HasExpiry(entry)
}

func (d *diff) PutExpiry(entry ExpiryEntry) {
d.expiryDiff.PutExpiry(entry)
}

func (d *diff) DeleteExpiry(entry ExpiryEntry) {
d.expiryDiff.DeleteExpiry(entry)
}

func (d *diff) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) {
// If the validator was modified in this diff, return the modified
// validator.
Expand Down Expand Up @@ -451,6 +489,13 @@ func (d *diff) Apply(baseState Chain) error {
for subnetID, supply := range d.currentSupply {
baseState.SetCurrentSupply(subnetID, supply)
}
for entry, isAdded := range d.expiryDiff.modified {
if isAdded {
baseState.PutExpiry(entry)
} else {
baseState.DeleteExpiry(entry)
}
}
for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs {
for _, validatorDiff := range subnetValidatorDiffs {
switch validatorDiff.validatorStatus {
Expand Down
161 changes: 161 additions & 0 deletions vms/platformvm/state/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/iterator"
"github.com/ava-labs/avalanchego/utils/iterator/iteratormock"
"github.com/ava-labs/avalanchego/utils/set"
"github.com/ava-labs/avalanchego/vms/components/avax"
"github.com/ava-labs/avalanchego/vms/components/gas"
"github.com/ava-labs/avalanchego/vms/platformvm/fx/fxmock"
Expand Down Expand Up @@ -112,6 +113,156 @@ func TestDiffCurrentSupply(t *testing.T) {
assertChainsEqual(t, state, d)
}

func TestDiffExpiry(t *testing.T) {
type op struct {
put bool
entry ExpiryEntry
}
tests := []struct {
name string
initialExpiries []ExpiryEntry
ops []op
}{
{
name: "empty noop",
},
{
name: "insert",
ops: []op{
{
put: true,
entry: ExpiryEntry{Timestamp: 1},
},
},
},
{
name: "remove",
initialExpiries: []ExpiryEntry{
{Timestamp: 1},
},
ops: []op{
{
put: false,
entry: ExpiryEntry{Timestamp: 1},
},
},
},
{
name: "add and immediately remove",
ops: []op{
{
put: true,
entry: ExpiryEntry{Timestamp: 1},
},
{
put: false,
entry: ExpiryEntry{Timestamp: 1},
},
},
},
{
name: "add + remove + add",
ops: []op{
{
put: true,
entry: ExpiryEntry{Timestamp: 1},
},
{
put: false,
entry: ExpiryEntry{Timestamp: 1},
},
{
put: true,
entry: ExpiryEntry{Timestamp: 1},
},
},
},
{
name: "everything",
initialExpiries: []ExpiryEntry{
{Timestamp: 1},
{Timestamp: 2},
{Timestamp: 3},
},
ops: []op{
{
put: false,
entry: ExpiryEntry{Timestamp: 1},
},
{
put: false,
entry: ExpiryEntry{Timestamp: 2},
},
{
put: true,
entry: ExpiryEntry{Timestamp: 1},
},
},
},
}

for _, test := range tests {
require := require.New(t)

state := newTestState(t, memdb.New())
for _, expiry := range test.initialExpiries {
state.PutExpiry(expiry)
}

d, err := NewDiffOn(state)
require.NoError(err)

var (
expectedExpiries = set.Of(test.initialExpiries...)
unexpectedExpiries set.Set[ExpiryEntry]
)
for _, op := range test.ops {
if op.put {
d.PutExpiry(op.entry)
expectedExpiries.Add(op.entry)
unexpectedExpiries.Remove(op.entry)
} else {
d.DeleteExpiry(op.entry)
expectedExpiries.Remove(op.entry)
unexpectedExpiries.Add(op.entry)
}
}

// If expectedExpiries is empty, we want expectedExpiriesSlice to be
// nil.
var expectedExpiriesSlice []ExpiryEntry
if expectedExpiries.Len() > 0 {
expectedExpiriesSlice = expectedExpiries.List()
utils.Sort(expectedExpiriesSlice)
}

verifyChain := func(chain Chain) {
expiryIterator, err := chain.GetExpiryIterator()
require.NoError(err)
require.Equal(
expectedExpiriesSlice,
iterator.ToSlice(expiryIterator),
)

for expiry := range expectedExpiries {
has, err := chain.HasExpiry(expiry)
require.NoError(err)
require.True(has)
}
for expiry := range unexpectedExpiries {
has, err := chain.HasExpiry(expiry)
require.NoError(err)
require.False(has)
}
}

verifyChain(d)
require.NoError(d.Apply(state))
verifyChain(state)
assertChainsEqual(t, d, state)
}
}

func TestDiffCurrentValidator(t *testing.T) {
require := require.New(t)
ctrl := gomock.NewController(t)
Expand Down Expand Up @@ -527,6 +678,16 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) {

t.Helper()

expectedExpiryIterator, expectedErr := expected.GetExpiryIterator()
actualExpiryIterator, actualErr := actual.GetExpiryIterator()
require.Equal(expectedErr, actualErr)
if expectedErr == nil {
require.Equal(
iterator.ToSlice(expectedExpiryIterator),
iterator.ToSlice(actualExpiryIterator),
)
}

expectedCurrentStakerIterator, expectedErr := expected.GetCurrentStakerIterator()
actualCurrentStakerIterator, actualErr := actual.GetCurrentStakerIterator()
require.Equal(expectedErr, actualErr)
Expand Down
65 changes: 61 additions & 4 deletions vms/platformvm/state/expiry.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/iterator"
)

// expiryEntry = [timestamp] + [validationID]
Expand All @@ -20,8 +22,26 @@ var (
errUnexpectedExpiryEntryLength = fmt.Errorf("expected expiry entry length %d", expiryEntryLength)

_ btree.LessFunc[ExpiryEntry] = ExpiryEntry.Less
_ utils.Sortable[ExpiryEntry] = ExpiryEntry{}
)

type Expiry interface {
// GetExpiryIterator returns an iterator of all the expiry entries in order
// of lowest to highest timestamp.
GetExpiryIterator() (iterator.Iterator[ExpiryEntry], error)

// HasExpiry returns true if the database has the specified entry.
HasExpiry(ExpiryEntry) (bool, error)

// PutExpiry adds the entry to the database. If the entry already exists, it
// is a noop.
PutExpiry(ExpiryEntry)

// DeleteExpiry removes the entry from the database. If the entry doesn't
// exist, it is a noop.
DeleteExpiry(ExpiryEntry)
}

type ExpiryEntry struct {
Timestamp uint64
ValidationID ids.ID
Expand All @@ -44,14 +64,51 @@ func (e *ExpiryEntry) Unmarshal(data []byte) error {
return nil
}

// Invariant: Less produces the same ordering as the marshalled bytes.
func (e ExpiryEntry) Less(o ExpiryEntry) bool {
return e.Compare(o) == -1
}

// Invariant: Compare produces the same ordering as the marshalled bytes.
func (e ExpiryEntry) Compare(o ExpiryEntry) int {
switch {
case e.Timestamp < o.Timestamp:
return true
return -1
case e.Timestamp > o.Timestamp:
return false
return 1
default:
return e.ValidationID.Compare(o.ValidationID) == -1
return e.ValidationID.Compare(o.ValidationID)
}
}

type expiryDiff struct {
modified map[ExpiryEntry]bool // bool represents isAdded
added *btree.BTreeG[ExpiryEntry]
}

func newExpiryDiff() *expiryDiff {
return &expiryDiff{
modified: make(map[ExpiryEntry]bool),
added: btree.NewG(defaultTreeDegree, ExpiryEntry.Less),
}
}

func (e *expiryDiff) PutExpiry(entry ExpiryEntry) {
e.modified[entry] = true
e.added.ReplaceOrInsert(entry)
}

func (e *expiryDiff) DeleteExpiry(entry ExpiryEntry) {
e.modified[entry] = false
e.added.Delete(entry)
}

func (e *expiryDiff) getExpiryIterator(parentIterator iterator.Iterator[ExpiryEntry]) iterator.Iterator[ExpiryEntry] {
return iterator.Merge(
ExpiryEntry.Less,
iterator.Filter(parentIterator, func(entry ExpiryEntry) bool {
_, ok := e.modified[entry]
return ok
}),
iterator.FromTree(e.added),
)
}
Loading

0 comments on commit d366a13

Please sign in to comment.