Skip to content

Commit

Permalink
[FIX] Prevent keys from starting with slashes (#172)
Browse files Browse the repository at this point in the history
* [FIX] Prevent keys from starting with slashes

If a key starts with a slash, then it becomes undeletable and prevents database purges from working properly as well.
This prevents that from occuring by stripping slashes from the left of the key name.

* Double newlines for flake8

* flake8 wanted another newline here

* Force `set_bulk_raw` to handle keys with slashes as well

* Add tests for keys starting with a slash

* Fix a typo I made twice

* flake8

* `del self.db[k]` not `self.db.delete(k)` in non-Async

* One space for flake8

* These were also wrong

* These shouldn't be using `get`

* Match format of some of the other tests in TestDatabase

* Perhaps the key is corrupted?

* Have to `get_raw` for `_raw` calls.

* Reassociate _dumps with def dumps

* Only call keyStrip at the root of the .set function hierarchy

* Clarify that keyStrip is an internal method

---------

Co-authored-by: Devon Stewart <[email protected]>
  • Loading branch information
Firepup6500 and blast-hardcheese authored Feb 28, 2024
1 parent 5c0ab7a commit 4a5a7ac
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/replit/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def dumps(val: Any) -> str:
_dumps = dumps


def _sanitize_key(key: str) -> str:
"""Strip slashes from the beginning of keys.
Args:
key (str): The key to strip
Returns:
str: The stripped key
"""
return key.lstrip("/")


class AsyncDatabase:
"""Async interface for Replit Database.
Expand Down Expand Up @@ -195,6 +207,7 @@ async def set_bulk_raw(self, values: Dict[str, str]) -> None:
Args:
values (Dict[str, str]): The key-value pairs to set.
"""
values = {_sanitize_key(k): v for k, v in values.items()}
async with self.client.post(self.db_url, data=values) as response:
response.raise_for_status()

Expand Down Expand Up @@ -629,6 +642,7 @@ def set_bulk_raw(self, values: Dict[str, str]) -> None:
Args:
values (Dict[str, str]): The key-value pairs to set.
"""
values = {_sanitize_key(k): v for k, v in values.items()}
r = self.sess.post(self.db_url, data=values)
r.raise_for_status()

Expand Down
55 changes: 55 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,33 @@ async def test_bulk_raw(self) -> None:
self.assertEqual(await self.db.get_raw("bulk1"), "val1")
self.assertEqual(await self.db.get_raw("bulk2"), "val2")

async def test_slash_keys(self) -> None:
"""Test that slash keys work."""
k = "/key"
# set
await self.db.set(k,"val1")
self.assertEqual(await self.db.get(k), "val1")
await self.db.delete(k)
with self.assertRaises(KeyError):
await self.db.get(k)
# set_raw
await self.db.set_raw(k,"val1")
self.assertEqual(await self.db.get_raw(k), "val1")
await self.db.delete(k)
with self.assertRaises(KeyError):
await self.db.get(k)
# set_bulk
await self.db.set_bulk({k: "val1"})
self.assertEqual(await self.db.get(k), "val1")
await self.db.delete(k)
with self.assertRaises(KeyError):
await self.db.get(k)
# set_bulk_raw
await self.db.set_bulk_raw({k: "val1"})
self.assertEqual(await self.db.get_raw(k), "val1")
await self.db.delete(k)
with self.assertRaises(KeyError):
await self.db.get(k)

class TestDatabase(unittest.TestCase):
"""Tests for replit.database.Database."""
Expand Down Expand Up @@ -259,3 +286,31 @@ def test_bulk_raw(self) -> None:
self.db.set_bulk_raw({"bulk1": "val1", "bulk2": "val2"})
self.assertEqual(self.db.get_raw("bulk1"), "val1")
self.assertEqual(self.db.get_raw("bulk2"), "val2")

def test_slash_keys(self) -> None:
"""Test that slash keys work."""
k = "/key"
# set
self.db.set(k,"val1")
self.assertEqual(self.db[k], "val1")
del self.db[k]
with self.assertRaises(KeyError):
self.db[k]
# set_raw
self.db.set_raw(k,"val1")
self.assertEqual(self.db.get_raw(k), "val1")
del self.db[k]
with self.assertRaises(KeyError):
self.db[k]
# set_bulk
self.db.set_bulk({k: "val1"})
self.assertEqual(self.db.get(k), "val1")
del self.db[k]
with self.assertRaises(KeyError):
self.db[k]
# set_bulk_raw
self.db.set_bulk_raw({k: "val1"})
self.assertEqual(self.db.get_raw(k), "val1")
del self.db[k]
with self.assertRaises(KeyError):
self.db[k]

0 comments on commit 4a5a7ac

Please sign in to comment.