Skip to content

Commit

Permalink
Remove Coordinator.{load,save}_state().
Browse files Browse the repository at this point in the history
Now that `AndroidEnv` has direct access to the simulator, we don't need the
`Coordinator` to forward these requests anymore. This simplifies the design and
the communication between the major components.

PiperOrigin-RevId: 683591042
  • Loading branch information
kenjitoyama authored and copybara-github committed Oct 8, 2024
1 parent 65ee54c commit 1a9ab4b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 70 deletions.
38 changes: 0 additions & 38 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import copy
import socket
import tempfile
import threading
import time
from typing import Any
Expand All @@ -32,7 +31,6 @@
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
import dm_env
import numpy as np

Expand Down Expand Up @@ -455,42 +453,6 @@ def stats(self) -> dict[str, Any]:

return copy.deepcopy(self._stats)

def load_state(
self, request: state_pb2.LoadStateRequest
) -> state_pb2.LoadStateResponse:
"""Loads a state.
Args:
request: A `LoadStateRequest` containing any parameters necessary to
specify how/what state to load.
Returns:
A `LoadStateResponse` containing the status, error message (if
applicable), and any other relevant information.
"""
self._task_manager.stop()
response = self._simulator.load_state(request)
self._task_manager.start(
adb_call_parser_factory=self._create_adb_call_parser,
log_stream=self._simulator.create_log_stream(),
)
return response

def save_state(
self, request: state_pb2.SaveStateRequest
) -> state_pb2.SaveStateResponse:
"""Saves a state.
Args:
request: A `SaveStateRequest` containing any parameters necessary to
specify how/what state to save.
Returns:
A `SaveStateResponse` containing the status, error message (if
applicable), and any other relevant information.
"""
return self._simulator.save_state(request)

def close(self):
"""Cleans up the state of this Coordinator."""
if self._interaction_thread is not None:
Expand Down
24 changes: 0 additions & 24 deletions android_env/components/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,30 +467,6 @@ def test_bar_visibility(self, show_navigation_bar, show_status_bar,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=expected_value))))

def test_load_state(self):
expected_response = state_pb2.LoadStateResponse(
status=state_pb2.LoadStateResponse.Status.OK
)
request = state_pb2.LoadStateRequest(args={'foo': 'bar'})
self._simulator.load_state.return_value = expected_response
stop_call_count = self._task_manager.stop.call_count
start_call_count = self._task_manager.start.call_count
response = self._coordinator.load_state(request)
self.assertEqual(response, expected_response)
self._simulator.load_state.assert_called_once_with(request)
self.assertEqual(self._task_manager.stop.call_count, stop_call_count + 1)
self.assertEqual(self._task_manager.start.call_count, start_call_count + 1)

def test_save_state(self):
expected_response = state_pb2.SaveStateResponse(
status=state_pb2.SaveStateResponse.Status.OK
)
request = state_pb2.SaveStateRequest(args={'foo': 'bar'})
self._simulator.save_state.return_value = expected_response
response = self._coordinator.save_state(request)
self.assertEqual(response, expected_response)
self._simulator.save_state.assert_called_once_with(request)


if __name__ == '__main__':
absltest.main()
15 changes: 13 additions & 2 deletions android_env/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from absl import logging
from android_env import env_interface
from android_env.components import adb_call_parser
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
Expand Down Expand Up @@ -161,7 +162,16 @@ def load_state(
A `LoadStateResponse` containing the status, error message (if
applicable), and any other relevant information.
"""
return self._coordinator.load_state(request)

self._task_manager.stop()
response = self._simulator.load_state(request)
self._task_manager.start(
adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
self._simulator.create_adb_controller()
),
log_stream=self._simulator.create_log_stream(),
)
return response

def save_state(
self, request: state_pb2.SaveStateRequest
Expand All @@ -176,4 +186,5 @@ def save_state(
A `SaveStateResponse` containing the status, error message (if
applicable), and any other relevant information.
"""
return self._coordinator.save_state(request)

return self._simulator.save_state(request)
15 changes: 9 additions & 6 deletions android_env/environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
Expand Down Expand Up @@ -223,7 +224,7 @@ def test_adb_call(self):
coordinator.execute_adb_call.assert_called_once_with(call)

def test_load_state(self):
simulator = _create_fake_simulator()
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
Expand All @@ -233,13 +234,15 @@ def test_load_state(self):
status=state_pb2.LoadStateResponse.Status.OK
)
request = state_pb2.LoadStateRequest(args={'foo': 'bar'})
coordinator.load_state.return_value = expected_response
simulator.load_state.return_value = expected_response
response = env.load_state(request)
self.assertEqual(response, expected_response)
coordinator.load_state.assert_called_once_with(request)
simulator.load_state.assert_called_once_with(request)
task_manager.stop.assert_called_once()
task_manager.start.assert_called_once()

def test_save_state(self):
simulator = _create_fake_simulator()
simulator = mock.create_autospec(base_simulator.BaseSimulator)
coordinator = _create_mock_coordinator()
task_manager = mock.create_autospec(task_manager_lib.TaskManager)
env = environment.AndroidEnv(
Expand All @@ -249,10 +252,10 @@ def test_save_state(self):
status=state_pb2.SaveStateResponse.Status.OK
)
request = state_pb2.SaveStateRequest(args={'foo': 'bar'})
coordinator.save_state.return_value = expected_response
simulator.save_state.return_value = expected_response
response = env.save_state(request)
self.assertEqual(response, expected_response)
coordinator.save_state.assert_called_once_with(request)
simulator.save_state.assert_called_once_with(request)

def test_double_close(self):
simulator = _create_fake_simulator()
Expand Down

0 comments on commit 1a9ab4b

Please sign in to comment.