Skip to content

Commit

Permalink
Move general configuration from Coordinator to DeviceSettings.
Browse files Browse the repository at this point in the history
This change moves some device configuration settings out of `Coordinator` to a
new class `DeviceSettings`, which for now contains logic for setting visual
debug ("show touch" and "show pointer location"), status/navigation bars, and
the device orientation. The class is also thoroughly tested.

PiperOrigin-RevId: 688949117
  • Loading branch information
kenjitoyama authored and copybara-github committed Oct 23, 2024
1 parent 83daeb2 commit 6e73777
Show file tree
Hide file tree
Showing 7 changed files with 436 additions and 163 deletions.
22 changes: 16 additions & 6 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ class AdbControllerConfig:


@dataclasses.dataclass
class CoordinatorConfig:
"""Config class for Coordinator."""
class DeviceSettingsConfig:
"""Config class for DeviceSettings."""

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Whether to show circles on the screen indicating touch position.
show_touches: bool = True
# Whether to show blue lines on the screen indicating touch position.
Expand All @@ -51,10 +47,24 @@ class CoordinatorConfig:
show_status_bar: bool = False
# Whether or not to show the navigation (bottom) bar.
show_navigation_bar: bool = False


@dataclasses.dataclass
class CoordinatorConfig:
"""Config class for Coordinator."""

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Time between periodic restarts in minutes. If > 0, will trigger
# a simulator restart at the beginning of the next episode once the time has
# been reached.
periodic_restart_time_min: float = 0.0
# General Android settings.
device_settings: DeviceSettingsConfig = dataclasses.field(
default_factory=DeviceSettingsConfig
)


@dataclasses.dataclass
Expand Down
105 changes: 13 additions & 92 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from android_env.components import action_type as action_type_lib
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import pixel_fns
from android_env.components import specs
Expand All @@ -42,6 +43,7 @@ def __init__(
self,
simulator: base_simulator.BaseSimulator,
task_manager: task_manager_lib.TaskManager,
device_settings: device_settings_lib.DeviceSettings,
config: config_classes.CoordinatorConfig | None = None,
):
"""Handles communication between AndroidEnv and its components.
Expand All @@ -54,12 +56,8 @@ def __init__(
self._simulator = simulator
self._task_manager = task_manager
self._config = config or config_classes.CoordinatorConfig()
self._device_settings = device_settings
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._orientation = np.zeros(4, dtype=np.uint8)

# The size of the device screen in pixels.
self._screen_width = 0
self._screen_height = 0

# Initialize stats.
self._stats = {
Expand Down Expand Up @@ -91,41 +89,9 @@ def action_spec(self) -> dict[str, dm_env.specs.Array]:

def observation_spec(self) -> dict[str, dm_env.specs.Array]:
return specs.base_observation_spec(
height=self._screen_height, width=self._screen_width
)

def _update_screen_size(self) -> None:
"""Sets the screen size from a screenshot ignoring the color channel."""
screenshot = self._simulator.get_screenshot()
self._screen_height = screenshot.shape[0]
self._screen_width = screenshot.shape[1]

def _update_device_orientation(self) -> None:
"""Updates the current device orientation."""

# Skip fetching the orientation if we already have it.
if not np.all(self._orientation == np.zeros(4)):
logging.info('self._orientation already set, not setting it again')
return

orientation_response = self._adb_call_parser.parse(
adb_pb2.AdbRequest(
get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
)
height=self._device_settings.screen_height(),
width=self._device_settings.screen_width(),
)
if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
logging.error('Got bad orientation: %r', orientation_response)
return

orientation = orientation_response.get_orientation.orientation
if orientation not in {0, 1, 2, 3}:
logging.error('Got bad orientation: %r', orientation_response)
return

# Transform into one-hot format.
orientation_onehot = np.zeros([4], dtype=np.uint8)
orientation_onehot[orientation] = 1
self._orientation = orientation_onehot

def _should_periodic_relaunch(self) -> bool:
"""Checks if it is time to restart the simulator.
Expand Down Expand Up @@ -178,9 +144,9 @@ def _launch_simulator(self, max_retries: int = 3):
# From here on, the simulator is assumed to be up and running.
self._adb_call_parser = self._create_adb_call_parser()
try:
self._update_settings()
self._device_settings.update(self._config.device_settings)
except errors.AdbControllerError as e:
logging.exception('_update_settings() failed.')
logging.exception('device_settings.update() failed.')
self._stats['relaunch_count_update_settings'] += 1
self._latest_error = e
num_tries += 1
Expand All @@ -205,51 +171,6 @@ def _launch_simulator(self, max_retries: int = 3):
self._stats['relaunch_count'] += 1
break

