Skip to content

Commit

Permalink
Merge pull request #558 from opencybersecurityalliance/k2-fix-dup-solve-ref
Browse files Browse the repository at this point in the history

fix duplicated ref resolution bug
  • Loading branch information
subbyte authored Jul 26, 2024
2 parents 1f31327 + 905d5d9 commit 3c7789d
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 41 deletions.
14 changes: 1 addition & 13 deletions packages/kestrel_core/src/kestrel/cache/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __setitem__(
def get_virtual_copy(self) -> AbstractCache:
v = copy(self)
v.cache_catalog = copy(self.cache_catalog)
v.__class__ = InMemoryCacheVirtual
v.cache = copy(self.cache)
return v

def evaluate_graph(
Expand Down Expand Up @@ -113,15 +113,3 @@ def _evaluate_instruction_in_graph(
else:
raise NotImplementedError(f"Unknown instruction type: {instruction}")
return df


@typechecked
class InMemoryCacheVirtual(InMemoryCache):
def __getitem__(self, instruction_id: UUID) -> Any:
return self.cache_catalog[instruction_id]

def __delitem__(self, instruction_id: UUID):
del self.cache_catalog[instruction_id]

def __setitem__(self, instruction_id: UUID, data: Any):
self.cache_catalog[instruction_id] = "virtual" + instruction_id.hex
14 changes: 11 additions & 3 deletions packages/kestrel_core/src/kestrel/cache/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sqlalchemy
from dateutil.parser import parse as dt_parser
from kestrel.cache.base import AbstractCache
from kestrel.config.internal import VIRTUAL_CACHE_VAR_DATA
from kestrel.display import GraphletExplanation, NativeQuery
from kestrel.interface.codegen.sql import SqlTranslator
from kestrel.interface.codegen.utils import variable_attributes_to_dataframe
Expand Down Expand Up @@ -252,13 +253,20 @@ def _evaluate_instruction_in_graph(

@typechecked
class SqlCacheVirtual(SqlCache):
def __getitem__(self, instruction_id: UUID) -> Any:
return self.cache_catalog[instruction_id]
def __getitem__(self, instruction_id: UUID) -> DataFrame:
if instruction_id in self.cache_catalog:
try:
df = read_sql(self.cache_catalog[instruction_id], self.connection)
except:
df = VIRTUAL_CACHE_VAR_DATA
else:
raise KeyError(instruction_id)
return df

def __delitem__(self, instruction_id: UUID):
del self.cache_catalog[instruction_id]

def __setitem__(self, instruction_id: UUID, data: Any):
def __setitem__(self, instruction_id: UUID, data: DataFrame):
self.cache_catalog[instruction_id] = instruction_id.hex + "v"
self.cache_catalog_schemas[instruction_id] = ["*"]

Expand Down
5 changes: 4 additions & 1 deletion packages/kestrel_core/src/kestrel/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from kestrel.exceptions import KestrelError
from kestrel.session import Session
from pandas import DataFrame


def add_logging_handler(handler, if_debug):
Expand Down Expand Up @@ -47,7 +48,9 @@ def kestrel():
with open(args.huntflow, "r") as fp:
huntflow = fp.read()
outputs = session.execute(huntflow)
results = "\n\n".join([o.to_string() for o in outputs])
results = "\n\n".join(
[o.to_string() if isinstance(o, DataFrame) else str(o) for o in outputs]
)
print(results)


Expand Down
4 changes: 4 additions & 0 deletions packages/kestrel_core/src/kestrel/config/internal.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
from pandas import DataFrame

CACHE_INTERFACE_IDENTIFIER = "cache"
CACHE_STORAGE_IDENTIFIER = "local"

VIRTUAL_CACHE_VAR_DATA = DataFrame({"*": ["*"]})
8 changes: 6 additions & 2 deletions packages/kestrel_core/src/kestrel/ir/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ class TimeRange(DataClassJSONMixin):
@typechecked
def get_references_from_exp(exp: FExpression) -> Iterable[ReferenceValue]:
if isinstance(exp, RefComparison):
yield exp.value
if isinstance(exp.value, ReferenceValue):
# if not already resolved
yield exp.value
elif isinstance(exp, BoolExp):
yield from get_references_from_exp(exp.lhs)
yield from get_references_from_exp(exp.rhs)
Expand All @@ -207,7 +209,9 @@ def resolve_reference_with_function(
exp: FExpression, f: Callable[[ReferenceValue], Any]
):
if isinstance(exp, RefComparison):
exp.value = f(exp.value)
if isinstance(exp.value, ReferenceValue):
# if not already resolved
exp.value = f(exp.value)
elif isinstance(exp, BoolExp):
resolve_reference_with_function(exp.lhs, f)
resolve_reference_with_function(exp.rhs, f)
Expand Down
4 changes: 2 additions & 2 deletions packages/kestrel_core/src/kestrel/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kestrel.analytics import PythonAnalyticsInterface
from kestrel.cache import SqlCache
from kestrel.config import load_kestrel_config
from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER
from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER, VIRTUAL_CACHE_VAR_DATA
from kestrel.display import Display, GraphExplanation
from kestrel.exceptions import InstructionNotFound
from kestrel.frontend.completor import do_complete
Expand Down Expand Up @@ -134,7 +134,7 @@ def evaluate_instruction(self, ins: Instruction) -> Display:
).items():
if is_explain:
display.graphlets.append(_display)
_cache[iid] = True # virtual cache; value type does not matter
_cache[iid] = VIRTUAL_CACHE_VAR_DATA
else:
display = _display
_cache[iid] = display
Expand Down
11 changes: 3 additions & 8 deletions packages/kestrel_core/tests/test_cache_inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pandas import DataFrame, read_csv

from kestrel.cache import InMemoryCache
from kestrel.cache.inmemory import InMemoryCacheVirtual
from kestrel.config import load_kestrel_config
from kestrel.frontend.parser import parse_kestrel_and_update_irgraph
from kestrel.ir.graph import IRGraph, IRGraphEvaluable
Expand Down Expand Up @@ -126,15 +125,11 @@ def test_get_virtual_copy():
mapping = c.evaluate_graph(graph, c)
v = c.get_virtual_copy()
new_entry = uuid4()
v[new_entry] = True
v[new_entry] = DataFrame()

# v[new_entry] calls the right method
assert isinstance(v, InMemoryCacheVirtual)
assert v[new_entry].startswith("virtual")

# v[new_entry] does not hit v.cache
# v[new_entry] does not hit c.cache
assert len(c.cache) == 2
assert len(v.cache) == 2
assert len(v.cache) == 3

# the two cache_catalog are different
assert new_entry not in c
Expand Down
5 changes: 3 additions & 2 deletions packages/kestrel_core/tests/test_cache_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlalchemy
from pandas import DataFrame, read_csv

from kestrel.config.internal import VIRTUAL_CACHE_VAR_DATA
from kestrel.cache import SqlCache
from kestrel.cache.sql import SqlCacheVirtual
from kestrel.config import load_kestrel_config
Expand Down Expand Up @@ -201,11 +202,11 @@ def test_get_virtual_copy():
mapping = c.evaluate_graph(graph, c)
v = c.get_virtual_copy()
new_entry = uuid4()
v[new_entry] = True
v[new_entry] = DataFrame()

# v[new_entry] calls the right method
assert isinstance(v, SqlCacheVirtual)
assert v[new_entry].endswith("v")
assert v[new_entry].equals(VIRTUAL_CACHE_VAR_DATA)

# the two cache_catalog are different
assert new_entry not in c
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def evaluate_graph(

# handle Information command
if isinstance(instruction, Return):
trunk, _ = graph.get_trunk_n_branches(instruction)
trunk = (
instruction.predecessor
if hasattr(instruction, "predecessor")
else graph.get_trunk_n_branches(instruction)[0]
)
if isinstance(trunk, Information):
df = variable_attributes_to_dataframe(df)

Expand All @@ -132,14 +136,19 @@ def explain_graph(
instructions_to_explain: Optional[Iterable[Instruction]] = None,
) -> Mapping[UUID, GraphletExplanation]:
mapping = {}
graph_genuine_copy = graph.deepcopy()
if not instructions_to_explain:
instructions_to_explain = graph.get_sink_nodes()
for instruction in instructions_to_explain:
dep_graph = graph.duplicate_dependent_subgraph_of_node(instruction)
graph_dict = dep_graph.to_dict()
# duplicate graph here before ref resolution
dep_graph = graph_genuine_copy.duplicate_dependent_subgraph_of_node(
instruction
)
# render the graph in SQL
translator = self._evaluate_instruction_in_graph(graph, cache, instruction)
query = NativeQuery("SQL", str(translator.result_w_literal_binds()))
mapping[instruction.id] = GraphletExplanation(graph_dict, query)
# return the graph and SQL
mapping[instruction.id] = GraphletExplanation(dep_graph.to_dict(), query)
return mapping

def _evaluate_instruction_in_graph(
Expand Down Expand Up @@ -220,7 +229,14 @@ def _evaluate_instruction_in_graph(
if instruction.id in subquery_memory:
translator = subquery_memory[instruction.id]
else:
trunk, r2n = graph.get_trunk_n_branches(instruction)
# record the predecessor so that references in a Filter are not resolved again,
# which would be impossible (the ReferenceValue is already gone, replaced with a subquery)
if hasattr(instruction, "predecessor"):
trunk, r2n = instruction.predecessor, {}
else:
trunk, r2n = graph.get_trunk_n_branches(instruction)
instruction.predecessor = trunk

translator = self._evaluate_instruction_in_graph(
graph, cache, trunk, graph_genuine_copy, subquery_memory
)
Expand All @@ -237,11 +253,16 @@ def _evaluate_instruction_in_graph(
translator.add_instruction(instruction)

elif isinstance(instruction, Filter):
instruction.resolve_references(
lambda x: self._evaluate_instruction_in_graph(
graph, cache, r2n[x], graph_genuine_copy, subquery_memory
).query
)
if r2n:
instruction.resolve_references(
lambda x: self._evaluate_instruction_in_graph(
graph,
cache,
r2n[x],
graph_genuine_copy,
subquery_memory,
).query
)
translator.add_instruction(instruction)

else:
Expand Down

0 comments on commit 3c7789d

Please sign in to comment.