Skip to content

Commit

Permalink
Merge pull request #4607 from twz123/join-client-hardening
Browse files Browse the repository at this point in the history
Use a ten second timeout for join requests
  • Loading branch information
twz123 authored Jun 17, 2024
2 parents a11b91b + dfb5031 commit 7f91d37
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 34 deletions.
27 changes: 20 additions & 7 deletions cmd/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,16 +691,29 @@ func joinController(ctx context.Context, tokenArg string, certRootDir string) (*
return nil, fmt.Errorf("wrong token type %s, expected type: controller-bootstrap", joinClient.JoinTokenType())
}

logrus.Info("Joining existing cluster via ", joinClient.Address())

var caData v1beta1.CaResponse
err = retry.Do(func() error {
caData, err = joinClient.GetCA()
retryErr := retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
caData, err = joinClient.GetCA(ctx)
return err
},
retry.Context(ctx),
retry.LastErrorOnly(true),
retry.OnRetry(func(attempt uint, err error) {
logrus.WithError(err).Debug("Failed to join in attempt #", attempt+1, ", retrying after backoff")
}),
)
if retryErr != nil {
if err != nil {
return fmt.Errorf("failed to sync CA: %w", err)
retryErr = err
}
return nil
}, retry.Context(ctx))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to join existing cluster via %s: %w", joinClient.Address(), retryErr)
}

logrus.Info("Got valid CA response, storing certificates")
return joinClient, writeCerts(caData, certRootDir)
}
33 changes: 22 additions & 11 deletions pkg/component/controller/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"time"

"github.com/avast/retry-go"
"github.com/sirupsen/logrus"
"go.etcd.io/etcd/client/pkg/v3/tlsutil"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -93,19 +94,29 @@ func (e *Etcd) Init(_ context.Context) error {
return assets.Stage(e.K0sVars.BinDir, "etcd", constant.BinDirMode)
}

func (e *Etcd) syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) {
func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) {
logrus.Info("Synchronizing etcd config with existing cluster via ", e.JoinClient.Address())

var etcdResponse v1beta1.EtcdResponse
var err error
for i := 0; i < 20; i++ {
logrus.Debugf("trying to sync etcd config")
etcdResponse, err = e.JoinClient.JoinEtcd(peerURL)
if err == nil {
break
retryErr := retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
etcdResponse, err = e.JoinClient.JoinEtcd(ctx, peerURL)
return err
},
retry.Context(ctx),
retry.LastErrorOnly(true),
retry.OnRetry(func(attempt uint, err error) {
logrus.WithError(err).Debug("Failed to synchronize etcd config in attempt #", attempt+1, ", retrying after backoff")
}),
)
if retryErr != nil {
if err != nil {
retryErr = err
}
time.Sleep(500 * time.Millisecond)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to synchronize etcd config with existing cluster via %s: %w", e.JoinClient.Address(), retryErr)
}

logrus.Debugf("got cluster info: %v", etcdResponse.InitialCluster)
Expand Down Expand Up @@ -179,7 +190,7 @@ func (e *Etcd) Start(ctx context.Context) error {
if file.Exists(filepath.Join(e.K0sVars.EtcdDataDir, "member", "snap", "db")) {
logrus.Warnf("etcd db file(s) already exist, not gonna run join process")
} else if e.JoinClient != nil {
initialCluster, err := e.syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey)
initialCluster, err := e.syncEtcdConfig(ctx, peerURL, etcdCaCert, etcdCaCertKey)
if err != nil {
return fmt.Errorf("failed to sync etcd config: %w", err)
}
Expand Down
31 changes: 15 additions & 16 deletions pkg/token/joinclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package token

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
Expand All @@ -27,7 +28,6 @@ import (
"os"

"github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/sirupsen/logrus"
"k8s.io/client-go/tools/clientcmd"
)

Expand Down Expand Up @@ -74,16 +74,23 @@ func JoinClientFromToken(encodedToken string) (*JoinClient, error) {
c.joinAddress = config.Host
c.joinTokenType = GetTokenType(&raw)

logrus.Info("initialized join client successfully")
return c, nil
}

func (j *JoinClient) Address() string {
return j.joinAddress
}

func (j *JoinClient) JoinTokenType() string {
return j.joinTokenType
}

// GetCA calls the CA sync API
func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
func (j *JoinClient) GetCA(ctx context.Context) (v1beta1.CaResponse, error) {
var caData v1beta1.CaResponse
req, err := http.NewRequest(http.MethodGet, j.joinAddress+"/v1beta1/ca", nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.joinAddress+"/v1beta1/ca", nil)
if err != nil {
return caData, err
return caData, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))

Expand All @@ -96,10 +103,6 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
if resp.StatusCode != http.StatusOK {
return caData, fmt.Errorf("unexpected response status: %s", resp.Status)
}
logrus.Info("got valid CA response")
if resp.Body == nil {
return caData, fmt.Errorf("response body is nil")
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return caData, err
Expand All @@ -112,7 +115,7 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
}

// JoinEtcd calls the etcd join API
func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error) {
func (j *JoinClient) JoinEtcd(ctx context.Context, peerAddress string) (v1beta1.EtcdResponse, error) {
var etcdResponse v1beta1.EtcdResponse
etcdRequest := v1beta1.EtcdRequest{
PeerAddress: peerAddress,
Expand All @@ -128,9 +131,9 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error)
return etcdResponse, err
}

req, err := http.NewRequest(http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf)
if err != nil {
return etcdResponse, err
return etcdResponse, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))
resp, err := j.httpClient.Do(req)
Expand All @@ -153,7 +156,3 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error)

