From c4719f2e2f2c4b3afc0bb54585b3f9bea776244b Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 23 Mar 2020 23:26:46 +0200 Subject: [PATCH] Add type annotations and fix docstrings --- trains/task.py | 101 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 31 deletions(-) diff --git a/trains/task.py b/trains/task.py index bb7a1306..9171cf85 100644 --- a/trains/task.py +++ b/trains/task.py @@ -12,7 +12,7 @@ except ImportError: from collections import Callable, Sequence -from typing import Optional +from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List import psutil import six @@ -41,7 +41,7 @@ from .debugging.log import LoggerRoot from .errors import UsageError from .logger import Logger -from .model import InputModel, OutputModel, ARCHIVED_TAG +from .model import Model, InputModel, OutputModel, ARCHIVED_TAG from .task_parameters import TaskParameters from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \ argparser_update_currenttask @@ -99,7 +99,8 @@ def _options(cls): def __init__(self, private=None, **kwargs): """ - Do not construct Task manually! + .. warning:: + Do not construct Task manually! **Please use Task.init() or Task.get_task(id=, project=, name=)** """ if private is not Task.__create_protection: @@ -141,6 +142,7 @@ def init( auto_connect_frameworks=True, auto_resource_monitoring=True, ): + # type: (Optional[str], Optional[str], TaskTypes, bool, Optional[str], bool, Union[bool, Mapping[str, bool]], bool) -> Task """ Return the Task object for the main execution task (task context). @@ -369,6 +371,7 @@ def create( task_name=None, task_type=TaskTypes.training, ): + # type: (Optional[str], Optional[str], TaskTypes) -> Task """ Create a new Task object, regardless of the main execution task (Task.init). 
@@ -403,6 +406,7 @@ def create( @classmethod def get_task(cls, task_id=None, project_name=None, task_name=None): + # type: (Optional[str], Optional[str], Optional[str]) -> Task """ Returns Task object based on either, task_id (system uuid) or task name @@ -415,6 +419,7 @@ def get_task(cls, task_id=None, project_name=None, task_name=None): @classmethod def get_tasks(cls, task_ids=None, project_name=None, task_name=None): + # type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> List[Task] """ Returns a list of Task objects, matching requested task name (or partially matching) @@ -429,10 +434,12 @@ def get_tasks(cls, task_ids=None, project_name=None, task_name=None): @property def output_uri(self): + # type: () -> str return self.storage_uri @output_uri.setter def output_uri(self, value): + # type: (str) -> None # check if we have the correct packages / configuration if value and value != self.storage_uri: from .storage.helper import StorageHelper @@ -445,9 +452,11 @@ def output_uri(self, value): @property def artifacts(self): + # type: () -> Dict[str, Artifact] """ - read-only dictionary of Task artifacts (name, artifact) - :return: dict + Read-only dictionary of Task artifacts (name, artifact) + + :return dict: dictionary of artifacts """ if not Session.check_min_api_version('2.3'): return ReadOnlyDict() @@ -470,6 +479,7 @@ def models(self): @classmethod def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None): + # type: (Optional[Task], Optional[str], Optional[str], Optional[str], Optional[str]) -> Task """ Clone a task object, create a copy a task. 
@@ -502,6 +512,7 @@ def clone(cls, source_task=None, name=None, comment=None, parent=None, project=N @classmethod def enqueue(cls, task, queue_name=None, queue_id=None): + # type: (Task, Optional[str], Optional[str]) -> Any """ Enqueue (send) a task for execution, by adding it to an execution queue @@ -534,6 +545,7 @@ def enqueue(cls, task, queue_name=None, queue_id=None): @classmethod def dequeue(cls, task): + # type: (Union[Task, str]) -> Any """ Dequeue (remove) task from execution queue. @@ -553,10 +565,11 @@ def dequeue(cls, task): return resp def add_tags(self, tags): + # type: (Union[Sequence[str], str]) -> None """ Add tags to this task. Old tags are not deleted - In remote, this is a no-op. + When running remotely, this method has no effect. :param tags: An iterable or space separated string of new tags (string) to add. :type tags: str or iterable of str @@ -570,6 +583,7 @@ def add_tags(self, tags): self._edit(tags=list(set(self.data.tags))) def connect(self, mutable): + # type: (Any) -> Any """ Connect an object to a task (see introduction to Task connect design) @@ -597,6 +611,7 @@ def connect(self, mutable): raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) def connect_configuration(self, configuration): + # type: (Union[Mapping, Path, str]) -> Union[Mapping, Path, str] """ Connect a configuration dict / file (pathlib.Path / str) with the Task Connecting configuration file should be called before reading the configuration file. 
@@ -626,14 +641,14 @@ def connect_configuration(self, configuration): # parameter dictionary if isinstance(configuration, dict): def _update_config_dict(task, config_dict): - task.set_model_config(config_dict=config_dict) + task._set_model_config(config_dict=config_dict) if not running_remotely() or not self.is_main_task(): - self.set_model_config(config_dict=configuration) + self._set_model_config(config_dict=configuration) configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) else: configuration.clear() - configuration.update(self.get_model_config_dict()) + configuration.update(self._get_model_config_dict()) configuration = ProxyDictPreWrite(False, False, **configuration) return configuration @@ -649,10 +664,10 @@ def _update_config_dict(task, config_dict): except Exception: raise ValueError("Could not connect configuration file {}, file could not be read".format( configuration_path.as_posix())) - self.set_model_config(config_text=configuration_text) + self._set_model_config(config_text=configuration_text) return configuration else: - configuration_text = self.get_model_config_text() + configuration_text = self._get_model_config_text() configuration_path = Path(configuration) fd, local_filename = mkstemp(prefix='trains_task_config_', suffix=configuration_path.suffixes[-1] if @@ -662,6 +677,7 @@ def _update_config_dict(task, config_dict): return Path(local_filename) if isinstance(configuration, Path) else local_filename def connect_label_enumeration(self, enumeration): + # type: (Dict[str, int]) -> Dict[str, int] """ Connect a label enumeration dictionary with the Task @@ -686,7 +702,7 @@ def connect_label_enumeration(self, enumeration): def get_logger(self): # type: () -> Logger """ - get a logger object for reporting, for this task context. + Get a logger object for reporting, for this task context. All reports (metrics, text etc.) 
related to this task are accessible in the web UI :return: Logger object @@ -695,7 +711,7 @@ def get_logger(self): def mark_started(self): """ - Manually Mark the task as started (will happen automatically) + Manually Mark the task as started (happens automatically) """ # UI won't let us see metrics if we're not started self.started() @@ -703,7 +719,7 @@ def mark_started(self): def mark_stopped(self): """ - Manually Mark the task as stopped (also used in self._at_exit) + Manually Mark the task as stopped (also used in :func:`_at_exit`) """ # flush any outstanding logs self.flush(wait_for_uploads=True) @@ -711,8 +727,9 @@ def mark_stopped(self): self.stopped() def flush(self, wait_for_uploads=False): + # type: (bool) -> bool """ - flush any outstanding reports or console logs + Flush any outstanding reports or console logs :param wait_for_uploads: if True the flush will exit only after all outstanding uploads are completed """ @@ -736,6 +753,7 @@ def flush(self, wait_for_uploads=False): return True def reset(self, set_started_on_success=False, force=False): + # type: (bool, bool) -> None """ Reset the task. Task will be reloaded following a successful reset. @@ -759,8 +777,9 @@ def close(self): self.__register_at_exit(None) def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): + # type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None """ - Add artifact for the current Task, used mostly for Data Audition. + Add artifact for the current Task, used mostly for Data Auditing. Currently supported artifacts object types: pandas.DataFrame :param str name: name of the artifacts. Notice! it will override previous artifacts if name already exists. 
@@ -776,6 +795,7 @@ def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=Tr self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) def unregister_artifact(self, name): + # type: (str) -> None """ Remove artifact from the watch list. Notice this will not remove the artifacts from the Task. It will only stop monitoring the artifact, @@ -784,6 +804,7 @@ def unregister_artifact(self, name): self._artifacts_manager.unregister_artifact(name=name) def get_registered_artifacts(self): + # type: () -> Dict[str, Artifact] """ dictionary of Task registered artifacts (name, artifact object) Notice these objects can be modified, changes will be uploaded automatically @@ -793,9 +814,11 @@ def get_registered_artifacts(self): return self._artifacts_manager.registered_artifacts def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False): + # type: (str, Union[str, Mapping, "pandas.DataFrame", "numpy.ndarray", "PIL.Image.Image"], Optional[Mapping], bool) -> bool """ Add static artifact to Task. Artifact file/object will be uploaded in the background - Raise ValueError if artifact_object is not supported + + :raises ValueError: if artifact_object is not supported :param str name: Artifact name. Notice! it will override previous artifact if name already exists :param object artifact_object: Artifact object to upload. Currently supports: @@ -814,8 +837,9 @@ def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upl metadata=metadata, delete_after_upload=delete_after_upload) def get_models(self): + # type: () -> Dict[str, List[Model]] """ - Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task. 
+ Return a dictionary with {'input': [], 'output': []} loaded/stored models of the current Task Input models are files loaded in the task, either manually or automatically logged Output models are files stored in the task, either manually or automatically logged Automatically logged frameworks are for example: TensorFlow, Keras, PyTorch, ScikitLearn(joblib) etc. @@ -828,26 +852,29 @@ def get_models(self): return task_models def is_current_task(self): + # type: () -> bool """ Check if this task is the main task (returned by Task.init()) - NOTE: This call is deprecated. Please use Task.is_main_task() + .. deprecated:: 0.1.0 + Use :func:`is_main_task()` instead If Task.init() was never called, this method will *not* create it, making this test cheaper than Task.init() == task - :return: True if this task is the current task + :return: True if this task is the main task """ return self.is_main_task() def is_main_task(self): + # type: () -> bool """ - Check if this task is the main task (returned by Task.init()) + Check if this task is the main task (created/returned by Task.init()) If Task.init() was never called, this method will *not* create it, making this test cheaper than Task.init() == task - :return: True if this task is the current task + :return: True if this task is the main task """ return self is self.__main_task @@ -876,6 +903,7 @@ def get_model_config_dict(self): return self._get_model_config_dict() def set_model_label_enumeration(self, enumeration=None): + # type: (Optional[Mapping[str, int]]) -> None """ Set Task output label enumeration (before creating an output model) When an output model is created it will inherit these properties @@ -886,6 +914,7 @@ def set_model_label_enumeration(self, enumeration=None): super(Task, self).set_model_label_enumeration(enumeration=enumeration) def get_last_iteration(self): + # type: () -> int """ Return the maximum reported iteration (i.e. 
the maximum iteration the task reported a metric for) Notice, this is not a cached call, it will ask the backend for the answer (no local caching) @@ -896,6 +925,7 @@ def get_last_iteration(self): return max(self.data.last_iteration, self._reporter.max_iteration if self._reporter else 0) def set_last_iteration(self, last_iteration): + # type: (int) -> None """ Forcefully set the last reported iteration (i.e. the maximum iteration the task reported a metric for) @@ -907,6 +937,7 @@ def set_last_iteration(self, last_iteration): self._edit(last_iteration=self.data.last_iteration) def set_initial_iteration(self, offset=0): + # type: (int) -> int """ Set initial iteration, instead of zero. Useful when continuing training from previous checkpoints @@ -916,17 +947,19 @@ def set_initial_iteration(self, offset=0): return super(Task, self).set_initial_iteration(offset=offset) def get_initial_iteration(self): + # type: () -> int """ - Return the initial iteration offset, default is 0. - Useful when continuing training from previous checkpoints. + Return the initial iteration offset, default is 0 + Useful when continuing training from previous checkpoints :return int: initial iteration offset """ return super(Task, self).get_initial_iteration() def get_last_scalar_metrics(self): + # type: () -> Dict[str, Dict[str, Dict[str, float]]] """ - Extract the last scalar metrics, ordered by title & series in a nested dictionary + Extract the last scalar metrics, ordered by title and series in a nested dictionary :return: dict. Example: {'title': {'series': {'last': 0.5, 'min': 0.1, 'max': 0.9}}} """ @@ -940,34 +973,40 @@ def get_last_scalar_metrics(self): return scalar_metrics def get_parameters_as_dict(self): + # type: () -> Dict """ Get task parameters as a raw nested dict - Note that values are not parsed and returned as is (i.e. string) + + .. note:: + values are not parsed and returned as is (i.e. 
string) """ return naive_nested_from_flat_dictionary(self.get_parameters()) def set_parameters_as_dict(self, dictionary): + # type: (Dict) -> None """ - Set task parameters from a (possibly nested) dict. + Set task parameters from a (possibly nested) dict While parameters are set just as they would be in connect(dict), this does not link the dict to the task, - but rather does a one-time update. + but rather performs a one-time update. """ self._arguments.copy_from_dict(flatten_dictionary(dictionary)) @classmethod def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None): + # type: (Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]) -> None """ - Set new default TRAINS-server host and credentials + Set new default trains-server host values and credentials These configurations will be overridden by either OS environment variables or trains.conf configuration file - Notice! credentials needs to be set *prior* to Task initialization + .. note:: + credentials need to be set *prior* to Task initialization :param str api_host: Trains API server url, example: host='http://localhost:8008' :param str web_host: Trains WEB server url, example: host='http://localhost:8080' :param str files_host: Trains Files server url, example: host='http://localhost:8081' :param str key: user key/secret pair, example: key='thisisakey123' :param str secret: user key/secret pair, example: secret='thisisseceret123' - :param str host: host url, example: host='http://localhost:8008' (deprecated) + :param str host: host url (overrides api_host), example: host='http://localhost:8008' """ if api_host: Session.default_host = api_host