Skip to content

Commit

Permalink
Implement numeric comparison modifiers (>, >=, <, <=) (#32)
Browse files Browse the repository at this point in the history
* Add type coercion helper

* Add testcase

* Refactor, add remaining comparators

* Move test up one layer of implementation detail
  • Loading branch information
bradleyjkemp authored Feb 15, 2023
1 parent 9c8e97b commit 47169b1
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 24 deletions.
6 changes: 5 additions & 1 deletion evaluator/evaluate_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ func (rule *RuleEvaluator) matcherMatchesValues(matcherValues []string, comparat
valueMatchedEvent := false
// There are multiple possible event fields that each expected value needs to be compared against
for _, actualValue := range actualValues {
if comparator(actualValue, expectedValue) {
comparatorMatched, err := comparator(actualValue, expectedValue)
if err != nil {
// todo
}
if comparatorMatched {
valueMatchedEvent = true
break
}
Expand Down
52 changes: 52 additions & 0 deletions evaluator/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,55 @@ func TestRuleEvaluator_MatchesCaseInsensitive(t *testing.T) {
t.Error("expected first condition to be true and second condition to be false")
}
}

func TestRuleEvaluator_MatchesGreaterThan(t *testing.T) {
rule := ForRule(sigma.Rule{
Detection: sigma.Detection{
Searches: map[string]sigma.Search{
"foo1": {
EventMatchers: []sigma.EventMatcher{
{
{
Field: "foo-field",
Modifiers: []string{"gt"},
Values: []interface{}{
"1",
},
},
},
},
},
"foo0.5": {
EventMatchers: []sigma.EventMatcher{
{
{
Field: "foo-field",
Modifiers: []string{"gt"},
Values: []interface{}{
"0.5",
},
},
},
},
},
},
Conditions: []sigma.Condition{
{
Search: sigma.SearchIdentifier{Name: "foo0.5"},
},
},
},
})

result, err := rule.Matches(context.Background(), map[string]interface{}{
"foo-field": 0.75,
})
switch {
case err != nil:
t.Fatal(err)
case !result.Match:
t.Error("rule should have matched", result.SearchResults)
case !result.SearchResults["foo0.5"] || result.SearchResults["foo1"]:
t.Error("expected foo0.5 to be true but not foo1")
}
}
132 changes: 109 additions & 23 deletions evaluator/modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,156 @@ package evaluator
import (
"encoding/base64"
"fmt"
"gopkg.in/yaml.v3"
"net"
"reflect"
"regexp"
"strings"
)

type valueComparator func(actual interface{}, expected string) bool
type valueComparator func(actual interface{}, expected interface{}) (bool, error)

func baseComparator(actual interface{}, expected string) bool {
func baseComparator(actual interface{}, expected interface{}) (bool, error) {
switch {
case actual == nil && expected == "null":
// special case: "null" should match the case where a field isn't present (and so actual is nil)
return true
return true, nil
default:
// The Sigma spec defines that by default comparisons are case-insensitive
return strings.EqualFold(fmt.Sprintf("%v", actual), expected)
return strings.EqualFold(fmt.Sprint(actual), fmt.Sprint(expected)), nil
}
}

type valueModifier func(next valueComparator) valueComparator

var modifiers = map[string]valueModifier{
"contains": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
return func(actual interface{}, expected interface{}) (bool, error) {
// The Sigma spec defines that by default comparisons are case-insensitive
return strings.Contains(strings.ToLower(fmt.Sprintf("%v", actual)), strings.ToLower(expected))
return strings.Contains(strings.ToLower(fmt.Sprint(actual)), strings.ToLower(fmt.Sprint(expected))), nil
}
},
"endswith": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
return func(actual interface{}, expected interface{}) (bool, error) {
// The Sigma spec defines that by default comparisons are case-insensitive
return strings.HasSuffix(strings.ToLower(fmt.Sprintf("%v", actual)), strings.ToLower(expected))
return strings.HasSuffix(strings.ToLower(fmt.Sprint(actual)), strings.ToLower(fmt.Sprint(expected))), nil
}
},
"startswith": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
return strings.HasPrefix(strings.ToLower(fmt.Sprintf("%v", actual)), strings.ToLower(expected))
return func(actual interface{}, expected interface{}) (bool, error) {
return strings.HasPrefix(strings.ToLower(fmt.Sprint(actual)), strings.ToLower(fmt.Sprint(expected))), nil
}
},
"base64": func(next valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
return next(actual, base64.StdEncoding.EncodeToString([]byte(expected)))
return func(actual interface{}, expected interface{}) (bool, error) {
return next(actual, base64.StdEncoding.EncodeToString([]byte(fmt.Sprint(expected))))
}
},
"re": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
re, err := regexp.Compile(expected)
return func(actual interface{}, expected interface{}) (bool, error) {
re, err := regexp.Compile(fmt.Sprint(expected))
if err != nil {
// TODO: what to do here?
return false
return false, err
}

return re.MatchString(fmt.Sprintf("%v", actual))
return re.MatchString(fmt.Sprint(actual)), nil
}
},
"cidr": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected string) bool {
_, cidr, err := net.ParseCIDR(expected)
return func(actual interface{}, expected interface{}) (bool, error) {
_, cidr, err := net.ParseCIDR(fmt.Sprint(expected))
if err != nil {
// TODO: what to do here?
return false
return false, err
}

ip := net.ParseIP(fmt.Sprintf("%v", actual))
return cidr.Contains(ip)
ip := net.ParseIP(fmt.Sprint(actual))
return cidr.Contains(ip), nil
}
},
"gt": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected interface{}) (bool, error) {
gt, _, _, _, err := compareNumeric(actual, expected)
return gt, err
}
},
"gte": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected interface{}) (bool, error) {
_, gte, _, _, err := compareNumeric(actual, expected)
return gte, err
}
},
"lt": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected interface{}) (bool, error) {
_, _, lt, _, err := compareNumeric(actual, expected)
return lt, err
}
},
"lte": func(_ valueComparator) valueComparator {
return func(actual interface{}, expected interface{}) (bool, error) {
_, _, _, lte, err := compareNumeric(actual, expected)
return lte, err
}
},
}

