Skip to content

Commit

Permalink
Add BatchIO and multiple ports options to UDPMux
Browse files Browse the repository at this point in the history
Add BatchIO and multiple ports options to
NewMultiUDPMuxPort(s)
  • Loading branch information
cnderrauber committed Aug 29, 2023
1 parent 0ec2333 commit 3338e2a
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 62 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/pion/mdns v0.0.7
github.com/pion/randutil v0.1.0
github.com/pion/stun v0.6.1
github.com/pion/transport/v2 v2.2.1
github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0
github.com/pion/turn/v2 v2.1.3
github.com/stretchr/testify v1.8.4
golang.org/x/net v0.13.0
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TB
github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4=
github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8=
github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc=
github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0 h1:7z51t0GDPVHvR8KTnVfUGgeE0KqZvc9o5J3UMVMrykY=
github.com/pion/transport/v2 v2.2.2-0.20230829043045-6a34769ff4b0/go.mod h1:OJg3ojoBJopjEeECq2yJdXH9YVrUJ1uQ++NjXLOUorc=
github.com/pion/turn/v2 v2.1.3 h1:pYxTVWG2gpC97opdRc5IGsQ1lJ9O/IlNhkzj7MMrGAA=
github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
8 changes: 8 additions & 0 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type UDPMux interface {
GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
GetListenAddresses() []net.Addr
ConnCount() int
}

// UDPMuxDefault is an implementation of the interface
Expand Down Expand Up @@ -176,6 +177,13 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return c, nil
}

// ConnCount return count of working connections created by UDPMuxDefault
func (m *UDPMuxDefault) ConnCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.connsIPv4) + len(m.connsIPv6)
}

// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
Expand Down
167 changes: 144 additions & 23 deletions udp_mux_multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package ice
import (
"fmt"
"net"
"time"

"github.com/pion/logging"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/stdnet"
tudp "github.com/pion/transport/v2/udp"
)

// MultiUDPMuxDefault implements both UDPMux and AllConnsGetter,
Expand All @@ -18,20 +20,93 @@ import (
type MultiUDPMuxDefault struct {
muxes []UDPMux
localAddrToMux map[string]UDPMux

enablePortBalance bool
// Manage port balance for mux that listen on multiple ports for same IP,
// for each IP, only return one addr (one port) for each GetListenAddresses call to
// avoid duplicate ip candidates be gathered for a single ice agent.
multiPortsAddresses []*multiPortsAddress
}

type addrMux struct {
addr net.Addr
mux UDPMux
}

// each multiPortsAddress represents muxes listen on different ports of a same IP
type multiPortsAddress struct {
addresseMuxes []*addrMux
}

func (mpa *multiPortsAddress) next() net.Addr {
leastAddr, leastConns := mpa.addresseMuxes[0].addr, mpa.addresseMuxes[0].mux.ConnCount()
for i := 1; i < len(mpa.addresseMuxes); i++ {
am := mpa.addresseMuxes[i]
if count := am.mux.ConnCount(); count < leastConns {
leastConns = count
leastAddr = am.addr
}
}
return leastAddr
}

// MultiUDPMuxOption provide options for NewMultiUDPMuxDefault
type MultiUDPMuxOption func(*multipleUDPMuxDefaultParams)

// MultiUDPMuxOptionWithPortBalance enable traffic balance on ports that belongs to same IP address,
// that means the MultiUDPMuxDefault will return a single port that has least connection for each IP address
// in GetListenAddresses return.
func MultiUDPMuxOptionWithPortBalance() MultiUDPMuxOption {
return func(params *multipleUDPMuxDefaultParams) {
params.portBalance = true
}
}

type multipleUDPMuxDefaultParams struct {
portBalance bool
}

// NewMultiUDPMuxDefault creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances.
func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault {
return NewMultiUDPMuxDefaultWithOptions(muxes)
}

// NewMultiUDPMuxDefaultWithOptions creates an instance of MultiUDPMuxDefault that
// uses the provided UDPMux instances and options.
func NewMultiUDPMuxDefaultWithOptions(muxes []UDPMux, opts ...MultiUDPMuxOption) *MultiUDPMuxDefault {
var params multipleUDPMuxDefaultParams
for _, opt := range opts {
opt(&params)
}

addrToMux := make(map[string]UDPMux)
ipToAddrs := make(map[string]*multiPortsAddress)
for _, mux := range muxes {
for _, addr := range mux.GetListenAddresses() {
addrToMux[addr.String()] = mux

udpAddr, _ := addr.(*net.UDPAddr)
ip := udpAddr.IP.String()
if mpa, ok := ipToAddrs[ip]; ok {
mpa.addresseMuxes = append(mpa.addresseMuxes, &addrMux{addr, mux})
} else {
ipToAddrs[ip] = &multiPortsAddress{
addresseMuxes: []*addrMux{{addr, mux}},
}
}
}
}

multiPortsAddresses := make([]*multiPortsAddress, 0, len(ipToAddrs))
for _, mpa := range ipToAddrs {
multiPortsAddresses = append(multiPortsAddresses, mpa)
}
return &MultiUDPMuxDefault{
muxes: muxes,
localAddrToMux: addrToMux,
muxes: muxes,
localAddrToMux: addrToMux,
multiPortsAddresses: multiPortsAddresses,
enablePortBalance: params.portBalance,
}
}

Expand All @@ -45,6 +120,15 @@ func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketCon
return mux.GetConn(ufrag, addr)
}

// ConnCount return count of working connections created by the mux.
func (m *MultiUDPMuxDefault) ConnCount() int {
var count int
for _, mux := range m.muxes {
count += mux.ConnCount()
}
return count
}

