package grpctunnel

//lint:file-ignore U1000 these aren't actually unused, but staticcheck is having trouble
// determining that, likely due to the use of generics

import (
	"container/list"
	"context"
	"math"
	"sync"
	"sync/atomic"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

const (
	// TODO: make these configurable
	initialWindowSize = 65536
	chunkMax          = 16384
)

var errFlowControlWindowExceeded = status.Errorf(codes.ResourceExhausted, "flow control window exceeded")

// sender is responsible for sending messages and managing flow control.
type sender interface {
	send(data []byte) error
	updateWindow(add uint32)
}

// receiver is responsible for receiving messages and managing flow control.
type receiver[T any] interface {
	accept(item T) error
	close()
	cancel()
	dequeue() (T, bool)
}

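// A sender and receiver pair is wired together across the tunnel: the local
// sender's window is replenished by window updates from the remote peer, and
// the local receiver issues window updates to the peer as the application
// consumes messages. A minimal sketch of that wiring, where sendToTunnel and
// sendWindowUpdateToPeer are hypothetical callbacks (not defined in this
// file):
//
//	snd := newSender(ctx, initialWindowSize, sendToTunnel)
//	rcv := newReceiver(
//		func(m []byte) uint { return uint(len(m)) },
//		sendWindowUpdateToPeer,
//		initialWindowSize,
//	)
//	// on a window update frame from the peer: snd.updateWindow(n)
//	// on a message chunk from the peer:       _ = rcv.accept(chunk)

// defaultSender is the flow-control-aware implementation of sender.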
type defaultSender struct {
	ctx           context.Context
	sendFunc      func([]byte, uint32, bool) error
	windowUpdates chan struct{}
	currentWindow atomic.Uint32
	// does not protect any fields, just used to prevent concurrent calls to send
	// (so messages are sent FIFO and not incorrectly interleaved)
	mu sync.Mutex
}

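// newSender returns the default sender implementation, whose flow-control
// window starts at initialWindowSize. Each chunk is transmitted via sendFunc,
// which is given the chunk, the total message size, and whether this is the
// first chunk of the message.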
func newSender(ctx context.Context, initialWindowSize uint32, sendFunc func([]byte, uint32, bool) error) sender {
	s := &defaultSender{
		ctx:           ctx,
		sendFunc:      sendFunc,
		windowUpdates: make(chan struct{}, 1),
	}
	s.currentWindow.Store(initialWindowSize)
	return s
}

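// updateWindow grows the send window by add bytes. If the window was
// previously exhausted, it also wakes up a send call that is blocked
// waiting for window capacity.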
func (s *defaultSender) updateWindow(add uint32) {
	if add == 0 {
		return
	}
	prevWindow := s.currentWindow.Add(add) - add
	if prevWindow == 0 {
		select {
		case s.windowUpdates <- struct{}{}:
		default:
		}
	}
}

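// send transmits data in chunks, each no larger than the remaining
// flow-control window or chunkMax, blocking whenever the window is exhausted
// until updateWindow replenishes it. For example, with a full 65536-byte
// window, a 40000-byte message goes out as chunks of 16384, 16384, and
// 7232 bytes.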
func (s *defaultSender) send(data []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if int64(len(data)) > math.MaxUint32 {
		return status.Errorf(codes.ResourceExhausted, "serialized message is too large: %d bytes > maximum %d bytes", len(data), math.MaxUint32)
	}
	size := uint32(len(data))
	first := true
	for {
		windowSz := s.currentWindow.Load()
		if windowSz == 0 {
			// must wait for window size update before we can send more
			select {
			case <-s.windowUpdates:
			case <-s.ctx.Done():
				return s.ctx.Err()
			}
			continue
		}
		chunkSz := windowSz
		if chunkSz > uint32(len(data)) {
			chunkSz = uint32(len(data))
		}
		if chunkSz > chunkMax {
			chunkSz = chunkMax
		}
		if !s.currentWindow.CompareAndSwap(windowSz, windowSz-chunkSz) {
			continue
		}
		last := chunkSz == uint32(len(data))
		if err := s.sendFunc(data[:chunkSz], size, first); err != nil {
			return err
		}
		if last {
			return nil
		}
		first = false
		data = data[chunkSz:]
	}
}

// defaultReceiver is a per-stream queue of messages. When we receive a message
// for a stream over a tunnel, we have to put it into this unbounded queue to
// prevent deadlock (where one consumer of a stream channel could otherwise
// block all operations on the tunnel).
//
// In practice, this does not use unbounded memory because flow control will
// apply backpressure to senders that are outpacing their respective consumers.
// A well-behaved sender will respect the flow control window. A misbehaving
// sender that exceeds the flow control window will be detected and its
// messages rejected.
type defaultReceiver[T any] struct {
	measure           func(T) uint
	updateWindow      func(uint32)
	mu                sync.Mutex
	cond              sync.Cond
	closed, cancelled bool
	items             *list.List
	currentWindow     uint32
}

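// newReceiver returns the default receiver implementation, whose flow-control
// window starts at initialWindowSize. The measure function reports the size
// of a queued item, and updateWindow is used to grant window back to the
// remote sender as items are dequeued.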
func newReceiver[T any](measure func(T) uint, updateWindow func(uint32), initialWindowSize uint32) receiver[T] {
	rcvr := &defaultReceiver[T]{
		measure:       measure,
		updateWindow:  updateWindow,
		items:         list.New(),
		currentWindow: initialWindowSize,
	}
	rcvr.cond.L = &rcvr.mu
	return rcvr
}

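// accept enqueues an item that arrived over the tunnel. It returns
// errFlowControlWindowExceeded if the item does not fit in the remaining
// window, which indicates a misbehaving sender.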
func (r *defaultReceiver[T]) accept(item T) error {
	sz := r.measure(item)
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed || r.cancelled {
		// The stream is done, so drop the item. (Checking cancelled, too,
		// matters: otherwise items accepted after cancel would accumulate
		// in the queue but never be dequeued.)
		return nil
	}
	if sz > uint(r.currentWindow) {
		return errFlowControlWindowExceeded
	}
	r.currentWindow -= uint32(sz)
	signal := r.items.Len() == 0
	r.items.PushBack(item)
	if signal {
		r.cond.Signal()
	}
	return nil
}

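// close marks the receiver as closed. No further items will be enqueued,
// but items already queued can still be dequeued.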
func (r *defaultReceiver[_]) close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.handleClosure(&r.closed)
}

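// cancel marks the receiver as cancelled and discards any queued items;
// subsequent dequeue calls return immediately.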
func (r *defaultReceiver[_]) cancel() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.handleClosure(&r.cancelled)
	r.items.Init() // clear list to free memory
}

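// handleClosure sets the given flag (closed or cancelled) and, if the queue
// is empty, wakes all blocked dequeue calls so they can observe the new
// state. (If the queue is non-empty, no dequeuer can be blocked in Wait.)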
func (r *defaultReceiver[_]) handleClosure(b *bool) {
	if *b {
		return
	}
	*b = true
	if r.items.Len() == 0 {
		r.cond.Broadcast()
	}
}

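// dequeue removes and returns the next queued item, blocking until an item
// is available or the receiver is closed or cancelled; it reports false when
// no item will ever be available. Removing an item frees window, which is
// granted back to the remote sender via a deferred call to updateWindow,
// after the receiver's lock has been released.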
func (r *defaultReceiver[T]) dequeue() (T, bool) {
	var windowUpdate uint
	defer func() {
		// TODO: Support minimum update size, so we can batch
		// updates and send fewer messages over the network.
		if windowUpdate > 0 {
			r.updateWindow(uint32(windowUpdate))
		}
	}()
	r.mu.Lock()
	defer r.mu.Unlock()
	var zero T
	for {
		if r.cancelled {
			return zero, false
		}
		element := r.items.Front()
		if element != nil {
			item := r.items.Remove(element).(T)
			sz := r.measure(item)
			r.currentWindow += uint32(sz)
			windowUpdate = sz
			return item, true
		}
		if r.closed {
			return zero, false
		}
		r.cond.Wait()
	}
}

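// noFlowControlSender implements sender without any flow control: messages
// are chunked (at most chunkMax bytes at a time) and written as fast as
// sendFunc allows.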
type noFlowControlSender struct {
	sendFunc func([]byte, uint32, bool) error
	// does not protect any fields, just used to prevent concurrent calls to send
	// (so messages are sent FIFO and not incorrectly interleaved)
	mu sync.Mutex
}

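// newSenderWithoutFlowControl returns a sender that ignores flow-control
// windows entirely, for use when the peer does not support flow control.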
func newSenderWithoutFlowControl(sendFunc func([]byte, uint32, bool) error) sender {
	return &noFlowControlSender{sendFunc: sendFunc}
}

func (s *noFlowControlSender) send(data []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if int64(len(data)) > math.MaxUint32 {
		return status.Errorf(codes.ResourceExhausted, "serialized message is too large: %d bytes > maximum %d bytes", len(data), math.MaxUint32)
	}
	size := uint32(len(data))
	first := true
	for {
		chunkSz := uint32(chunkMax)
		if chunkSz > uint32(len(data)) {
			chunkSz = uint32(len(data))
		}
		last := chunkSz == uint32(len(data))
		if err := s.sendFunc(data[:chunkSz], size, first); err != nil {
			return err
		}
		if last {
			return nil
		}
		first = false
		data = data[chunkSz:]
	}
}

func (s *noFlowControlSender) updateWindow(_ uint32) {
	// should never actually be called
}

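// noFlowControlReceiver implements receiver without any flow-control
// accounting. Instead of an unbounded queue, it uses a one-element channel:
// when the consumer falls behind, accept blocks, which applies backpressure
// to the tunnel as a whole.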
type noFlowControlReceiver[T any] struct {
	ctx      context.Context
	ingestMu sync.Mutex
	ch       chan T
	closed   chan struct{}
	doClose  sync.Once
}

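// newReceiverWithoutFlowControl returns a receiver that performs no
// flow-control accounting.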
func newReceiverWithoutFlowControl[T any](ctx context.Context) receiver[T] {
	return &noFlowControlReceiver[T]{
		ctx:    ctx,
		ch:     make(chan T, 1),
		closed: make(chan struct{}),
	}
}

func (r *noFlowControlReceiver[T]) accept(item T) error {
	r.ingestMu.Lock()
	defer r.ingestMu.Unlock()
	// First check the closed channel. If already closed, we can't run the
	// select below because trying to write to the closed channel r.ch
	// would panic.
	select {
	case <-r.closed:
		return nil
	default:
	}
	select {
	case r.ch <- item:
	case <-r.closed:
		// another thread intends to close, so abort and release the lock
	}
	return nil
}

func (r *noFlowControlReceiver[T]) close() {
	r.doClose.Do(func() {
		// Let any concurrent accepting thread know that we intend
		// to close and thus need the lock.
		close(r.closed)
		// Must close the channel while lock is held to prevent
		// panic in accept().
		r.ingestMu.Lock()
		defer r.ingestMu.Unlock()
		close(r.ch)
	})
}

func (r *noFlowControlReceiver[T]) cancel() {
	r.close()
}

func (r *noFlowControlReceiver[T]) dequeue() (T, bool) {
	t, ok := <-r.ch
	return t, ok
}