Skip to content

Commit

Permalink
Modified vault unseal
Browse files Browse the repository at this point in the history
  • Loading branch information
Shifna12Zarnaz committed Sep 19, 2023
1 parent 5f30d22 commit 6bd617e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 172 deletions.
2 changes: 1 addition & 1 deletion charts/vault-cred/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.1.2
version: 0.1.3

# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
Expand Down
4 changes: 4 additions & 0 deletions charts/vault-cred/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ spec:
value: "{{ .Values.env.logLevel }}"
- name: VAULT_ADDR
value: "{{ .Values.vault.vaultAddress }}"
- name: VAULT_ADDR2
value: "{{ .Values.vault.vaultAddress2 }}"
- name: VAULT_ADDR3
value: "{{ .Values.vault.vaultAddress3 }}"
- name: HA_ENABLED
value: "{{ .Values.vault.haenabled }}"
- name: VAULT_SECRET_NAME
Expand Down
4 changes: 3 additions & 1 deletion charts/vault-cred/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ env:

vault:
haenabled: true
vaultAddress: http://vault-hash:8200
vaultAddress: http://vault-hash-1:8200
vaultAddress2: http://vault-hash-2:8200
vaultAddress3: http://vault-hash-3:8200
secretName: vault-server
secretTokenKeyName: roottoken
secretUnSealKeyPrefix: unsealkey
Expand Down
24 changes: 4 additions & 20 deletions internal/client/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (vc *VaultClient) DeleteCredential(ctx context.Context, mountPath, secretPa
}