// RemoveConnByUfrag stops and removes the muxed packet connection
// from all underlying UDPMux instances.
func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
Expand All @@ -64,8 +148,18 @@ func (m *MultiUDPMuxDefault) Close() error {
return err
}

// GetListenAddresses returns the list of addresses that this mux is listening on
// GetListenAddresses returns the list of addresses that this mux is listening on,
// if port balance enabled and there are multiple mux listen on different ports of a same IP addr,
// will return the mux who has least connections of that IP addr.
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
if m.enablePortBalance {
addrs := make([]net.Addr, 0, len(m.multiPortsAddresses))
for _, mpa := range m.multiPortsAddresses {
addrs = append(addrs, mpa.next())
}
return addrs
}

addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxes {
addrs = append(addrs, mux.GetListenAddresses()...)
Expand All @@ -76,6 +170,12 @@ func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
return NewMultiUDPMuxFromPorts([]int{port}, opts...)
}

// NewMultiUDPMuxFromPorts creates an instance of MultiUDPMuxDefault that
// listen all interfaces and balance traffic on the provided ports.
func NewMultiUDPMuxFromPorts(ports []int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
}
Expand All @@ -95,20 +195,29 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
return nil, err
}

conns := make([]net.PacketConn, 0, len(ips))
conns := make([]net.PacketConn, 0, len(ports)*len(ips))
for _, ip := range ips {
conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
for _, port := range ports {
conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
_ = conn.SetReadBuffer(params.readBufferSize)
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
}
if params.batchWriteSize > 0 {
conns = append(conns, tudp.NewBatchConn(conn, params.batchWriteSize, params.batchWriteInterval))
} else {
conns = append(conns, conn)
}
}
if params.writeBufferSize > 0 {
_ = conn.SetWriteBuffer(params.writeBufferSize)
if err != nil {
break
}
conns = append(conns, conn)
}

if err != nil {
Expand All @@ -128,7 +237,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
muxes = append(muxes, mux)
}

return NewMultiUDPMuxDefault(muxes...), nil
return NewMultiUDPMuxDefaultWithOptions(muxes, MultiUDPMuxOptionWithPortBalance()), nil
}

// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort
Expand All @@ -137,14 +246,16 @@ type UDPMuxFromPortOption interface {
}

type multiUDPMuxFromPortParam struct {
ifFilter func(string) bool
ipFilter func(ip net.IP) bool
networks []NetworkType
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
includeLoopback bool
net transport.Net
ifFilter func(string) bool
ipFilter func(ip net.IP) bool
networks []NetworkType
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
includeLoopback bool
net transport.Net
batchWriteSize int
batchWriteInterval time.Duration
}

type udpMuxFromPortOption struct {
Expand Down Expand Up @@ -226,3 +337,13 @@ func UDPMuxFromPortWithNet(n transport.Net) UDPMuxFromPortOption {
},
}
}

// UDPMuxFromPortWithBatchWrite enable batch write for UDPMux
func UDPMuxFromPortWithBatchWrite(batchWriteSize int, batchWriteInterval time.Duration) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.batchWriteSize = batchWriteSize
p.batchWriteInterval = batchWriteInterval
},
}
}
96 changes: 63 additions & 33 deletions udp_mux_multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package ice

import (
"fmt"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -117,39 +118,68 @@ func TestUnspecifiedUDPMux(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

muxPort := 7778
udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
defaultDockerBridgeNetwork := strings.Contains(s, "docker")
customDockerBridgeNetwork := strings.Contains(s, "br-")
return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork
}))
require.NoError(t, err)

require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes")
defer func() {
_ = udpMuxMulti.Close()
}()

wg := sync.WaitGroup{}

wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4)
}()

// Skip IPv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6)
cases := map[string][]int{
"single port": {7778},
"multi ports": {7779, 7780, 7781},
}

wg.Wait()

require.NoError(t, udpMuxMulti.Close())
for name, ports := range cases {
cname, cports := name, ports
t.Run(cname, func(t *testing.T) {
udpMuxMulti, err := NewMultiUDPMuxFromPorts(cports, UDPMuxFromPortWithInterfaceFilter(func(s string) bool {
defaultDockerBridgeNetwork := strings.Contains(s, "docker")
customDockerBridgeNetwork := strings.Contains(s, "br-")
return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork
}))
require.NoError(t, err)

require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes")
defer func() {
_ = udpMuxMulti.Close()
}()

wg := sync.WaitGroup{}

wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag1", udp)
}()
wg.Add(1)
go func() {
defer wg.Done()
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag2", udp4)
}()

// Skip IPv6 test on i386
const ptrSize = 32 << (^uintptr(0) >> 63)
if ptrSize != 32 {
testMultiUDPMuxConnections(t, udpMuxMulti, "ufrag3", udp6)
}

wg.Wait()

// check port allocation is balanced
if len(cports) > 1 {
expectPorts := make(map[int]bool)
for i := range cports {
addr := udpMuxMulti.GetListenAddresses()[0]
ufrag := fmt.Sprintf("ufragetest%d", i)
conn, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err)
require.NotNil(t, conn)
udpLocalAddr, _ := conn.LocalAddr().(*net.UDPAddr)
require.False(t, expectPorts[udpLocalAddr.Port], fmt.Sprint("port ", udpLocalAddr.Port, " is already used", expectPorts))
expectPorts[udpLocalAddr.Port] = true

conn2, err := udpMuxMulti.GetConn(ufrag, addr)
require.NoError(t, err)
require.Equal(t, conn, conn2)
}
require.Equal(t, len(cports), len(expectPorts))
}

require.NoError(t, udpMuxMulti.Close())
})
}
}
Loading

0 comments on commit 3338e2a

Please sign in to comment.