From 942feaa4e89c1389c94a5dc92d2f5b54f5e0467f Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sun, 14 May 2023 21:02:34 +0200 Subject: [PATCH 1/2] chore: make server/resolver save for concurrent use --- resolver.go | 31 ++++++++++++++++++++++++++++--- server.go | 6 +++--- server_test.go | 2 +- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/resolver.go b/resolver.go index 968f6c2..f641ed9 100644 --- a/resolver.go +++ b/resolver.go @@ -36,13 +36,17 @@ type Zone struct { // and so can be used as a drop-in replacement for it if tested code // supports it. type Resolver struct { - Zones map[string]Zone + zonesMutex sync.RWMutex + Zones map[string]Zone // Don't follow CNAME in Zones for Lookup*. - SkipCNAME bool + skipCNAME bool } func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) { + r.zonesMutex.RLock() + defer r.zonesMutex.RUnlock() + arpa, err := dns.ReverseAddr(addr) if err != nil { return nil, err @@ -62,6 +66,9 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, } func (r *Resolver) LookupCNAME(ctx context.Context, host string) (cname string, err error) { + r.zonesMutex.RLock() + defer r.zonesMutex.RUnlock() + rzone, ok := r.Zones[strings.ToLower(host)] if !ok { return "", notFound(host) @@ -91,6 +98,9 @@ func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, } func (r *Resolver) targetZone(name string) (ad bool, rname string, zone Zone, err error) { + r.zonesMutex.RLock() + defer r.zonesMutex.RUnlock() + rname = strings.ToLower(dns.Fqdn(name)) rzone, ok := r.Zones[rname] if !ok { @@ -103,7 +113,7 @@ func (r *Resolver) targetZone(name string) (ad bool, rname string, zone Zone, er ad = rzone.AD - if !r.SkipCNAME { + if !r.skipCNAME { for rzone.CNAME != "" { rname = rzone.CNAME rzone, ok = r.Zones[rname] @@ -314,3 +324,18 @@ func (r *Resolver) DialContext(ctx context.Context, network, addr string) (net.C } return nil, lastErr } + +func (r *Resolver) GetZone(name string) (Zone, bool) { + r.zonesMutex.RLock() + defer r.zonesMutex.RUnlock() + + zone, ok := r.Zones[name] + return zone, ok +} + +func (r *Resolver) SetSkipCNAME(skip bool) { + r.zonesMutex.Lock() + defer r.zonesMutex.Unlock() + + r.skipCNAME = skip +} diff --git a/server.go b/server.go index 16a305b..d99617a 100644 --- a/server.go +++ b/server.go @@ -169,7 +169,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { return } - qnameZone, ok := s.r.Zones[qname] + qnameZone, ok := s.r.GetZone(qname) if !ok { s.writeErr(w, reply, notFound(qname)) return @@ -314,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { }) } case dns.TypePTR: - rzone, ok := s.r.Zones[q.Name] + rzone, ok := s.r.GetZone(q.Name) if !ok { s.writeErr(w, reply, notFound(q.Name)) return @@ -350,7 +350,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { }, } default: - rzone, ok := s.r.Zones[q.Name] + rzone, ok := s.r.GetZone(q.Name) if !ok { s.writeErr(w, reply, notFound(q.Name)) return diff --git a/server_test.go b/server_test.go index 1e6690c..e820381 100644 --- a/server_test.go +++ b/server_test.go @@ -158,7 +158,7 @@ func TestServer_Authoritative(t *testing.T) { if err != nil { t.Fatal(err) } - srv.Resolver().SkipCNAME = true + srv.Resolver().SetSkipCNAME(true) defer srv.Close() msg := new(dns.Msg) From 1887f3e55c34f9a1bf15200bfb4c17adba783bac Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sun, 14 May 2023 21:03:02 +0200 Subject: [PATCH 2/2] feat: add `Server#AddZone()` and `Router#AddZone()` methods --- resolver.go | 15 +++++ server.go | 4 ++ server_test.go | 156 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+) diff --git a/resolver.go b/resolver.go index f641ed9..cccbc6b 100644 --- a/resolver.go +++ b/resolver.go @@ -5,10 +5,13 @@ import ( "fmt" "net" "strings" + "sync" "github.com/miekg/dns" ) +const ErrExistingZoneFmt = "attempted to add existing zone %q" + type Zone struct { // Return the specified error on any lookup using this zone. // For Server, non-nil value results in SERVFAIL response. @@ -325,6 +328,18 @@ func (r *Resolver) DialContext(ctx context.Context, network, addr string) (net.C return nil, lastErr } +func (r *Resolver) AddZone(name string, zone Zone) error { + r.zonesMutex.Lock() + defer r.zonesMutex.Unlock() + + if _, ok := r.Zones[name]; ok { + return fmt.Errorf(ErrExistingZoneFmt, name) + } + + r.Zones[name] = zone + return nil +} + func (r *Resolver) GetZone(name string) (Zone, bool) { r.zonesMutex.RLock() defer r.zonesMutex.RUnlock() diff --git a/server.go b/server.go index d99617a..2efa7bb 100644 --- a/server.go +++ b/server.go @@ -399,6 +399,10 @@ func (s *Server) PatchNet(r *net.Resolver) { } } +func (s *Server) AddZone(name string, zone Zone) error { + return s.r.AddZone(name, zone) +} + func UnpatchNet(r *net.Resolver) { r.PreferGo = false r.Dial = nil diff --git a/server_test.go b/server_test.go index e820381..b87851d 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package mockdns import ( "context" + "fmt" "net" "reflect" "sort" @@ -175,3 +176,158 @@ func TestServer_Authoritative(t *testing.T) { t.Fatal("The authoritative flag should be set") } } + +func TestServer_AddZone_Simple(t *testing.T) { + const ( + initialZoneName = "initial.example." + additionalZoneName = "additional.example." + expectedName = "resolved.example" + ) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } + + // add additional zone record + err = srv.AddZone(additionalZoneName, Zone{ + CNAME: expectedName, + }) + if err != nil { + t.Fatal(err) + } + + // ensure additional zone record resolves correctly + resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedAdditionalName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } +} + +func TestServer_AddZone_Existing(t *testing.T) { + const ( + initialZoneName = "initial.example." + expectedName = "expected.example" + unexpectedName = "unexpected.example" + ) + + var expectedErr = fmt.Errorf(ErrExistingZoneFmt, initialZoneName) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName) + } + + // attempt to add existing zone record + err = srv.AddZone(initialZoneName, Zone{ + CNAME: unexpectedName, + }) + if expectedErr.Error() != err.Error() { + t.Fatalf("expected error %q but got %q", expectedErr, err) + } + + // ensure initial zone record resolves correctly + resolvedInitialName, err = srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName) + } + + // ensure unexpected zone record does not resolve + _, err = srv.Resolver().LookupCNAME(context.Background(), unexpectedName) + if err == nil { + t.Fatal("expected error but got nil") + } +} + +func TestServer_AddZone_Concurrent(t *testing.T) { + const ( + initialZoneName = "initial.example." + additionalZoneName = "additional.example." + expectedName = "resolved.example" + ) + + var ( + errCh = make(chan error, 1) + ) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + go func() { + // add additional zone record + err := srv.AddZone(additionalZoneName, Zone{ + CNAME: expectedName, + }) + if err != nil { + errCh <- err + } + + // ensure additional zone record resolves correctly + resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName) + if err != nil { + errCh <- err + } + if expectedName != resolvedAdditionalName { + errCh <- fmt.Errorf("expected: %s but got: %s", expectedName, resolvedAdditionalName) + } + + close(errCh) + }() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } + + if err := <-errCh; err != nil { + t.Fatalf("unexpected error: %s", err) + } +}