Merge pull request #548 from opencybersecurityalliance/k2-cmd-info
add command INFO
subbyte authored Jul 22, 2024
2 parents e098e97 + 989c047 commit 9c2364d
Showing 18 changed files with 122 additions and 14 deletions.
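For context, the new INFO command lists the attributes available on a variable, grouped by their dotted entity prefix. A minimal sketch based on the tests added in this commit (the kestrel.session.Session import path is assumed here, not part of the diff):

from kestrel.session import Session  # assumed import path

huntflow = """
events = NEW event [ {"process.name": "cmd.exe", "process.pid": 123, "user.name": "user", "event_type": "process"} ]
INFO events
"""

with Session() as session:
    # execute() returns one result per returning command; INFO yields a
    # one-column DataFrame named "attributes"
    df = session.execute(huntflow)[0]
    print(df["attributes"].to_list())
    # expected (per the new cache test): ['event_type', 'process.name, process.pid', 'user.name']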
12 changes: 11 additions & 1 deletion packages/kestrel_core/src/kestrel/cache/sql.py
@@ -8,11 +8,13 @@
from kestrel.cache.base import AbstractCache
from kestrel.display import GraphletExplanation, NativeQuery
from kestrel.interface.codegen.sql import SqlTranslator
from kestrel.interface.codegen.utils import variable_attributes_to_dataframe
from kestrel.ir.graph import IRGraphEvaluable
from kestrel.ir.instructions import (
Construct,
Explain,
Filter,
Information,
Instruction,
Return,
SolePredecessorTransformingInstruction,
@@ -110,7 +112,15 @@ def evaluate_graph(
translator = self._evaluate_instruction_in_graph(graph, instruction)
# TODO: may catch error in case evaluation starts from incomplete SQL
_logger.debug(f"SQL query generated: {translator.result_w_literal_binds()}")
mapping[instruction.id] = read_sql(translator.result(), self.connection)
df = read_sql(translator.result(), self.connection)

# handle Information command
if isinstance(instruction, Return):
trunk, _ = graph.get_trunk_n_branches(instruction)
if isinstance(trunk, Information):
df = variable_attributes_to_dataframe(df)

mapping[instruction.id] = df
return mapping

def explain_graph(
4 changes: 4 additions & 0 deletions packages/kestrel_core/src/kestrel/exceptions.py
@@ -194,3 +194,7 @@ class InvalidProjectEntityFromEntity(KestrelError):

class InvalidMappingWithMultipleIdentifierFields(KestrelError):
pass


class InvalidAttributes(KestrelError):
pass
7 changes: 7 additions & 0 deletions packages/kestrel_core/src/kestrel/frontend/compile.py
@@ -37,6 +37,7 @@
DataSource,
Explain,
Filter,
Information,
Instruction,
Limit,
Offset,
@@ -698,3 +699,9 @@ def explain(self, args) -> List[Return]:
explain = self.irgraph.add_node(Explain(), variable)
ret = self.irgraph.add_node(Return(), explain)
return [ret]

def info(self, args) -> List[Return]:
variable = self.irgraph.get_variable(args[0].value)
info = self.irgraph.add_node(Information(), variable)
ret = self.irgraph.add_node(Return(), info)
return [ret]
14 changes: 13 additions & 1 deletion packages/kestrel_core/src/kestrel/interface/codegen/dataframe.py
@@ -6,9 +6,11 @@
from typing import Callable

from kestrel.exceptions import (
InvalidAttributes,
InvalidOperatorInMultiColumnComparison,
MismatchedFieldValueInMultiColumnComparison,
)
from kestrel.interface.codegen.utils import variable_attributes_to_dataframe
from kestrel.ir.filter import (
AbsoluteTrue,
BoolExp,
@@ -24,6 +26,7 @@
from kestrel.ir.instructions import (
Construct,
Filter,
Information,
Limit,
ProjectAttrs,
ProjectEntity,
@@ -70,8 +73,17 @@ def _eval_Limit(instruction: Limit, dataframe: DataFrame) -> DataFrame:
return dataframe.head(instruction.num)


@typechecked
def _eval_Information(instruction: Information, dataframe: DataFrame) -> DataFrame:
return variable_attributes_to_dataframe(dataframe)


@typechecked
def _eval_ProjectAttrs(instruction: ProjectAttrs, dataframe: DataFrame) -> DataFrame:
cols = set(list(dataframe))
invalid_attrs = set(instruction.attrs) - cols
if invalid_attrs:
raise InvalidAttributes(list(invalid_attrs))
return dataframe[list(instruction.attrs)]


@@ -84,7 +96,7 @@ def _eval_ProjectEntity(instruction: ProjectEntity, dataframe: DataFrame) -> DataFrame:
df = dataframe[
[col for col in dataframe if col.startswith(instruction.ocsf_field)]
]
df.rename(columns=lambda x: x[len(instruction.ocsf_field) + 1 :], inplace=True)
df = df.rename(columns=lambda x: x[len(instruction.ocsf_field) + 1 :])
df = df.drop_duplicates()
return df

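The attribute validation added above means a stray attribute in a projection now fails fast with InvalidAttributes, on both the dataframe and SQL code-generation paths. A rough dataframe-side sketch (assuming ProjectAttrs accepts the attribute list directly):

from pandas import DataFrame

from kestrel.exceptions import InvalidAttributes
from kestrel.interface.codegen.dataframe import evaluate_transforming_instruction
from kestrel.ir.instructions import ProjectAttrs

df = DataFrame([{"name": "cmd.exe", "pid": 123}])
try:
    # "no_such_attr" is not a column of df, so the new check fires
    evaluate_transforming_instruction(ProjectAttrs(["name", "no_such_attr"]), df)
except InvalidAttributes as e:
    print(e)  # carries the list of unknown attributes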
16 changes: 14 additions & 2 deletions packages/kestrel_core/src/kestrel/interface/codegen/sql.py
@@ -4,6 +4,7 @@

import sqlalchemy
from kestrel.exceptions import (
InvalidAttributes,
InvalidMappingWithMultipleIdentifierFields,
InvalidProjectEntityFromEntity,
SourceSchemaNotFound,
@@ -22,6 +23,7 @@
)
from kestrel.ir.instructions import (
Filter,
Information,
Instruction,
Limit,
Offset,
@@ -244,8 +246,15 @@ def add_Filter(self, filt: Filter) -> None:
self.query = self.query.where(selection)

def add_ProjectAttrs(self, proj: ProjectAttrs) -> None:
cols = [column(col) for col in proj.attrs]
self.query = self.query.with_only_columns(*cols)
if not self.source_schema:
raise SourceSchemaNotFound(self.result_w_literal_binds())
else:
if self.source_schema != ["*"]:
invalid_attrs = set(proj.attrs) - set(self.source_schema)
if invalid_attrs:
raise InvalidAttributes(list(invalid_attrs))
cols = [column(col) for col in proj.attrs]
self.query = self.query.with_only_columns(*cols)

def add_ProjectEntity(self, proj: ProjectEntity) -> None:
if self.projection_base_field and self.projection_base_field != "event":
@@ -290,6 +299,9 @@ def add_ProjectEntity(self, proj: ProjectEntity) -> None:
def add_Limit(self, lim: Limit) -> None:
self.query = self.query.limit(lim.num)

def add_Information(self, ins: Information) -> None:
self.query = self.query.limit(1)

def add_Offset(self, offset: Offset) -> None:
self.query = self.query.offset(offset.num)

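Design note on add_Information: the SQL side only needs the column names of the result, so the translator caps the query at one row and leaves the attribute summarization to the caller (see variable_attributes_to_dataframe in the next file). A standalone SQLAlchemy sketch of the shape of the generated query (table name illustrative):

from sqlalchemy import literal_column, select, table
from sqlalchemy.dialects import sqlite

# roughly what an INFO evaluation asks the database for: one row, all columns
query = select(literal_column("*")).select_from(table("my_table")).limit(1)
print(query.compile(dialect=sqlite.dialect()))
# approximately: SELECT * FROM my_table LIMIT ?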
11 changes: 11 additions & 0 deletions packages/kestrel_core/src/kestrel/interface/codegen/utils.py
@@ -0,0 +1,11 @@
from itertools import groupby
from typing import List

from pandas import DataFrame


def variable_attributes_to_dataframe(attrs: List[str]) -> DataFrame:
categories = []
for k, g in groupby(sorted(attrs), lambda s: s.split(".")[0] if "." in s else ""):
categories.append(", ".join(g))
return DataFrame(data={"attributes": categories})
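A quick illustration of the helper above, mirroring the new unit test: iterating a DataFrame yields its column names, so callers can pass either a DataFrame or a plain list of attribute names; names are grouped by the text before the first dot.

from pandas import DataFrame

from kestrel.interface.codegen.utils import variable_attributes_to_dataframe

df = DataFrame([{"process.name": "cmd.exe", "process.pid": 123, "user.name": "user", "event_type": "process"}])
print(variable_attributes_to_dataframe(df)["attributes"].to_list())
# ['event_type', 'process.name, process.pid', 'user.name']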
5 changes: 5 additions & 0 deletions packages/kestrel_core/src/kestrel/ir/instructions.py
@@ -202,6 +202,11 @@ class Explain(SolePredecessorTransformingInstruction):
pass


@dataclass(eq=False)
class Information(SolePredecessorTransformingInstruction):
pass


@dataclass(eq=False)
class Limit(SolePredecessorTransformingInstruction):
num: int
18 changes: 18 additions & 0 deletions packages/kestrel_core/tests/test_cache_sqlite.py
@@ -300,3 +300,21 @@ def test_explain_find_event_to_entity(process_creation_events):
WHERE device.os = 'Linux')
SELECT DISTINCT *
FROM procs"""


def test_eval_information():
stmt = """
events = NEW event [ {"process.name": "cmd.exe", "process.pid": 123, "user.name": "user", "event_type": "process"} ]
INFO events
"""
graph = IRGraph()
rets = parse_kestrel_and_update_irgraph(stmt, graph, {})
graph = IRGraphEvaluable(graph)
c = SqlCache()
mapping = c.evaluate_graph(graph, c)

# check the return is correct
assert len(rets) == 1
df = mapping[rets[0].id]
attrs = df["attributes"].to_list()
assert attrs == ['event_type', 'process.name, process.pid', 'user.name']
@@ -6,6 +6,7 @@
evaluate_source_instruction, evaluate_transforming_instruction)
from kestrel.ir.graph import IRGraph
from kestrel.ir.instructions import Construct, Limit, ProjectAttrs, Variable
from kestrel.interface.codegen.utils import variable_attributes_to_dataframe


def test_evaluate_Construct():
@@ -97,3 +98,11 @@ def test_evaluate_Construct_Filter_ProjectAttrs():
ft = next(graph.predecessors(graph.get_variable("p4")))
dfx = evaluate_transforming_instruction(ft, df0)
assert dfx.to_dict("records") == [ {"name": "cmd.exe", "pid": 123} ]


def test_information():
data = [ {"process.name": "cmd.exe", "process.pid": 123, "user.name": "user", "event_type": "process"} ]
df = DataFrame(data)
idf = variable_attributes_to_dataframe(df)
attrs = idf["attributes"].to_list()
assert attrs == ['event_type', 'process.name, process.pid', 'user.name']
@@ -56,7 +56,7 @@ def _remove_nl(s):
]
)
def test_sql_translator(iseq, sql):
trans = SqlTranslator(sqlalchemy.dialects.sqlite.dialect(), "my_table", None, None, None, _time2string, "timestamp")
trans = SqlTranslator(sqlalchemy.dialects.sqlite.dialect(), "my_table", ["foo", "bar", "baz"], None, None, _time2string, "timestamp")
for i in iseq:
trans.add_instruction(i)
result = trans.result()
1 change: 0 additions & 1 deletion packages/kestrel_core/tests/test_ir_graph.py
@@ -17,7 +17,6 @@ def test_add_get_datasource():
g.add_datasource("stixshifter://abc")

s = g.add_datasource(DataSource("stixshifter://abc"))
print(g.to_json())
assert len(g) == 1

s2 = DataSource("stixshifter://abcd")
1 change: 0 additions & 1 deletion packages/kestrel_core/tests/test_ir_instructions.py
@@ -77,7 +77,6 @@ def test_construct():
c = Construct(data)
assert c.data.equals(DataFrame(data))
assert c.interface == CACHE_INTERFACE_IDENTIFIER
print(c.to_dict())


def test_instruction_from_dict():
2 changes: 0 additions & 2 deletions packages/kestrel_core/tests/test_mapping_data_model.py
@@ -152,7 +152,6 @@ def test_reverse_mapping_executable():

def test_reverse_mapping_event_id():
rmap = reverse_mapping(WINLOGBEAT_MAPPING)
print(json.dumps(rmap, indent=4))
assert rmap["winlog.event_id"][0]["ocsf_value"]["1"] == [100701]
assert rmap["winlog.event_id"][0]["ocsf_value"]["5"] == [100702]
assert rmap["winlog.event_id"][0]["ocsf_value"]["4688"] == [100701]
@@ -260,7 +259,6 @@ def test_translate_dataframe_events():
}
)
df = translate_dataframe(df, WINLOGBEAT_MAPPING)
print(df)
assert df["type_uid"].iloc[0] == 100701
assert df["type_uid"].iloc[1] == "1234" # Passthrough?

@@ -8,11 +8,13 @@
from kestrel.exceptions import SourceNotFound
from kestrel.interface import AbstractInterface
from kestrel.interface.codegen.sql import ingest_dataframe_to_temp_table
from kestrel.interface.codegen.utils import variable_attributes_to_dataframe
from kestrel.ir.graph import IRGraphEvaluable
from kestrel.ir.instructions import (
DataSource,
Explain,
Filter,
Information,
Instruction,
Return,
SolePredecessorTransformingInstruction,
@@ -40,10 +42,10 @@ def __init__(
):
_logger.debug("SQLAlchemyInterface: loading config")
super().__init__(serialized_cache_catalog, session_id)
self.config = load_config()
self.schemas: dict = {} # Schema per table (index)
self.engines: dict = {} # Map of conn name -> engine
self.conns: dict = {} # Map of conn name -> connection
self.config = load_config()
for info in self.config.datasources.values():
name = info.connection
conn_info = self.config.connections[name]
@@ -102,7 +104,16 @@ def evaluate_graph(
# pass through
_logger.debug("No result/value translation")
dmm = None
mapping[instruction.id] = translate_dataframe(df, dmm) if dmm else df

df = translate_dataframe(df, dmm) if dmm else df

# handle Information command
if isinstance(instruction, Return):
trunk, _ = graph.get_trunk_n_branches(instruction)
if isinstance(trunk, Information):
df = variable_attributes_to_dataframe(df)

mapping[instruction.id] = df
return mapping

def explain_graph(
12 changes: 12 additions & 0 deletions packages/kestrel_interface_sqlalchemy/tests/test_interface.py
@@ -237,3 +237,15 @@ def test_find_entity_to_entity_2(setup_sqlite_ecs_process_creation):
parents = session.execute(huntflow)[0]
assert parents.shape[0] == 2
assert list(parents) == ['endpoint.uid', 'file.endpoint.uid', 'user.endpoint.uid', 'endpoint.name', 'file.endpoint.name', 'user.endpoint.name', 'endpoint.os', 'file.endpoint.os', 'user.endpoint.os', 'cmd_line', 'name', 'pid', 'uid']


def test_information(setup_sqlite_ecs_process_creation):
with Session() as session:
huntflow = """
evs = GET event FROM sqlalchemy://events WHERE os.name = 'Linux'
INFO evs
"""
df = session.execute(huntflow)[0]
attrs = df["attributes"].to_list()
assert attrs == ['actor.process.cmd_line, actor.process.endpoint.name, actor.process.endpoint.os, actor.process.endpoint.uid, actor.process.file.endpoint.name, actor.process.file.endpoint.os, actor.process.file.endpoint.uid, actor.process.file.name, actor.process.file.parent_folder, actor.process.file.path, actor.process.name, actor.process.parent_process.cmd_line, actor.process.parent_process.endpoint.name, actor.process.parent_process.endpoint.os, actor.process.parent_process.endpoint.uid, actor.process.parent_process.file.endpoint.name, actor.process.parent_process.file.endpoint.os, actor.process.parent_process.file.endpoint.uid, actor.process.parent_process.name, actor.process.parent_process.pid, actor.process.parent_process.uid, actor.process.parent_process.user.endpoint.name, actor.process.parent_process.user.endpoint.os, actor.process.parent_process.user.endpoint.uid, actor.process.pid, actor.process.uid, actor.process.user.endpoint.name, actor.process.user.endpoint.os, actor.process.user.endpoint.uid, actor.user.name, actor.user.uid', 'endpoint.name, endpoint.os, endpoint.uid', 'file.endpoint.name, file.endpoint.os, file.endpoint.uid', 'process.cmd_line, process.endpoint.name, process.endpoint.os, process.endpoint.uid, process.file.endpoint.name, process.file.endpoint.os, process.file.endpoint.uid, process.file.name, process.file.parent_folder, process.file.path, process.name, process.parent_process.cmd_line, process.parent_process.endpoint.name, process.parent_process.endpoint.os, process.parent_process.endpoint.uid, process.parent_process.file.endpoint.name, process.parent_process.file.endpoint.os, process.parent_process.file.endpoint.uid, process.parent_process.name, process.parent_process.pid, process.parent_process.uid, process.parent_process.user.endpoint.name, process.parent_process.user.endpoint.os, process.parent_process.user.endpoint.uid, process.pid, process.uid, process.user.endpoint.name, process.user.endpoint.os, process.user.endpoint.uid', 'reg_key.endpoint.name, reg_key.endpoint.os, reg_key.endpoint.uid', 'reg_value.endpoint.name, reg_value.endpoint.os, reg_value.endpoint.uid', 'user.name, user.uid']

@@ -17,7 +17,7 @@
ENGINE = sqlalchemy.create_engine("sqlite:///test.db")
DIALECT = ENGINE.dialect
TABLE = "my_table"
TABLE_SCHEMA = ["CommandLine", "Image", "ProcessId", "ParentProcessId"]
TABLE_SCHEMA = ["CommandLine", "Image", "ProcessId", "ParentProcessId", "foo", "bar", "baz"]


TIMEFMT = '%Y-%m-%dT%H:%M:%S.%fZ'
@@ -37,7 +37,7 @@ def to_html_blocks(d: Display) -> Iterable[str]:
elif isinstance(d, GraphExplanation):
for graphlet in d.graphlets:
graph = IRGraph(graphlet.graph)
plt.figure(figsize=(4, 2))
plt.figure(figsize=(10, 8))
nx.draw(
graph,
with_labels=True,
@@ -46,8 +46,9 @@ def do_execute(

except Exception as e:
_logger.error("Exception occurred", exc_info=True)
error = f"{e.__class__.__name__}: {e}"
self.send_response(
self.iopub_socket, "stream", {"name": "stderr", "text": str(e)}
self.iopub_socket, "stream", {"name": "stderr", "text": error}
)

return {
