diff --git a/lib/service/service.go b/lib/service/service.go
index 3de8967dd71f3..53492052aa26e 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -4583,6 +4583,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
TracerProvider: process.TracingProvider,
AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels,
IntegrationAppHandler: connectionsHandler,
+ FeatureWatchInterval: utils.HalfJitter(web.DefaultFeatureWatchInterval * 2),
}
webHandler, err := web.NewHandler(webConfig)
if err != nil {
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index 771adabe1f3f3..76a440143b977 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -21,6 +21,7 @@
package web
import (
+ "cmp"
"compress/gzip"
"context"
"encoding/base64"
@@ -126,6 +127,9 @@ const (
// Example values:
// - github-actions-ssh: indicates that the resource was added via the Bot GitHub Actions SSH flow
webUIFlowLabelKey = "teleport.internal/ui-flow"
+ // DefaultFeatureWatchInterval is the default interval at which the feature watcher
+ // pings the auth server to check for updated features.
+ DefaultFeatureWatchInterval = time.Minute * 5
)
// healthCheckAppServerFunc defines a function used to perform a health check
@@ -162,10 +166,6 @@ type Handler struct {
userConns atomic.Int32
// ClusterFeatures contain flags for supported and unsupported features.
- // Note: This field can become stale since it's only set on initial proxy
- // startup. To get the latest feature flags you'll need to ping from the
- // auth server.
- // https://github.com/gravitational/teleport/issues/39161
ClusterFeatures proto.Features
// nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from
@@ -325,6 +325,10 @@ type Config struct {
// IntegrationAppHandler handles App Access requests which use an Integration.
IntegrationAppHandler app.ServerHandler
+
+ // FeatureWatchInterval is the interval between pings to the auth server
+ // to fetch new cluster features
+ FeatureWatchInterval time.Duration
}
// SetDefaults ensures proper default values are set if
@@ -339,6 +343,8 @@ func (c *Config) SetDefaults() {
if c.PresenceChecker == nil {
c.PresenceChecker = client.RunPresenceTask
}
+
+ c.FeatureWatchInterval = cmp.Or(c.FeatureWatchInterval, DefaultFeatureWatchInterval)
}
type APIHandler struct {
@@ -669,6 +675,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
}
}
+ go h.startFeatureWatcher()
+
return &APIHandler{
handler: h,
appHandler: appHandler,
@@ -1692,14 +1700,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou
}
}
- clusterFeatures := h.ClusterFeatures
- // ping server to get cluster features since h.ClusterFeatures may be stale
- pingResponse, err := h.GetProxyClient().Ping(r.Context())
- if err != nil {
- h.log.WithError(err).Warn("Cannot retrieve cluster features, client may receive stale features")
- } else {
- clusterFeatures = *pingResponse.ServerFeatures
- }
+ clusterFeatures := h.GetClusterFeatures()
// get tunnel address to display on cloud instances
tunnelPublicAddr := ""
diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go
index 854695f648866..dc7aedbe97d25 100644
--- a/lib/web/apiserver_test.go
+++ b/lib/web/apiserver_test.go
@@ -4592,6 +4592,7 @@ func TestGetWebConfig(t *testing.T) {
env := newWebPack(t, 1, func(cfg *proxyConfig) {
cfg.serviceConfig = svcConfig
})
+ handler := env.proxies[0].handler.handler
// Set auth preference with passwordless.
const MOTD = "Welcome to cluster, your activity will be recorded."
@@ -4622,6 +4623,9 @@ func TestGetWebConfig(t *testing.T) {
_, err = env.server.Auth().UpsertGithubConnector(ctx, github)
require.NoError(t, err)
+ // advance the clock so the feature watcher ticks and the web config gets new features
+ env.clock.Advance(DefaultFeatureWatchInterval * 2)
+
expectedCfg := webclient.WebConfig{
Auth: webclient.WebConfigAuthSettings{
SecondFactor: constants.SecondFactorOptional,
@@ -4669,6 +4673,7 @@ func TestGetWebConfig(t *testing.T) {
AutomaticUpgrades: true,
},
})
+ env.clock.Advance(DefaultFeatureWatchInterval * 2)
svcConfig.Proxy.AssistAPIKey = "test"
require.NoError(t, err)
@@ -4680,7 +4685,7 @@ func TestGetWebConfig(t *testing.T) {
},
}
require.NoError(t, channels.CheckAndSetDefaults())
- env.proxies[0].handler.handler.cfg.AutomaticUpgradesChannels = channels
+ handler.cfg.AutomaticUpgradesChannels = channels
expectedCfg.IsCloud = true
expectedCfg.IsUsageBasedBilling = true
@@ -4689,14 +4694,20 @@ func TestGetWebConfig(t *testing.T) {
expectedCfg.AssistEnabled = false
expectedCfg.JoinActiveSessions = false
- // request and verify enabled features are enabled.
- re, err = clt.Get(ctx, endpoint, nil)
- require.NoError(t, err)
- require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))
- str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
- err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
- require.NoError(t, err)
- require.Equal(t, expectedCfg, cfg)
+ // request and verify enabled features are eventually enabled.
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ re, err := clt.Get(ctx, endpoint, nil)
+ if !assert.NoError(t, err) {
+ return
+ }
+ assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG")))
+ res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{})
+ err = json.Unmarshal(res[:len(res)-1], &cfg)
+ assert.NoError(t, err)
+ diff := cmp.Diff(expectedCfg, cfg)
+ assert.Empty(t, diff)
+
+ }, time.Second*5, time.Millisecond*50)
// use mock client to assert that if ping returns an error, we'll default to
// cluster config
@@ -4715,15 +4726,22 @@ func TestGetWebConfig(t *testing.T) {
IsUsageBasedBilling: false,
},
})
+ env.clock.Advance(DefaultFeatureWatchInterval * 2)
// request and verify again
- re, err = clt.Get(ctx, endpoint, nil)
- require.NoError(t, err)
- require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))
- str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
- err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
- require.NoError(t, err)
- require.Equal(t, expectedCfg, cfg)
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ re, err := clt.Get(ctx, endpoint, nil)
+ if !assert.NoError(t, err) {
+ return
+ }
+ assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG")))
+ res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{})
+ err = json.Unmarshal(res[:len(res)-1], &cfg)
+ assert.NoError(t, err)
+ diff := cmp.Diff(expectedCfg, cfg)
+ assert.Empty(t, diff)
+
+ }, time.Second*5, time.Millisecond*50)
}
func TestGetWebConfig_IGSFeatureLimits(t *testing.T) {
@@ -4745,6 +4763,8 @@ func TestGetWebConfig_IGSFeatureLimits(t *testing.T) {
Questionnaire: true,
},
})
+ // advance the clock so the feature watcher ticks and the web config gets new features
+ env.clock.Advance(DefaultFeatureWatchInterval * 2)
expectedCfg := webclient.WebConfig{
Auth: webclient.WebConfigAuthSettings{
@@ -4766,20 +4786,25 @@ func TestGetWebConfig_IGSFeatureLimits(t *testing.T) {
IsUsageBasedBilling: true,
}
- // Make a request.
clt := env.proxies[0].newClient(t)
- endpoint := clt.Endpoint("web", "config.js")
- re, err := clt.Get(ctx, endpoint, nil)
- require.NoError(t, err)
- require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))
-
- // Response is type application/javascript, we need to strip off the variable name
- // and the semicolon at the end, then we are left with json like object.
- var cfg webclient.WebConfig
- str := strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
- err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
- require.NoError(t, err)
- require.Equal(t, expectedCfg, cfg)
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ // Make a request.
+ endpoint := clt.Endpoint("web", "config.js")
+ re, err := clt.Get(ctx, endpoint, nil)
+ if !assert.NoError(t, err) {
+ return
+ }
+ assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG")))
+
+ // Response is type application/javascript, we need to strip off the variable name
+ // and the semicolon at the end, then we are left with json like object.
+ var cfg webclient.WebConfig
+ res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{})
+ err = json.Unmarshal(res[:len(res)-1], &cfg)
+ assert.NoError(t, err)
+ diff := cmp.Diff(expectedCfg, cfg)
+ assert.Empty(t, diff)
+ }, time.Second*5, time.Millisecond*50)
}
func TestCreatePrivilegeToken(t *testing.T) {
diff --git a/lib/web/features.go b/lib/web/features.go
new file mode 100644
index 0000000000000..b387d0f176958
--- /dev/null
+++ b/lib/web/features.go
@@ -0,0 +1,71 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package web
+
+import (
+ "github.com/gravitational/teleport/api/client/proto"
+)
+
+// SetClusterFeatures overwrites the handler's stored cluster feature flags
+// while holding h.Mutex.
+// TODO(mcbattirola): make method unexported, fix tests using it to set
+// test modules instead.
+func (h *Handler) SetClusterFeatures(features proto.Features) {
+	h.Mutex.Lock()
+	h.ClusterFeatures = features
+	h.Mutex.Unlock()
+}
+
+// GetClusterFeatures returns, by value, the cluster feature flags read under h.Mutex.
+func (h *Handler) GetClusterFeatures() proto.Features {
+	h.Mutex.Lock()
+	current := h.ClusterFeatures
+	h.Mutex.Unlock()
+
+	return current
+}
+
+// startFeatureWatcher periodically pings the auth server and stores the returned
+// features through SetClusterFeatures. A failed ping, or a response that carries
+// no features, is logged and leaves the previously stored features untouched.
+// The call blocks until h.cfg.Context is done, so it is meant to run in its own
+// goroutine.
+// The watcher doesn't ping the auth server immediately upon start because features are
+// already set by the config object in `NewHandler`.
+func (h *Handler) startFeatureWatcher() {
+	ticker := h.clock.NewTicker(h.cfg.FeatureWatchInterval)
+	defer ticker.Stop()
+
+	h.log.WithField("interval", h.cfg.FeatureWatchInterval).Info("Proxy handler features watcher has started")
+	ctx := h.cfg.Context
+
+	for {
+		select {
+		case <-ticker.Chan():
+			h.log.Info("Pinging auth server for features")
+			pingResponse, err := h.GetProxyClient().Ping(ctx)
+			if err != nil {
+				h.log.WithError(err).Error("Auth server ping failed")
+				continue
+			}
+			// Dereferencing a nil pointer here would panic in this background
+			// goroutine and crash the process, so skip the update instead.
+			if pingResponse.ServerFeatures == nil {
+				h.log.Error("Auth server ping response did not include features")
+				continue
+			}
+
+			h.SetClusterFeatures(*pingResponse.ServerFeatures)
+			h.log.WithField("features", pingResponse.ServerFeatures).Info("Done updating proxy features")
+		case <-ctx.Done():
+			h.log.Info("Feature service has stopped")
+			return
+		}
+	}
+}
diff --git a/lib/web/features_test.go b/lib/web/features_test.go
new file mode 100644
index 0000000000000..9e167f58324ca
--- /dev/null
+++ b/lib/web/features_test.go
@@ -0,0 +1,160 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package web
+
+import (
+	"context"
+	"log/slog"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/jonboulle/clockwork"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/gravitational/teleport"
+	"github.com/gravitational/teleport/api/client/proto"
+	"github.com/gravitational/teleport/api/utils"
+	"github.com/gravitational/teleport/lib/auth/authclient"
+)
+
+// mockedFeatureGetter is a test client with a mocked Ping method
+// that returns the internal features.
+type mockedFeatureGetter struct {
+	authclient.ClientI
+
+	// mu guards features: the test goroutine writes it through setFeatures
+	// while the watcher goroutine reads it through Ping.
+	mu       sync.Mutex
+	features proto.Features
+}
+
+// Ping returns a response whose ServerFeatures is a clone of the currently
+// configured features. It never returns an error.
+func (m *mockedFeatureGetter) Ping(ctx context.Context) (proto.PingResponse, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	return proto.PingResponse{
+		ServerFeatures: utils.CloneProtoMsg(&m.features),
+	}, nil
+}
+
+// setFeatures replaces the features returned by subsequent Ping calls.
+func (m *mockedFeatureGetter) setFeatures(f proto.Features) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	m.features = f
+}
+
+// TestFeaturesWatcher runs startFeatureWatcher against a mocked client and a
+// fake clock. It checks that the handler's features follow the client's
+// response while the watcher runs, and stop changing once the watcher's
+// context is canceled.
+func TestFeaturesWatcher(t *testing.T) {
+	clock := clockwork.NewFakeClock()
+
+	mockClient := &mockedFeatureGetter{features: proto.Features{
+		Kubernetes:     true,
+		AccessRequests: &proto.AccessRequestsFeature{},
+	}}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// Make sure the watcher goroutine is stopped even if an assertion aborts
+	// the test before the explicit cancel below.
+	t.Cleanup(cancel)
+
+	handler := &Handler{
+		cfg: Config{
+			FeatureWatchInterval: 100 * time.Millisecond,
+			ProxyClient:          mockClient,
+			Context:              ctx,
+		},
+		clock:           clock,
+		ClusterFeatures: proto.Features{},
+		log:             newPackageLogger(),
+		logger:          slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb),
+	}
+
+	// before running the watcher, features should match the value passed to the handler
+	requireFeatures(t, clock, proto.Features{}, handler.GetClusterFeatures)
+
+	// watcherDone is closed when startFeatureWatcher returns.
+	watcherDone := make(chan struct{})
+	go func() {
+		defer close(watcherDone)
+		handler.startFeatureWatcher()
+	}()
+	clock.BlockUntil(1)
+
+	// after starting the watcher, handler.GetClusterFeatures should return
+	// values matching the client's response
+	features := proto.Features{
+		Kubernetes:     true,
+		AccessRequests: &proto.AccessRequestsFeature{},
+	}
+	expected := utils.CloneProtoMsg(&features)
+	requireFeatures(t, clock, *expected, handler.GetClusterFeatures)
+
+	// update values once again and check if the features are properly updated
+	features = proto.Features{
+		Kubernetes:     false,
+		AccessRequests: &proto.AccessRequestsFeature{},
+	}
+	mockClient.setFeatures(features)
+	expected = utils.CloneProtoMsg(&features)
+	requireFeatures(t, clock, *expected, handler.GetClusterFeatures)
+
+	// test updating features
+	features = proto.Features{
+		Kubernetes:           true,
+		ExternalAuditStorage: true,
+		AccessList:           &proto.AccessListFeature{CreateLimit: 10},
+		AccessMonitoring:     &proto.AccessMonitoringFeature{},
+		App:                  true,
+		AccessRequests:       &proto.AccessRequestsFeature{},
+	}
+	mockClient.setFeatures(features)
+
+	expected = &proto.Features{
+		Kubernetes:           true,
+		ExternalAuditStorage: true,
+		AccessList:           &proto.AccessListFeature{CreateLimit: 10},
+		AccessMonitoring:     &proto.AccessMonitoringFeature{},
+		App:                  true,
+		AccessRequests:       &proto.AccessRequestsFeature{},
+	}
+	requireFeatures(t, clock, *expected, handler.GetClusterFeatures)
+
+	// stop watcher and ensure it stops updating features
+	cancel()
+	// Wait for the watcher to actually return before changing the mock: if a
+	// tick and the cancellation are both ready, its select may still serve the
+	// tick once, which must not observe the features set below.
+	select {
+	case <-watcherDone:
+	case <-time.After(5 * time.Second):
+		t.Fatal("feature watcher did not stop after its context was canceled")
+	}
+
+	features = proto.Features{
+		Kubernetes:     !features.Kubernetes,
+		App:            !features.App,
+		DB:             true,
+		AccessRequests: &proto.AccessRequestsFeature{},
+	}
+	mockClient.setFeatures(features)
+	notExpected := utils.CloneProtoMsg(&features)
+	// assert the handler never get these last features as the watcher is stopped
+	neverFeatures(t, clock, *notExpected, handler.GetClusterFeatures)
+}
+
+// requireFeatures is a helper function that advances the fake clock by one
+// second, then calls `getFeatures` every 100ms for up to 5 seconds, until it
+// returns the expected result (`want`). The test fails if it never does.
+func requireFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) {
+	t.Helper()
+
+	// Advance the clock so the service fetches and stores features
+	fakeClock.Advance(time.Second)
+
+	featuresMatch := func(c *assert.CollectT) {
+		assert.Empty(c, cmp.Diff(want, getFeatures()))
+	}
+	require.EventuallyWithT(t, featuresMatch, 5*time.Second, 100*time.Millisecond)
+}
+
+// neverFeatures is a helper function that advances the fake clock by one
+// second, then calls `getFeatures` every 100ms for up to 1 second. The test
+// fails if `getFeatures` returns `doNotWant` at any of those checks.
+func neverFeatures(t *testing.T, fakeClock clockwork.FakeClock, doNotWant proto.Features, getFeatures func() proto.Features) {
+	t.Helper()
+
+	fakeClock.Advance(time.Second)
+
+	sawUnwanted := func() bool {
+		return len(cmp.Diff(doNotWant, getFeatures())) == 0
+	}
+	require.Never(t, sawUnwanted, time.Second, 100*time.Millisecond)
+}
diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go
index 1819a66059ece..631643f1fd7d4 100644
--- a/lib/web/integrations_awsoidc.go
+++ b/lib/web/integrations_awsoidc.go
@@ -125,7 +125,7 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p
}
teleportVersionTag := teleport.Version
- if automaticUpgrades(h.ClusterFeatures) {
+ if automaticUpgrades(h.GetClusterFeatures()) {
cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx)
if err != nil {
return "", trace.Wrap(err)
@@ -178,7 +178,7 @@ func (h *Handler) awsOIDCDeployDatabaseServices(w http.ResponseWriter, r *http.R
}
teleportVersionTag := teleport.Version
- if automaticUpgrades(h.ClusterFeatures) {
+ if automaticUpgrades(h.GetClusterFeatures()) {
cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx)
if err != nil {
return "", trace.Wrap(err)
@@ -480,7 +480,7 @@ func (h *Handler) awsOIDCEnrollEKSClusters(w http.ResponseWriter, r *http.Reques
return nil, trace.BadParameter("an integration name is required")
}
- agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.ClusterFeatures, h.cfg.AutomaticUpgradesChannels)
+ agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.GetClusterFeatures(), h.cfg.AutomaticUpgradesChannels)
if err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go
index 9b5bfd459aaf3..31338c34837e3 100644
--- a/lib/web/join_tokens.go
+++ b/lib/web/join_tokens.go
@@ -216,7 +216,7 @@ func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, para
func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string, error) {
var autoUpgradesVersion string
var err error
- autoUpgrades := automaticUpgrades(h.ClusterFeatures)
+ autoUpgrades := automaticUpgrades(h.GetClusterFeatures())
if autoUpgrades {
autoUpgradesVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx)
if err != nil {