diff --git a/p2p/exchange.go b/p2p/exchange.go index 607bfb0f..503d7494 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -44,7 +44,8 @@ type Exchange[H header.Header[H]] struct { peerTracker *peerTracker metrics *exchangeMetrics - Params ClientParameters + Params ClientParameters + CustomValidate func(H) error } func NewExchange[H header.Header[H]]( @@ -253,6 +254,7 @@ func (ex *Exchange[H]) GetRangeByHeight( ) ([]H, error) { session := newSession[H]( ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RangeRequestTimeout, ex.metrics, withValidation(from), + func(s *session[H]) { s.customValidate = ex.CustomValidate }, ) defer session.close() // we request the next header height that we don't have: `fromHead`+1 @@ -335,7 +337,7 @@ func (ex *Exchange[H]) request( return nil, err } - hdrs, err := processResponses[H](responses) + hdrs, err := processResponses[H](responses, ex.CustomValidate) if err != nil { return nil, err } diff --git a/p2p/session.go b/p2p/session.go index 37f1e258..276c9d84 100644 --- a/p2p/session.go +++ b/p2p/session.go @@ -41,9 +41,10 @@ type session[H header.Header[H]] struct { from H requestTimeout time.Duration - ctx context.Context - cancel context.CancelFunc - reqCh chan *p2p_pb.HeaderRequest + ctx context.Context + cancel context.CancelFunc + reqCh chan *p2p_pb.HeaderRequest + customValidate func(H) error } func newSession[H header.Header[H]]( @@ -222,7 +223,7 @@ func (s *session[H]) doRequest( // processResponses converts HeaderResponse to Header. func (s *session[H]) processResponses(responses []*p2p_pb.HeaderResponse) ([]H, error) { - hdrs, err := processResponses[H](responses) + hdrs, err := processResponses[H](responses, s.customValidate) if err != nil { return nil, err } @@ -288,7 +289,7 @@ func prepareRequests(from, amount, headersPerPeer uint64) []*p2p_pb.HeaderReques } // processResponses converts HeaderResponses to Headers -func processResponses[H header.Header[H]](resps []*p2p_pb.HeaderResponse) ([]H, error) { +func processResponses[H header.Header[H]](resps []*p2p_pb.HeaderResponse, customValidate func(H) error) ([]H, error) { if len(resps) == 0 { return nil, errEmptyResponse } @@ -311,6 +312,13 @@ func processResponses[H header.Header[H]](resps []*p2p_pb.HeaderResponse) ([]H, return nil, err } + if customValidate != nil { + err = customValidate(hdr) + if err != nil { + return nil, err + } + } + hdrs = append(hdrs, hdr) } diff --git a/p2p/subscriber.go b/p2p/subscriber.go index c5f70dcc..99781267 100644 --- a/p2p/subscriber.go +++ b/p2p/subscriber.go @@ -26,10 +26,11 @@ type SubscriberParams struct { type Subscriber[H header.Header[H]] struct { pubsubTopicID string - metrics *subscriberMetrics - pubsub *pubsub.PubSub - topic *pubsub.Topic - msgID pubsub.MsgIdFunction + metrics *subscriberMetrics + pubsub *pubsub.PubSub + topic *pubsub.Topic + msgID pubsub.MsgIdFunction + CustomValidate func(H) error } // WithSubscriberMetrics enables metrics collection for the Subscriber. @@ -118,6 +119,17 @@ func (s *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error { return pubsub.ValidationReject } + if s.CustomValidate != nil { + err = s.CustomValidate(hdr) + if err != nil { + log.Errorw("invalid header", + "from", p.ShortString(), + "err", err) + s.metrics.reject(ctx) + return pubsub.ValidationReject + } + } + var verErr *header.VerifyError err = val(ctx, hdr) switch {