Skip to content

Commit

Permalink
Resolve unsoundness caught by pytype --strict-none-binding.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707220791
  • Loading branch information
DeepMind authored and copybara-github committed Dec 17, 2024
1 parent f612730 commit 0f30fee
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions android_env/components/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import json
import re
import threading
from typing import Any
from typing import Any, Optional

from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
Expand All @@ -40,6 +40,11 @@
class TaskManager:
"""Handles all events and information related to the task."""

_setup_step_interpreter: setup_step_interpreter.SetupStepInterpreter
_dumpsys_thread: dumpsys_thread.DumpsysThread
_task_start_time: datetime.datetime
_logcat_thread: logcat_thread.LogcatThread

def __init__(
self,
task: task_pb2.Task,
Expand All @@ -55,9 +60,6 @@ def __init__(
self._task = task
self._config = config or config_classes.TaskManagerConfig()
self._lock = threading.Lock()
self._logcat_thread = None
self._dumpsys_thread = None
self._setup_step_interpreter = None

# Initialize stats.
self._stats = {
Expand All @@ -71,7 +73,6 @@ def __init__(
}

# Initialize internal state
self._task_start_time = None
self._bad_state_counter = 0
self._is_bad_episode = False

Expand All @@ -84,6 +85,11 @@ def __init__(

logging.info('Task config: %s', self._task)

@property
def _logcate_thread_ok(self) -> logcat_thread.LogcatThread:
assert self._logcat_thread is not None
return self._logcat_thread

def stats(self) -> dict[str, Any]:
"""Returns a dictionary of stats.
Expand All @@ -109,16 +115,16 @@ def start(
"""Starts task processing."""

self._start_logcat_thread(log_stream=log_stream)
self._logcat_thread.resume()
self._logcate_thread_ok.resume()
self._start_dumpsys_thread(adb_call_parser_factory())
self._start_setup_step_interpreter(adb_call_parser_factory())

def reset_task(self) -> None:
"""Resets a task for a new run."""

self._logcat_thread.pause()
self._logcate_thread_ok.pause()
self._setup_step_interpreter.interpret(self._task.reset_steps)
self._logcat_thread.resume()
self._logcate_thread_ok.resume()

# Reset some other variables.
if not self._is_bad_episode:
Expand All @@ -139,7 +145,7 @@ def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:

self._stats['episode_steps'] = 0

self._logcat_thread.line_ready().wait()
self._logcate_thread_ok.line_ready().wait()
with self._lock:
extras = self._get_current_extras()

Expand All @@ -156,7 +162,7 @@ def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:

self._stats['episode_steps'] += 1

self._logcat_thread.line_ready().wait()
self._logcate_thread_ok.line_ready().wait()
with self._lock:
reward = self._get_current_reward()
extras = self._get_current_extras()
Expand Down

0 comments on commit 0f30fee

Please sign in to comment.