Skip to content

Commit

Permalink
drpcstream: fix data race in packet buffer handling
Browse files Browse the repository at this point in the history
the packet buffer would go into a finished state when
closed even if someone is still holding and using the
data slice from get, allowing put to unblock and causing
a data race. this adds a held boolean to the packet buffer
state so that close will wait for any handing of the
data buffer before going into the finished state.

fixes #48

Change-Id: I131fc7addb26153e31b68015e58a2e400c1ce8f5
  • Loading branch information
zeebo committed Feb 20, 2023
1 parent 220d855 commit 89d4b63
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 23 deletions.
22 changes: 19 additions & 3 deletions drpcstream/pktbuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type packetBuffer struct {
err error
data []byte
set bool
held bool
}

func (pb *packetBuffer) init() {
Expand All @@ -23,10 +24,14 @@ func (pb *packetBuffer) Close(err error) {
pb.mu.Lock()
defer pb.mu.Unlock()

for pb.held {
pb.cond.Wait()
}

if pb.err == nil {
pb.err = err
pb.data = nil
pb.set = false
pb.err = err
pb.cond.Broadcast()
}
}
Expand All @@ -35,15 +40,19 @@ func (pb *packetBuffer) Put(data []byte) {
pb.mu.Lock()
defer pb.mu.Unlock()

for pb.set && pb.err == nil {
pb.cond.Wait()
}
if pb.err != nil {
return
}

pb.data = data
pb.set = true
pb.held = false
pb.cond.Broadcast()

for pb.set && pb.err == nil {
for pb.set || pb.held {
pb.cond.Wait()
}
}
Expand All @@ -55,8 +64,14 @@ func (pb *packetBuffer) Get() ([]byte, error) {
for !pb.set && pb.err == nil {
pb.cond.Wait()
}
if pb.err != nil {
return nil, pb.err
}

pb.held = true
pb.cond.Broadcast()

return pb.data, pb.err
return pb.data, nil
}

func (pb *packetBuffer) Done() {
Expand All @@ -65,5 +80,6 @@ func (pb *packetBuffer) Done() {

pb.data = nil
pb.set = false
pb.held = false
pb.cond.Broadcast()
}
49 changes: 30 additions & 19 deletions drpcstream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,37 +374,39 @@ func (s *Stream) rawFlushLocked() (err error) {
return s.checkCancelError(errs.Wrap(s.wr.Flush()))
}

// RawRecv returns the raw bytes received for a message.
func (s *Stream) RawRecv() (data []byte, err error) {
data, err = s.rawRecv()
if err != nil {
return nil, err
}
data = append([]byte(nil), data...)
s.pbuf.Done()
return data, nil
}

// rawRecv returns the raw bytes received for a message. It does not make a
// copy of the bytes and so care must be taken to signal when HandlePacket
// is allowed to return.
func (s *Stream) rawRecv() (data []byte, err error) {
func (s *Stream) checkRecvFlush() (err error) {
s.flush.Do(func() { err = s.RawFlush() })
if err != nil {
return nil, err
return err
}

if s.opts.ManualFlush && !s.wr.Empty() {
if err := s.RawFlush(); err != nil {
return nil, err
return err
}
}

return nil
}

// RawRecv returns the raw bytes received for a message.
func (s *Stream) RawRecv() (data []byte, err error) {
if err := s.checkRecvFlush(); err != nil {
return nil, err
}

defer s.checkFinished()
s.read.Lock()
defer s.read.Unlock()

return s.pbuf.Get()
data, err = s.pbuf.Get()
if err != nil {
return nil, err
}
data = append([]byte(nil), data...)
s.pbuf.Done()

return data, nil
}

//
Expand Down Expand Up @@ -437,12 +439,21 @@ func (s *Stream) MsgSend(msg drpc.Message, enc drpc.Encoding) (err error) {

// MsgRecv recives some message data and unmarshals it with enc into msg.
func (s *Stream) MsgRecv(msg drpc.Message, enc drpc.Encoding) (err error) {
data, err := s.rawRecv()
if err := s.checkRecvFlush(); err != nil {
return err
}

defer s.checkFinished()
s.read.Lock()
defer s.read.Unlock()

data, err := s.pbuf.Get()
if err != nil {
return err
}
err = enc.Unmarshal(data, msg)
s.pbuf.Done()

return err
}

Expand Down
45 changes: 44 additions & 1 deletion drpcstream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func TestStream_CorkUntilFirstRead(t *testing.T) {

assert.Equal(t, buf.String(), "\x05\x00\x01\x05write")
}
for i := 0; i < 10000; i++ {
for i := 0; i < 100; i++ {
run()
}
}
Expand All @@ -259,3 +259,46 @@ func (byteEncoding) Unmarshal(buf []byte, msg drpc.Message) error {
*msg.(*[]byte) = append(*msg.(*[]byte), buf...)
return nil
}

func TestStream_PacketBufferReuse(t *testing.T) {
run := func() {
ctx := drpctest.NewTracker(t)
defer ctx.Close()
defer ctx.Wait()

buf := make([]byte, 20)
st := New(ctx, 0, drpcwire.NewWriter(io.Discard, 0))

ctx.Run(func(ctx context.Context) {
for !st.IsTerminated() {
err := st.HandlePacket(drpcwire.Packet{
Data: buf,
Kind: drpcwire.KindMessage,
})
if err != nil {
return
}
for i := range buf {
buf[i]++
}
}
})

ctx.Run(func(ctx context.Context) {
for !st.IsTerminated() {
_, err := st.RawRecv()
if err != nil {
return
}
}
})

ctx.Run(func(ctx context.Context) {
st.Cancel(context.Canceled)
})
}

for i := 0; i < 100; i++ {
run()
}
}

0 comments on commit 89d4b63

Please sign in to comment.