From 1999d6a695f67ca46a7fa08b45fc1e20932270b9 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 12 Dec 2024 16:56:16 -0500 Subject: [PATCH] Implement ResolveSSHTarget in the auth server --- constants.go | 8 +++ integration/integration_test.go | 4 +- lib/auth/auth_with_roles.go | 91 +++++++++++++++++++++++++++++-- lib/auth/authclient/clt.go | 3 + lib/auth/grpcserver.go | 15 +++++ lib/auth/grpcserver_test.go | 63 +++++++++++++++++++++ lib/proxy/router.go | 5 +- lib/proxy/router_test.go | 8 +-- lib/services/readonly/readonly.go | 1 + lib/web/terminal.go | 2 +- tool/tsh/common/tsh.go | 4 +- tool/tsh/common/tsh_test.go | 6 +- 12 files changed, 190 insertions(+), 20 deletions(-) diff --git a/constants.go b/constants.go index 1c7ecf3226a96..63ba08b791c46 100644 --- a/constants.go +++ b/constants.go @@ -21,6 +21,8 @@ package teleport import ( "strings" "time" + + "github.com/gravitational/trace" ) // WebAPIVersion is a current webapi version @@ -823,9 +825,15 @@ const ( UsageWindowsDesktopOnly = "usage:windows_desktop" ) +// ErrNodeIsAmbiguous serves as an identifying error value indicating that +// the proxy subsystem found multiple nodes matching the specified hostname. +var ErrNodeIsAmbiguous = &trace.NotFoundError{Message: "ambiguous host could match multiple nodes"} + const ( // NodeIsAmbiguous serves as an identifying error string indicating that // the proxy subsystem found multiple nodes matching the specified hostname. 
+ // TODO(tross) DELETE IN v20.0.0 + // Deprecated: Prefer using ErrNodeIsAmbiguous NodeIsAmbiguous = "err-node-is-ambiguous" // MaxLeases serves as an identifying error string indicating that the diff --git a/integration/integration_test.go b/integration/integration_test.go index cbc138b498f4f..9b5ff50af1f85 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -809,9 +809,7 @@ func testUUIDBasedProxy(t *testing.T, suite *integrationTestSuite) { // attempting to run a command by hostname should generate NodeIsAmbiguous error. _, err = runCommand(t, teleportSvr, []string{"echo", "Hello there!"}, helpers.ClientConfig{Login: suite.Me.Username, Cluster: helpers.Site, Host: Host}, 1) require.Error(t, err) - if !strings.Contains(err.Error(), teleport.NodeIsAmbiguous) { - require.FailNowf(t, "Expected %s, got %s", teleport.NodeIsAmbiguous, err.Error()) - } + require.ErrorContains(t, err, "ambiguous") // attempting to run a command by uuid should succeed. _, err = runCommand(t, teleportSvr, []string{"echo", "Hello there!"}, helpers.ClientConfig{Login: suite.Me.Username, Cluster: helpers.Site, Host: uuid1}, 1) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e5fdaa9dab8ce..eb023db2592f5 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1648,22 +1648,23 @@ func (a *ServerWithRoles) GetSSHTargets(ctx context.Context, req *proto.GetSSHTa return nil, trace.Wrap(err) } - lreq := proto.ListResourcesRequest{ - ResourceType: types.KindNode, + lreq := &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindNode}, + SortBy: types.SortBy{Field: types.ResourceMetadataName}, UseSearchAsRoles: true, } var servers []*types.ServerV2 for { - // note that we're calling ServerWithRoles.ListResources here rather than some internal method. This method + // note that we're calling ServerWithRoles.ListUnifiedResources here rather than some internal method. 
This method // delegates all RBAC filtering to ListResources, and then performs additional filtering on top of that. - lrsp, err := a.ListResources(ctx, lreq) + lrsp, err := a.ListUnifiedResources(ctx, lreq) if err != nil { return nil, trace.Wrap(err) } for _, rsc := range lrsp.Resources { - srv, ok := rsc.(*types.ServerV2) - if !ok { + srv := rsc.GetNode() + if srv == nil { log.Warnf("Unexpected resource type %T, expected *types.ServerV2 (skipping)", rsc) continue } @@ -1687,6 +1688,84 @@ func (a *ServerWithRoles) GetSSHTargets(ctx context.Context, req *proto.GetSSHTa }, nil } +// ResolveSSHTarget gets a server that would match an equivalent ssh dial request. +func (a *ServerWithRoles) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) { + var servers []*types.ServerV2 + switch { + case req.Host != "": + resp, err := a.GetSSHTargets(ctx, &proto.GetSSHTargetsRequest{ + Host: req.Host, + Port: req.Port, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + servers = resp.Servers + case len(req.Labels) > 0 || req.PredicateExpression != "" || len(req.SearchKeywords) > 0: + lreq := &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindNode}, + SortBy: types.SortBy{Field: types.ResourceMetadataName}, + Labels: req.Labels, + PredicateExpression: req.PredicateExpression, + SearchKeywords: req.SearchKeywords, + } + for { + // note that we're calling ServerWithRoles.ListUnifiedResources here rather than some internal method. This method + // delegates all RBAC filtering to ListUnifiedResources, and then selects a single server from the results. 
+ lrsp, err := a.ListUnifiedResources(ctx, lreq) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, rsc := range lrsp.Resources { + srv := rsc.GetNode() + if srv == nil { + log.Warnf("Unexpected resource type %T, expected *types.ServerV2 (skipping)", rsc) + continue + } + + servers = append(servers, srv) + } + + if lrsp.NextKey == "" || len(lrsp.Resources) == 0 { + break + } + + lreq.StartKey = lrsp.NextKey + } + default: + return nil, trace.NotFound("no matching hosts") + } + + switch len(servers) { + case 1: + return &proto.ResolveSSHTargetResponse{Server: servers[0]}, nil + case 0: + return nil, trace.NotFound("no matching hosts") + default: + // try to detect the most-recent routing strategy setting, but default to false if we can't load + // networking config (equivalent to proxy routing behavior). + var routeToMostRecent bool + if cfg, err := a.authServer.GetReadOnlyClusterNetworkingConfig(ctx); err == nil { + routeToMostRecent = cfg.GetRoutingStrategy() == types.RoutingStrategy_MOST_RECENT + } + + if !routeToMostRecent { + return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous) + } + + // Sort the servers by expiry so we can identify the most "recent" server. + slices.SortFunc(servers, func(a, b *types.ServerV2) int { + return a.Expiry().Compare(b.Expiry()) + }) + + // Sorting above is oldest expiry to newest expiry, so proceed + // with the last server in the slice. + return &proto.ResolveSSHTargetResponse{Server: servers[len(servers)-1]}, nil + } +} + // ListResources returns a paginated list of resources filtered by user access. 
func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) { // Check if auth server has a license for this resource type but only return an diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 4217b12a5991e..09e4caff54d29 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -1875,6 +1875,9 @@ type ClientI interface { // but may result in confusing behavior if it is used outside of those contexts. GetSSHTargets(ctx context.Context, req *proto.GetSSHTargetsRequest) (*proto.GetSSHTargetsResponse, error) + // ResolveSSHTarget returns the server that would be resolved in an equivalent ssh dial request. + ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) + // PerformMFACeremony retrieves an MFA challenge from the server with the given challenge extensions // and prompts the user to answer the challenge with the given promptOpts, and ultimately returning // an MFA challenge response for the user. diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 741f4626c957c..1ac98a7c93e52 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -4198,6 +4198,21 @@ func (g *GRPCServer) GetSSHTargets(ctx context.Context, req *authpb.GetSSHTarget return rsp, nil } +// ResolveSSHTarget gets a server that would match an equivalent ssh dial request. +func (g *GRPCServer) ResolveSSHTarget(ctx context.Context, req *authpb.ResolveSSHTargetRequest) (*authpb.ResolveSSHTargetResponse, error) { + auth, err := g.authenticate(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + rsp, err := auth.ServerWithRoles.ResolveSSHTarget(ctx, req) + if err != nil { + return nil, trace.Wrap(err) + } + + return rsp, nil +} + // CreateSessionTracker creates a tracker resource for an active session. 
func (g *GRPCServer) CreateSessionTracker(ctx context.Context, req *authpb.CreateSessionTrackerRequest) (*types.SessionTrackerV1, error) { auth, err := g.authenticate(ctx) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 8a91f952e001e..c92e521e386c0 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -2988,6 +2988,69 @@ func TestGetSSHTargets(t *testing.T) { require.ElementsMatch(t, []string{rsp.Servers[0].GetHostname(), rsp.Servers[1].GetHostname()}, []string{"foo", "Foo"}) } +func TestResolveSSHTarget(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv := newTestTLSServer(t) + + clt, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + + upper, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{ + Hostname: "Foo", + UseTunnel: true, + }, nil) + require.NoError(t, err) + upper.SetExpiry(time.Now().Add(time.Hour)) + + lower, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{ + Hostname: "foo", + UseTunnel: true, + }, nil) + require.NoError(t, err) + + other, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{ + Hostname: "bar", + UseTunnel: true, + }, nil) + require.NoError(t, err) + + for _, node := range []types.Server{upper, lower, other} { + _, err = clt.UpsertNode(ctx, node) + require.NoError(t, err) + } + + rsp, err := clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{ + Host: "foo", + Port: "0", + }) + require.NoError(t, err) + require.Equal(t, "foo", rsp.Server.GetHostname()) + + cnc := types.DefaultClusterNetworkingConfig() + cnc.SetCaseInsensitiveRouting(true) + _, err = clt.UpsertClusterNetworkingConfig(ctx, cnc) + require.NoError(t, err) + + rsp, err = clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{ + Host: "foo", + Port: "0", + }) + require.Error(t, err) + require.Nil(t, rsp) + + cnc.SetRoutingStrategy(types.RoutingStrategy_MOST_RECENT) + _, err 
= clt.UpsertClusterNetworkingConfig(ctx, cnc) + require.NoError(t, err) + + rsp, err = clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{ + Host: "foo", + Port: "0", + }) + require.NoError(t, err) + require.Equal(t, "Foo", rsp.Server.GetHostname()) +} + func TestNodesCRUD(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/lib/proxy/router.go b/lib/proxy/router.go index f54f9718af604..b01d67c94125e 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -501,7 +501,10 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re } } case len(matches) > 1: - return nil, trace.NotFound(teleport.NodeIsAmbiguous) + // TODO(tross) DELETE IN V20.0.0 + // NodeIsAmbiguous is included in the error message for backwards compatibility + // with older nodes that expect to see that string in the error message. + return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous) case len(matches) == 1: server = matches[0] } diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index 660f9fd435762..fcd2a4d1963a5 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -211,7 +211,7 @@ func TestRouteScoring(t *testing.T) { t.Run(tt.desc, func(t *testing.T) { srv, err := getServerWithResolver(ctx, tt.host, tt.port, site, resolver) if tt.ambiguous { - require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous) return } require.Equal(t, tt.expect, srv.GetHostname()) @@ -375,7 +375,7 @@ func TestGetServers(t *testing.T) { site: testSite{cfg: &unambiguousCfg, nodes: servers}, host: "sheep", errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous) }, serverAssertion: func(t *testing.T, srv types.Server) { require.Empty(t, srv) @@ -456,7 +456,7 @@ func TestGetServers(t *testing.T) { site: testSite{cfg: 
&unambiguousInsensitiveCfg, nodes: servers}, host: "platypus", errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous)) + require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous) }, serverAssertion: func(t *testing.T, srv types.Server) { require.Empty(t, srv) @@ -670,7 +670,7 @@ func TestRouter_DialHost(t *testing.T) { clusterName: "test", log: logger, tracer: tracing.NoopTracer("test"), - serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)), + serverResolver: serverResolver(nil, teleport.ErrNodeIsAmbiguous), }, assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.Error(t, err) diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index c4ed3185ace66..744f2b4cd3a5c 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -71,6 +71,7 @@ func sealAuthPreference(p types.AuthPreference) AuthPreference { type ClusterNetworkingConfig interface { GetCaseInsensitiveRouting() bool GetWebIdleTimeout() time.Duration + GetRoutingStrategy() types.RoutingStrategy Clone() types.ClusterNetworkingConfig } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 9326140447eac..74799aa94dc5b 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -922,7 +922,7 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn, if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host.") - if errors.Is(err, trace.NotFound(teleport.NodeIsAmbiguous)) { + if errors.Is(err, teleport.ErrNodeIsAmbiguous) { const message = "error: ambiguous host could match multiple nodes\n\nHint: try addressing the node by unique id (ex: user@node-id)\n" return nil, trace.NotFound(message) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index aa6fb5cbf5666..db9252ee40233 100644 --- a/tool/tsh/common/tsh.go +++ 
b/tool/tsh/common/tsh.go @@ -3920,7 +3920,9 @@ func onSSH(cf *CLIConf) error { err = client.RetryWithRelogin(cf.Context, tc, sshFunc) } if err != nil { - if strings.Contains(utils.UserMessageFromError(err), teleport.NodeIsAmbiguous) { + if errors.Is(err, teleport.ErrNodeIsAmbiguous) || + // TODO(tross) DELETE IN v20.0.0 + strings.Contains(utils.UserMessageFromError(err), teleport.NodeIsAmbiguous) { clt, err := tc.ConnectToCluster(cf.Context) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index e19a3b945517d..792c507b5c55d 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -7028,8 +7028,7 @@ func TestSCP(t *testing.T) { return filepath.Join(dir, targetFile1) }, assertion: func(tt require.TestingT, err error, i ...any) { - require.Error(tt, err, i...) - require.ErrorContains(tt, err, "multiple matching hosts", i...) + require.ErrorIs(tt, err, teleport.ErrNodeIsAmbiguous, i...) }, }, { @@ -7088,8 +7087,7 @@ func TestSCP(t *testing.T) { return "dev.example.com:" + filepath.Join(dir, targetFile1) }, assertion: func(tt require.TestingT, err error, i ...any) { - require.Error(tt, err, i...) - require.ErrorContains(tt, err, "multiple matching hosts", i...) + require.ErrorIs(tt, err, teleport.ErrNodeIsAmbiguous, i...) }, }, {