diff --git a/resolver.go b/resolver.go index f641ed9..cccbc6b 100644 --- a/resolver.go +++ b/resolver.go @@ -5,10 +5,13 @@ import ( "fmt" "net" "strings" + "sync" "github.com/miekg/dns" ) +const ErrExistingZoneFmt = "attempted to add existing zone %q" + type Zone struct { // Return the specified error on any lookup using this zone. // For Server, non-nil value results in SERVFAIL response. @@ -325,6 +328,18 @@ func (r *Resolver) DialContext(ctx context.Context, network, addr string) (net.C return nil, lastErr } +func (r *Resolver) AddZone(name string, zone Zone) error { + r.zonesMutex.Lock() + defer r.zonesMutex.Unlock() + + if _, ok := r.Zones[name]; ok { + return fmt.Errorf(ErrExistingZoneFmt, name) + } + + r.Zones[name] = zone + return nil +} + func (r *Resolver) GetZone(name string) (Zone, bool) { r.zonesMutex.RLock() defer r.zonesMutex.RUnlock() diff --git a/server.go b/server.go index d99617a..2efa7bb 100644 --- a/server.go +++ b/server.go @@ -399,6 +399,10 @@ func (s *Server) PatchNet(r *net.Resolver) { } } +func (s *Server) AddZone(name string, zone Zone) error { + return s.r.AddZone(name, zone) +} + func UnpatchNet(r *net.Resolver) { r.PreferGo = false r.Dial = nil diff --git a/server_test.go b/server_test.go index e820381..b87851d 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package mockdns import ( "context" + "fmt" "net" "reflect" "sort" @@ -175,3 +176,158 @@ func TestServer_Authoritative(t *testing.T) { t.Fatal("The authoritative flag should be set") } } + +func TestServer_AddZone_Simple(t *testing.T) { + const ( + initialZoneName = "initial.example." + additionalZoneName = "additional.example." + expectedName = "resolved.example" + ) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } + + // add additional zone record + err = srv.AddZone(additionalZoneName, Zone{ + CNAME: expectedName, + }) + if err != nil { + t.Fatal(err) + } + + // ensure additional zone record resolves correctly + resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedAdditionalName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } +} + +func TestServer_AddZone_Existing(t *testing.T) { + const ( + initialZoneName = "initial.example." + expectedName = "expected.example" + unexpectedName = "unexpected.example" + ) + + var expectedErr = fmt.Errorf(ErrExistingZoneFmt, initialZoneName) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName) + } + + // attempt to add existing zone record + err = srv.AddZone(initialZoneName, Zone{ + CNAME: unexpectedName, + }) + if expectedErr.Error() != err.Error() { + t.Fatalf("expected error %q but got %q", expectedErr, err) + } + + // ensure initial zone record resolves correctly + resolvedInitialName, err = srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName) + } + + // ensure unexpected zone record does not resolve + _, err = srv.Resolver().LookupCNAME(context.Background(), unexpectedName) + if err == nil { + t.Fatal("expected error but got nil") + } +} + +func TestServer_AddZone_Concurrent(t *testing.T) { + const ( + initialZoneName = "initial.example." + additionalZoneName = "additional.example." + expectedName = "resolved.example" + ) + + var ( + errCh = make(chan error, 1) + ) + + // create server with initial zone record + srv, err := NewServer(map[string]Zone{ + initialZoneName: Zone{ + CNAME: expectedName, + }, + }, false) + if err != nil { + t.Fatal(err) + } + defer srv.Close() + + go func() { + // add additional zone record + err := srv.AddZone(additionalZoneName, Zone{ + CNAME: expectedName, + }) + if err != nil { + errCh <- err + } + + // ensure additional zone record resolves correctly + resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName) + if err != nil { + errCh <- err + } + if expectedName != resolvedAdditionalName { + errCh <- fmt.Errorf("expected: %s but got: %s", expectedName, resolvedAdditionalName) + } + + close(errCh) + }() + + // ensure initial zone record resolves correctly + resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName) + if err != nil { + t.Fatal(err) + } + if expectedName != resolvedInitialName { + t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName) + } + + if err := <-errCh; err != nil { + t.Fatalf("unexpected error: %s", err) + } +}