Skip to content

Commit

Permalink
Add JSON encoding support to posture checks
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga committed Jan 5, 2024
1 parent e9e0041 commit ce3e080
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 0 deletions.
66 changes: 66 additions & 0 deletions management/server/posture/checks.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package posture

import (
"encoding/json"

nbpeer "github.com/netbirdio/netbird/management/server/peer"
)

const (
NBVersionCheckName = "NBVersionCheck"
)

// Check represents an interface for performing a check on a peer.
type Check interface {
Check(peer nbpeer.Peer) error
Name() string
}

type Checks struct {
Expand All @@ -31,6 +38,37 @@ func (*Checks) TableName() string {
return "posture_checks"
}

// MarshalJSON returns the JSON encoding of the Checks object.
// The Checks object is marshaled as a map[string]json.RawMessage,
// where the key is the name of the check and the value is the JSON
// representation of the Check object.
func (pc *Checks) MarshalJSON() ([]byte, error) {
type Alias Checks
return json.Marshal(&struct {
Checks map[string]json.RawMessage
*Alias
}{
Checks: pc.marshalChecks(),
Alias: (*Alias)(pc),
})
}

// UnmarshalJSON unmarshal the JSON data into the Checks object.
func (pc *Checks) UnmarshalJSON(data []byte) error {
type Alias Checks
aux := &struct {
Checks map[string]json.RawMessage
*Alias
}{
Alias: (*Alias)(pc),
}

if err := json.Unmarshal(data, &aux); err != nil {
return err
}
return pc.unmarshalChecks(aux.Checks)
}

// Copy returns a copy of a policy rule.
func (pc *Checks) Copy() *Checks {
checks := &Checks{
Expand All @@ -43,3 +81,31 @@ func (pc *Checks) Copy() *Checks {
copy(checks.Checks, pc.Checks)
return checks
}

func (pc *Checks) marshalChecks() map[string]json.RawMessage {
result := make(map[string]json.RawMessage)
for _, check := range pc.Checks {
data, err := json.Marshal(check)
if err != nil {
return result
}
result[check.Name()] = data
}
return result
}

func (pc *Checks) unmarshalChecks(rawChecks map[string]json.RawMessage) error {
pc.Checks = make([]Check, 0, len(rawChecks))

for name, rawCheck := range rawChecks {
switch name {

Check failure on line 101 in management/server/posture/checks.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

singleCaseSwitch: should rewrite switch statement to if statement (gocritic)

Check failure on line 101 in management/server/posture/checks.go

View workflow job for this annotation

GitHub Actions / lint (windows-latest)

singleCaseSwitch: should rewrite switch statement to if statement (gocritic)

Check failure on line 101 in management/server/posture/checks.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

singleCaseSwitch: should rewrite switch statement to if statement (gocritic)
case NBVersionCheckName:
check := &NBVersionCheck{}
if err := json.Unmarshal(rawCheck, check); err != nil {
return err
}
pc.Checks = append(pc.Checks, check)
}
}
return nil
}
155 changes: 155 additions & 0 deletions management/server/posture/checks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package posture

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestChecks_MarshalJSON(t *testing.T) {
tests := []struct {
name string
checks *Checks
want []byte
wantErr bool
}{
{
name: "Valid Posture Checks Marshal",
checks: &Checks{
ID: "id1",
Name: "name1",
Description: "desc1",
AccountID: "acc1",
Checks: []Check{
&NBVersionCheck{
Enabled: true,
MinVersion: "1.0.0",
MaxVersion: "1.2.9",
},
},
},
want: []byte(`
{
"ID": "id1",
"Name": "name1",
"Description": "desc1",
"Checks": {
"NBVersionCheck": {
"Enabled": true,
"MinVersion": "1.0.0",
"MaxVersion": "1.2.9"
}
}
}
`),
wantErr: false,
},
{
name: "Empty Posture Checks Marshal",
checks: &Checks{
ID: "",
Name: "",
Description: "",
AccountID: "",
Checks: []Check{
&NBVersionCheck{},
},
},
want: []byte(`
{
"ID": "",
"Name": "",
"Description": "",
"Checks": {
"NBVersionCheck": {
"Enabled": false,
"MinVersion": "",
"MaxVersion": ""
}
}
}
`),
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.checks.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Checks.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}

assert.JSONEq(t, string(got), string(tt.want))
assert.Equal(t, tt.checks, tt.checks.Copy(), "original Checks should not be modified")
})
}
}

func TestChecks_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
in []byte
expected *Checks
expectedError bool
}{
{
name: "Valid JSON Posture Checks Unmarshal",
in: []byte(`
{
"ID": "id1",
"Name": "name1",
"Description": "desc1",
"Checks": {
"NBVersionCheck": {
"Enabled": true,
"MinVersion": "1.0.0",
"MaxVersion": "1.2.9"
}
}
}
`),
expected: &Checks{
ID: "id1",
Name: "name1",
Description: "desc1",
Checks: []Check{
&NBVersionCheck{
Enabled: true,
MinVersion: "1.0.0",
MaxVersion: "1.2.9",
},
},
},
expectedError: false,
},
{
name: "Invalid JSON Posture Checks Unmarshal",
in: []byte(`{`),
expectedError: true,
},
{
name: "Empty JSON Posture Check Unmarshal",
in: []byte(`{}`),
expected: &Checks{
Checks: make([]Check, 0),
},
expectedError: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
checks := &Checks{}

err := checks.UnmarshalJSON(tc.in)
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expected, checks)
}
})
}
}
4 changes: 4 additions & 0 deletions management/server/posture/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) error {
n.MaxVersion,
)
}

func (n *NBVersionCheck) Name() string {
return NBVersionCheckName
}

0 comments on commit ce3e080

Please sign in to comment.