Skip to content

Commit

Permalink
Merge pull request #553 from opencybersecurityalliance/k2-fix-nullable-int
Browse files Browse the repository at this point in the history

fix nullable int
  • Loading branch information
subbyte authored Jul 25, 2024
2 parents 7ebec8d + 0af883b commit 8d492cf
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 26 deletions.
11 changes: 8 additions & 3 deletions packages/kestrel_core/src/kestrel/cache/sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
from copy import copy
from tempfile import mkstemp
from typing import Any, Iterable, Mapping, MutableMapping, Optional
from uuid import UUID

Expand Down Expand Up @@ -46,12 +48,14 @@ class SqlCache(AbstractCache):
def __init__(
self,
initial_cache: Optional[Mapping[UUID, DataFrame]] = None,
session_id: Optional[UUID] = None,
debug: bool = False,
):
super().__init__()

basename = session_id or "cache"
self.db_path = f"{basename}.db"
if debug:
self.db_path = "local.db"
else:
_, self.db_path = mkstemp(suffix=".db")

# for an absolute file path, the three slashes are followed by the absolute path
# for a relative path, it's also three slashes?
Expand All @@ -68,6 +72,7 @@ def __init__(

def __del__(self):
    """Close the cache connection and delete the database file at ``db_path``.

    Finalizers also run on objects whose ``__init__`` raised part-way, so
    neither attribute is assumed to exist. The file is removed even when
    closing the connection fails, and a file that has already been deleted
    is not treated as an error.
    """
    # NOTE(review): this removes whatever ``db_path`` points at, including
    # the fixed "local.db" used in debug mode -- confirm that is intended.
    connection = getattr(self, "connection", None)
    db_path = getattr(self, "db_path", None)
    try:
        if connection is not None:
            connection.close()
    finally:
        if db_path is not None:
            try:
                os.remove(db_path)
            except FileNotFoundError:
                # Already cleaned up (e.g. by a test or an earlier finalizer).
                pass

def __getitem__(self, instruction_id: UUID) -> DataFrame:
return read_sql(self.cache_catalog[instruction_id], self.connection)
Expand Down
6 changes: 2 additions & 4 deletions packages/kestrel_core/src/kestrel/config/kestrel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
language:
default_sort_order: "desc"

# debug options
debug:
env_var: "KESTREL_DEBUG" # debug mode if the environment variable exists
cache_directory_path: "~/kestrel-debug-session" # put in user's home directory by default
# debug mode
debug: false

# default identifier attribute(s) of an entity across all datasource interfaces
# always provide a list as identifiers even if it is a single identifier
Expand Down
13 changes: 9 additions & 4 deletions packages/kestrel_core/src/kestrel/mapping/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from functools import reduce
from typing import Any, Iterable, List, Optional, Tuple, Union

import numpy as np
import numpy
import yaml
from kestrel.exceptions import IncompleteDataMapping
from kestrel.ir.filter import ReferenceValue
from kestrel.mapping.transformers import run_transformer, run_transformer_on_series
from kestrel.utils import list_folder_files
from pandas import DataFrame
from pandas import DataFrame, Int64Dtype
from typeguard import typechecked

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -438,8 +438,13 @@ def translate_dataframe(df: DataFrame, to_native_nested_map: dict) -> DataFrame:
# Not actually a named function; it's a literal value map
df[col] = df[col].replace(transformer_name)
else:
df[col] = run_transformer_on_series(
s = run_transformer_on_series(
transformer_name, df[col].dropna()
)
df = df.replace({np.nan: None})
df[col] = s
# if the series is integers, use Int64 (Nullable int) to allow NaN/NA
# if not, pandas will use float64 by default, which gives .0
if s.dtype == numpy.int64:
df[col] = df[col].astype(Int64Dtype())
df = df.replace({numpy.nan: None})
return df
5 changes: 4 additions & 1 deletion packages/kestrel_core/src/kestrel/mapping/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ def to_int(value) -> int:
return int(value)
except ValueError:
# Maybe it's a hexadecimal string?
return int(value, 16)
try:
return int(value, 16)
except:
return -1


@transformer
Expand Down
6 changes: 5 additions & 1 deletion packages/kestrel_core/src/kestrel/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from contextlib import AbstractContextManager
from os import environ
from typing import Iterable
from uuid import uuid4

Expand Down Expand Up @@ -28,8 +29,11 @@ def __init__(self):
self.irgraph = IRGraph()
self.config = load_kestrel_config()

if "KESTREL_DEBUG" in environ:
self.config["debug"] = True

# load all interfaces; cache is a special interface
cache = SqlCache()
cache = SqlCache(debug=self.config["debug"])

# Python analytics are "built-in"
pyanalytics = PythonAnalyticsInterface()
Expand Down
8 changes: 4 additions & 4 deletions packages/kestrel_core/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_env_vars_in_config_overwrite():
credentials:
username: ${TEST_USER}
password: ${TEST_PASSWORD}
debug:
testattr:
cache_directory_prefix: $KESTREL_CACHE_DIRECTORY_PREFIX
"""
os.environ["TEST_USER"] = "test-user"
Expand All @@ -39,14 +39,14 @@ def test_env_vars_in_config_overwrite():
config = load_kestrel_config()
assert config["credentials"]["username"] == "test-user"
assert config["credentials"]["password"] == "test-password"
assert config["debug"]["cache_directory_prefix"] == "Kestrel2.0-"
assert config["testattr"]["cache_directory_prefix"] == "Kestrel2.0-"

def test_empty_env_var_in_config():
test_config = """---
credentials:
username: ${TEST_USER}
password: ${TEST_PASSWORD}
debug:
testattr:
cache_directory_prefix: $I_DONT_EXIST
"""
os.environ["TEST_USER"] = "test-user"
Expand All @@ -58,7 +58,7 @@ def test_empty_env_var_in_config():
config = load_kestrel_config()
assert config["credentials"]["username"] == "test-user"
assert config["credentials"]["password"] == "test-password"
assert config["debug"]["cache_directory_prefix"] == "$I_DONT_EXIST"
assert config["testattr"]["cache_directory_prefix"] == "$I_DONT_EXIST"

def test_yaml_load_in_config(tmp_path):
test_config = """---
Expand Down
10 changes: 2 additions & 8 deletions packages/kestrel_core/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ class Gateway(SqlCache):
def schemes():
return ["gateway"]

extra_db = []
with Session() as session:
stmt1 = """
procs = NEW process [ {"name": "cmd.exe", "pid": 123}
Expand All @@ -235,8 +234,7 @@ def schemes():
session.interface_manager[CACHE_INTERFACE_IDENTIFIER].__class__ = DataLake
session.irgraph.get_nodes_by_type_and_attributes(Construct, {"interface": CACHE_INTERFACE_IDENTIFIER})[0].interface = "datalake"

new_cache = SqlCache(session_id = uuid4())
extra_db.append(new_cache.db_path)
new_cache = SqlCache()
session.interface_manager.interfaces.append(new_cache)
stmt2 = """
nt = NEW network [ {"pid": 123, "source": "192.168.1.1", "destination": "1.1.1.1"}
Expand All @@ -248,8 +246,7 @@ def schemes():
session.interface_manager[CACHE_INTERFACE_IDENTIFIER].__class__ = Gateway
session.irgraph.get_nodes_by_type_and_attributes(Construct, {"interface": CACHE_INTERFACE_IDENTIFIER})[0].interface = "gateway"

new_cache = SqlCache(session_id = uuid4())
extra_db.append(new_cache.db_path)
new_cache = SqlCache()
session.interface_manager.interfaces.append(new_cache)
stmt3 = """
domain = NEW domain [ {"ip": "1.1.1.1", "domain": "cloudflare.com"}
Expand Down Expand Up @@ -307,9 +304,6 @@ def schemes():
df_ref = DataFrame([{"ip": "1.1.1.2", "domain": "xyz.cloudflare.com"}])
assert df_ref.equals(df_res)

for db_file in extra_db:
os.remove(db_file)


def test_apply_on_construct():
hf = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def gen_label_mapping(g: IRGraph) -> Mapping[Instruction, str]:

def to_html_blocks(d: Display) -> Iterable[str]:
if isinstance(d, DataFrame):
escaped_df = d.replace("\$", "\\\$", inplace=False, regex=True)
escaped_df = d.map(lambda x: x.replace("$", "\\$") if isinstance(x, str) else x)
# escaped_df = d.replace("\$", "\\\$", inplace=False, regex=True)
yield escaped_df.to_html(index=False, na_rep="")
elif isinstance(d, GraphExplanation):
for graphlet in d.graphlets:
Expand Down

0 comments on commit 8d492cf

Please sign in to comment.