Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add EXPLAIN to APPLY #567

Merged
merged 5 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions packages/kestrel_core/src/kestrel/analytics/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5)
from uuid import UUID

from kestrel.analytics.config import get_profile, load_profiles
from kestrel.display import GraphletExplanation
from kestrel.display import AnalyticOperation, GraphletExplanation
from kestrel.exceptions import (
AnalyticsError,
InvalidAnalytics,
Expand Down Expand Up @@ -156,6 +156,10 @@ def run(self, config: dict) -> DataFrame:
_logger.debug("python analytics job result:\n%s", df)
return df

def get_module_and_func_name(self, config: dict) -> tuple:
    """Look up the module and function names registered for this analytic.

    The previous ``-> str`` annotation was wrong: the method returns a
    two-item tuple, which callers unpack.

    Args:
        config: configuration passed to ``get_profile`` together with
            ``self.analytic``.

    Returns:
        A ``(module_name, func_name)`` pair unpacked from ``get_profile``.
    """
    # Unpacking here makes a profile that is not a 2-item pair fail at the
    # lookup site rather than later in the caller.
    module_name, func_name = get_profile(self.analytic, config)
    return module_name, func_name


class PythonAnalyticsInterface(AnalyticsInterface):
def __init__(
Expand Down Expand Up @@ -187,9 +191,20 @@ def store(
def explain_graph(
    self,
    graph: IRGraphEvaluable,
    cache: MutableMapping[UUID, Any],
    instructions_to_explain: Optional[Iterable[Instruction]] = None,
) -> Mapping[UUID, GraphletExplanation]:
    """Describe the Python analytic behind each instruction to explain.

    The leftover ``raise NotImplementedError`` placeholder is removed; it
    made the implementation below unreachable.

    Args:
        graph: the graph containing the instructions.
        cache: mapping from instruction id to cached data, forwarded to
            ``_evaluate_instruction_in_graph``.
        instructions_to_explain: instructions to explain; when empty or
            ``None`` the sink nodes of ``graph`` are used.

    Returns:
        A mapping from instruction id to a ``GraphletExplanation`` holding
        the serialized dependent subgraph and an ``AnalyticOperation``
        labelled ``"Python"`` with ``"<module>::<function>"``.
    """
    mapping = {}
    if not instructions_to_explain:
        instructions_to_explain = graph.get_sink_nodes()
    for instruction in instructions_to_explain:
        dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction)
        graph_dict = dep_graph.to_dict()
        # NOTE(review): presumably this only builds the analytics job and
        # does not run it -- confirm against _evaluate_instruction_in_graph.
        job = self._evaluate_instruction_in_graph(graph, cache, instruction)
        module_name, func_name = job.get_module_and_func_name(self.config)
        action = AnalyticOperation("Python", module_name + "::" + func_name)
        mapping[instruction.id] = GraphletExplanation(graph_dict, action)
    return mapping

def evaluate_graph(
self,
Expand Down
6 changes: 3 additions & 3 deletions packages/kestrel_core/src/kestrel/cache/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def explain_graph(
instructions_to_explain: Optional[Iterable[Instruction]] = None,
) -> Mapping[UUID, GraphletExplanation]:
mapping = {}
if not instructions_to_evaluate:
instructions_to_evaluate = graph.get_sink_nodes()
for instruction in instructions_to_evaluate:
if not instructions_to_explain:
instructions_to_explain = graph.get_sink_nodes()
for instruction in instructions_to_explain:
dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction)
graph_dict = dep_graph.to_dict()
query = NativeQuery("DataFrame", "")
Expand Down
10 changes: 9 additions & 1 deletion packages/kestrel_core/src/kestrel/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ class NativeQuery(DataClassJSONMixin):
statement: str


@dataclass
class AnalyticOperation(DataClassJSONMixin):
    """Description of an analytic operation, used as a graphlet's action.

    Fields are positional: ``AnalyticOperation(interface, operation)``.
    """

    # which interface
    interface: str
    # operation description
    operation: str


@dataclass
class GraphletExplanation(DataClassJSONMixin):
    """Explanation of one graphlet: its graph and the action applied to it.

    The former ``query`` field is replaced by ``action`` so that the class
    is built from exactly two values, ``GraphletExplanation(graph, action)``.
    """

    # serialized IRGraph
    graph: Mapping
    # what is executed for the graphlet: a data source query or an
    # analytic operation
    action: Union[NativeQuery, AnalyticOperation]


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions packages/kestrel_core/src/kestrel/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,12 @@ def _add_node(self, node: Instruction, deref: bool = True) -> Instruction:
self.store = node.store
return super()._add_node(node, deref)

def to_dict(self) -> Mapping[str, Iterable[Mapping]]:
    """Serialize the graph together with its ``interface`` and ``store``.

    Returns:
        The dictionary produced by the parent class's ``to_dict`` with the
        values of ``self.interface`` and ``self.store`` added under the
        keys ``"interface"`` and ``"store"``.
    """
    serialized = super().to_dict()
    # Same insertion order as before: "interface" first, then "store".
    for attr_name in ("interface", "store"):
        serialized[attr_name] = getattr(self, attr_name)
    return serialized


@typechecked
class IRGraphSimpleQuery(IRGraphEvaluable):
Expand Down
2 changes: 1 addition & 1 deletion packages/kestrel_core/tests/test_cache_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_explain_find_event_to_entity(process_creation_events):
assert len(rets) == 1
explanation = mapping[rets[0].id]
construct = graph.get_nodes_by_type(Construct)[0]
stmt = explanation.query.statement.replace('"', '')
stmt = explanation.action.statement.replace('"', '')
assert stmt == f"""WITH es AS
(SELECT DISTINCT *
FROM {construct.id.hex}),
Expand Down
18 changes: 9 additions & 9 deletions packages/kestrel_core/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER
from kestrel.display import GraphExplanation
from kestrel.frontend.parser import parse_kestrel_and_update_irgraph
from kestrel.ir.graph import IRGraph
from kestrel.ir.graph import IRGraph, IRGraphEvaluable
from kestrel.ir.instructions import Construct, SerializableDataFrame


Expand Down Expand Up @@ -200,10 +200,10 @@ def test_explain_in_cache():
assert isinstance(res, GraphExplanation)
assert len(res.graphlets) == 1
ge = res.graphlets[0]
assert ge.graph == session.irgraph.to_dict()
assert ge.graph == IRGraphEvaluable(session.irgraph).to_dict()
construct = session.irgraph.get_nodes_by_type(Construct)[0]
assert ge.query.language == "SQL"
stmt = ge.query.statement.replace('"', '')
assert ge.action.language == "SQL"
stmt = ge.action.statement.replace('"', '')
assert stmt == f"WITH proclist AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nbrowsers AS \n(SELECT DISTINCT * \nFROM proclist \nWHERE name != 'cmd.exe'), \nchrome AS \n(SELECT DISTINCT * \nFROM browsers \nWHERE pid = 205)\n SELECT DISTINCT * \nFROM chrome"
with pytest.raises(StopIteration):
next(ress)
Expand Down Expand Up @@ -275,28 +275,28 @@ def schemes():

# DISP procs
assert len(disp.graphlets[0].graph["nodes"]) == 5
query = disp.graphlets[0].query.statement.replace('"', '')
query = disp.graphlets[0].action.statement.replace('"', '')
procs = session.irgraph.get_variable("procs")
c1 = next(session.irgraph.predecessors(procs))
assert query == f"WITH procs AS \n(SELECT DISTINCT * \nFROM {c1.id.hex}), \np2 AS \n(SELECT DISTINCT * \nFROM procs \nWHERE name IN ('firefox.exe', 'chrome.exe'))\n SELECT DISTINCT pid \nFROM p2"

# DISP nt
assert len(disp.graphlets[1].graph["nodes"]) == 2
query = disp.graphlets[1].query.statement.replace('"', '')
query = disp.graphlets[1].action.statement.replace('"', '')
nt = session.irgraph.get_variable("nt")
c2 = next(session.irgraph.predecessors(nt))
assert query == f"WITH nt AS \n(SELECT DISTINCT * \nFROM {c2.id.hex})\n SELECT DISTINCT * \nFROM nt"

# DISP domain
assert len(disp.graphlets[2].graph["nodes"]) == 2
query = disp.graphlets[2].query.statement.replace('"', '')
query = disp.graphlets[2].action.statement.replace('"', '')
domain = session.irgraph.get_variable("domain")
c3 = next(session.irgraph.predecessors(domain))
assert query == f"WITH domain AS \n(SELECT DISTINCT * \nFROM {c3.id.hex})\n SELECT DISTINCT * \nFROM domain"

# EXPLAIN d2
assert len(disp.graphlets[3].graph["nodes"]) == 11
query = disp.graphlets[3].query.statement.replace('"', '')
query = disp.graphlets[3].action.statement.replace('"', '')
p2 = session.irgraph.get_variable("p2")
p2pa = next(session.irgraph.successors(p2))
assert query == f"WITH ntx AS \n(SELECT DISTINCT * \nFROM {nt.id.hex}v \nWHERE abc IN (SELECT DISTINCT * \nFROM {p2pa.id.hex}v)), \nd2 AS \n(SELECT DISTINCT * \nFROM {domain.id.hex}v \nWHERE ip IN (SELECT DISTINCT destination \nFROM ntx))\n SELECT DISTINCT * \nFROM d2"
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_explain_find_event_to_entity(process_creation_events):
session.irgraph = process_creation_events
res = session.execute("procs = FIND process RESPONDED es WHERE device.os = 'Linux' EXPLAIN procs")[0]
construct = session.irgraph.get_nodes_by_type(Construct)[0]
stmt = res.graphlets[0].query.statement.replace('"', '')
stmt = res.graphlets[0].action.statement.replace('"', '')
# cache.sql will use "*" as columns for __setitem__ in virtual cache
# so the result is different from test_cache_sqlite::test_explain_find_event_to_entity
assert stmt == f"WITH es AS \n(SELECT DISTINCT * \nFROM {construct.id.hex}v), \nprocs AS \n(SELECT DISTINCT * \nFROM es \nWHERE device.os = \'Linux\')\n SELECT DISTINCT * \nFROM procs"
6 changes: 3 additions & 3 deletions packages/kestrel_interface_sqlalchemy/tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_find_event_to_entity(setup_sqlite_ecs_process_creation):
evs, explain, procs = session.execute(huntflow)
assert evs.shape[0] == 9 # all events

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_event_to_entity.txt")
with open(result_file) as h:
Expand All @@ -187,7 +187,7 @@ def test_find_entity_to_event(setup_sqlite_ecs_process_creation):
"""
explain, e2 = session.execute(huntflow)

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_entity_to_event.txt")
with open(result_file) as h:
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_find_entity_to_entity(setup_sqlite_ecs_process_creation):
"""
explain, parents = session.execute(huntflow)

stmt = explain.graphlets[0].query.statement
stmt = explain.graphlets[0].action.statement
test_dir = os.path.dirname(os.path.abspath(__file__))
result_file = os.path.join(test_dir, "result_interface_find_entity_to_entity.txt")
with open(result_file) as h:
Expand Down
51 changes: 30 additions & 21 deletions packages/kestrel_jupyter/src/kestrel_jupyter_kernel/display.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import base64
import tempfile
from io import BytesIO
from math import ceil, sqrt
from typing import Iterable, Mapping

import matplotlib.pyplot as plt
import networkx as nx
import numpy
import sqlparse
from kestrel.display import Display, GraphExplanation
from kestrel.display import AnalyticOperation, Display, GraphExplanation, NativeQuery
from kestrel.ir.graph import IRGraph
from kestrel.ir.instructions import Construct, DataSource, Instruction, Variable
from pandas import DataFrame
Expand Down Expand Up @@ -39,7 +40,10 @@ def to_html_blocks(d: Display) -> Iterable[str]:
elif isinstance(d, GraphExplanation):
for graphlet in d.graphlets:
graph = IRGraph(graphlet.graph)
plt.figure(figsize=(10, 8))
yield f"<h5>INTERFACE: {graphlet.graph['interface']}; STORE: {graphlet.graph['store']}</h5>"

fig_side_length = min(10, ceil(sqrt(len(graph))) + 1)
plt.figure(figsize=(fig_side_length, fig_side_length))
nx.draw(
graph,
with_labels=True,
Expand All @@ -48,24 +52,29 @@ def to_html_blocks(d: Display) -> Iterable[str]:
node_size=260,
node_color="#bfdff5",
)
with tempfile.NamedTemporaryFile(delete_on_close=False) as tf:
tf.close()
plt.savefig(tf.name, format="png")
with open(tf.name, "rb") as tfx:
data = tfx.read()

img = data_uri = base64.b64encode(data).decode("utf-8")
fig_buffer = BytesIO()
plt.savefig(fig_buffer, format="png")
img = data_uri = base64.b64encode(fig_buffer.getvalue()).decode("utf-8")
imgx = f'<img src="data:image/png;base64,{img}">'
yield imgx

query = graphlet.query.statement
if graphlet.query.language == "SQL":
lexer = SqlLexer()
query = sqlparse.format(query, reindent=True, keyword_case="upper")
elif graphlet.query.language == "KQL":
lexer = KustoLexer()
else:
lexer = guess_lexer(query)
query = highlight(query, lexer, HtmlFormatter())
style = "<style>" + HtmlFormatter().get_style_defs() + "</style>"
yield style + query
if isinstance(graphlet.action, NativeQuery):
native_query = graphlet.action
language = native_query.language
query = native_query.statement
if language == "SQL":
lexer = SqlLexer()
query = sqlparse.format(query, reindent=True, keyword_case="upper")
elif language == "KQL":
lexer = KustoLexer()
else:
lexer = guess_lexer(query)
query = highlight(query, lexer, HtmlFormatter())
style = "<style>" + HtmlFormatter().get_style_defs() + "</style>"
yield style + query
elif isinstance(graphlet.action, AnalyticOperation):
analytic_operation = graphlet.action
data = {
"Analytics": [analytic_operation.operation],
}
yield DataFrame(data).to_html(index=False)
Loading