From efe2551b2e40cacbe5fba5d93d0e5ade71c4ddd3 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Tue, 29 Aug 2023 17:26:36 +0800 Subject: [PATCH] Add BatchIO and multiple ports options to UDPMux Add BatchIO and multiple ports options to NewMultiUDPMuxPort(s) --- go.mod | 2 +- go.sum | 3 +- udp_mux.go | 8 ++ udp_mux_multi.go | 167 ++++++++++++++++++++++++++++++++++++------ udp_mux_multi_test.go | 106 ++++++++++++++++++++------- udp_muxed_conn.go | 25 ++++++- 6 files changed, 256 insertions(+), 55 deletions(-) diff --git a/go.mod b/go.mod index 600da208..a739a67b 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/pion/mdns v0.0.7 github.com/pion/randutil v0.1.0 github.com/pion/stun v0.6.1 - github.com/pion/transport/v2 v2.2.1 + github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0 github.com/pion/turn/v2 v2.1.3 github.com/stretchr/testify v1.8.4 golang.org/x/net v0.13.0 diff --git a/go.sum b/go.sum index 0f24d2c1..3ab4c78e 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,9 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= -github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= +github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0 h1:7z51t0GDPVHvR8KTnVfUGgeE0KqZvc9o5J3UMVMrykY= +github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc= github.com/pion/turn/v2 v2.1.3 h1:pYxTVWG2gpC97opdRc5IGsQ1lJ9O/IlNhkzj7MMrGAA= github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/udp_mux.go b/udp_mux.go index 47084f83..8afb00a7 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -23,6 +23,7 @@ type UDPMux interface { GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) GetListenAddresses() []net.Addr + ConnCount() int } // UDPMuxDefault is an implementation of the interface @@ -176,6 +177,13 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return c, nil } +// ConnCount return count of working connections created by UDPMuxDefault +func (m *UDPMuxDefault) ConnCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.connsIPv4) + len(m.connsIPv6) +} + // RemoveConnByUfrag stops and removes the muxed packet connection func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) diff --git a/udp_mux_multi.go b/udp_mux_multi.go index 158cbc37..f64566ef 100644 --- a/udp_mux_multi.go +++ b/udp_mux_multi.go @@ -6,10 +6,12 @@ package ice import ( "fmt" "net" + "time" "github.com/pion/logging" "github.com/pion/transport/v2" "github.com/pion/transport/v2/stdnet" + tudp "github.com/pion/transport/v2/udp" ) // MultiUDPMuxDefault implements both UDPMux and AllConnsGetter, @@ -18,20 +20,93 @@ import ( type MultiUDPMuxDefault struct { muxes []UDPMux localAddrToMux map[string]UDPMux + + enablePortBalance bool + // Manage port balance for mux that listen on multiple ports for same IP, + // for each IP, only return one addr (one port) for each GetListenAddresses call to + // avoid duplicate ip candidates be gathered for a single ice agent. + multiPortsAddresses []*multiPortsAddress +} + +type addrMux struct { + addr net.Addr + mux UDPMux +} + +// each multiPortsAddress represents muxes listen on different ports of a same IP +type multiPortsAddress struct { + addresseMuxes []*addrMux +} + +func (mpa *multiPortsAddress) next() net.Addr { + leastAddr, leastConns := mpa.addresseMuxes[0].addr, mpa.addresseMuxes[0].mux.ConnCount() + for i := 1; i < len(mpa.addresseMuxes); i++ { + am := mpa.addresseMuxes[i] + if count := am.mux.ConnCount(); count < leastConns { + leastConns = count + leastAddr = am.addr + } + } + return leastAddr +} + +// MultiUDPMuxOption provide options for NewMultiUDPMuxDefault +type MultiUDPMuxOption func(*multipleUDPMuxDefaultParams) + +// MultiUDPMuxOptionWithPortBalance enable traffic balance on ports that belongs to same IP address, +// that means the MultiUDPMuxDefault will return a single port that has least connection for each IP address +// in GetListenAddresses return. +func MultiUDPMuxOptionWithPortBalance() MultiUDPMuxOption { + return func(params *multipleUDPMuxDefaultParams) { + params.portBalance = true + } +} + +type multipleUDPMuxDefaultParams struct { + portBalance bool } // NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that // uses the provided UDPMux instances. func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault { + return NewMultiUDPMuxDefaultWithOptions(muxes) +} + +// NewMultiUDPMuxDefaultWithOptions creates an instance of MultiUDPMuxDefault that +// uses the provided UDPMux instances and options. +func NewMultiUDPMuxDefaultWithOptions(muxes []UDPMux, opts ...MultiUDPMuxOption) *MultiUDPMuxDefault { + var params multipleUDPMuxDefaultParams + for _, opt := range opts { + opt(¶ms) + } + addrToMux := make(map[string]UDPMux) + ipToAddrs := make(map[string]*multiPortsAddress) for _, mux := range muxes { for _, addr := range mux.GetListenAddresses() { addrToMux[addr.String()] = mux + + udpAddr, _ := addr.(*net.UDPAddr) + ip := udpAddr.IP.String() + if mpa, ok := ipToAddrs[ip]; ok { + mpa.addresseMuxes = append(mpa.addresseMuxes, &addrMux{addr, mux}) + } else { + ipToAddrs[ip] = &multiPortsAddress{ + addresseMuxes: []*addrMux{{addr, mux}}, + } + } } } + + multiPortsAddresses := make([]*multiPortsAddress, 0, len(ipToAddrs)) + for _, mpa := range ipToAddrs { + multiPortsAddresses = append(multiPortsAddresses, mpa) + } return &MultiUDPMuxDefault{ - muxes: muxes, - localAddrToMux: addrToMux, + muxes: muxes, + localAddrToMux: addrToMux, + multiPortsAddresses: multiPortsAddresses, + enablePortBalance: params.portBalance, } } @@ -45,6 +120,15 @@ func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketCon return mux.GetConn(ufrag, addr) } +// ConnCount return count of working connections created by the mux. +func (m *MultiUDPMuxDefault) ConnCount() int { + var count int + for _, mux := range m.muxes { + count += mux.ConnCount() + } + return count +} + // RemoveConnByUfrag stops and removes the muxed packet connection // from all underlying UDPMux instances. func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) { @@ -64,8 +148,18 @@ func (m *MultiUDPMuxDefault) Close() error { return err } -// GetListenAddresses returns the list of addresses that this mux is listening on +// GetListenAddresses returns the list of addresses that this mux is listening on, +// if port balance enabled and there are multiple mux listen on different ports of a same IP addr, +// will return the mux who has least connections of that IP addr. func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { + if m.enablePortBalance { + addrs := make([]net.Addr, 0, len(m.multiPortsAddresses)) + for _, mpa := range m.multiPortsAddresses { + addrs = append(addrs, mpa.next()) + } + return addrs + } + addrs := make([]net.Addr, 0, len(m.localAddrToMux)) for _, mux := range m.muxes { addrs = append(addrs, mux.GetListenAddresses()...) @@ -76,6 +170,12 @@ func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { // NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that // listen all interfaces on the provided port. func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { + return NewMultiUDPMuxFromPorts([]int{port}, opts...) +} + +// NewMultiUDPMuxFromPorts creates an instance of MultiUDPMuxDefault that +// listen all interfaces and balance traffic on the provided ports. +func NewMultiUDPMuxFromPorts(ports []int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { params := multiUDPMuxFromPortParam{ networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, } @@ -95,20 +195,29 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu return nil, err } - conns := make([]net.PacketConn, 0, len(ips)) + conns := make([]net.PacketConn, 0, len(ports)*len(ips)) for _, ip := range ips { - conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) - if listenErr != nil { - err = listenErr - break - } - if params.readBufferSize > 0 { - _ = conn.SetReadBuffer(params.readBufferSize) + for _, port := range ports { + conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) + if listenErr != nil { + err = listenErr + break + } + if params.readBufferSize > 0 { + _ = conn.SetReadBuffer(params.readBufferSize) + } + if params.writeBufferSize > 0 { + _ = conn.SetWriteBuffer(params.writeBufferSize) + } + if params.batchWriteSize > 0 { + conns = append(conns, tudp.NewBatchConn(conn, params.batchWriteSize, params.batchWriteInterval)) + } else { + conns = append(conns, conn) + } } - if params.writeBufferSize > 0 { - _ = conn.SetWriteBuffer(params.writeBufferSize) + if err != nil { + break } - conns = append(conns, conn) } if err != nil { @@ -128,7 +237,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu muxes = append(muxes, mux) } - return NewMultiUDPMuxDefault(muxes...), nil + return NewMultiUDPMuxDefaultWithOptions(muxes, MultiUDPMuxOptionWithPortBalance()), nil } // UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort @@ -137,14 +246,16 @@ type UDPMuxFromPortOption interface { } type multiUDPMuxFromPortParam struct { - ifFilter func(string) bool - ipFilter func(ip net.IP) bool - networks []NetworkType - readBufferSize int - writeBufferSize int - logger logging.LeveledLogger - includeLoopback bool - net transport.Net + ifFilter func(string) bool + ipFilter func(ip net.IP) bool + networks []NetworkType + readBufferSize int + writeBufferSize int + logger logging.LeveledLogger + includeLoopback bool + net transport.Net + batchWriteSize int + batchWriteInterval time.Duration } type udpMuxFromPortOption struct { @@ -226,3 +337,13 @@ func UDPMuxFromPortWithNet(n transport.Net) UDPMuxFromPortOption { }, } } + +// UDPMuxFromPortWithBatchWrite enable batch write for UDPMux +func UDPMuxFromPortWithBatchWrite(batchWriteSize int, batchWriteInterval time.Duration) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.batchWriteSize = batchWriteSize + p.batchWriteInterval = batchWriteInterval + }, + } +} diff --git a/udp_mux_multi_test.go b/udp_mux_multi_test.go index f37617df..6bb622c3 100644 --- a/udp_mux_multi_test.go +++ b/udp_mux_multi_test.go @@ -7,6 +7,7 @@ package ice import ( + "fmt" "net" "strings" "sync" @@ -110,6 +111,11 @@ func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, u } } +type muxCase struct { + ports []int + opts []UDPMuxFromPortOption +} + func TestUnspecifiedUDPMux(t *testing.T) { report := test.CheckRoutines(t) defer report() @@ -117,39 +123,87 @@ func TestUnspecifiedUDPMux(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() - muxPort := 7778 - udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { + cases := map[string]muxCase{ + "single port": {ports: []int{7778}}, + "multi ports": {ports: []int{7779, 7780, 7781}}, + "batch write": {ports: []int{7782}, opts: []UDPMuxFromPortOption{UDPMuxFromPortWithBatchWrite(10, 2*time.Millisecond)}}, + } + + for name, val := range cases { + cname, muxCase := name, val + t.Run(cname, func(t *testing.T) { + opts := []UDPMuxFromPortOption{ + UDPMuxFromPortWithInterfaceFilter(func(s string) bool { + defaultDockerBridgeNetwork := strings.Contains(s, "docker") + customDockerBridgeNetwork := strings.Contains(s, "br-") + return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork + }), + UDPMuxFromPortWithReadBufferSize(1024 * 1024), + UDPMuxFromPortWithWriteBufferSize(1024 * 1024), + } + opts = append(opts, muxCase.opts...) + udpMuxMulti, err := NewMultiUDPMuxFromPorts(muxCase.ports, opts...) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") + defer func() { + _ = udpMuxMulti.Close() + }() + + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) + }() + wg.Add(1) + go func() { + defer wg.Done() + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) + }() + + // Skip IPv6 test on i386 + const ptrSize = 32 << (^uintptr(0) >> 63) + if ptrSize != 32 { + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) + } + + wg.Wait() + + require.NoError(t, udpMuxMulti.Close()) + }) + } +} + +func TestMultiUDPMuxPortBalance(t *testing.T) { + ports := []int{8779, 8780, 8781} + udpMuxMulti, err := NewMultiUDPMuxFromPorts(ports, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { defaultDockerBridgeNetwork := strings.Contains(s, "docker") customDockerBridgeNetwork := strings.Contains(s, "br-") return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork })) require.NoError(t, err) - - require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") defer func() { - _ = udpMuxMulti.Close() + require.NoError(t, udpMuxMulti.Close()) }() - wg := sync.WaitGroup{} - - wg.Add(1) - go func() { - defer wg.Done() - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) - }() - wg.Add(1) - go func() { - defer wg.Done() - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) - }() - - // Skip IPv6 test on i386 - const ptrSize = 32 << (^uintptr(0) >> 63) - if ptrSize != 32 { - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) + // check port allocation is balanced + expectPorts := make(map[int]bool) + for i := range ports { + addr := udpMuxMulti.GetListenAddresses()[0] + ufrag := fmt.Sprintf("ufragetest%d", i) + conn, err := udpMuxMulti.GetConn(ufrag, addr) + require.NoError(t, err) + require.NotNil(t, conn) + udpLocalAddr, _ := conn.LocalAddr().(*net.UDPAddr) + require.False(t, expectPorts[udpLocalAddr.Port], fmt.Sprint("port ", udpLocalAddr.Port, " is already used", expectPorts)) + expectPorts[udpLocalAddr.Port] = true + + conn2, err := udpMuxMulti.GetConn(ufrag, addr) + require.NoError(t, err) + require.Equal(t, conn, conn2) } - - wg.Wait() - - require.NoError(t, udpMuxMulti.Close()) + require.Equal(t, len(ports), len(expectPorts)) + require.Equal(t, len(ports), udpMuxMulti.ConnCount()) } diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index 09e4b3a8..8c3c078e 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -8,12 +8,17 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/transport/v2/packetio" ) +const ( + iceConnectedTimeout = 25 * time.Second +) + type udpMuxedConnParams struct { Mux *UDPMuxDefault AddrPool *sync.Pool @@ -33,6 +38,9 @@ type udpMuxedConn struct { closedChan chan struct{} closeOnce sync.Once mu sync.Mutex + + startAt time.Time + iceConnected atomic.Bool } func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { @@ -40,6 +48,7 @@ func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { params: params, buf: packetio.NewBuffer(), closedChan: make(chan struct{}), + startAt: time.Now(), } return p @@ -80,10 +89,18 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { if c.isClosed() { return 0, io.ErrClosedPipe } - // Each time we write to a new address, we'll register it with the mux - addr := rAddr.String() - if !c.containsAddress(addr) { - c.addAddress(addr) + + // Only checking the address at the ice connecting stage to reduce the check cost + if !c.iceConnected.Load() { + if time.Since(c.startAt) > iceConnectedTimeout { + c.iceConnected.Store(true) + } else { + // Each time we write to a new address, we'll register it with the mux + addr := rAddr.String() + if !c.containsAddress(addr) { + c.addAddress(addr) + } + } } return c.params.Mux.writeTo(buf, rAddr)