Skip to content

Commit

Permalink
Implement ResolveSSHTarget in the auth server
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Dec 16, 2024
1 parent 242f1cc commit 1999d6a
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 20 deletions.
8 changes: 8 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package teleport
import (
"strings"
"time"

"github.com/gravitational/trace"
)

// WebAPIVersion is a current webapi version
Expand Down Expand Up @@ -823,9 +825,15 @@ const (
UsageWindowsDesktopOnly = "usage:windows_desktop"
)

// ErrNodeIsAmbiguous serves as an identifying error string indicating that
// the proxy subsystem found multiple nodes matching the specified hostname.
var ErrNodeIsAmbiguous = &trace.NotFoundError{Message: "ambiguous host could match multiple nodes"}

const (
// NodeIsAmbiguous serves as an identifying error string indicating that
// the proxy subsystem found multiple nodes matching the specified hostname.
// TODO(tross) DELETE IN v20.0.0
// Deprecated: Prefer using ErrNodeIsAmbiguous
NodeIsAmbiguous = "err-node-is-ambiguous"

// MaxLeases serves as an identifying error string indicating that the
Expand Down
4 changes: 1 addition & 3 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,9 +809,7 @@ func testUUIDBasedProxy(t *testing.T, suite *integrationTestSuite) {
// attempting to run a command by hostname should generate NodeIsAmbiguous error.
_, err = runCommand(t, teleportSvr, []string{"echo", "Hello there!"}, helpers.ClientConfig{Login: suite.Me.Username, Cluster: helpers.Site, Host: Host}, 1)
require.Error(t, err)
if !strings.Contains(err.Error(), teleport.NodeIsAmbiguous) {
require.FailNowf(t, "Expected %s, got %s", teleport.NodeIsAmbiguous, err.Error())
}
require.ErrorContains(t, err, "ambiguous")

// attempting to run a command by uuid should succeed.
_, err = runCommand(t, teleportSvr, []string{"echo", "Hello there!"}, helpers.ClientConfig{Login: suite.Me.Username, Cluster: helpers.Site, Host: uuid1}, 1)
Expand Down
91 changes: 85 additions & 6 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -1648,22 +1648,23 @@ func (a *ServerWithRoles) GetSSHTargets(ctx context.Context, req *proto.GetSSHTa
return nil, trace.Wrap(err)
}

lreq := proto.ListResourcesRequest{
ResourceType: types.KindNode,
lreq := &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindNode},
SortBy: types.SortBy{Field: types.ResourceMetadataName},
UseSearchAsRoles: true,
}
var servers []*types.ServerV2
for {
// note that we're calling ServerWithRoles.ListResources here rather than some internal method. This method
// note that we're calling ServerWithRoles.ListUnifiedResources here rather than some internal method. This method
// delegates all RBAC filtering to ListResources, and then performs additional filtering on top of that.
lrsp, err := a.ListResources(ctx, lreq)
lrsp, err := a.ListUnifiedResources(ctx, lreq)
if err != nil {
return nil, trace.Wrap(err)
}

for _, rsc := range lrsp.Resources {
srv, ok := rsc.(*types.ServerV2)
if !ok {
srv := rsc.GetNode()
if srv == nil {
log.Warnf("Unexpected resource type %T, expected *types.ServerV2 (skipping)", rsc)
continue
}
Expand All @@ -1687,6 +1688,84 @@ func (a *ServerWithRoles) GetSSHTargets(ctx context.Context, req *proto.GetSSHTa
}, nil
}

// ResolveSSHTarget gets a server that would match an equivalent ssh dial request.
func (a *ServerWithRoles) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) {
var servers []*types.ServerV2
switch {
case req.Host != "":
resp, err := a.GetSSHTargets(ctx, &proto.GetSSHTargetsRequest{
Host: req.Host,
Port: req.Port,
})
if err != nil {
return nil, trace.Wrap(err)
}

servers = resp.Servers
case len(req.Labels) > 0 || req.PredicateExpression != "" || len(req.SearchKeywords) > 0:
lreq := &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindNode},
SortBy: types.SortBy{Field: types.ResourceMetadataName},
Labels: req.Labels,
PredicateExpression: req.PredicateExpression,
SearchKeywords: req.SearchKeywords,
}
for {
// note that we're calling ServerWithRoles.ListUnifiedResources here rather than some internal method. This method
// delegates all RBAC filtering to ListResources, and then performs additional filtering on top of that.
lrsp, err := a.ListUnifiedResources(ctx, lreq)
if err != nil {
return nil, trace.Wrap(err)
}

for _, rsc := range lrsp.Resources {
srv := rsc.GetNode()
if srv == nil {
log.Warnf("Unexpected resource type %T, expected *types.ServerV2 (skipping)", rsc)
continue
}

servers = append(servers, srv)
}

if lrsp.NextKey == "" || len(lrsp.Resources) == 0 {
break
}

lreq.StartKey = lrsp.NextKey
}
default:
return nil, trace.NotFound("no matching hosts")
}

switch len(servers) {
case 1:
return &proto.ResolveSSHTargetResponse{Server: servers[0]}, nil
case 0:
return nil, trace.NotFound("no matching hosts")
default:
// try to detect case-insensitive routing setting, but default to false if we can't load
// networking config (equivalent to proxy routing behavior).
var routeToMostRecent bool
if cfg, err := a.authServer.GetReadOnlyClusterNetworkingConfig(ctx); err == nil {
routeToMostRecent = cfg.GetRoutingStrategy() == types.RoutingStrategy_MOST_RECENT
}

if !routeToMostRecent {
return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous)
}

// Sort the resource by expiry so we can identify the most "recent" server.
slices.SortFunc(servers, func(a, b *types.ServerV2) int {
return a.Expiry().Compare(b.Expiry())
})

// Sorting above is oldest expiry to newest expiry, so proceed
// with the last server in the slice.
return &proto.ResolveSSHTargetResponse{Server: servers[len(servers)-1]}, nil
}
}

