Skip to content

Commit

Permalink
Multiple refactors on cplb.tcpproxy
Browse files Browse the repository at this point in the history
1- Simplify tcpproxy by reomivng useless interfaces:
The original tcpproxy allowed different types of routes and needed to do
a bunch of interfacing for it to work. Since we only implement one kind
of route and one kind of target, we remove all the interfaces and merge
both structs into a unique struct.

2- Remove proxy.AddRoute: We only used it once and setRoutes can cover
that use case

3- Lock tcpproxy.Proxy when modifying routes to make it thread safe.
Prior to this we relied on the proxy being called only from one
goroutine. Now it can be called concurrently, not that we expect to do
that but a lock gives us extra safety.

4- Panic if tcpproxy.SetRoutes gets and empty route list. We now check
this in cplb_unix.go.

5- Remove the route feeding goroutine for round robin, since we added a
lock to make proxy.SetRoutes threadsafe we don't need that anymore and
it can be made much simpler by adding a lock.

Signed-off-by: Juan-Luis de Sousa-Valadas Castaño <[email protected]>
  • Loading branch information
juanluisvaladas committed Dec 19, 2024
1 parent 527b32d commit 5155c18
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 129 deletions.
8 changes: 6 additions & 2 deletions pkg/component/controller/cplb/cplb_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error {
k.proxy = tcpproxy.Proxy{}
// We don't know how long until we get the first update, so initially we
// forward everything to localhost
k.proxy.AddRoute(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort)))
k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), []tcpproxy.Route{tcpproxy.To(fmt.Sprintf("127.0.0.1:%d", k.APIPort))})

if err := k.proxy.Start(); err != nil {
return fmt.Errorf("failed to start proxy: %w", err)
Expand All @@ -372,11 +372,15 @@ func (k *Keepalived) watchReconcilerUpdatesReverseProxy() error {
}

func (k *Keepalived) setProxyRoutes() {
routes := []tcpproxy.Target{}
routes := []tcpproxy.Route{}
for _, addr := range k.reconciler.GetIPs() {
routes = append(routes, tcpproxy.To(fmt.Sprintf("%s:%d", addr, k.APIPort)))
}

if len(routes) == 0 {
k.log.Error("No API servers available, leave previous configuration")
return
}
k.proxy.SetRoutes(fmt.Sprintf(":%d", k.Config.UserSpaceProxyPort), routes)
}

Expand Down
173 changes: 55 additions & 118 deletions pkg/component/controller/cplb/tcpproxy/tcpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"time"

"github.com/sirupsen/logrus"
)

