Skip to content

Commit

Permalink
Merge pull request #7 from martinohanlon/dev
Browse files Browse the repository at this point in the history
v0.1.0
  • Loading branch information
Martin O'Hanlon authored Aug 12, 2018
2 parents aefd7d2 + 337da39 commit 61d8413
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.py[cdo]
pythonhosted/
.pytest_cache/

# Editor detritus
*.vim
Expand Down
11 changes: 10 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
quickdraw
=========

|pypibadge| |docsbadge|

`Quick Draw`_ is a drawing game which is training a neural network to recognise doodles.

|quickdraw|
Expand Down Expand Up @@ -128,7 +130,7 @@ The drawings have been moderated but there is no guarantee it'll actually be a p
Status
------

**Alpha** - under active dev, the API may change, problems might occur.
**Beta** - stable, under active dev, the API may change.


.. |quickdraw| image:: https://raw.githubusercontent.com/martinohanlon/quickdraw_python/master/docs/images/quickdraw.png
Expand All @@ -143,6 +145,13 @@ Status
:scale: 100 %
:alt: quickdraw_preview

.. |pypibadge| image:: https://badge.fury.io/py/quickdraw.svg
:target: https://badge.fury.io/py/quickdraw
:alt: Latest Version

.. |docsbadge| image:: https://readthedocs.org/projects/quickdraw/badge/
:target: https://readthedocs.org/projects/quickdraw/
:alt: Docs

.. _Martin O'Hanlon: https://github.com/martinohanlon
.. _stuffaboutco.de: http://stuffaboutco.de
Expand Down
8 changes: 8 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ Change log

.. currentmodule:: quickdraw

0.1.0
-----

+ Beta
+ Bug fixes
+ Additional properties methods and stuff
+ Tests

0.0.1 > 0.0.4
-------------

Expand Down
Empty file added docs/getstarted.rst
Empty file.
194 changes: 170 additions & 24 deletions quickdraw/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import unicode_literals

import struct
from random import choice
from os import path, makedirs
Expand Down Expand Up @@ -28,17 +30,22 @@ class QuickDrawData():
anvil = qd.get_drawing("anvil")
anvil.image.save("my_anvil.gif")
:param bool recognized:
If ``True`` only recognized drawings will be loaded, if ``False``
only unrecognized drawings will be loaded, if ``None`` (the default)
both recognized and unrecognized drawings will be loaded.
:param int max_drawings:
The maximum number of drawings to be loaded into memory,
defaults to 1000.
:param bool refresh_data:
If `True` forces data to be downloaded even if it has been
downloaded before, defaults to `False`.
If ``True`` forces data to be downloaded even if it has been
downloaded before, defaults to ``False``.
:param bool jit_loading:
If `True` (the default) only downloads and loads data into
memory when it is required (jit = just in time). If `False`
If ``True`` (the default) only downloads and loads data into
memory when it is required (jit = just in time). If ``False``
all drawings will be downloaded and loaded into memory.
:param bool print_messages:
Expand All @@ -49,7 +56,16 @@ class QuickDrawData():
Specify a cache directory to use when downloading data files,
defaults to `./.quickdrawcache`.
"""
def __init__(self, max_drawings=1000, refresh_data=False, jit_loading=True, print_messages=True, cache_dir=CACHE_DIR):
def __init__(
self,
recognized=None,
max_drawings=1000,
refresh_data=False,
jit_loading=True,
print_messages=True,
cache_dir=CACHE_DIR):

self._recognized = recognized
self._print_messages = print_messages
self._refresh_data = refresh_data
self._max_drawings = max_drawings
Expand All @@ -74,7 +90,7 @@ def get_drawing(self, name, index=None):
:param int index:
The index of the drawing to get.
If `None` (the default) a random drawing will be returned.
If ``None`` (the default) a random drawing will be returned.
"""
return self.get_drawing_group(name).get_drawing(index)

