From 337da3986d7f3d2cadb88f83255f66a7fc71107f Mon Sep 17 00:00:00 2001 From: Martin O'Hanlon Date: Sun, 12 Aug 2018 20:39:25 +0100 Subject: [PATCH] v0.1.0 --- .gitignore | 1 + README.rst | 11 +- docs/changelog.rst | 8 ++ docs/getstarted.rst | 0 quickdraw/data.py | 194 +++++++++++++++++++++++++++---- setup.py | 6 +- tests/test_quickdrawdata.py | 138 ++++++++++++++++++++++ tests/test_quickdrawdatagroup.py | 135 +++++++++++++++++++++ 8 files changed, 465 insertions(+), 28 deletions(-) create mode 100644 docs/getstarted.rst create mode 100644 tests/test_quickdrawdata.py create mode 100644 tests/test_quickdrawdatagroup.py diff --git a/.gitignore b/.gitignore index 4a99e44..e35f42e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.py[cdo] pythonhosted/ +.pytest_cache/ # Editor detritus *.vim diff --git a/README.rst b/README.rst index 01768f3..d3e16e9 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,8 @@ quickdraw ========= +|pypibadge| |docsbadge| + `Quick Draw`_ is a drawing game which is training a neural network to recognise doodles. |quickdraw| @@ -128,7 +130,7 @@ The drawings have been moderated but there is no guarantee it'll actually be a p Status ------ -**Alpha** - under active dev, the API may change, problems might occur. +**Beta** - stable, under active dev, the API may change. .. |quickdraw| image:: https://raw.githubusercontent.com/martinohanlon/quickdraw_python/master/docs/images/quickdraw.png @@ -143,6 +145,13 @@ Status :scale: 100 % :alt: quickdraw_preview +.. |pypibadge| image:: https://badge.fury.io/py/quickdraw.svg + :target: https://badge.fury.io/py/quickdraw + :alt: Latest Version + +.. |docsbadge| image:: https://readthedocs.org/projects/quickdraw/badge/ + :target: https://readthedocs.org/projects/quickdraw/ + :alt: Docs .. _Martin O'Hanlon: https://github.com/martinohanlon .. _stuffaboutco.de: http://stuffaboutco.de diff --git a/docs/changelog.rst b/docs/changelog.rst index 587d205..fa9a080 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,14 @@ Change log .. currentmodule:: quickdraw +0.1.0 +----- + ++ Beta ++ Bug fixes ++ Additional properties methods and stuff ++ Tests + 0.0.1 > 0.0.4 ------------- diff --git a/docs/getstarted.rst b/docs/getstarted.rst new file mode 100644 index 0000000..e69de29 diff --git a/quickdraw/data.py b/quickdraw/data.py index 64819f8..7590805 100644 --- a/quickdraw/data.py +++ b/quickdraw/data.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import struct from random import choice from os import path, makedirs @@ -28,17 +30,22 @@ class QuickDrawData(): anvil = qd.get_drawing("anvil") anvil.image.save("my_anvil.gif") + :param bool recognized: + If ``True`` only recognized drawings will be loaded, if ``False`` + only unrecognized drawings will be loaded, if ``None`` (the default) + both recognized and unrecognized drawings will be loaded. + :param int max_drawings: The maximum number of drawings to be loaded into memory, defaults to 1000. :param bool refresh_data: - If `True` forces data to be downloaded even if it has been - downloaded before, defaults to `False`. + If ``True`` forces data to be downloaded even if it has been + downloaded before, defaults to ``False``. :param bool jit_loading: - If `True` (the default) only downloads and loads data into - memory when it is required (jit = just in time). If `False` + If ``True`` (the default) only downloads and loads data into + memory when it is required (jit = just in time). If ``False`` all drawings will be downloaded and loaded into memory. :param bool print_messages: @@ -49,7 +56,16 @@ class QuickDrawData(): Specify a cache directory to use when downloading data files, defaults to `./.quickdrawcache`. """ - def __init__(self, max_drawings=1000, refresh_data=False, jit_loading=True, print_messages=True, cache_dir=CACHE_DIR): + def __init__( + self, + recognized=None, + max_drawings=1000, + refresh_data=False, + jit_loading=True, + print_messages=True, + cache_dir=CACHE_DIR): + + self._recognized = recognized self._print_messages = print_messages self._refresh_data = refresh_data self._max_drawings = max_drawings @@ -74,7 +90,7 @@ def get_drawing(self, name, index=None): :param int index: The index of the drawing to get. - If `None` (the default) a random drawing will be returned. + If ``None`` (the default) a random drawing will be returned. """ return self.get_drawing_group(name).get_drawing(index) @@ -91,6 +107,7 @@ def get_drawing_group(self, name): if name not in self._drawing_groups.keys(): drawings = QuickDrawDataGroup( name, + recognized=self._recognized, max_drawings=self._max_drawings, refresh_data=self._refresh_data, print_messages=self._print_messages, @@ -99,11 +116,58 @@ def get_drawing_group(self, name): return self._drawing_groups[name] + def search_drawings(self, name, key_id=None, recognized=None, countrycode=None, timestamp=None): + """ + Search the drawings. + + Returns an list of :class:`QuickDrawing` instances representing the + matched drawings. + + Note - search criteria are a compound. + + Search for all the drawings with the ``countrycode`` "PL" :: + + from quickdraw import QuickDrawDataGroup + + anvils = QuickDrawDataGroup("anvil") + results = anvils.search_drawings(countrycode="PL") + + :param string name: + The name of the drawings (anvil, ant, aircraft, etc) + to search. + + :param int key_id: + The ``key_id`` to such for. If ``None`` (the default) the + ``key_id`` is not used. + + :param bool recognized: + To search for drawings which were ``recognized``. If ``None`` + (the default) ``recognized`` is not used. + + :param str countrycode: + To search for drawings which with the ``countrycode``. If + ``None`` (the default) ``countrycode`` is not used. + + :param int countrycode: + To search for drawings which with the ``timestamp``. If ``None`` + (the default) ``timestamp`` is not used. + """ + return self.get_drawing_group(name).search_drawings(key_id, recognized, countrycode, timestamp) + def load_all_drawings(self): """ Loads (and downloads if required) all drawings into memory. """ - for drawing_group in QUICK_DRAWING_NAMES: + self.load_drawings(self.drawing_names) + + def load_drawings(self, list_of_drawings): + """ + Loads (and downloads if required) all drawings into memory. + + :param list list_of_drawings: + A list of the drawings to be loaded (anvil, ant, aircraft, etc). + """ + for drawing_group in list_of_drawings: self.get_drawing_group(drawing_group) @property @@ -113,6 +177,13 @@ def drawing_names(self): """ return QUICK_DRAWING_NAMES + @property + def loaded_drawings(self): + """ + Returns a list of drawing which have been loaded into memory. + """ + return list(self._drawing_groups.keys()) + class QuickDrawDataGroup(): """ @@ -130,12 +201,17 @@ class QuickDrawDataGroup(): :param string name: The name of the drawings to be loaded (anvil, ant, aircraft, etc). + :param bool recognized: + If ``True`` only recognized drawings will be loaded, if ``False`` + only unrecognized drawings will be loaded, if ``None`` (the default) + both recognized and unrecognized drawings will be loaded. + :param int max_drawings: The maximum number of drawings to be loaded into memory, defaults to 1000. :param bool refresh_data: - If `True` forces data to be downloaded even if it has been + If ``True`` forces data to be downloaded even if it has been downloaded before, defaults to `False`. :param bool print_messages: @@ -144,9 +220,16 @@ class QuickDrawDataGroup(): :param string cache_dir: Specify a cache directory to use when downloading data files, - defaults to `./.quickdrawcache`. + defaults to ``./.quickdrawcache``. """ - def __init__(self, name, max_drawings=1000, refresh_data=False, print_messages=True, cache_dir=CACHE_DIR): + def __init__( + self, + name, + recognized=None, + max_drawings=1000, + refresh_data=False, + print_messages=True, + cache_dir=CACHE_DIR): if name not in QUICK_DRAWING_NAMES: raise ValueError("{} is not a valid google quick drawing".format(name)) @@ -155,6 +238,7 @@ def __init__(self, name, max_drawings=1000, refresh_data=False, print_messages=T self._print_messages = print_messages self._max_drawings = max_drawings self._cache_dir = cache_dir + self._recognized = recognized self._drawings = [] @@ -227,21 +311,27 @@ def _load_drawings(self, filename): y = struct.unpack(fmt, binary_file.read(n_points)) image.append((x, y)) - self._drawings.append({ - 'key_id': key_id, - 'countrycode': countrycode, - 'recognized': recognized, - 'timestamp': timestamp, - 'n_strokes': n_strokes, - 'image': image - }) + append_drawing = True + if self._recognized is not None: + if bool(recognized) != self._recognized: + append_drawing = False + + if append_drawing: + self._drawings.append({ + 'key_id': key_id, + 'countrycode': countrycode, + 'recognized': recognized, + 'timestamp': timestamp, + 'n_strokes': n_strokes, + 'image': image + }) + + self._drawing_count += 1 # nothing left to read except struct.error: break - self._drawing_count += 1 - self._print_message("load complete") def _print_message(self, message): @@ -270,7 +360,7 @@ def drawings(self): """ while True: self._current_drawing += 1 - if self._current_drawing == self._drawing_count - 1: + if self._current_drawing > self._drawing_count - 1: # reached the end to the drawings self._current_drawing = 0 raise StopIteration() @@ -295,16 +385,72 @@ def get_drawing(self, index=None): :param int index: The index of the drawing to get. - If `None` (the default) a random drawing will be returned. + If ``None`` (the default) a random drawing will be returned. """ if index is None: return QuickDrawing(self._name, choice(self._drawings)) else: - if index < self.drawing_count - 1: + if index < self.drawing_count: return QuickDrawing(self._name, self._drawings[index]) else: raise IndexError("index {} out of range, there are {} drawings".format(index, self.drawing_count)) + def search_drawings(self, key_id=None, recognized=None, countrycode=None, timestamp=None): + """ + Searches the drawings in this group. + + Returns an list of :class:`QuickDrawing` instances representing the + matched drawings. + + Note - search criteria are a compound. + + Search for all the drawings with the ``countrycode`` "PL" :: + + from quickdraw import QuickDrawDataGroup + + anvils = QuickDrawDataGroup("anvil") + results = anvils.search_drawings(countrycode="PL") + + :param int key_id: + The ``key_id`` to such for. If ``None`` (the default) the + ``key_id`` is not used. + + :param bool recognized: + To search for drawings which were ``recognized``. If ``None`` + (the default) ``recognized`` is not used. + + :param str countrycode: + To search for drawings which with the ``countrycode``. If + ``None`` (the default) ``countrycode`` is not used. + + :param int countrycode: + To search for drawings which with the ``timestamp``. If ``None`` + (the default) ``timestamp`` is not used. + """ + results = [] + + for drawing in self.drawings: + match = True + if key_id is not None: + if key_id != drawing.key_id: + match = False + + if recognized is not None: + if recognized != drawing.recognized: + match = False + + if countrycode is not None: + if countrycode != drawing.countrycode: + match = False + + if timestamp is not None: + if timestamp != drawing.timestamp: + match = False + + if match: + results.append(drawing) + + return results class QuickDrawing(): """ @@ -335,7 +481,7 @@ def countrycode(self): """ Returns the country code for the drawing. """ - return self._drawing_data["countrycode"] + return self._drawing_data["countrycode"].decode("utf-8") @property def recognized(self): diff --git a/setup.py b/setup.py index 12abd76..2df1766 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ __project__ = 'quickdraw' __desc__ = 'An API for downloading and reading the google quickdraw data.' -__version__ = '0.0.4' +__version__ = '0.1.0' __author__ = "Martin O'Hanlon" __author_email__ = 'martin@ohanlonweb.com' __license__ = 'MIT' @@ -63,8 +63,8 @@ """ __classifiers__ = [ - "Development Status :: 3 - Alpha", -# "Development Status :: 4 - Beta", +# "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", # "Development Status :: 5 - Production/Stable", "Intended Audience :: Education", "Intended Audience :: Developers", diff --git a/tests/test_quickdrawdata.py b/tests/test_quickdrawdata.py new file mode 100644 index 0000000..8d90590 --- /dev/null +++ b/tests/test_quickdrawdata.py @@ -0,0 +1,138 @@ +from quickdraw import QuickDrawData, QuickDrawDataGroup +from PIL.Image import Image + +def test_get_specific_drawing(): + qd = QuickDrawData() + + # get the first anvil drawing and test the values + d = qd.get_drawing("anvil", 0) + assert d.name == "anvil" + assert d.key_id == 5355190515400704 + assert d.recognized == True + assert d.countrycode == "PL" + assert d.timestamp == 1488368345 + + # 1 stroke, 2 x,y coords, 33 points + assert len(d.image_data) == 1 + assert len(d.image_data[0]) == 2 + assert len(d.image_data[0][0]) == 33 + assert len(d.image_data[0][1]) == 33 + + assert d.no_of_strokes == 1 + assert len(d.strokes) == 1 + assert len(d.strokes[0]) == 33 + assert len(d.strokes[0][0]) == 2 + + assert isinstance(d.image, Image) + assert isinstance(d.get_image(stroke_color=(10,10,10), stroke_width=4, bg_color=(200,200,200)), Image) + +def test_get_random_drawing(): + qd = QuickDrawData() + + d = qd.get_drawing("anvil", 0) + assert d.name == "anvil" + assert isinstance(d.key_id, int) + assert isinstance(d.recognized, bool) + assert isinstance(d.timestamp, int) + assert isinstance(d.countrycode, str) + + assert isinstance(d.image_data, list) + assert len(d.image_data) == d.no_of_strokes + + assert isinstance(d.strokes, list) + assert len(d.strokes) == d.no_of_strokes + for stroke in d.strokes: + for point in stroke: + assert len(point) == 2 + + assert isinstance(d.image, Image) + assert isinstance(d.get_image(stroke_color=(10,10,10), stroke_width=4, bg_color=(200,200,200)), Image) + +def test_drawing_names(): + qd = QuickDrawData() + assert len(qd.drawing_names) == 345 + +def test_load_drawings(): + qd = QuickDrawData() + qd.load_drawings(["anvil", "ant"]) + assert qd.loaded_drawings == ["anvil", "ant"] + + qd.get_drawing("angel") + assert qd.loaded_drawings == ["anvil", "ant", "angel"] + +def test_get_drawing_group(): + qd = QuickDrawData() + assert isinstance(qd.get_drawing_group("anvil"), QuickDrawDataGroup) + +def test_recognized_data(): + qdg = QuickDrawData(recognized=True).get_drawing_group("anvil") + assert qdg.drawing_count == 1000 + + rec = 0 + unrec = 0 + + for drawing in qdg.drawings: + if drawing.recognized: + rec += 1 + else: + unrec += 1 + + assert rec == qdg.drawing_count + assert unrec == 0 + +def test_unrecognized_data(): + qdg = QuickDrawData(recognized=False).get_drawing_group("anvil") + assert qdg.drawing_count == 1000 + + rec = 0 + unrec = 0 + + for drawing in qdg.drawings: + if drawing.recognized: + rec += 1 + else: + unrec += 1 + + assert rec == 0 + assert unrec == qdg.drawing_count + +def test_search_drawings(): + qd = QuickDrawData() + # test a search with no criteria returns 1000 results + r = qd.search_drawings("anvil") + assert len(r) == 1000 + + # test a recognized search + r = qd.search_drawings("anvil", recognized=True) + for d in r: + assert d.recognized + + r = qd.search_drawings("anvil", recognized=False) + for d in r: + assert not d.recognized + + # test a country search + r = qd.search_drawings("anvil", countrycode="US") + for d in r: + assert d.countrycode == "US" + + # pull first drawing + key_id = r[0].key_id + timestamp = r[0].timestamp + + # test key_id search + r = qd.search_drawings("anvil", key_id=key_id) + for d in r: + assert d.key_id == key_id + + # test timestamp search + r = qd.search_drawings("anvil", timestamp=timestamp) + for d in r: + assert d.timestamp == timestamp + + # test a compound search of recognized and country code + r = qd.search_drawings("anvil", recognized=True, countrycode="US") + for d in r: + assert d.recognized + assert d.countrycode == "US" + diff --git a/tests/test_quickdrawdatagroup.py b/tests/test_quickdrawdatagroup.py new file mode 100644 index 0000000..a565cb9 --- /dev/null +++ b/tests/test_quickdrawdatagroup.py @@ -0,0 +1,135 @@ +from quickdraw import QuickDrawDataGroup +from PIL.Image import Image + +def test_get_data_group(): + qdg = QuickDrawDataGroup("anvil") + assert qdg.drawing_count == 1000 + + qdg = QuickDrawDataGroup("anvil", max_drawings=2000) + assert qdg.drawing_count == 2000 + +def test_get_specific_drawing(): + qdg = QuickDrawDataGroup("anvil") + + # get the first anvil drawing and test the values + d = qdg.get_drawing(0) + assert d.name == "anvil" + assert d.key_id == 5355190515400704 + assert d.recognized == True + assert d.countrycode == "PL" + assert d.timestamp == 1488368345 + + # 1 stroke, 2 x,y coords, 33 points + assert len(d.image_data) == 1 + assert len(d.image_data[0]) == 2 + assert len(d.image_data[0][0]) == 33 + assert len(d.image_data[0][1]) == 33 + + assert d.no_of_strokes == 1 + assert len(d.strokes) == 1 + assert len(d.strokes[0]) == 33 + assert len(d.strokes[0][0]) == 2 + + assert isinstance(d.image, Image) + assert isinstance(d.get_image(stroke_color=(10,10,10), stroke_width=4, bg_color=(200,200,200)), Image) + +def test_get_random_drawing(): + qdg = QuickDrawDataGroup("anvil") + + d = qdg.get_drawing(0) + assert d.name == "anvil" + assert isinstance(d.key_id, int) + assert isinstance(d.recognized, bool) + assert isinstance(d.timestamp, int) + assert isinstance(d.countrycode, str) + + assert isinstance(d.image_data, list) + assert len(d.image_data) == d.no_of_strokes + + assert isinstance(d.strokes, list) + assert len(d.strokes) == d.no_of_strokes + for stroke in d.strokes: + for point in stroke: + assert len(point) == 2 + + assert isinstance(d.image, Image) + assert isinstance(d.get_image(stroke_color=(10,10,10), stroke_width=4, bg_color=(200,200,200)), Image) + +def test_drawings(): + qdg = QuickDrawDataGroup("anvil") + count = 0 + for drawing in qdg.drawings: + count += 1 + assert count == 1000 + +def test_recognized_data(): + qdg = QuickDrawDataGroup("anvil", recognized=True) + assert qdg.drawing_count == 1000 + + rec = 0 + unrec = 0 + + for drawing in qdg.drawings: + if drawing.recognized: + rec += 1 + else: + unrec += 1 + + assert rec == qdg.drawing_count + assert unrec == 0 + +def test_unrecognized_data(): + qdg = QuickDrawDataGroup("anvil", recognized=False) + assert qdg.drawing_count == 1000 + + rec = 0 + unrec = 0 + + for drawing in qdg.drawings: + if drawing.recognized: + rec += 1 + else: + unrec += 1 + + assert rec == 0 + assert unrec == qdg.drawing_count + +def test_search_drawings(): + qdg = QuickDrawDataGroup("anvil") + # test a search with no criteria returns 1000 results + r = qdg.search_drawings() + assert len(r) == 1000 + + # test a recognized search + r = qdg.search_drawings(recognized=True) + for d in r: + assert d.recognized + + r = qdg.search_drawings(recognized=False) + for d in r: + assert not d.recognized + + # test a country search + r = qdg.search_drawings(countrycode="US") + for d in r: + assert d.countrycode == "US" + + # pull first drawing + key_id = r[0].key_id + timestamp = r[0].timestamp + + # test key_id search + r = qdg.search_drawings(key_id=key_id) + for d in r: + assert d.key_id == key_id + + # test timestamp search + r = qdg.search_drawings(timestamp=timestamp) + for d in r: + assert d.timestamp == timestamp + + # test a compound search of recognized and country code + r = qdg.search_drawings(recognized=True, countrycode="US") + for d in r: + assert d.recognized + assert d.countrycode == "US" \ No newline at end of file