return etcdResponse, nil
}

func (j *JoinClient) JoinTokenType() string {
return j.joinTokenType
}
100 changes: 100 additions & 0 deletions pkg/token/joinclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
Copyright 2024 k0s authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package token_test

import (
"bytes"
"context"
"net"
"net/http"
"net/url"
"testing"

"github.com/k0sproject/k0s/pkg/token"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestJoinClient_Cancellation(t *testing.T) {
t.Parallel()

for _, test := range []struct {
name string
funcUnderTest func(context.Context, *token.JoinClient) error
}{
{"GetCA", func(ctx context.Context, c *token.JoinClient) error {
_, err := c.GetCA(ctx)
return err
}},
{"JoinEtcd", func(ctx context.Context, c *token.JoinClient) error {
_, err := c.JoinEtcd(ctx, "")
return err
}},
} {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

clientContext, cancelClientContext := context.WithCancelCause(context.Background())
joinURL := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) {
cancelClientContext(assert.AnError) // cancel the client's context
<-req.Context().Done() // block forever
})

kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), nil, "", "")
require.NoError(t, err)
tok, err := token.JoinEncode(bytes.NewReader(kubeconfig))
require.NoError(t, err)

underTest, err := token.JoinClientFromToken(tok)
require.NoError(t, err)

err = test.funcUnderTest(clientContext, underTest)
assert.ErrorIs(t, err, context.Canceled, "Expected the call to be cancelled")
assert.Same(t, context.Cause(clientContext), assert.AnError, "Didn't receive an HTTP request")
})
}
}

func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *url.URL {
requestCtx, cancelRequests := context.WithCancel(context.Background())

listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
require.NoError(t, err)
}

server := &http.Server{
Addr: listener.Addr().String(),
Handler: http.HandlerFunc(handler),
BaseContext: func(net.Listener) context.Context { return requestCtx },
}

serverError := make(chan error)
go func() { defer close(serverError); serverError <- server.Serve(listener) }()

t.Cleanup(func() {
cancelRequests()
if !assert.NoError(t, server.Shutdown(context.Background()), "Couldn't shutdown HTTP server") {
return
}
assert.ErrorIs(t, <-serverError, http.ErrServerClosed, "HTTP server terminated unexpectedly")
})

return &url.URL{Scheme: "http", Host: server.Addr}
}

0 comments on commit 7f91d37

Please sign in to comment.