// Proxy is a proxy. Its zero value is a valid proxy that does
Expand All @@ -38,12 +40,13 @@ import (
// The order that routes are added in matters; each is matched in the order
// registered.
type Proxy struct {
mux sync.RWMutex
configs map[string]*config // ip:port => config

lns []net.Listener
donec chan struct{} // closed before err
err error // any error from listening
routesChan chan route
connNumber int // connection number counter, used for round robin

// ListenFunc optionally specifies an alternate listen
// function. If nil, net.Dial is used.
Expand All @@ -56,22 +59,7 @@ type Matcher func(ctx context.Context, hostname string) bool

// config contains the proxying state for one listener.
type config struct {
routes []route
}

// A route matches a connection to a target.
type route interface {
// match examines the initial bytes of a connection, looking for a
// match. If a match is found, match returns a non-nil Target to
// which the stream should be proxied. match returns nil if the
// connection doesn't match.
//
// match must not consume bytes from the given bufio.Reader, it
// can only Peek.
//
// If an sni or host header was parsed successfully, that will be
// returned as the second parameter.
match(*bufio.Reader) (Target, string)
routes []Route
}

func (p *Proxy) netListen() func(net, laddr string) (net.Listener, error) {
Expand All @@ -91,28 +79,7 @@ func (p *Proxy) configFor(ipPort string) *config {
return p.configs[ipPort]
}

func (p *Proxy) addRoute(ipPort string, r route) {
cfg := p.configFor(ipPort)
cfg.routes = append(cfg.routes, r)
}

// AddRoute appends an always-matching route to the ipPort listener,
// directing any connection to dest.
//
// This is generally used as either the only rule (for simple TCP
// proxies), or as the final fallback rule for an ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
p.addRoute(ipPort, fixedTarget{dest})
}

func (p *Proxy) setRoutes(ipPort string, targets []Target) {
var routes []route
for _, target := range targets {
routes = append(routes, fixedTarget{target})
}

func (p *Proxy) setRoutes(ipPort string, routes []Route) {
cfg := p.configFor(ipPort)
cfg.routes = routes
}
Expand All @@ -122,19 +89,15 @@ func (p *Proxy) setRoutes(ipPort string, targets []Target) {
// It's possible that the old routes are still used once after this
// function is called. If an empty slice is passed, the routes are
// preserved in order to avoid an infinite loop.
func (p *Proxy) SetRoutes(ipPort string, targets []Target) {
func (p *Proxy) SetRoutes(ipPort string, targets []Route) {
p.mux.Lock()
defer p.mux.Unlock()
if len(targets) == 0 {
return
panic("SetRoutes with empty targets")
}
p.setRoutes(ipPort, targets)
}

type fixedTarget struct {
t Target
}

func (m fixedTarget) match(*bufio.Reader) (Target, string) { return m.t, "" }

// Run is calls Start, and then Wait.
//
// It blocks until there's an error. The return value is always
Expand Down Expand Up @@ -183,7 +146,6 @@ func (p *Proxy) Start() error {
return err
}
p.lns = append(p.lns, ln)
p.routesChan = make(chan route)
go p.serveListener(errc, ln, config)
}
go p.awaitFirstError(errc)
Expand All @@ -196,48 +158,35 @@ func (p *Proxy) awaitFirstError(errc <-chan error) {
}

func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) {
go p.roundRobin(cfg)
for {
c, err := ln.Accept()
if err != nil {
ret <- err
return
}
go p.serveConn(c)
go p.serveConn(c, cfg)
}
}

// serveConn runs in its own goroutine and matches c against routes.
// It returns whether it matched purely for testing.
func (p *Proxy) serveConn(c net.Conn) bool {
func (p *Proxy) serveConn(c net.Conn, cfg *config) bool {
br := bufio.NewReader(c)
for route := range p.routesChan {
if target, hostName := route.match(br); target != nil {
if n := br.Buffered(); n > 0 {
peeked, _ := br.Peek(br.Buffered())
c = &Conn{
HostName: hostName,
Peeked: peeked,
Conn: c,
}
}
target.HandleConn(c)
return true
}
}
// TODO: hook for this?
log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
c.Close()
return false
}

// roundRobin writes to a channel the next route to use.
func (p *Proxy) roundRobin(cfg *config) {
for {
for _, route := range cfg.routes {
p.routesChan <- route
p.mux.RLock()
p.connNumber++
route := cfg.routes[p.connNumber%(len(cfg.routes))]
p.mux.RUnlock()

if n := br.Buffered(); n > 0 {
peeked, _ := br.Peek(br.Buffered())
c = &Conn{
Peeked: peeked,
Conn: c,
}
}
route.HandleConn(c)
return true
}

// Conn is an incoming connection that has had some bytes read from it
Expand Down Expand Up @@ -276,29 +225,17 @@ func (c *Conn) Read(p []byte) (n int, err error) {
return c.Conn.Read(p)
}

// Target is what an incoming matched connection is sent to.
type Target interface {
// HandleConn is called when an incoming connection is
// matched. After the call to HandleConn, the tcpproxy
// package never touches the conn again. Implementations are
// responsible for closing the connection when needed.
//
// The concrete type of conn will be of type *Conn if any
// bytes have been consumed for the purposes of route
// matching.
HandleConn(net.Conn)
}

// To is shorthand way of writing &tcpproxy.DialProxy{Addr: addr}.
func To(addr string) *DialProxy {
return &DialProxy{Addr: addr}
func To(addr string) Route {
return Route{Addr: addr}
}

// DialProxy implements Target by dialing a new connection to Addr
// Route is what an incoming connection is sent to.
// It handles them by dialing a new connection to Addr
// and then proxying data back and forth.
//
// The To func is a shorthand way of creating a DialProxy.
type DialProxy struct {
// The To func is a shorthand way of creating a Route.
type Route struct {
// Addr is the TCP address to proxy to.
Addr string

Expand Down Expand Up @@ -366,29 +303,29 @@ func closeWrite(c net.Conn) {
}

// HandleConn implements the Target interface.
func (dp *DialProxy) HandleConn(src net.Conn) {
func (r *Route) HandleConn(src net.Conn) {
ctx := context.Background()
var cancel context.CancelFunc
if dp.DialTimeout >= 0 {
ctx, cancel = context.WithTimeout(ctx, dp.dialTimeout())
if r.DialTimeout >= 0 {
ctx, cancel = context.WithTimeout(ctx, r.dialTimeout())
}
dst, err := dp.dialContext()(ctx, "tcp", dp.Addr)
dst, err := r.dialContext()(ctx, "tcp", r.Addr)
if cancel != nil {
cancel()
}
if err != nil {
dp.onDialError()(src, err)
r.onDialError()(src, err)
return
}
defer goCloseConn(dst)

if err = dp.sendProxyHeader(dst, src); err != nil {
dp.onDialError()(src, err)
if err = r.sendProxyHeader(dst, src); err != nil {
r.onDialError()(src, err)
return
}
defer goCloseConn(src)

if ka := dp.keepAlivePeriod(); ka > 0 {
if ka := r.keepAlivePeriod(); ka > 0 {
for _, c := range []net.Conn{src, dst} {
if c, ok := tcpConn(c); ok {
_ = c.SetKeepAlive(true)
Expand All @@ -404,8 +341,8 @@ func (dp *DialProxy) HandleConn(src net.Conn) {
<-errc
}

func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
switch dp.ProxyProtocolVersion {
func (r *Route) sendProxyHeader(w io.Writer, src net.Conn) error {
switch r.ProxyProtocolVersion {
case 0:
return nil
case 1:
Expand All @@ -429,7 +366,7 @@ func (dp *DialProxy) sendProxyHeader(w io.Writer, src net.Conn) error {
_, err := fmt.Fprintf(w, "PROXY %s %s %s %d %d\r\n", family, srcAddr.IP, dstAddr.IP, srcAddr.Port, dstAddr.Port)
return err
default:
return fmt.Errorf("PROXY protocol version %d not supported", dp.ProxyProtocolVersion)
return fmt.Errorf("PROXY protocol version %d not supported", r.ProxyProtocolVersion)
}
}

Expand Down Expand Up @@ -458,35 +395,35 @@ func proxyCopy(errc chan<- error, dst, src net.Conn) {
errc <- err
}

func (dp *DialProxy) keepAlivePeriod() time.Duration {
if dp.KeepAlivePeriod != 0 {
return dp.KeepAlivePeriod
func (r *Route) keepAlivePeriod() time.Duration {
if r.KeepAlivePeriod != 0 {
return r.KeepAlivePeriod
}
return time.Minute
}

func (dp *DialProxy) dialTimeout() time.Duration {
if dp.DialTimeout > 0 {
return dp.DialTimeout
func (r *Route) dialTimeout() time.Duration {
if r.DialTimeout > 0 {
return r.DialTimeout
}
return 10 * time.Second
}

var defaultDialer = new(net.Dialer)

func (dp *DialProxy) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
if dp.DialContext != nil {
return dp.DialContext
func (r *Route) dialContext() func(ctx context.Context, network, address string) (net.Conn, error) {
if r.DialContext != nil {
return r.DialContext
}
return defaultDialer.DialContext
}

func (dp *DialProxy) onDialError() func(src net.Conn, dstDialErr error) {
if dp.OnDialError != nil {
return dp.OnDialError
func (r *Route) onDialError() func(src net.Conn, dstDialErr error) {
if r.OnDialError != nil {
return r.OnDialError
}
return func(src net.Conn, dstDialErr error) {
log.Printf("tcpproxy: for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), dp.Addr, dstDialErr)
logrus.WithFields(logrus.Fields{"component": "tcpproxy"}).Errorf("for incoming conn %v, error dialing %q: %v", src.RemoteAddr().String(), r.Addr, dstDialErr)
src.Close()
}
}
Loading

0 comments on commit 5155c18

Please sign in to comment.