Skip to content

Commit

Permalink
Add ssh keepalive support
Browse files Browse the repository at this point in the history
Signed-off-by: Kimmo Lehto <[email protected]>
  • Loading branch information
kke committed Mar 11, 2024
1 parent fc225b5 commit 3c67645
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
46 changes: 45 additions & 1 deletion protocol/ssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type Connection struct {
log.LoggerInjectable `yaml:"-"`
Config `yaml:",inline"`

options *Options

alias string
name string

Expand All @@ -42,6 +44,8 @@ type Connection struct {

client *ssh.Client

done chan struct{}

keyPaths []string
}

Expand All @@ -51,7 +55,7 @@ func NewConnection(cfg Config, opts ...Option) (*Connection, error) {
options.InjectLoggerTo(cfg, log.KeyProtocol, "ssh-config")
cfg.SetDefaults()

c := &Connection{Config: cfg}
c := &Connection{Config: cfg, options: options}
options.InjectLoggerTo(c, log.KeyProtocol, "ssh")

return c, nil
Expand Down Expand Up @@ -187,6 +191,15 @@ func (c *Connection) IPAddress() string {
return c.Address
}

// IsConnected returns true if the connection is open.
func (c *Connection) IsConnected() bool {
if c.client == nil || c.client.Conn == nil {
return false
}
_, _, err := c.client.Conn.SendRequest("keepalive@rig", true, nil)
return err == nil
}

// SSHConfigGetAll by default points to ssh_config package's GetAll() function
// you can override it with your own implementation for testing purposes.
var SSHConfigGetAll = ssh_config.GetAll
Expand All @@ -212,6 +225,9 @@ func (c *Connection) Disconnect() {
if c.client == nil {
return
}
if c.options.KeepAliveInterval != nil {
close(c.done)
}
c.client.Close()
}

Expand Down Expand Up @@ -414,9 +430,34 @@ func (c *Connection) connectViaBastion(dst string, config *ssh.ClientConfig) err
}
c.client = ssh.NewClient(client, chans, reqs)

c.startKeepalive()

return nil
}

func (c *Connection) startKeepalive() {
if c.options.KeepAliveInterval == nil {
return
}

c.done = make(chan struct{})
go func() {
ticker := time.NewTicker(*c.options.KeepAliveInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !c.IsConnected() {
close(c.done)
return
}
case <-c.done:
return
}
}
}()
}

// Connect opens the SSH connection.
func (c *Connection) Connect() error {
c.SetDefaults()
Expand All @@ -440,6 +481,9 @@ func (c *Connection) Connect() error {
return fmt.Errorf("ssh dial: %w", err)
}
c.client = clientDirect

c.startKeepalive()

return nil
}

Expand Down
14 changes: 13 additions & 1 deletion protocol/ssh/options.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package ssh

import "github.com/k0sproject/rig/v2/log"
import (
"time"

"github.com/k0sproject/rig/v2/log"
)

// Options for the SSH client.
type Options struct {
log.LoggerInjectable
KeepAliveInterval *time.Duration
}

// Option is a function that sets some option on the Options struct.
Expand All @@ -25,3 +30,10 @@ func WithLogger(l log.Logger) Option {
o.SetLogger(l)
}
}

// WithKeepAlive sets the keep-alive interval option.
func WithKeepAlive(d time.Duration) Option {
return func(o *Options) {
o.KeepAliveInterval = &d
}
}

0 comments on commit 3c67645

Please sign in to comment.