From 9810af779754e1f9f24697a312cf92b4e0870d2b Mon Sep 17 00:00:00 2001 From: Jeremy Howard Date: Sat, 7 Sep 2024 08:54:23 +1000 Subject: [PATCH] fixes #24 --- fastlite/_modidx.py | 1 + fastlite/core.py | 20 ++++++++++++-- nbs/00_core.ipynb | 64 ++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/fastlite/_modidx.py b/fastlite/_modidx.py index 617d0ae..a653404 100644 --- a/fastlite/_modidx.py +++ b/fastlite/_modidx.py @@ -6,6 +6,7 @@ 'git_url': 'https://github.com/AnswerDotAI/fastlite', 'lib_path': 'fastlite'}, 'syms': { 'fastlite.core': { 'fastlite.core.Database.create': ('core.html#database.create', 'fastlite/core.py'), + 'fastlite.core.Database.import_file': ('core.html#database.import_file', 'fastlite/core.py'), 'fastlite.core.Database.q': ('core.html#database.q', 'fastlite/core.py'), 'fastlite.core.Database.t': ('core.html#database.t', 'fastlite/core.py'), 'fastlite.core.Database.v': ('core.html#database.v', 'fastlite/core.py'), diff --git a/fastlite/core.py b/fastlite/core.py index 3eba638..9a32ef2 100644 --- a/fastlite/core.py +++ b/fastlite/core.py @@ -14,6 +14,7 @@ from fastcore.xml import highlight from fastcore.xtras import hl_md, dataclass_src from sqlite_minutils.db import * +from sqlite_minutils.utils import rows_from_file,TypeTracker,Format import types try: from graphviz import Source @@ -168,7 +169,22 @@ def create( res.cls = cls return res -# %% ../nbs/00_core.ipynb 58 +# %% ../nbs/00_core.ipynb 55 +@patch +def import_file(self:Database, table_name, file, format=None, pk=None): + "Import path or handle `file` to new table `table_name`" + if isinstance(file, str): file = file.encode() + if isinstance(file, bytes): file = io.BytesIO(file) + with maybe_open(file) as fp: rows, format_used = rows_from_file(fp, format=format) + tracker = TypeTracker() + rows = tracker.wrap(rows) + tbl = self[table_name] + tbl.insert_all(rows, alter=True) + tbl.transform(types=tracker.types) + if pk: tbl.transform(pk=pk) + return tbl + +# %% ../nbs/00_core.ipynb 61 def _edge(tbl): return "\n".join(f"{fk.table}:{fk.column} -> {fk.other_table}:{fk.other_column};" for fk in tbl.foreign_keys) @@ -186,7 +202,7 @@ def _tnode(tbl): """ return f"{tbl.name} [label=<{res}>];\n" -# %% ../nbs/00_core.ipynb 59 +# %% ../nbs/00_core.ipynb 62 def diagram(tbls, ratio=0.7, size="10", neato=False, render=True): layout = "\nlayout=neato;\noverlap=prism;\noverlap_scaling=0.5;""" if neato else "" edges = "\n".join(map(_edge, tbls)) diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index d6e19fa..6af0874 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -40,6 +40,7 @@ "from fastcore.xml import highlight\n", "from fastcore.xtras import hl_md, dataclass_src\n", "from sqlite_minutils.db import *\n", + "from sqlite_minutils.utils import rows_from_file,TypeTracker,Format\n", "import types\n", "\n", "try: from graphviz import Source\n", @@ -750,7 +751,7 @@ "metadata": {}, "outputs": [], "source": [ - "#| exports\n", + "#| export\n", "@patch\n", "def create(\n", " self: Database,\n", @@ -827,7 +828,7 @@ { "data": { "text/plain": [ - "" + "
" ] }, "execution_count": null, @@ -870,11 +871,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "CREATE TABLE [cat] (\n", + "CREATE TABLE \"cat\" (\n", " [id] INTEGER PRIMARY KEY,\n", " [name] TEXT,\n", " [age] INTEGER,\n", - " [city] TEXT\n", + " [city] TEXT,\n", + " [breed] TEXT\n", ")\n" ] } @@ -892,6 +894,60 @@ "db.t.cat.drop()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "@patch\n", + "def import_file(self:Database, table_name, file, format=None, pk=None):\n", + " \"Import path or handle `file` to new table `table_name`\"\n", + " if isinstance(file, str): file = file.encode()\n", + " if isinstance(file, bytes): file = io.BytesIO(file)\n", + " with maybe_open(file) as fp: rows, format_used = rows_from_file(fp, format=format)\n", + " tracker = TypeTracker()\n", + " rows = tracker.wrap(rows)\n", + " tbl = self[table_name]\n", + " tbl.insert_all(rows, alter=True)\n", + " tbl.transform(types=tracker.types)\n", + " if pk: tbl.transform(pk=pk)\n", + " return tbl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This uses [`sqlite_utils.utils.rows_from_file`](https://sqlite-utils.datasette.io/en/stable/reference.html#sqlite-utils-utils-rows-from-file) to load the file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'id': 1, 'name': 'Alice', 'age': 30}, {'id': 2, 'name': 'Bob', 'age': 25}, {'id': 3, 'name': 'Charlie', 'age': 35}]\n" + ] + } + ], + "source": [ + "db = Database(\":memory:\")\n", + "csv_data = \"\"\"id,name,age\n", + "1,Alice,30\n", + "2,Bob,25\n", + "3,Charlie,35\"\"\"\n", + "\n", + "table = db.import_file(\"people\", csv_data)\n", + "print(table())\n", + "table.drop()" + ] + }, { "cell_type": "markdown", "metadata": {},