Skip to content

Commit

Permalink
feat(dnscoretest): add listening customisation points
Browse files Browse the repository at this point in the history
This change allows to customise how we create listening connections
with dnscoretest. In turn, this opens up the possibility of reusing
this code with custom listeners (e.g., rbmk-project/x/netsim).
  • Loading branch information
bassosimone committed Nov 22, 2024
1 parent 5043a32 commit f4af514
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dnscoretest/dohttps.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (s *Server) StartHTTPS(handler Handler) <-chan struct{} {
go func() {
cert := runtimex.Try1(tls.X509KeyPair(certPEM, keyPEM))
config := &tls.Config{Certificates: []tls.Certificate{cert}}
listener := runtimex.Try1(tls.Listen("tcp", "127.0.0.1:0", config))
listener := runtimex.Try1(s.listenTLS("tcp", "127.0.0.1:0", config))
s.Addr = listener.Addr().String()
s.RootCAs = x509.NewCertPool()
runtimex.Assert(s.RootCAs.AppendCertsFromPEM(certPEM), "cannot append PEM cert")
Expand Down
10 changes: 9 additions & 1 deletion dnscoretest/dotcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (s *Server) StartTCP(handler Handler) <-chan struct{} {
runtimex.Assert(!s.started, "already started")
ready := make(chan struct{})
go func() {
listener := runtimex.Try1(net.Listen("tcp", "127.0.0.1:0"))
listener := runtimex.Try1(s.listen("tcp", "127.0.0.1:0"))
s.Addr = listener.Addr().String()
s.ioclosers = append(s.ioclosers, listener)
s.started = true
Expand All @@ -34,6 +34,14 @@ func (s *Server) StartTCP(handler Handler) <-chan struct{} {
return ready
}

// listen either used the stdlib or the custom Listen func.
func (s *Server) listen(network, address string) (net.Listener, error) {
if s.Listen != nil {
return s.Listen(network, address)
}
return net.Listen(network, address)
}

// serveConn serves a single DNS query over TCP or TLS.
func (s *Server) serveConn(handler Handler, conn net.Conn) {
// Close the connection when done serving
Expand Down
11 changes: 10 additions & 1 deletion dnscoretest/dotls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"crypto/x509"
_ "embed"
"net"

"github.com/rbmk-project/common/runtimex"
)
Expand All @@ -27,7 +28,7 @@ func (s *Server) StartTLS(handler Handler) <-chan struct{} {
go func() {
cert := runtimex.Try1(tls.X509KeyPair(certPEM, keyPEM))
config := &tls.Config{Certificates: []tls.Certificate{cert}}
listener := runtimex.Try1(tls.Listen("tcp", "127.0.0.1:0", config))
listener := runtimex.Try1(s.listenTLS("tcp", "127.0.0.1:0", config))
s.Addr = listener.Addr().String()
s.RootCAs = x509.NewCertPool()
runtimex.Assert(s.RootCAs.AppendCertsFromPEM(certPEM), "cannot append PEM cert")
Expand All @@ -44,3 +45,11 @@ func (s *Server) StartTLS(handler Handler) <-chan struct{} {
}()
return ready
}

// listenTLS either uses the stdlib or the custom ListenTLS func.
func (s *Server) listenTLS(network, address string, config *tls.Config) (net.Listener, error) {
if s.ListenTLS != nil {
return s.ListenTLS(network, address, config)
}
return tls.Listen(network, address, config)
}
10 changes: 9 additions & 1 deletion dnscoretest/doudp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (s *Server) StartUDP(handler Handler) <-chan struct{} {
runtimex.Assert(!s.started, "already started")
ready := make(chan struct{})
go func() {
pconn := runtimex.Try1(net.ListenPacket("udp", "127.0.0.1:0"))
pconn := runtimex.Try1(s.listenPacket("udp", "127.0.0.1:0"))
s.Addr = pconn.LocalAddr().String()
s.ioclosers = append(s.ioclosers, pconn)
s.started = true
Expand All @@ -27,6 +27,14 @@ func (s *Server) StartUDP(handler Handler) <-chan struct{} {
return ready
}

// listenPacket either uses the standard library or the custom ListenPacket func.
func (s *Server) listenPacket(network, address string) (net.PacketConn, error) {
if s.ListenPacket != nil {
return s.ListenPacket(network, address)
}
return net.ListenPacket(network, address)
}

// servePacketConn serves a single DNS query over UDP.
func (s *Server) servePacketConn(handler Handler, pconn net.PacketConn) error {
buf := make([]byte, 4096)
Expand Down
14 changes: 14 additions & 0 deletions dnscoretest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package dnscoretest

import (
"crypto/tls"
"crypto/x509"
"io"
"net"
)

// Server is a fake DNS server.
Expand All @@ -15,6 +17,18 @@ type Server struct {
// DNS-over-TCP, and DNS-over-TLS.
Addr string

// Listen is an optional func to override the default
// function used to create a [net.Listener].
Listen func(network, address string) (net.Listener, error)

// ListenPacket is an optional func to override the default
// function used to create a listening [net.PacketConn].
ListenPacket func(network, address string) (net.PacketConn, error)

// ListenTLS is an optional func to override the default
// function used to listen using TLS.
ListenTLS func(network, address string, config *tls.Config) (net.Listener, error)

// RootCAs contains the cert pool the client should use
// for DNS-over-TLS and DNS-over-HTTPS.
RootCAs *x509.CertPool
Expand Down
36 changes: 36 additions & 0 deletions dnscoretest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,48 @@
package dnscoretest

import (
"crypto/tls"
"errors"
"net"
"testing"

"github.com/rbmk-project/common/mocks"
"github.com/stretchr/testify/assert"
)

func TestServer_listen(t *testing.T) {
expectedErr := errors.New("mocked error")
srv := &Server{
Listen: func(network, address string) (net.Listener, error) {
return nil, expectedErr
},
}
_, err := srv.listen("tcp", "127.0.0.1:0")
assert.ErrorIs(t, err, expectedErr)
}

func TestServer_listenPacket(t *testing.T) {
expectedErr := errors.New("mocked error")
srv := &Server{
ListenPacket: func(network, address string) (net.PacketConn, error) {
return nil, expectedErr
},
}
_, err := srv.listenPacket("udp", "127.0.0.1:0")
assert.ErrorIs(t, err, expectedErr)
}

func TestServer_listenTLS(t *testing.T) {
expectedErr := errors.New("mocked error")
srv := &Server{
ListenTLS: func(network, address string, config *tls.Config) (net.Listener, error) {
return nil, expectedErr
},
}
_, err := srv.listenTLS("tcp", "127.0.0.1:0", nil)
assert.ErrorIs(t, err, expectedErr)
}

func TestServer_Close(t *testing.T) {
expected := errors.New("mocked error")
srv := &Server{}
Expand Down

0 comments on commit f4af514

Please sign in to comment.