Skip to content

Commit

Permalink
fix temp table deleting/reusing error
Browse files Browse the repository at this point in the history
  • Loading branch information
subbyte committed Jul 26, 2024
1 parent 1581e27 commit e5672c2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
20 changes: 17 additions & 3 deletions packages/kestrel_core/src/kestrel/interface/codegen/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,27 @@ def _execute_create(self):
with self.pd_sql.run_transaction():
self.table.create(bind=self.pd_sql.con)

def create(self) -> None:
# Create the temp table, first dropping any existing table of the same name.
#
# SQLite actually reuses the same temp DB every time,
# even after the connection is closed and opened again,
# so we need to drop the previous temp table.
#
# We override the superclass code here
# since we need to drop a table with a different schema,
# which is not supported by the superclass method.
if self.exists():
with self.pd_sql.run_transaction():
self.pd_sql.get_table(self.name).drop(bind=self.pd_sql.con)
# NOTE(review): presumably discards the cached table definitions so the
# dropped table's old schema is not reused by the create below; confirm.
self.pd_sql.meta.clear()
self._execute_create()


@typechecked
def ingest_dataframe_to_temp_table(conn: Connection, df: DataFrame, table_name: str):
with pandasSQL_builder(conn) as pandas_engine:
table = _TemporaryTable(
table_name, pandas_engine, frame=df, if_exists="replace", index=False
)
# no need to pass if_exists="replace"
# since our customized .create() always drops an existing table before creating it
table = _TemporaryTable(table_name, pandas_engine, frame=df, index=False)
table.create()
df.to_sql(table_name, con=conn, if_exists="append", index=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,12 @@ def __init__(
_logger.debug("SQLAlchemyInterface: loading config")
super().__init__(serialized_cache_catalog, session_id)
self.engines: dict = {} # Map of conn name -> engine
self.conns: dict = {} # Map of conn name -> connection
self.config = load_config()
for info in self.config.datasources.values():
name = info.connection
conn_info = self.config.connections[name]
if name not in self.engines:
self.engines[name] = sqlalchemy.create_engine(conn_info.url)
if name not in self.conns:
engine = self.engines[name]
self.conns[name] = engine.connect()
_logger.debug("SQLAlchemyInterface: configured %s", name)

@staticmethod
Expand All @@ -75,8 +71,7 @@ def store(
raise NotImplementedError("SQLAlchemyInterface.store") # TEMP

def __del__(self):
for conn in self.conns.values():
conn.close()
pass

def evaluate_graph(
self,
Expand All @@ -88,14 +83,18 @@ def evaluate_graph(
if not instructions_to_evaluate:
instructions_to_evaluate = graph.get_sink_nodes()
for instruction in instructions_to_evaluate:
translator = self._evaluate_instruction_in_graph(graph, cache, instruction)
conn = self.engines[graph.store].connect()

translator = self._evaluate_instruction_in_graph(
conn, graph, cache, instruction
)
# TODO: may catch error in case evaluation starts from incomplete SQL
sql = translator.result()
_logger.debug("SQL query generated: %s", sql)

# Get the "from" table for this query
conn = self.conns[graph.store]
df = read_sql(sql, conn)
conn.close()

# value translation
if translator.data_mapping:
Expand Down Expand Up @@ -145,14 +144,19 @@ def explain_graph(
instruction
)
# render the graph in SQL
translator = self._evaluate_instruction_in_graph(graph, cache, instruction)
conn = self.engines[graph.store].connect()
translator = self._evaluate_instruction_in_graph(
conn, graph, cache, instruction
)
query = NativeQuery("SQL", str(translator.result_w_literal_binds()))
conn.close()
# return the graph and SQL
mapping[instruction.id] = GraphletExplanation(dep_graph.to_dict(), query)
return mapping

def _evaluate_instruction_in_graph(
self,
conn: sqlalchemy.engine.Connection,
graph: IRGraphEvaluable,
cache: MutableMapping[UUID, Any],
instruction: Instruction,
Expand Down Expand Up @@ -180,7 +184,7 @@ def _evaluate_instruction_in_graph(

# write to temp table
ingest_dataframe_to_temp_table(
self.conns[graph.store],
conn,
cache[instruction.id],
table_name,
)
Expand All @@ -206,11 +210,9 @@ def _evaluate_instruction_in_graph(
if isinstance(instruction, DataSource):
ds_config = self.config.datasources[instruction.datasource]
columns = list(
self.conns[ds_config.connection]
.execute(
conn.execute(
sqlalchemy.text(f"SELECT * FROM {ds_config.table} LIMIT 1")
)
.keys()
).keys()
)
translator = SQLAlchemyTranslator(
NativeTable(
Expand Down Expand Up @@ -238,7 +240,7 @@ def _evaluate_instruction_in_graph(
instruction.predecessor = trunk

translator = self._evaluate_instruction_in_graph(
graph, cache, trunk, graph_genuine_copy, subquery_memory
conn, graph, cache, trunk, graph_genuine_copy, subquery_memory
)

if isinstance(instruction, SolePredecessorTransformingInstruction):
Expand All @@ -256,6 +258,7 @@ def _evaluate_instruction_in_graph(
if r2n:
instruction.resolve_references(
lambda x: self._evaluate_instruction_in_graph(
conn,
graph,
cache,
r2n[x],
Expand Down
13 changes: 10 additions & 3 deletions packages/kestrel_interface_sqlalchemy/tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,20 @@ def test_write_to_temp_table(setup_sqlite_ecs_process_creation):
datalake = session.interface_manager["sqlalchemy"]
idx = uuid4().hex
df = DataFrame({'foo': [1, 2, 3]})
conn_name = list(datalake.conns.keys())[0]
conn = datalake.conns[conn_name]
conn = datalake.engines["datalake"].connect()
ingest_dataframe_to_temp_table(conn, df, idx)
assert read_sql(f'SELECT * FROM "{idx}"', conn).equals(df)
conn.close()
conn = datalake.engines[conn_name].connect()
conn = datalake.engines["datalake"].connect()
assert read_sql(f'SELECT * FROM "{idx}"', conn).empty
# ingesting again actually writes to the same temp table,
# which still exists in the temp DB.
# kestrel.interface.codegen.sql needs to handle this
ingest_dataframe_to_temp_table(conn, df, idx)
conn.close()
conn = datalake.engines["datalake"].connect()
assert read_sql(f'SELECT * FROM "{idx}"', conn).empty
conn.close()



Expand Down

0 comments on commit e5672c2

Please sign in to comment.