From 0f30feef59ca034fd821b8d2f4211c1774d9c358 Mon Sep 17 00:00:00 2001 From: DeepMind Date: Tue, 17 Dec 2024 13:09:04 -0800 Subject: [PATCH] Resolve unsoundness caught by pytype --strict-none-binding. PiperOrigin-RevId: 707220791 --- android_env/components/task_manager.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/android_env/components/task_manager.py b/android_env/components/task_manager.py index 8171854..05b9808 100644 --- a/android_env/components/task_manager.py +++ b/android_env/components/task_manager.py @@ -22,7 +22,7 @@ import json import re import threading -from typing import Any +from typing import Any, Optional from absl import logging from android_env.components import adb_call_parser as adb_call_parser_lib @@ -40,6 +40,11 @@ class TaskManager: """Handles all events and information related to the task.""" + _setup_step_interpreter: setup_step_interpreter.SetupStepInterpreter + _dumpsys_thread: dumpsys_thread.DumpsysThread + _task_start_time: datetime.datetime + _logcat_thread: logcat_thread.LogcatThread + def __init__( self, task: task_pb2.Task, @@ -55,9 +60,6 @@ def __init__( self._task = task self._config = config or config_classes.TaskManagerConfig() self._lock = threading.Lock() - self._logcat_thread = None - self._dumpsys_thread = None - self._setup_step_interpreter = None # Initialize stats. self._stats = { @@ -71,7 +73,6 @@ def __init__( } # Initialize internal state - self._task_start_time = None self._bad_state_counter = 0 self._is_bad_episode = False @@ -84,6 +85,11 @@ def __init__( logging.info('Task config: %s', self._task) + @property + def _logcate_thread_ok(self) -> logcat_thread.LogcatThread: + assert self._logcat_thread is not None + return self._logcat_thread + def stats(self) -> dict[str, Any]: """Returns a dictionary of stats. @@ -109,16 +115,16 @@ def start( """Starts task processing.""" self._start_logcat_thread(log_stream=log_stream) - self._logcat_thread.resume() + self._logcate_thread_ok.resume() self._start_dumpsys_thread(adb_call_parser_factory()) self._start_setup_step_interpreter(adb_call_parser_factory()) def reset_task(self) -> None: """Resets a task for a new run.""" - self._logcat_thread.pause() + self._logcate_thread_ok.pause() self._setup_step_interpreter.interpret(self._task.reset_steps) - self._logcat_thread.resume() + self._logcate_thread_ok.resume() # Reset some other variables. if not self._is_bad_episode: @@ -139,7 +145,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep: self._stats['episode_steps'] = 0 - self._logcat_thread.line_ready().wait() + self._logcate_thread_ok.line_ready().wait() with self._lock: extras = self._get_current_extras() @@ -156,7 +162,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep: self._stats['episode_steps'] += 1 - self._logcat_thread.line_ready().wait() + self._logcate_thread_ok.line_ready().wait() with self._lock: reward = self._get_current_reward() extras = self._get_current_extras()