Expand All @@ -91,6 +107,7 @@ def get_drawing_group(self, name):
if name not in self._drawing_groups.keys():
drawings = QuickDrawDataGroup(
name,
recognized=self._recognized,
max_drawings=self._max_drawings,
refresh_data=self._refresh_data,
print_messages=self._print_messages,
Expand All @@ -99,11 +116,58 @@ def get_drawing_group(self, name):

return self._drawing_groups[name]

def search_drawings(self, name, key_id=None, recognized=None, countrycode=None, timestamp=None):
"""
Search the drawings.
Returns an list of :class:`QuickDrawing` instances representing the
matched drawings.
Note - search criteria are a compound.
Search for all the drawings with the ``countrycode`` "PL" ::
from quickdraw import QuickDrawDataGroup
anvils = QuickDrawDataGroup("anvil")
results = anvils.search_drawings(countrycode="PL")
:param string name:
The name of the drawings (anvil, ant, aircraft, etc)
to search.
:param int key_id:
The ``key_id`` to such for. If ``None`` (the default) the
``key_id`` is not used.
:param bool recognized:
To search for drawings which were ``recognized``. If ``None``
(the default) ``recognized`` is not used.
:param str countrycode:
To search for drawings which with the ``countrycode``. If
``None`` (the default) ``countrycode`` is not used.
:param int countrycode:
To search for drawings which with the ``timestamp``. If ``None``
(the default) ``timestamp`` is not used.
"""
return self.get_drawing_group(name).search_drawings(key_id, recognized, countrycode, timestamp)

def load_all_drawings(self):
"""
Loads (and downloads if required) all drawings into memory.
"""
for drawing_group in QUICK_DRAWING_NAMES:
self.load_drawings(self.drawing_names)

def load_drawings(self, list_of_drawings):
"""
Loads (and downloads if required) all drawings into memory.
:param list list_of_drawings:
A list of the drawings to be loaded (anvil, ant, aircraft, etc).
"""
for drawing_group in list_of_drawings:
self.get_drawing_group(drawing_group)

@property
Expand All @@ -113,6 +177,13 @@ def drawing_names(self):
"""
return QUICK_DRAWING_NAMES

@property
def loaded_drawings(self):
"""
Returns a list of drawing which have been loaded into memory.
"""
return list(self._drawing_groups.keys())


class QuickDrawDataGroup():
"""
Expand All @@ -130,12 +201,17 @@ class QuickDrawDataGroup():
:param string name:
The name of the drawings to be loaded (anvil, ant, aircraft, etc).
:param bool recognized:
If ``True`` only recognized drawings will be loaded, if ``False``
only unrecognized drawings will be loaded, if ``None`` (the default)
both recognized and unrecognized drawings will be loaded.
:param int max_drawings:
The maximum number of drawings to be loaded into memory,
defaults to 1000.
:param bool refresh_data:
If `True` forces data to be downloaded even if it has been
If ``True`` forces data to be downloaded even if it has been
downloaded before, defaults to `False`.
:param bool print_messages:
Expand All @@ -144,9 +220,16 @@ class QuickDrawDataGroup():
:param string cache_dir:
Specify a cache directory to use when downloading data files,
defaults to `./.quickdrawcache`.
defaults to ``./.quickdrawcache``.
"""
def __init__(self, name, max_drawings=1000, refresh_data=False, print_messages=True, cache_dir=CACHE_DIR):
def __init__(
self,
name,
recognized=None,
max_drawings=1000,
refresh_data=False,
print_messages=True,
cache_dir=CACHE_DIR):

if name not in QUICK_DRAWING_NAMES:
raise ValueError("{} is not a valid google quick drawing".format(name))
Expand All @@ -155,6 +238,7 @@ def __init__(self, name, max_drawings=1000, refresh_data=False, print_messages=T
self._print_messages = print_messages
self._max_drawings = max_drawings
self._cache_dir = cache_dir
self._recognized = recognized

self._drawings = []

Expand Down Expand Up @@ -227,21 +311,27 @@ def _load_drawings(self, filename):
y = struct.unpack(fmt, binary_file.read(n_points))
image.append((x, y))

self._drawings.append({
'key_id': key_id,
'countrycode': countrycode,
'recognized': recognized,
'timestamp': timestamp,
'n_strokes': n_strokes,
'image': image
})
append_drawing = True
if self._recognized is not None:
if bool(recognized) != self._recognized:
append_drawing = False

if append_drawing:
self._drawings.append({
'key_id': key_id,
'countrycode': countrycode,
'recognized': recognized,
'timestamp': timestamp,
'n_strokes': n_strokes,
'image': image
})

self._drawing_count += 1

# nothing left to read
except struct.error:
break

self._drawing_count += 1

self._print_message("load complete")

def _print_message(self, message):
Expand Down Expand Up @@ -270,7 +360,7 @@ def drawings(self):
"""
while True:
self._current_drawing += 1
if self._current_drawing == self._drawing_count - 1:
if self._current_drawing > self._drawing_count - 1:
# reached the end to the drawings
self._current_drawing = 0
raise StopIteration()
Expand All @@ -295,16 +385,72 @@ def get_drawing(self, index=None):
:param int index:
The index of the drawing to get.
If `None` (the default) a random drawing will be returned.
If ``None`` (the default) a random drawing will be returned.
"""
if index is None:
return QuickDrawing(self._name, choice(self._drawings))
else:
if index < self.drawing_count - 1:
if index < self.drawing_count:
return QuickDrawing(self._name, self._drawings[index])
else:
raise IndexError("index {} out of range, there are {} drawings".format(index, self.drawing_count))

def search_drawings(self, key_id=None, recognized=None, countrycode=None, timestamp=None):
"""
Searches the drawings in this group.
Returns an list of :class:`QuickDrawing` instances representing the
matched drawings.
Note - search criteria are a compound.
Search for all the drawings with the ``countrycode`` "PL" ::
from quickdraw import QuickDrawDataGroup
anvils = QuickDrawDataGroup("anvil")
results = anvils.search_drawings(countrycode="PL")
:param int key_id:
The ``key_id`` to such for. If ``None`` (the default) the
``key_id`` is not used.
:param bool recognized:
To search for drawings which were ``recognized``. If ``None``
(the default) ``recognized`` is not used.
:param str countrycode:
To search for drawings which with the ``countrycode``. If
``None`` (the default) ``countrycode`` is not used.
:param int countrycode:
To search for drawings which with the ``timestamp``. If ``None``
(the default) ``timestamp`` is not used.
"""
results = []

for drawing in self.drawings:
match = True
if key_id is not None:
if key_id != drawing.key_id:
match = False

if recognized is not None:
if recognized != drawing.recognized:
match = False

if countrycode is not None:
if countrycode != drawing.countrycode:
match = False

if timestamp is not None:
if timestamp != drawing.timestamp:
match = False

if match:
results.append(drawing)

return results

class QuickDrawing():
"""
Expand Down Expand Up @@ -335,7 +481,7 @@ def countrycode(self):
"""
Returns the country code for the drawing.
"""
return self._drawing_data["countrycode"]
return self._drawing_data["countrycode"].decode("utf-8")

@property
def recognized(self):
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

__project__ = 'quickdraw'
__desc__ = 'An API for downloading and reading the google quickdraw data.'
__version__ = '0.0.4'
__version__ = '0.1.0'
__author__ = "Martin O'Hanlon"
__author_email__ = '[email protected]'
__license__ = 'MIT'
Expand Down Expand Up @@ -63,8 +63,8 @@
"""

__classifiers__ = [
"Development Status :: 3 - Alpha",
# "Development Status :: 4 - Beta",
# "Development Status :: 3 - Alpha",
"Development Status :: 4 - Beta",
# "Development Status :: 5 - Production/Stable",
"Intended Audience :: Education",
"Intended Audience :: Developers",
Expand Down
Loading

0 comments on commit 61d8413

Please sign in to comment.