// coerceNumeric makes both operands into the widest possible number of the same type
func coerceNumeric(left, right interface{}) (interface{}, interface{}, error) {
leftV := reflect.ValueOf(left)
leftType := reflect.ValueOf(left).Type()
rightV := reflect.ValueOf(right)
rightType := reflect.ValueOf(right).Type()

switch {
// Both integers or both floats? Return directly
case leftType.Kind() == reflect.Int && rightType.Kind() == reflect.Int:
fallthrough
case leftType.Kind() == reflect.Float64 && rightType.Kind() == reflect.Float64:
return left, right, nil

// Mixed integer, float? Return two floats
case leftType.Kind() == reflect.Int && rightType.Kind() == reflect.Float64:
fallthrough
case leftType.Kind() == reflect.Float64 && rightType.Kind() == reflect.Int:
floatType := reflect.TypeOf(float64(0))
return leftV.Convert(floatType).Interface(), rightV.Convert(floatType).Interface(), nil

// One or more strings? Parse and recurse.
// We use `yaml.Unmarshal` to parse the string because it's a cheat's way of parsing either an integer or a float
case leftType.Kind() == reflect.String:
var leftParsed interface{}
if err := yaml.Unmarshal([]byte(left.(string)), &leftParsed); err != nil {
return nil, nil, err
}
return coerceNumeric(leftParsed, right)
case rightType.Kind() == reflect.String:
var rightParsed interface{}
if err := yaml.Unmarshal([]byte(right.(string)), &rightParsed); err != nil {
return nil, nil, err
}
return coerceNumeric(left, rightParsed)

default:
return nil, nil, fmt.Errorf("cannot coerce %T and %T to numeric", left, right)
}
}

func compareNumeric(left, right interface{}) (gt, gte, lt, lte bool, err error) {
left, right, err = coerceNumeric(left, right)
if err != nil {
return
}

switch left.(type) {
case int:
left := left.(int)
right := right.(int)
return left > right, left >= right, left < right, left <= right, nil
case float64:
left := left.(float64)
right := right.(float64)
return left > right, left >= right, left < right, left <= right, nil
default:
err = fmt.Errorf("internal, please report! coerceNumeric returned unexpected types %T and %T", left, right)
return
}
}
46 changes: 46 additions & 0 deletions evaluator/modifiers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package evaluator

import (
"fmt"
"testing"
)

func Test_compareNumeric(t *testing.T) {
tests := []struct {
left interface{}
right interface{}
wantGt bool
wantGte bool
wantLt bool
wantLte bool
}{
{1, 2, false, false, true, true},
{1.1, 1.2, false, false, true, true},
{1, 1.2, false, false, true, true},
{1.1, 2, false, false, true, true},
{1, "2", false, false, true, true},
{"1.1", 1.2, false, false, true, true},
{"1.1", 1.1, false, true, false, true},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", tt.left, tt.right), func(t *testing.T) {
gotGt, gotGte, gotLt, gotLte, err := compareNumeric(tt.left, tt.right)
if err != nil {
t.Errorf("compareNumeric() error = %v", err)
return
}
if gotGt != tt.wantGt {
t.Errorf("compareNumeric() gotGt = %v, want %v", gotGt, tt.wantGt)
}
if gotGte != tt.wantGte {
t.Errorf("compareNumeric() gotGte = %v, want %v", gotGte, tt.wantGte)
}
if gotLt != tt.wantLt {
t.Errorf("compareNumeric() gotLt = %v, want %v", gotLt, tt.wantLt)
}
if gotLte != tt.wantLte {
t.Errorf("compareNumeric() gotLte = %v, want %v", gotLte, tt.wantLte)
}
})
}
}

0 comments on commit 47169b1

Please sign in to comment.