func (vc *VaultClient) JoinRaftCluster(podip string, leaderaddress string) error {
// Construct the Vault API address

address := fmt.Sprintf("http://%s:8200", podip)

// Set the Vault client address
Expand All @@ -190,29 +190,13 @@ func (vc *VaultClient) JoinRaftCluster(podip string, leaderaddress string) error

vc.log.Debugf("Address: %s", address)

// Extract the leader address from the response

// Retrieve leader information
// leaderInfo, err := vc.c.Sys().Leader()
// if err != nil {
// return fmt.Errorf("failed to retrieve leader information: %v", err)
// }

//vc.log.Debugf("Leader address: %s", leaderInfo.LeaderAddress)

// if leaderInfo.LeaderAddress == "" {
// // Handle the case where leader address is empty
// vc.log.Debug("Leader address is empty")
// return fmt.Errorf("leader address is empty")
// }

req := &api.RaftJoinRequest{
Retry: true,
LeaderAPIAddr: leaderaddress,
//LeaderAPIAddr: leaderInfo.LeaderAddress,

}

//vc.log.Debugf("Leader API address: %s", leaderInfo.LeaderAddress)


_, err := vc.c.Sys().RaftJoin(req)
if err != nil {
Expand All @@ -230,7 +214,7 @@ func (vc *VaultClient) LeaderAPIAddr(podip string) (string, error) {
return "", err
}

vc.log.Debugf("Address: %s", address)


leaderInfo, err := vc.c.Sys().Leader()
if err != nil {
Expand Down
76 changes: 1 addition & 75 deletions internal/client/vault_seal.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,87 +128,14 @@ func (vc *VaultClient) getVaultSecretValues() (string, []string, error) {
return rootToken, unsealKeys, nil
}

func (vc *VaultClient) UnsealVaultInstance(podip string, unsealKey []string) error {
// Create a Vault API client
vc.log.Debug("Checking Unseal status for vault Instance")
address := fmt.Sprintf("http://%s:8200", podip)
err := vc.c.SetAddress(address)
if err != nil {
vc.log.Errorf("Error while setting address")
}
vc.log.Debug("Address", address)

for _, key := range unsealKey {
unsealResponse, err := vc.c.Sys().Unseal(key)
if err != nil {
return errors.WithMessage(err, "error while unsealing")
}
if unsealResponse.Sealed {
vc.log.Debug("Vault is still sealed after unsealing attempt")
}
}

// Check if Vault is sealed and unseal if necessary

// Vault is sealed; unseal it
// unsealResponse, err := vc.c.Sys().Unseal(unsealKey)
// if err != nil {
// return err
// }

// if unsealResponse.Sealed {
// vc.log.Debug("Vault is still sealed after unsealing attempt")
// }

return nil
}

func (vc *VaultClient) GetVaultSecretValuesforMultiInstance() (string, []string, error) {
k8s, err := NewK8SClient(vc.log)
if err != nil {
return "", nil, errors.WithMessage(err, "error initializing k8s client")
}

vaultSec, err := k8s.GetSecret(context.Background(), vc.conf.VaultSecretName, vc.conf.VaultSecretNameSpace)
if err != nil {
if strings.Contains(err.Error(), "secret not found") {
vc.log.Debugf("secret %d not found", vc.conf.VaultSecretName)
return "", nil, nil
}

return "", nil, errors.WithMessage(err, "error fetching vault secret")
}

vc.log.Debugf("found %d vault secret values", len(vaultSec.Data))
unsealKeys := []string{}
var rootToken string
for key, val := range vaultSec.Data {
if strings.HasPrefix(key, vc.conf.VaultSecretUnSealKeyPrefix) {
// decodedValue, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return "", nil, errors.WithMessage(err, "error decoding value")
}

unsealKeys = append(unsealKeys, val)
vc.log.Debug("Unseal Keys", unsealKeys)
continue
}
if strings.EqualFold(key, vc.conf.VaultSecretTokenKeyName) {
// decodedValue, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return "", nil, errors.WithMessage(err, "error decoding root token")
}
rootToken = val
vc.log.Debug("Root Token Key", rootToken)
}
}
return rootToken, unsealKeys, nil
}

func (vc *VaultClient) IsVaultSealedForAllInstances(svc string) (bool, error) {
address := fmt.Sprintf("http://%s:8200", svc)
err := vc.c.SetAddress(address)
vc.log.Debug("Address for checking vault status", address)

if err != nil {
vc.log.Errorf("Error while setting address")
}
Expand All @@ -225,7 +152,6 @@ func (vc *VaultClient) GetPodIP(podName, namespace string) (string, error) {
return "", errors.WithMessage(err, "error initializing k8s client")
}

// Get the pod's IP address
pod, err := k8s.client.CoreV1().Pods(namespace).Get(context.TODO(), podName, metav1.GetOptions{})
if err != nil {
return "", err
Expand Down
117 changes: 65 additions & 52 deletions internal/job/vault_seal_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,64 +32,45 @@ func (v *VaultSealWatcher) CronSpec() string {

func (v *VaultSealWatcher) Run() {
v.log.Debug("started vault seal watcher job")
addresses := []string{
v.conf.Address,
v.conf.Address2,
v.conf.Adddress3,
}
k8sclient, err := client.NewK8SClient(v.log)
if err != nil {
v.log.Errorf("Error while connecting with k8s %s", err)
return
}
podname, err := k8sclient.GetVaultPodInstances(context.Background())
if err != nil {
v.log.Errorf("Error while retrieving vault instances %s", err)
return
}

var vc *client.VaultClient

var vaultClients []*client.VaultClient
for _, address := range addresses {
conf := v.conf
conf.Address = address
vc, err := client.NewVaultClient(v.log, conf)
if v.conf.HAEnabled {
v.log.Infof(" Vault HA ENABLED", v.conf.HAEnabled)

addresses := []string{
v.conf.Address,
v.conf.Address2,
v.conf.Adddress3,
}
k8sclient, err := client.NewK8SClient(v.log)
if err != nil {
v.log.Errorf("%s", err)
v.log.Errorf("Error while connecting with k8s %s", err)
return
}
podname, err := k8sclient.GetVaultPodInstances(context.Background())
if err != nil {
v.log.Errorf("Error while retrieving vault instances %s", err)
return
}

vaultClients = append(vaultClients, vc)
}

if v.conf.HAEnabled {

v.log.Infof("HA ENABLED", v.conf.HAEnabled)

for _, svc := range podname {
var vc *client.VaultClient

switch svc {
case "vault-hash-0":
vc = vaultClients[0]
var vaultClients []*client.VaultClient
for _, address := range addresses {
conf := v.conf
conf.Address = address
vc, err := client.NewVaultClient(v.log, conf)

podip, err := vc.GetPodIP(svc, v.conf.VaultSecretNameSpace)
if err != nil {
v.log.Errorf("failed to retrieve pod ip, %s", err)
return
}

v.conf.LeaderPodIp = podip
case "vault-hash-1":
vc = vaultClients[1]
if err != nil {
v.log.Errorf("%s", err)
return
}

case "vault-hash-2":
vc = vaultClients[2]
vaultClients = append(vaultClients, vc)
}

default:
for i, svc := range podname {

}
vc = vaultClients[i]

podip, err := vc.GetPodIP(svc, v.conf.VaultSecretNameSpace)
if err != nil {
Expand All @@ -104,13 +85,13 @@ func (v *VaultSealWatcher) Run() {
}
v.log.Info("Seal Status for %v", podip, res)
if res {
v.log.Info("vault is sealed, trying to unseal")
if svc == "vault-hash-0" {

if i == 0 {

v.log.Info("Unsealing for first instance")
podip, err := vc.GetPodIP(svc, v.conf.VaultSecretNameSpace)
v.conf.LeaderPodIp = podip
v.log.Info("Leader Ip", v.conf.LeaderPodIp)

if err != nil {
v.log.Errorf("failed to retrieve pod ip, %s", err)
return
Expand All @@ -122,13 +103,13 @@ func (v *VaultSealWatcher) Run() {
}

} else {
v.log.Info("Leader Pod Ip", v.conf.LeaderPodIp)

leaderaddr, err := vc.LeaderAPIAddr(v.conf.LeaderPodIp)
if err != nil {
v.log.Errorf("failed to retrieve leader address, %s", err)
return
}
v.log.Info("Leader Address", leaderaddr)
v.log.Debug("Leader Address", leaderaddr)
podip, err := vc.GetPodIP(svc, v.conf.VaultSecretNameSpace)
v.log.Infof("Unsealing for %v instance", podip)
if err != nil {
Expand All @@ -154,5 +135,37 @@ func (v *VaultSealWatcher) Run() {

}

} else {
vc, err := client.NewVaultClient(v.log, v.conf)
if err != nil {
v.log.Errorf("%s", err)
return
}

res, err := vc.IsVaultSealed()
if err != nil {
v.log.Errorf("failed to get vault seal status, %s", err)
return
}

if res {
v.log.Info("vault is sealed, trying to unseal")
err := vc.Unseal()
if err != nil {
v.log.Errorf("failed to unseal vault, %s", err)
return
}
v.log.Info("vault unsealed executed")

res, err := vc.IsVaultSealed()
if err != nil {
v.log.Errorf("failed to get vault seal status, %s", err)
return
}
v.log.Infof("vault sealed status: %v", res)
return
} else {
v.log.Debug("vault is in unsealed status")
}
}
}
46 changes: 23 additions & 23 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,28 @@ func initScheduler(log logging.Logger, cfg config.Configuration) (s *job.Schedul
}
}

// if cfg.VaultPolicyWatchInterval != "" {
// pj, err := job.NewVaultPolicyWatcher(log, cfg.VaultPolicyWatchInterval)
// if err != nil {
// log.Fatal("failed to init policy watcher job", err)
// }

// err = s.AddJob("vault-policy-watcher", pj)
// if err != nil {
// log.Fatal("failed to add policy watcher job", err)
// }
// }

// if cfg.VaultCredSyncInterval != "" {
// pj, err := job.NewVaultCredSync(log, cfg.VaultCredSyncInterval)
// if err != nil {
// log.Fatal("failed to init cred sync job", err)
// }

// err = s.AddJob("vault-cred-sync", pj)
// if err != nil {
// log.Fatal("failed to add cred sync job", err)
// }
// }
if cfg.VaultPolicyWatchInterval != "" {
pj, err := job.NewVaultPolicyWatcher(log, cfg.VaultPolicyWatchInterval)
if err != nil {
log.Fatal("failed to init policy watcher job", err)
}

err = s.AddJob("vault-policy-watcher", pj)
if err != nil {
log.Fatal("failed to add policy watcher job", err)
}
}

if cfg.VaultCredSyncInterval != "" {
pj, err := job.NewVaultCredSync(log, cfg.VaultCredSyncInterval)
if err != nil {
log.Fatal("failed to init cred sync job", err)
}

err = s.AddJob("vault-cred-sync", pj)
if err != nil {
log.Fatal("failed to add cred sync job", err)
}
}
return
}

0 comments on commit 6bd617e

Please sign in to comment.