Skip to content

Commit

Permalink
feat: save local and remote addr (#9)
Browse files Browse the repository at this point in the history
This changeset implements saving the local and remote addr for all the
transports we support, improving data quality.

The changes should be obvious for all protocols but DoH, for which we
introduced a lower-level mechanism to collect this info, and namely the
HTTPClientDo function, opening up the possibility of advanced
customization code to create a connection on the fly. The previous
mechanism using HTTPClient is still working as intended.

Another minor change is refactoring the structured logging code inside
the slog.go and slog_test.go files.
  • Loading branch information
bassosimone authored Nov 22, 2024
1 parent 95b7df4 commit 36a27f8
Show file tree
Hide file tree
Showing 8 changed files with 641 additions and 156 deletions.
68 changes: 66 additions & 2 deletions dohttps.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import (
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"sync"

"github.com/miekg/dns"
)
Expand All @@ -36,6 +40,66 @@ func (t *Transport) httpClient() *http.Client {
return http.DefaultClient
}

// httpClientDo performs an HTTP request using one of two methods:
//
// 1. if HTTPClientDo is not nil, use it directly;
//
// 2. otherwise use [*Transport.httpClient] to obtain a suitable
// [*http.Client] and perform the request with it.
func (t *Transport) httpClientDo(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) {
// If HTTPClientDo isn't nil, use it directly.
if t.HTTPClientDo != nil {
return t.HTTPClientDo(req)
}

// Prepare to collect info in a goroutine-safe way.
var (
laddr netip.AddrPort
mu sync.Mutex
raddr netip.AddrPort
)

// Create clean context for tracing where "clean" means
// we don't compose with other possible context traces
traceCtx, cancel := context.WithCancel(context.Background())

// Configure the trace for extracting laddr, raddr
trace := &httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) {
mu.Lock()
defer mu.Unlock()
if addr, ok := info.Conn.LocalAddr().(*net.TCPAddr); ok {
laddr = addr.AddrPort()
}
if addr, ok := info.Conn.RemoteAddr().(*net.TCPAddr); ok {
raddr = addr.AddrPort()
}
},
}
req = req.WithContext(httptrace.WithClientTrace(traceCtx, trace))

// Arrange for the inner context to be canceled
// when the outer context is done.
//
// This must be after req.WithContext to avoid
// a data race in the context itself.
go func() {
defer cancel()
select {
case <-req.Context().Done():
case <-traceCtx.Done():
}
}()

// Perform the request and return the response while holding
// the mutex protecting laddr and raddr.
client := t.httpClient()
resp, err := client.Do(req)
mu.Lock()
defer mu.Unlock()
return resp, laddr, raddr, err
}

// readAllContext is a helper function that reads all from the reader using the
// namesake transport function or the stdlib if the given function is nil.
func (t *Transport) readAllContext(ctx context.Context, r io.Reader, c io.Closer) ([]byte, error) {
Expand Down Expand Up @@ -73,7 +137,7 @@ func (t *Transport) queryHTTPS(ctx context.Context,
// the body, the response code is 200, and the content type
// is the expected one. Since servers always include the
// content type, we don't need to be flexible here.
httpResp, err := t.httpClient().Do(req)
httpResp, laddr, raddr, err := t.httpClientDo(req)
if err != nil {
return nil, err
}
Expand All @@ -96,6 +160,6 @@ func (t *Transport) queryHTTPS(ctx context.Context,
if err := resp.Unpack(rawResp); err != nil {
return nil, err
}
t.maybeLogResponse(ctx, addr, t0, rawQuery, rawResp)
t.maybeLogResponseAddrPort(ctx, addr, t0, rawQuery, rawResp, laddr, raddr)
return resp, nil
}
174 changes: 174 additions & 0 deletions dohttps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@ import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"testing"

"github.com/miekg/dns"
"github.com/rbmk-project/common/mocks"
"github.com/rbmk-project/common/runtimex"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -110,6 +114,142 @@ func TestTransport_httpClient(t *testing.T) {
}
}

func TestTransport_httpClientDo(t *testing.T) {
tests := []struct {
name string
setupTransport func() *Transport
expectedError error
expectedLocalAddr netip.AddrPort
expectedRemoteAddr netip.AddrPort
}{
{
name: "HTTPClientDo takes precedence",
setupTransport: func() *Transport {
return &Transport{
HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) {
return &http.Response{StatusCode: 200}, netip.AddrPort{}, netip.AddrPort{}, nil
},
HTTPClient: &http.Client{
Transport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, errors.New("should not be called")
},
},
},
}
},
expectedError: nil,
expectedLocalAddr: netip.AddrPort{},
expectedRemoteAddr: netip.AddrPort{},
},

{
name: "HTTPClientDo returns error",
setupTransport: func() *Transport {
return &Transport{
HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) {
return nil, netip.AddrPort{}, netip.AddrPort{}, errors.New("custom error")
},
}
},
expectedError: errors.New("custom error"),
expectedLocalAddr: netip.AddrPort{},
expectedRemoteAddr: netip.AddrPort{},
},

