diff --git a/adapters/dns.go b/adapters/dns.go index 37d3a4f..033a642 100644 --- a/adapters/dns.go +++ b/adapters/dns.go @@ -10,9 +10,12 @@ import ( "sync" "time" + "github.com/cenkalti/backoff/v4" "github.com/miekg/dns" "github.com/overmindtech/sdp-go" "github.com/overmindtech/sdpcache" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // DNSAdapter struct on which all methods are registered @@ -47,6 +50,7 @@ func (s *DNSAdapter) Cache() *sdpcache.Cache { } var DefaultServers = []string{ + "169.254.169.253:53", // Route 53 default resolver. See https://docs.aws.amazon.com/vpc/latest/userguide/AmazonDNS-concepts.html#AmazonDNS "1.1.1.1:53", "8.8.8.8:53", "8.8.4.4:53", @@ -57,26 +61,6 @@ const UniqueAttribute = "name" var ErrNoServersAvailable = errors.New("no dns servers available") -// getActiveServer -func (d *DNSAdapter) getActiveServer(ctx context.Context) (string, error) { - if len(d.Servers) == 0 { - d.Servers = DefaultServers - } - - for _, server := range d.Servers { - conn, err := d.client.DialContext(ctx, server) - - if err != nil { - continue - } - - defer conn.Close() - return server, nil - } - - return "", ErrNoServersAvailable -} - // Type is the type of items that this returns func (d *DNSAdapter) Type() string { return "dns" @@ -92,6 +76,13 @@ func (d *DNSAdapter) Weight() int { return 100 } +func (d *DNSAdapter) GetServers() []string { + if len(d.Servers) == 0 { + return DefaultServers + } + return d.Servers +} + func (d *DNSAdapter) Metadata() *sdp.AdapterMetadata { return dnsMetadata } @@ -221,15 +212,71 @@ func (d *DNSAdapter) Search(ctx context.Context, scope string, query string, ign return items, nil } -// MakeReverseQuery Makes a reverse DNS query, then forward DNS queries for all results -func (d *DNSAdapter) MakeReverseQuery(ctx context.Context, query string) ([]*sdp.Item, error) { - arpa, err := dns.ReverseAddr(query) +// retryDNSQuery handles retrying DNS queries with backoff and server rotation +func (d 
*DNSAdapter) retryDNSQuery(ctx context.Context, queryFn func(context.Context, string) ([]*sdp.Item, error)) ([]*sdp.Item, error) { + b := backoff.NewExponentialBackOff() + b.InitialInterval = 100 * time.Millisecond + b.MaxInterval = 500 * time.Millisecond + b.MaxElapsedTime = 30 * time.Second + + var items []*sdp.Item + var i int + var server string + + operation := func() error { + if i >= len(d.GetServers()) { + i = 0 + } + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + server = d.GetServers()[i] + + var err error + items, err = queryFn(ctx, server) + if err != nil { + i++ // Move to next server on error + if errors.Is(err, context.DeadlineExceeded) || + strings.Contains(err.Error(), "timeout") || + strings.Contains(err.Error(), "temporary failure") { + return err // Retry on timeout + } + return backoff.Permanent(err) + } + + return nil + } + + err := backoff.Retry(operation, backoff.WithContext(b, ctx)) // Stop retrying as soon as the caller's ctx is cancelled or expires, rather than spinning until MaxElapsedTime + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("ovm.dns.server", server), + ) if err != nil { return nil, err } - server, err := d.getActiveServer(ctx) + return items, nil +} + +// MakeQuery makes A and AAAA queries for a given DNS entry, retrying with backoff and rotating through the configured servers on timeouts +func (d *DNSAdapter) MakeQuery(ctx context.Context, query string) ([]*sdp.Item, error) { + return d.retryDNSQuery(ctx, func(ctx context.Context, server string) ([]*sdp.Item, error) { + return d.makeQueryImpl(ctx, query, server) + }) +} + +// MakeReverseQuery makes a reverse DNS query, then forward DNS queries for all results, retrying with backoff and rotating through the configured servers on timeouts +func (d *DNSAdapter) MakeReverseQuery(ctx context.Context, query string) ([]*sdp.Item, error) { + return d.retryDNSQuery(ctx, func(ctx context.Context, server string) ([]*sdp.Item, error) { + return d.makeReverseQueryImpl(ctx, query, server) + }) +} + +func (d *DNSAdapter) makeReverseQueryImpl(ctx context.Context, query string, server string) ([]*sdp.Item, error) { + arpa, err := dns.ReverseAddr(query) if err != nil { return nil, err @@ -282,14 +329,7 @@ func trimDnsSuffix(name string) string { return name } -// MakeQuery Actually makes A and AAAA queries for a given 
DNS entry -func (d *DNSAdapter) MakeQuery(ctx context.Context, query string) ([]*sdp.Item, error) { - server, err := d.getActiveServer(ctx) - - if err != nil { - return nil, err - } - +func (d *DNSAdapter) makeQueryImpl(ctx context.Context, query string, server string) ([]*sdp.Item, error) { // Create the query msg := dns.Msg{ Question: []dns.Question{ diff --git a/cmd/root.go b/cmd/root.go index def89bf..95b199f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,8 @@ package cmd import ( + "context" + "errors" "fmt" "net/http" "os" @@ -72,11 +74,35 @@ var rootCmd = &cobra.Command{ healthCheckPort := viper.GetString("service-port") healthCheckPath := "/healthz" + healthCheckDNSAdapter := adapters.DNSAdapter{} + + // Set up the health check + healthCheck := func() error { + if !e.IsNATSConnected() { + return errors.New("NATS not connected") + } + + // We have seen some issues with DNS lookups within kube where the + // stdlib container will just start timing out on DNS requests. We + // should check that the DNS adapter is working so that the + // container can die if this happens to it + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := healthCheckDNSAdapter.Search(ctx, "global", "www.google.com", true) + if err != nil { + return fmt.Errorf("test dns lookup failed: %w", err) + } + + return nil + } + + e.EngineConfig.HeartbeatOptions.HealthCheck = healthCheck http.HandleFunc(healthCheckPath, func(rw http.ResponseWriter, r *http.Request) { - if e.IsNATSConnected() { + err := healthCheck() + if err == nil { fmt.Fprint(rw, "ok") } else { - http.Error(rw, "NATS not connected", http.StatusInternalServerError) + http.Error(rw, err.Error(), http.StatusInternalServerError) } })