def _update_settings(self) -> None:
"""Updates some internal state and preferences given in the constructor."""

self._update_screen_size()
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='show_touches',
value='1' if self._config.show_touches else '0',
),
)
)
)
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='pointer_location',
value='1' if self._config.show_pointer_location else '0',
),
)
)
)
if self._config.show_navigation_bar and self._config.show_status_bar:
policy_control_value = 'null*'
elif self._config.show_navigation_bar and not self._config.show_status_bar:
policy_control_value = 'immersive.status=*'
elif not self._config.show_navigation_bar and self._config.show_status_bar:
policy_control_value = 'immersive.navigation=*'
else:
policy_control_value = 'immersive.full=*'
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=policy_control_value
),
)
)
)

def _create_adb_call_parser(self):
"""Creates a new AdbCallParser instance."""
return adb_call_parser.AdbCallParser(
Expand All @@ -276,16 +197,16 @@ def rl_reset(self) -> dm_env.TimeStep:
if not action_fns.send_action_to_simulator(
action_fns.lift_all_fingers_action(self._config.num_fingers),
self._simulator,
self._screen_width,
self._screen_height,
self._device_settings.screen_width(),
self._device_settings.screen_height(),
self._config.num_fingers,
):
self._stats['relaunch_count_execute_action'] += 1
self._simulator_healthy = False

# Reset the task.
self._task_manager.reset_task()
self._update_device_orientation()
self._device_settings.get_orientation()

# Get data from the simulator.
simulator_signals = self._gather_simulator_signals()
Expand All @@ -307,8 +228,8 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
if not action_fns.send_action_to_simulator(
agent_action,
self._simulator,
self._screen_width,
self._screen_height,
self._device_settings.screen_width(),
self._device_settings.screen_height(),
self._config.num_fingers,
):
self._stats['relaunch_count_execute_action'] += 1
Expand Down Expand Up @@ -341,7 +262,7 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:

return {
'pixels': self._simulator.get_screenshot(),
'orientation': self._orientation,
'orientation': self._device_settings.get_orientation(),
'timedelta': np.array(timestamp_delta, dtype=np.int64),
}

Expand Down
68 changes: 6 additions & 62 deletions android_env/components/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import task_manager
from android_env.components.simulators import base_simulator
Expand Down Expand Up @@ -54,7 +55,9 @@ def setUp(self):
autospec=True,
return_value=self._adb_call_parser))
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator, task_manager=self._task_manager
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
)

def tearDown(self):
Expand Down Expand Up @@ -92,6 +95,7 @@ def test_lift_all_fingers(self, unused_mock_sleep):
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
config=config_classes.CoordinatorConfig(num_fingers=3),
)
self._coordinator.rl_reset()
Expand Down Expand Up @@ -183,6 +187,7 @@ def test_execute_multitouch_action(self, unused_mock_sleep):
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
config=config_classes.CoordinatorConfig(num_fingers=3),
)

Expand Down Expand Up @@ -273,67 +278,6 @@ def test_execute_adb_call(self, unused_mock_sleep):
self.assertEqual(response, expected_response)
self._adb_call_parser.parse.assert_called_with(call)

@parameterized.parameters(
(True, '1'),
(False, '0'),
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_touch_indicator(self, show, expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(show_touches=show),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='show_touches', value=expected_value))))

@parameterized.parameters(
(True, '1'),
(False, '0'),
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_pointer_location(self, show, expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(show_pointer_location=show),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='pointer_location', value=expected_value))))

@parameterized.parameters(
(True, True, 'null*'),
(True, False, 'immersive.status=*'),
(False, True, 'immersive.navigation=*'),
(False, False, 'immersive.full=*'),
(None, None, 'immersive.full=*'), # Defaults to hiding both.
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_bar_visibility(self, show_navigation_bar, show_status_bar,
expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
show_navigation_bar=show_navigation_bar,
show_status_bar=show_status_bar,
),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=expected_value))))


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 6e73777

Please sign in to comment.