{
name: "Fallback to HTTPClient success",
setupTransport: func() *Transport {
return &Transport{
HTTPClient: &http.Client{
Transport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200}, nil
},
},
},
}
},
expectedError: nil,
expectedLocalAddr: netip.AddrPort{},
expectedRemoteAddr: netip.AddrPort{},
},

{
name: "Fallback to HTTPClient failure",
setupTransport: func() *Transport {
return &Transport{
HTTPClient: &http.Client{
Transport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, errors.New("http error")
},
},
},
}
},
expectedError: errors.New("Get \"https://example.com\": http error"),
expectedLocalAddr: netip.AddrPort{},
expectedRemoteAddr: netip.AddrPort{},
},

{
name: "Fallback to HTTPClient collects addresses",
setupTransport: func() *Transport {
return &Transport{
HTTPClient: &http.Client{
Transport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
trace := httptrace.ContextClientTrace(req.Context())
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: &mocks.Conn{
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP("::1"),
Port: 12345,
}
},
MockRemoteAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP("::2"),
Port: 443,
}
},
},
})
}
return &http.Response{StatusCode: 200}, nil
},
},
},
}
},
expectedError: nil,
expectedLocalAddr: netip.MustParseAddrPort("[::1]:12345"),
expectedRemoteAddr: netip.MustParseAddrPort("[::2]:443"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := tt.setupTransport()
req := runtimex.Try1(http.NewRequest("GET", "https://example.com", nil))
resp, la, ra, err := transport.httpClientDo(req)
if tt.expectedError != nil {
assert.Error(t, err)
assert.Equal(t, tt.expectedError.Error(), err.Error())
assert.Nil(t, resp)
} else {
assert.NoError(t, err)
assert.NotNil(t, resp)
}
assert.Equal(t, tt.expectedLocalAddr, la)
assert.Equal(t, tt.expectedRemoteAddr, ra)
})
}
}

func TestTransport_readAllContext(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -364,6 +504,40 @@ func TestTransport_queryHTTPS(t *testing.T) {
url: "https://dns.google/dns-query",
expectedError: errors.New("read failed"),
},

{
name: "HTTPClientDo takes precedence over HTTPClient",
setupTransport: func() *Transport {
return &Transport{
// Should be used
HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) {
dnsResp := &dns.Msg{}
rawDnsResp, err := dnsResp.Pack()
if err != nil {
panic(err)
}
resp := &http.Response{
StatusCode: 200,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(rawDnsResp)),
}
resp.Header.Set("content-type", "application/dns-message")
return resp, netip.AddrPort{}, netip.AddrPort{}, nil
},
// Should not be used
HTTPClient: &http.Client{
Transport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, errors.New("HTTPClient should not be used")
},
},
},
}
},
questionName: "example.com.",
url: "https://dns.google/dns-query",
expectedError: nil,
},
}

for _, tt := range tests {
Expand Down
2 changes: 1 addition & 1 deletion dotcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (t *Transport) queryStream(ctx context.Context,
if err := resp.Unpack(rawResp); err != nil {
return nil, err
}
t.maybeLogResponse(ctx, addr, t0, rawQuery, rawResp)
t.maybeLogResponseConn(ctx, addr, t0, rawQuery, rawResp, conn)
return resp, nil
}

Expand Down
38 changes: 1 addition & 37 deletions doudp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ package dnscore

import (
"context"
"log/slog"
"net"
"time"

Expand All @@ -36,41 +35,6 @@ func (t *Transport) timeNow() time.Time {
return time.Now()
}

// maybeLogQuery is a helper function that logs the query if the logger is set
// and returns the current time for subsequent logging.
func (t *Transport) maybeLogQuery(
ctx context.Context, addr *ServerAddr, rawQuery []byte) time.Time {
t0 := t.timeNow()
if t.Logger != nil {
t.Logger.InfoContext(
ctx,
"dnsQuery",
slog.Any("rawQuery", rawQuery),
slog.String("serverAddr", addr.Address),
slog.String("serverProtocol", string(addr.Protocol)),
slog.Time("t", t0),
)
}
return t0
}

// maybeLogResponse is a helper function that logs the response if the logger is set.
func (t *Transport) maybeLogResponse(ctx context.Context,
addr *ServerAddr, t0 time.Time, rawQuery, rawResp []byte) {
if t.Logger != nil {
t.Logger.InfoContext(
ctx,
"dnsResponse",
slog.Any("rawQuery", rawQuery),
slog.Any("rawResponse", rawResp),
slog.String("serverAddr", addr.Address),
slog.String("serverProtocol", string(addr.Protocol)),
slog.Time("t0", t0),
slog.Time("t", t.timeNow()),
)
}
}

// sendQueryUDP dials a connection, sends and logs the query and
// returns the following values:
//
Expand Down Expand Up @@ -146,7 +110,7 @@ func (t *Transport) recvResponseUDP(ctx context.Context, addr *ServerAddr, conn
if err := resp.Unpack(rawResp); err != nil {
return nil, err
}
t.maybeLogResponse(ctx, addr, t0, rawQuery, rawResp)
t.maybeLogResponseConn(ctx, addr, t0, rawQuery, rawResp, conn)
return resp, nil
}

Expand Down
Loading

0 comments on commit 36a27f8

Please sign in to comment.