diff --git a/api/client/client.go b/api/client/client.go index 28871aa871957..84d509e7b86d4 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -4261,6 +4261,12 @@ func (c *Client) GetSSHTargets(ctx context.Context, req *proto.GetSSHTargetsRequ return rsp, trace.Wrap(err) } +// ResolveSSHTarget gets a server that would match an equivalent ssh dial request. +func (c *Client) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) { + rsp, err := c.grpc.ResolveSSHTarget(ctx, req) + return rsp, trace.Wrap(err) +} + // CreateSessionTracker creates a tracker resource for an active session. func (c *Client) CreateSessionTracker(ctx context.Context, st types.SessionTracker) (types.SessionTracker, error) { v1, ok := st.(*types.SessionTrackerV1) diff --git a/lib/client/api.go b/lib/client/api.go index 88693bc768bb4..9e3072069335f 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1489,7 +1489,7 @@ type TargetNode struct { func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]TargetNode, error) { ctx, span := tc.Tracer.Start( ctx, - "teleportClient/getTargetNodes", + "teleportClient/GetTargetNodes", oteltrace.WithSpanKind(oteltrace.SpanKindClient), ) defer span.End() @@ -1553,6 +1553,112 @@ func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUni }, nil } +// GetTargetNode returns a single host matching the target host provided by users. Host resolution +// honors an explicit host, i.e. tsh ssh user@hostname, label based hosts, i.e. tsh ssh user@foo=bar, +// as well as respecting any proxy templates that are specified. +func (tc *TeleportClient) GetTargetNode(ctx context.Context, clt authclient.ClientI, options *SSHOptions) (*TargetNode, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/GetTargetNode", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + if options != nil && options.HostAddress != "" { + return &TargetNode{ + Hostname: options.HostAddress, + Addr: options.HostAddress, + }, nil + } + + if len(tc.Labels) == 0 && len(tc.SearchKeywords) == 0 && tc.PredicateExpression == "" { + log.Debugf("Using provided host %s", tc.Host) + + // detect the common error when users use host:port address format + _, port, err := net.SplitHostPort(tc.Host) + // client has used host:port notation + if err == nil { + return nil, trace.BadParameter("please use ssh subcommand with '--port=%v' flag instead of semicolon", port) + } + + addr := net.JoinHostPort(tc.Host, strconv.Itoa(tc.HostPort)) + return &TargetNode{ + Hostname: tc.Host, + Addr: addr, + }, nil + } + + // Query for nodes if labels, fuzzy search, or predicate expressions were provided. + log.Debugf("Attempting to resolve matching host from labels=%v|search=%v|predicate=%v", tc.Labels, tc.SearchKeywords, tc.PredicateExpression) + resp, err := clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{ + PredicateExpression: tc.PredicateExpression, + SearchKeywords: tc.SearchKeywords, + Labels: tc.Labels, + }) + switch { + //TODO(tross): DELETE IN v20.0.0 + case trace.IsNotImplemented(err): + resources, err := client.GetAllUnifiedResources(ctx, clt, &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindNode}, + SortBy: types.SortBy{Field: types.ResourceMetadataName}, + Labels: tc.Labels, + SearchKeywords: tc.SearchKeywords, + PredicateExpression: tc.PredicateExpression, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + if len(resources) == 0 { + return nil, trace.NotFound("no matching SSH hosts found for search terms or query expression") + } + + if len(resources) > 1 { + // If routing does not allow choosing the most recent host, then abort with + // an ambiguous host error. + cnc, err := clt.GetClusterNetworkingConfig(ctx) + if err != nil || cnc.GetRoutingStrategy() != types.RoutingStrategy_MOST_RECENT { + return nil, trace.BadParameter("found multiple matching SSH hosts %v", resources[:2]) + } + + // Sort the resource by expiry so we can identify the most "recent". + slices.SortFunc(resources, func(a, b *types.EnrichedResource) int { + return a.Expiry().Compare(b.Expiry()) + }) + + } + + // Sorting above is oldest expiry to newest expiry, so proceed + // with the last item server in the slice. + server, ok := resources[len(resources)-1].ResourceWithLabels.(types.Server) + if !ok { + return nil, trace.BadParameter("recevied unexpected resource type %T", resources[0].ResourceWithLabels) + } + + // Dialing is happening by UUID but a port is still required by + // the Proxy dial request. Zero is an indicator to the Proxy that + // it may chose the appropriate port based on the target server. + return &TargetNode{ + Hostname: server.GetHostname(), + Addr: server.GetName() + ":0", + }, nil + case err == nil: + if resp.GetServer() == nil { + return nil, trace.NotFound("no matching SSH hosts found") + } + + // Dialing is happening by UUID but a port is still required by + // the Proxy dial request. Zero is an indicator to the Proxy that + // it may chose the appropriate port based on the target server. + return &TargetNode{ + Hostname: resp.GetServer().GetHostname(), + Addr: resp.GetServer().GetName() + ":0", + }, nil + default: + return nil, trace.Wrap(err) + } +} + // ReissueUserCerts issues new user certs based on params and stores them in // the local key agent (usually on disk in ~/.tsh). func (tc *TeleportClient) ReissueUserCerts(ctx context.Context, cachePolicy CertCachePolicy, params ReissueParams) error { @@ -2434,19 +2540,11 @@ func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination defer clt.Close() // Respect any proxy templates and attempt host resolution. - resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{}) + target, err := tc.GetTargetNode(ctx, clt.AuthClient, nil) if err != nil { return trace.Wrap(err) } - switch len(resolvedNodes) { - case 1: - case 0: - return trace.NotFound("no matching hosts found") - default: - return trace.BadParameter("multiple matching hosts found") - } - var cfg *sftp.Config switch { case isDownload: @@ -2469,7 +2567,7 @@ func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination } } - return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg)) + return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, target.Addr, cfg)) } // TransferFiles copies files between the current machine and the diff --git a/lib/client/api_test.go b/lib/client/api_test.go index 9da016c7af401..8265ecd2846dc 100644 --- a/lib/client/api_test.go +++ b/lib/client/api_test.go @@ -43,6 +43,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/grpc/interceptors" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -1380,6 +1381,192 @@ func TestGetTargetNodes(t *testing.T) { } } +type fakeGetTargetNodeClient struct { + authclient.ClientI + + nodes []*types.ServerV2 + resolved *types.ServerV2 + resolveErr error + routeToMostRecent bool +} + +func (f fakeGetTargetNodeClient) ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error) { + out := make([]*proto.PaginatedResource, 0, len(f.nodes)) + for _, n := range f.nodes { + out = append(out, &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: n}}) + } + + return &proto.ListUnifiedResourcesResponse{Resources: out}, nil +} + +func (f fakeGetTargetNodeClient) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) { + if f.resolveErr != nil { + return nil, f.resolveErr + } + + return &proto.ResolveSSHTargetResponse{Server: f.resolved}, nil +} + +func (f fakeGetTargetNodeClient) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) { + cfg := types.DefaultClusterNetworkingConfig() + if f.routeToMostRecent { + cfg.SetRoutingStrategy(types.RoutingStrategy_MOST_RECENT) + } + + return cfg, nil +} + +func TestGetTargetNode(t *testing.T) { + now := time.Now() + then := now.Add(-5 * time.Hour) + + tests := []struct { + name string + options *SSHOptions + labels map[string]string + search []string + predicate string + host string + port int + clt fakeGetTargetNodeClient + errAssertion require.ErrorAssertionFunc + expected TargetNode + }{ + { + name: "options override", + options: &SSHOptions{ + HostAddress: "test:1234", + }, + host: "llama", + port: 56789, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "test:1234", Addr: "test:1234"}, + }, + { + name: "explicit target", + host: "test", + port: 1234, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "test", Addr: "test:1234"}, + }, + { + name: "resolved labels", + labels: map[string]string{"foo": "bar"}, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "resolved-labels", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}}, + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-labels"}}, + }, + }, + { + name: "fallback labels", + labels: map[string]string{"foo": "bar"}, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "labels", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}}, + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-labels"}}, + resolveErr: trace.NotImplemented(""), + }, + }, + { + name: "resolved search", + search: []string{"foo", "bar"}, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "resolved-search", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}}, + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-search"}}, + }, + }, + + { + name: "fallback search", + search: []string{"foo", "bar"}, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "search", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}}, + resolveErr: trace.NotImplemented(""), + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-search"}}, + }, + }, + { + name: "resolved predicate", + predicate: `resource.spec.hostname == "test"`, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "resolved-predicate", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}}, + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}}, + }, + }, + { + name: "fallback predicate", + predicate: `resource.spec.hostname == "test"`, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "predicate", Addr: "abcd:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}}, + resolveErr: trace.NotImplemented(""), + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}}, + }, + }, + { + name: "fallback ambiguous hosts", + predicate: `resource.spec.hostname == "test"`, + errAssertion: require.Error, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{ + {Metadata: types.Metadata{Name: "abcd-1"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}, + {Metadata: types.Metadata{Name: "abcd-2"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}, + }, + resolveErr: trace.NotImplemented(""), + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}}, + }, + }, + { + name: "fallback and route to recent", + predicate: `resource.spec.hostname == "test"`, + errAssertion: require.NoError, + expected: TargetNode{Hostname: "predicate-now", Addr: "abcd-1:0"}, + clt: fakeGetTargetNodeClient{ + nodes: []*types.ServerV2{ + {Metadata: types.Metadata{Name: "abcd-0", Expires: &then}, Spec: types.ServerSpecV2{Hostname: "predicate-then"}}, + {Metadata: types.Metadata{Name: "abcd-1", Expires: &now}, Spec: types.ServerSpecV2{Hostname: "predicate-now"}}, + {Metadata: types.Metadata{Name: "abcd-2", Expires: &then}, Spec: types.ServerSpecV2{Hostname: "predicate-then-again"}}, + }, + resolveErr: trace.NotImplemented(""), + routeToMostRecent: true, + resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clt := TeleportClient{ + Config: Config{ + Tracer: tracing.NoopTracer(""), + Labels: test.labels, + SearchKeywords: test.search, + PredicateExpression: test.predicate, + Host: test.host, + HostPort: test.port, + }, + } + + match, err := clt.GetTargetNode(context.Background(), test.clt, test.options) + test.errAssertion(t, err) + if match == nil { + match = &TargetNode{} + } + require.EqualValues(t, test.expected, *match) + }) + } +} + func TestNonRetryableError(t *testing.T) { orgError := trace.AccessDenied("do not enter") err := &NonRetryableError{ diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 4f0f1fee92135..5a11adbbf5309 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" libclient "github.com/gravitational/teleport/lib/client" @@ -63,39 +62,20 @@ func onProxyCommandSSH(cf *CLIConf) error { return trace.Wrap(err) } - var target string - switch { - case tc.Host != "": - targetHost, targetPort, err := net.SplitHostPort(tc.Host) - if err != nil { - targetHost = tc.Host - targetPort = strconv.Itoa(tc.HostPort) - } - targetHost = cleanTargetHost(targetHost, tc.WebProxyHost(), clt.ClusterName()) - target = net.JoinHostPort(targetHost, targetPort) - case len(tc.SearchKeywords) != 0 || tc.PredicateExpression != "": - nodes, err := client.GetAllResources[types.Server](cf.Context, clt.AuthClient, tc.ResourceFilter(types.KindNode)) - if err != nil { - return trace.Wrap(err) - } - - if len(nodes) == 0 { - return trace.NotFound("no matching SSH hosts found for search terms or query expression") - } - - if len(nodes) > 1 { - return trace.BadParameter("found multiple matching SSH hosts %v", nodes[:2]) - } + targetHost, targetPort, err := net.SplitHostPort(tc.Host) + if err != nil { + targetHost = tc.Host + targetPort = strconv.Itoa(tc.HostPort) + } + targetHost = cleanTargetHost(targetHost, tc.WebProxyHost(), clt.ClusterName()) + tc.Host = net.JoinHostPort(targetHost, targetPort) - // Dialing is happening by UUID but a port is still required by - // the Proxy dial request. Zero is an indicator to the Proxy that - // it may chose the appropriate port based on the target server. - target = fmt.Sprintf("%s:0", nodes[0].GetName()) - default: - return trace.BadParameter("no hostname, search terms or query expression provided") + target, err := tc.GetTargetNode(cf.Context, clt.AuthClient, nil) + if err != nil { + return trace.Wrap(err) } - conn, _, err := clt.DialHostWithResumption(cf.Context, target, clt.ClusterName(), tc.LocalAgent().ExtendedAgent) + conn, _, err := clt.DialHostWithResumption(cf.Context, target.Addr, clt.ClusterName(), tc.LocalAgent().ExtendedAgent) if err != nil { return trace.Wrap(err) }