// ListResources returns a paginated list of resources filtered by user access.
func (a *ServerWithRoles) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) {
// Check if auth server has a license for this resource type but only return an
Expand Down
3 changes: 3 additions & 0 deletions lib/auth/authclient/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,9 @@ type ClientI interface {
// but may result in confusing behavior if it is used outside of those contexts.
GetSSHTargets(ctx context.Context, req *proto.GetSSHTargetsRequest) (*proto.GetSSHTargetsResponse, error)

// ResolveSSHTarget returns the server that would be resolved in an equivalent ssh dial request.
ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error)

// PerformMFACeremony retrieves an MFA challenge from the server with the given challenge extensions
// and prompts the user to answer the challenge with the given promptOpts, and ultimately returning
// an MFA challenge response for the user.
Expand Down
15 changes: 15 additions & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4198,6 +4198,21 @@ func (g *GRPCServer) GetSSHTargets(ctx context.Context, req *authpb.GetSSHTarget
return rsp, nil
}

// ResolveSSHTarget gets a server that would match an equivalent ssh dial request.
func (g *GRPCServer) ResolveSSHTarget(ctx context.Context, req *authpb.ResolveSSHTargetRequest) (*authpb.ResolveSSHTargetResponse, error) {
auth, err := g.authenticate(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

rsp, err := auth.ServerWithRoles.ResolveSSHTarget(ctx, req)
if err != nil {
return nil, trace.Wrap(err)
}

return rsp, nil
}

// CreateSessionTracker creates a tracker resource for an active session.
func (g *GRPCServer) CreateSessionTracker(ctx context.Context, req *authpb.CreateSessionTrackerRequest) (*types.SessionTrackerV1, error) {
auth, err := g.authenticate(ctx)
Expand Down
63 changes: 63 additions & 0 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,69 @@ func TestGetSSHTargets(t *testing.T) {
require.ElementsMatch(t, []string{rsp.Servers[0].GetHostname(), rsp.Servers[1].GetHostname()}, []string{"foo", "Foo"})
}

func TestResolveSSHTarget(t *testing.T) {
t.Parallel()
ctx := context.Background()
srv := newTestTLSServer(t)

clt, err := srv.NewClient(TestAdmin())
require.NoError(t, err)

upper, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{
Hostname: "Foo",
UseTunnel: true,
}, nil)
require.NoError(t, err)
upper.SetExpiry(time.Now().Add(time.Hour))

lower, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{
Hostname: "foo",
UseTunnel: true,
}, nil)
require.NoError(t, err)

other, err := types.NewServerWithLabels(uuid.New().String(), types.KindNode, types.ServerSpecV2{
Hostname: "bar",
UseTunnel: true,
}, nil)
require.NoError(t, err)

for _, node := range []types.Server{upper, lower, other} {
_, err = clt.UpsertNode(ctx, node)
require.NoError(t, err)
}

rsp, err := clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{
Host: "foo",
Port: "0",
})
require.NoError(t, err)
require.Equal(t, "foo", rsp.Server.GetHostname())

cnc := types.DefaultClusterNetworkingConfig()
cnc.SetCaseInsensitiveRouting(true)
_, err = clt.UpsertClusterNetworkingConfig(ctx, cnc)
require.NoError(t, err)

rsp, err = clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{
Host: "foo",
Port: "0",
})
require.Error(t, err)
require.Nil(t, rsp)

