diff --git a/go.mod b/go.mod index 600da208..f4d6c95c 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/pion/mdns v0.0.7 github.com/pion/randutil v0.1.0 github.com/pion/stun v0.6.1 - github.com/pion/transport/v2 v2.2.1 + github.com/pion/transport/v2 v2.2.2 github.com/pion/turn/v2 v2.1.3 github.com/stretchr/testify v1.8.4 golang.org/x/net v0.13.0 diff --git a/go.sum b/go.sum index 0f24d2c1..2c07a13a 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,9 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4= github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= -github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= +github.com/pion/transport/v2 v2.2.2 h1:yv+EKSU2dpmInuCebQ1rsBFCYL7p+aV90xIlshSBO+A= +github.com/pion/transport/v2 v2.2.2/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc= github.com/pion/turn/v2 v2.1.3 h1:pYxTVWG2gpC97opdRc5IGsQ1lJ9O/IlNhkzj7MMrGAA= github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/udp_mux.go b/udp_mux.go index 47084f83..823bdbeb 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -25,6 +25,11 @@ type UDPMux interface { GetListenAddresses() []net.Addr } +// MuxConnCount return count of working connections created by the mux. +type MuxConnCount interface { + ConnCount() int +} + // UDPMuxDefault is an implementation of the interface type UDPMuxDefault struct { params UDPMuxParams @@ -176,6 +181,13 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return c, nil } +// ConnCount return count of working connections created by UDPMuxDefault +func (m *UDPMuxDefault) ConnCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.connsIPv4) + len(m.connsIPv6) +} + // RemoveConnByUfrag stops and removes the muxed packet connection func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) diff --git a/udp_mux_multi.go b/udp_mux_multi.go index 158cbc37..e25ef547 100644 --- a/udp_mux_multi.go +++ b/udp_mux_multi.go @@ -4,12 +4,19 @@ package ice import ( + "errors" "fmt" "net" + "time" "github.com/pion/logging" "github.com/pion/transport/v2" "github.com/pion/transport/v2/stdnet" + tudp "github.com/pion/transport/v2/udp" +) + +var ( + errPortBalanceRequireConnCount = errors.New("Port balance requires UDPMux implements MuxConnCount interface") ) // MultiUDPMuxDefault implements both UDPMux and AllConnsGetter, @@ -18,21 +25,108 @@ import ( type MultiUDPMuxDefault struct { muxes []UDPMux localAddrToMux map[string]UDPMux + + enablePortBalance bool + // Manage port balance for mux that listen on multiple ports for same IP, + // for each IP, only return one addr (one port) for each GetListenAddresses call to + // avoid duplicate ip candidates be gathered for a single ice agent. + multiPortsAddresses []*multiPortsAddress +} + +type addrMux struct { + addr net.Addr + mux MuxConnCount +} + +// each multiPortsAddress represents muxes listen on different ports of a same IP +type multiPortsAddress struct { + addresseMuxes []*addrMux +} + +func (mpa *multiPortsAddress) next() net.Addr { + leastAddr, leastConns := mpa.addresseMuxes[0].addr, mpa.addresseMuxes[0].mux.ConnCount() + for i := 1; i < len(mpa.addresseMuxes); i++ { + am := mpa.addresseMuxes[i] + if count := am.mux.ConnCount(); count < leastConns { + leastConns = count + leastAddr = am.addr + } + } + return leastAddr +} + +// MultiUDPMuxOption provide options for NewMultiUDPMuxDefault +type MultiUDPMuxOption func(*multipleUDPMuxDefaultParams) + +// MultiUDPMuxOptionWithPortBalance enables load balancing traffic on multiple ports belonging to the same IP +// When enabled, GetListenAddresses will return the port with the least number of connections for each corresponding IP +func MultiUDPMuxOptionWithPortBalance() MultiUDPMuxOption { + return func(params *multipleUDPMuxDefaultParams) { + params.portBalance = true + } +} + +type multipleUDPMuxDefaultParams struct { + portBalance bool } // NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that // uses the provided UDPMux instances. func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault { + mux, err := NewMultiUDPMuxDefaultWithOptions(muxes) + if err != nil { + panic(err) + } + return mux +} + +// NewMultiUDPMuxDefaultWithOptions creates an instance of MultiUDPMuxDefault that +// uses the provided UDPMux instances and options. +func NewMultiUDPMuxDefaultWithOptions(muxes []UDPMux, opts ...MultiUDPMuxOption) (*MultiUDPMuxDefault, error) { + var params multipleUDPMuxDefaultParams + for _, opt := range opts { + opt(¶ms) + } + + if params.portBalance { + for _, mux := range muxes { + if _, ok := mux.(MuxConnCount); !ok { + return nil, errPortBalanceRequireConnCount + } + } + } + addrToMux := make(map[string]UDPMux) + ipToAddrs := make(map[string]*multiPortsAddress) for _, mux := range muxes { for _, addr := range mux.GetListenAddresses() { addrToMux[addr.String()] = mux + + if params.portBalance { + muxCount, _ := mux.(MuxConnCount) + udpAddr, _ := addr.(*net.UDPAddr) + ip := udpAddr.IP.String() + if mpa, ok := ipToAddrs[ip]; ok { + mpa.addresseMuxes = append(mpa.addresseMuxes, &addrMux{addr, muxCount}) + } else { + ipToAddrs[ip] = &multiPortsAddress{ + addresseMuxes: []*addrMux{{addr, muxCount}}, + } + } + } } } - return &MultiUDPMuxDefault{ - muxes: muxes, - localAddrToMux: addrToMux, + + multiPortsAddresses := make([]*multiPortsAddress, 0, len(ipToAddrs)) + for _, mpa := range ipToAddrs { + multiPortsAddresses = append(multiPortsAddresses, mpa) } + return &MultiUDPMuxDefault{ + muxes: muxes, + localAddrToMux: addrToMux, + multiPortsAddresses: multiPortsAddresses, + enablePortBalance: params.portBalance, + }, nil } // GetConn returns a PacketConn given the connection's ufrag and network @@ -64,8 +158,18 @@ func (m *MultiUDPMuxDefault) Close() error { return err } -// GetListenAddresses returns the list of addresses that this mux is listening on +// GetListenAddresses returns the list of addresses that this mux is listening on, +// if port balance enabled and there are multiple muxes listening to different ports of the same IP addr, +// it will return the mux that has the least number of connections. func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { + if m.enablePortBalance { + addrs := make([]net.Addr, 0, len(m.multiPortsAddresses)) + for _, mpa := range m.multiPortsAddresses { + addrs = append(addrs, mpa.next()) + } + return addrs + } + addrs := make([]net.Addr, 0, len(m.localAddrToMux)) for _, mux := range m.muxes { addrs = append(addrs, mux.GetListenAddresses()...) @@ -76,6 +180,12 @@ func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { // NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that // listen all interfaces on the provided port. func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { + return NewMultiUDPMuxFromPorts([]int{port}, opts...) +} + +// NewMultiUDPMuxFromPorts creates an instance of MultiUDPMuxDefault that +// listens to all interfaces and balances traffic on the provided ports. +func NewMultiUDPMuxFromPorts(ports []int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { params := multiUDPMuxFromPortParam{ networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, } @@ -95,20 +205,29 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu return nil, err } - conns := make([]net.PacketConn, 0, len(ips)) + conns := make([]net.PacketConn, 0, len(ports)*len(ips)) for _, ip := range ips { - conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) - if listenErr != nil { - err = listenErr - break + for _, port := range ports { + conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) + if listenErr != nil { + err = listenErr + break + } + if params.readBufferSize > 0 { + _ = conn.SetReadBuffer(params.readBufferSize) + } + if params.writeBufferSize > 0 { + _ = conn.SetWriteBuffer(params.writeBufferSize) + } + if params.batchWriteSize > 0 { + conns = append(conns, tudp.NewBatchConn(conn, params.batchWriteSize, params.batchWriteInterval)) + } else { + conns = append(conns, conn) + } } - if params.readBufferSize > 0 { - _ = conn.SetReadBuffer(params.readBufferSize) - } - if params.writeBufferSize > 0 { - _ = conn.SetWriteBuffer(params.writeBufferSize) + if err != nil { + break } - conns = append(conns, conn) } if err != nil { @@ -128,7 +247,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu muxes = append(muxes, mux) } - return NewMultiUDPMuxDefault(muxes...), nil + return NewMultiUDPMuxDefaultWithOptions(muxes, MultiUDPMuxOptionWithPortBalance()) } // UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort @@ -137,14 +256,16 @@ type UDPMuxFromPortOption interface { } type multiUDPMuxFromPortParam struct { - ifFilter func(string) bool - ipFilter func(ip net.IP) bool - networks []NetworkType - readBufferSize int - writeBufferSize int - logger logging.LeveledLogger - includeLoopback bool - net transport.Net + ifFilter func(string) bool + ipFilter func(ip net.IP) bool + networks []NetworkType + readBufferSize int + writeBufferSize int + logger logging.LeveledLogger + includeLoopback bool + net transport.Net + batchWriteSize int + batchWriteInterval time.Duration } type udpMuxFromPortOption struct { @@ -226,3 +347,13 @@ func UDPMuxFromPortWithNet(n transport.Net) UDPMuxFromPortOption { }, } } + +// UDPMuxFromPortWithBatchWrite enable batch write for UDPMux +func UDPMuxFromPortWithBatchWrite(batchWriteSize int, batchWriteInterval time.Duration) UDPMuxFromPortOption { + return &udpMuxFromPortOption{ + f: func(p *multiUDPMuxFromPortParam) { + p.batchWriteSize = batchWriteSize + p.batchWriteInterval = batchWriteInterval + }, + } +} diff --git a/udp_mux_multi_test.go b/udp_mux_multi_test.go index f37617df..6f7964fc 100644 --- a/udp_mux_multi_test.go +++ b/udp_mux_multi_test.go @@ -7,6 +7,7 @@ package ice import ( + "fmt" "net" "strings" "sync" @@ -110,6 +111,11 @@ func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, u } } +type muxCase struct { + ports []int + opts []UDPMuxFromPortOption +} + func TestUnspecifiedUDPMux(t *testing.T) { report := test.CheckRoutines(t) defer report() @@ -117,39 +123,95 @@ func TestUnspecifiedUDPMux(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() - muxPort := 7778 - udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { + cases := map[string]muxCase{ + "single port": {ports: []int{7778}}, + "multi ports": {ports: []int{7779, 7780, 7781}}, + "batch write": {ports: []int{7782}, opts: []UDPMuxFromPortOption{UDPMuxFromPortWithBatchWrite(10, 2*time.Millisecond)}}, + } + + for name, val := range cases { + cname, muxCase := name, val + t.Run(cname, func(t *testing.T) { + opts := []UDPMuxFromPortOption{ + UDPMuxFromPortWithInterfaceFilter(func(s string) bool { + defaultDockerBridgeNetwork := strings.Contains(s, "docker") + customDockerBridgeNetwork := strings.Contains(s, "br-") + return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork + }), + UDPMuxFromPortWithReadBufferSize(1024 * 1024), + UDPMuxFromPortWithWriteBufferSize(1024 * 1024), + } + opts = append(opts, muxCase.opts...) + udpMuxMulti, err := NewMultiUDPMuxFromPorts(muxCase.ports, opts...) + require.NoError(t, err) + + require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") + defer func() { + _ = udpMuxMulti.Close() + }() + + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) + }() + wg.Add(1) + go func() { + defer wg.Done() + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) + }() + + // Skip IPv6 test on i386 + const ptrSize = 32 << (^uintptr(0) >> 63) + if ptrSize != 32 { + testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) + } + + wg.Wait() + + require.NoError(t, udpMuxMulti.Close()) + }) + } +} + +func TestMultiUDPMuxPortBalance(t *testing.T) { + ports := []int{8779, 8780, 8781} + udpMuxMulti, err := NewMultiUDPMuxFromPorts(ports, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { defaultDockerBridgeNetwork := strings.Contains(s, "docker") customDockerBridgeNetwork := strings.Contains(s, "br-") return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork })) require.NoError(t, err) - - require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") defer func() { - _ = udpMuxMulti.Close() + require.NoError(t, udpMuxMulti.Close()) }() - wg := sync.WaitGroup{} - - wg.Add(1) - go func() { - defer wg.Done() - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp) - }() - wg.Add(1) - go func() { - defer wg.Done() - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4) - }() - - // Skip IPv6 test on i386 - const ptrSize = 32 << (^uintptr(0) >> 63) - if ptrSize != 32 { - testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6) + // check port allocation is balanced + expectPorts := make(map[int]bool) + for i := range ports { + addr := udpMuxMulti.GetListenAddresses()[0] + ufrag := fmt.Sprintf("ufragetest%d", i) + conn, err := udpMuxMulti.GetConn(ufrag, addr) + require.NoError(t, err) + require.NotNil(t, conn) + udpLocalAddr, _ := conn.LocalAddr().(*net.UDPAddr) + require.False(t, expectPorts[udpLocalAddr.Port], fmt.Sprint("port ", udpLocalAddr.Port, " is already used", expectPorts)) + expectPorts[udpLocalAddr.Port] = true + + conn2, err := udpMuxMulti.GetConn(ufrag, addr) + require.NoError(t, err) + require.Equal(t, conn, conn2) } + require.Equal(t, len(ports), len(expectPorts)) +} - wg.Wait() +type udpMuxWithoutConnCount struct { + UDPMux +} - require.NoError(t, udpMuxMulti.Close()) +func TestMultiUDPMuxPortBalanceWithoutConnCount(t *testing.T) { + _, err := NewMultiUDPMuxDefaultWithOptions([]UDPMux{&udpMuxWithoutConnCount{}, &udpMuxWithoutConnCount{}}, MultiUDPMuxOptionWithPortBalance()) + require.ErrorIs(t, err, errPortBalanceRequireConnCount) } diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index 09e4b3a8..8c3c078e 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -8,12 +8,17 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/transport/v2/packetio" ) +const ( + iceConnectedTimeout = 25 * time.Second +) + type udpMuxedConnParams struct { Mux *UDPMuxDefault AddrPool *sync.Pool @@ -33,6 +38,9 @@ type udpMuxedConn struct { closedChan chan struct{} closeOnce sync.Once mu sync.Mutex + + startAt time.Time + iceConnected atomic.Bool } func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { @@ -40,6 +48,7 @@ func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { params: params, buf: packetio.NewBuffer(), closedChan: make(chan struct{}), + startAt: time.Now(), } return p @@ -80,10 +89,18 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { if c.isClosed() { return 0, io.ErrClosedPipe } - // Each time we write to a new address, we'll register it with the mux - addr := rAddr.String() - if !c.containsAddress(addr) { - c.addAddress(addr) + + // Only checking the address at the ice connecting stage to reduce the check cost + if !c.iceConnected.Load() { + if time.Since(c.startAt) > iceConnectedTimeout { + c.iceConnected.Store(true) + } else { + // Each time we write to a new address, we'll register it with the mux + addr := rAddr.String() + if !c.containsAddress(addr) { + c.addAddress(addr) + } + } } return c.params.Mux.writeTo(buf, rAddr)