Skip to content

Commit

Permalink
feat: add Server#AddZone() and Router#AddZone() methods
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite committed May 15, 2023
1 parent 942feaa commit 1887f3e
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
15 changes: 15 additions & 0 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ import (
"fmt"
"net"
"strings"
"sync"

"github.com/miekg/dns"
)

const ErrExistingZoneFmt = "attempted to add existing zone %q"

type Zone struct {
// Return the specified error on any lookup using this zone.
// For Server, non-nil value results in SERVFAIL response.
Expand Down Expand Up @@ -325,6 +328,18 @@ func (r *Resolver) DialContext(ctx context.Context, network, addr string) (net.C
return nil, lastErr
}

func (r *Resolver) AddZone(name string, zone Zone) error {
r.zonesMutex.Lock()
defer r.zonesMutex.Unlock()

if _, ok := r.Zones[name]; ok {
return fmt.Errorf(ErrExistingZoneFmt, name)
}

r.Zones[name] = zone
return nil
}

func (r *Resolver) GetZone(name string) (Zone, bool) {
r.zonesMutex.RLock()
defer r.zonesMutex.RUnlock()
Expand Down
4 changes: 4 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ func (s *Server) PatchNet(r *net.Resolver) {
}
}

func (s *Server) AddZone(name string, zone Zone) error {
return s.r.AddZone(name, zone)
}

func UnpatchNet(r *net.Resolver) {
r.PreferGo = false
r.Dial = nil
Expand Down
156 changes: 156 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mockdns

import (
"context"
"fmt"
"net"
"reflect"
"sort"
Expand Down Expand Up @@ -175,3 +176,158 @@ func TestServer_Authoritative(t *testing.T) {
t.Fatal("The authoritative flag should be set")
}
}

func TestServer_AddZone_Simple(t *testing.T) {
const (
initialZoneName = "initial.example."
additionalZoneName = "additional.example."
expectedName = "resolved.example"
)

// create server with initial zone record
srv, err := NewServer(map[string]Zone{
initialZoneName: Zone{
CNAME: expectedName,
},
}, false)
if err != nil {
t.Fatal(err)
}
defer srv.Close()

// ensure initial zone record resolves correctly
resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName)
if err != nil {
t.Fatal(err)
}
if expectedName != resolvedInitialName {
t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName)
}

// add additional zone record
err = srv.AddZone(additionalZoneName, Zone{
CNAME: expectedName,
})
if err != nil {
t.Fatal(err)
}

// ensure additional zone record resolves correctly
resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName)
if err != nil {
t.Fatal(err)
}
if expectedName != resolvedAdditionalName {
t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName)
}
}

func TestServer_AddZone_Existing(t *testing.T) {
const (
initialZoneName = "initial.example."
expectedName = "expected.example"
unexpectedName = "unexpected.example"
)

var expectedErr = fmt.Errorf(ErrExistingZoneFmt, initialZoneName)

// create server with initial zone record
srv, err := NewServer(map[string]Zone{
initialZoneName: Zone{
CNAME: expectedName,
},
}, false)
if err != nil {
t.Fatal(err)
}
defer srv.Close()

// ensure initial zone record resolves correctly
resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName)
if err != nil {
t.Fatal(err)
}
if expectedName != resolvedInitialName {
t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName)
}

// attempt to add existing zone record
err = srv.AddZone(initialZoneName, Zone{
CNAME: unexpectedName,
})
if expectedErr.Error() != err.Error() {
t.Fatalf("expected error %q but got %q", expectedErr, err)
}

// ensure initial zone record resolves correctly
resolvedInitialName, err = srv.Resolver().LookupCNAME(context.Background(), initialZoneName)
if err != nil {
t.Fatal(err)
}
if expectedName != resolvedInitialName {
t.Fatalf("expected: %q but got: %q", initialZoneName, resolvedInitialName)
}

// ensure unexpected zone record does not resolve
_, err = srv.Resolver().LookupCNAME(context.Background(), unexpectedName)
if err == nil {
t.Fatal("expected error but got nil")
}
}

func TestServer_AddZone_Concurrent(t *testing.T) {
const (
initialZoneName = "initial.example."
additionalZoneName = "additional.example."
expectedName = "resolved.example"
)

var (
errCh = make(chan error, 1)
)

// create server with initial zone record
srv, err := NewServer(map[string]Zone{
initialZoneName: Zone{
CNAME: expectedName,
},
}, false)
if err != nil {
t.Fatal(err)
}
defer srv.Close()

go func() {
// add additional zone record
err := srv.AddZone(additionalZoneName, Zone{
CNAME: expectedName,
})
if err != nil {
errCh <- err
}

// ensure additional zone record resolves correctly
resolvedAdditionalName, err := srv.Resolver().LookupCNAME(context.Background(), additionalZoneName)
if err != nil {
errCh <- err
}
if expectedName != resolvedAdditionalName {
errCh <- fmt.Errorf("expected: %s but got: %s", expectedName, resolvedAdditionalName)
}

close(errCh)
}()

// ensure initial zone record resolves correctly
resolvedInitialName, err := srv.Resolver().LookupCNAME(context.Background(), initialZoneName)
if err != nil {
t.Fatal(err)
}
if expectedName != resolvedInitialName {
t.Fatalf("expected: %s; got: %s", expectedName, resolvedInitialName)
}

if err := <-errCh; err != nil {
t.Fatalf("unexpected error: %s", err)
}
}

0 comments on commit 1887f3e

Please sign in to comment.