cnc.SetRoutingStrategy(types.RoutingStrategy_MOST_RECENT)
_, err = clt.UpsertClusterNetworkingConfig(ctx, cnc)
require.NoError(t, err)

rsp, err = clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{
Host: "foo",
Port: "0",
})
require.NoError(t, err)
require.Equal(t, "Foo", rsp.Server.GetHostname())
}

func TestNodesCRUD(t *testing.T) {
t.Parallel()
ctx := context.Background()
Expand Down
5 changes: 4 additions & 1 deletion lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,10 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re
}
}
case len(matches) > 1:
return nil, trace.NotFound(teleport.NodeIsAmbiguous)
// TODO(tross) DELETE IN V20.0.0
// NodeIsAmbiguous is included in the error message for backwards compatibility
// with older nodes that expect to see that string in the error message.
return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous)
case len(matches) == 1:
server = matches[0]
}
Expand Down
8 changes: 4 additions & 4 deletions lib/proxy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestRouteScoring(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
srv, err := getServerWithResolver(ctx, tt.host, tt.port, site, resolver)
if tt.ambiguous {
require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous))
require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous)
return
}
require.Equal(t, tt.expect, srv.GetHostname())
Expand Down Expand Up @@ -375,7 +375,7 @@ func TestGetServers(t *testing.T) {
site: testSite{cfg: &unambiguousCfg, nodes: servers},
host: "sheep",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous))
require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous)
},
serverAssertion: func(t *testing.T, srv types.Server) {
require.Empty(t, srv)
Expand Down Expand Up @@ -456,7 +456,7 @@ func TestGetServers(t *testing.T) {
site: testSite{cfg: &unambiguousInsensitiveCfg, nodes: servers},
host: "platypus",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, trace.NotFound(teleport.NodeIsAmbiguous))
require.ErrorIs(t, err, teleport.ErrNodeIsAmbiguous)
},
serverAssertion: func(t *testing.T, srv types.Server) {
require.Empty(t, srv)
Expand Down Expand Up @@ -670,7 +670,7 @@ func TestRouter_DialHost(t *testing.T) {
clusterName: "test",
log: logger,
tracer: tracing.NoopTracer("test"),
serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)),
serverResolver: serverResolver(nil, teleport.ErrNodeIsAmbiguous),
},
assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) {
require.Error(t, err)
Expand Down
1 change: 1 addition & 0 deletions lib/services/readonly/readonly.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func sealAuthPreference(p types.AuthPreference) AuthPreference {
type ClusterNetworkingConfig interface {
GetCaseInsensitiveRouting() bool
GetWebIdleTimeout() time.Duration
GetRoutingStrategy() types.RoutingStrategy
Clone() types.ClusterNetworkingConfig
}

Expand Down
2 changes: 1 addition & 1 deletion lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws terminal.WSConn,
if err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failed to dial host.")

if errors.Is(err, trace.NotFound(teleport.NodeIsAmbiguous)) {
if errors.Is(err, teleport.ErrNodeIsAmbiguous) {
const message = "error: ambiguous host could match multiple nodes\n\nHint: try addressing the node by unique id (ex: user@node-id)\n"
return nil, trace.NotFound(message)
}
Expand Down
4 changes: 3 additions & 1 deletion tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3920,7 +3920,9 @@ func onSSH(cf *CLIConf) error {
err = client.RetryWithRelogin(cf.Context, tc, sshFunc)
}
if err != nil {
if strings.Contains(utils.UserMessageFromError(err), teleport.NodeIsAmbiguous) {
if errors.Is(err, teleport.ErrNodeIsAmbiguous) ||
// TODO(tross) DELETE IN v20.0.0
strings.Contains(utils.UserMessageFromError(err), teleport.NodeIsAmbiguous) {
clt, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
Expand Down
6 changes: 2 additions & 4 deletions tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7028,8 +7028,7 @@ func TestSCP(t *testing.T) {
return filepath.Join(dir, targetFile1)
},
assertion: func(tt require.TestingT, err error, i ...any) {
require.Error(tt, err, i...)
require.ErrorContains(tt, err, "multiple matching hosts", i...)
require.ErrorIs(tt, err, teleport.ErrNodeIsAmbiguous, i...)
},
},
{
Expand Down Expand Up @@ -7088,8 +7087,7 @@ func TestSCP(t *testing.T) {
return "dev.example.com:" + filepath.Join(dir, targetFile1)
},
assertion: func(tt require.TestingT, err error, i ...any) {
require.Error(tt, err, i...)
require.ErrorContains(tt, err, "multiple matching hosts", i...)
require.ErrorIs(tt, err, teleport.ErrNodeIsAmbiguous, i...)
},
},
{
Expand Down

0 comments on commit 1999d6a

Please sign in to comment.