Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move general configuration from Coordinator to DeviceSettings. #235

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ class AdbControllerConfig:


@dataclasses.dataclass
class CoordinatorConfig:
"""Config class for Coordinator."""
class DeviceSettingsConfig:
"""Config class for DeviceSettings."""

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Whether to show circles on the screen indicating touch position.
show_touches: bool = True
# Whether to show blue lines on the screen indicating touch position.
Expand All @@ -51,10 +47,24 @@ class CoordinatorConfig:
show_status_bar: bool = False
# Whether or not to show the navigation (bottom) bar.
show_navigation_bar: bool = False


@dataclasses.dataclass
class CoordinatorConfig:
"""Config class for Coordinator."""

# Number of virtual "fingers" of the agent.
num_fingers: int = 1
# Whether to enable keyboard key events.
enable_key_events: bool = False
# Time between periodic restarts in minutes. If > 0, will trigger
# a simulator restart at the beginning of the next episode once the time has
# been reached.
periodic_restart_time_min: float = 0.0
# General Android settings.
device_settings: DeviceSettingsConfig = dataclasses.field(
default_factory=DeviceSettingsConfig
)


@dataclasses.dataclass
Expand Down
105 changes: 13 additions & 92 deletions android_env/components/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from android_env.components import action_type as action_type_lib
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import pixel_fns
from android_env.components import specs
Expand All @@ -42,6 +43,7 @@ def __init__(
self,
simulator: base_simulator.BaseSimulator,
task_manager: task_manager_lib.TaskManager,
device_settings: device_settings_lib.DeviceSettings,
config: config_classes.CoordinatorConfig | None = None,
):
"""Handles communication between AndroidEnv and its components.
Expand All @@ -54,12 +56,8 @@ def __init__(
self._simulator = simulator
self._task_manager = task_manager
self._config = config or config_classes.CoordinatorConfig()
self._device_settings = device_settings
self._adb_call_parser: adb_call_parser.AdbCallParser = None
self._orientation = np.zeros(4, dtype=np.uint8)

# The size of the device screen in pixels.
self._screen_width = 0
self._screen_height = 0

# Initialize stats.
self._stats = {
Expand Down Expand Up @@ -91,41 +89,9 @@ def action_spec(self) -> dict[str, dm_env.specs.Array]:

def observation_spec(self) -> dict[str, dm_env.specs.Array]:
return specs.base_observation_spec(
height=self._screen_height, width=self._screen_width
)

def _update_screen_size(self) -> None:
"""Sets the screen size from a screenshot ignoring the color channel."""
screenshot = self._simulator.get_screenshot()
self._screen_height = screenshot.shape[0]
self._screen_width = screenshot.shape[1]

def _update_device_orientation(self) -> None:
"""Updates the current device orientation."""

# Skip fetching the orientation if we already have it.
if not np.all(self._orientation == np.zeros(4)):
logging.info('self._orientation already set, not setting it again')
return

orientation_response = self._adb_call_parser.parse(
adb_pb2.AdbRequest(
get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
)
height=self._device_settings.screen_height(),
width=self._device_settings.screen_width(),
)
if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
logging.error('Got bad orientation: %r', orientation_response)
return

orientation = orientation_response.get_orientation.orientation
if orientation not in {0, 1, 2, 3}:
logging.error('Got bad orientation: %r', orientation_response)
return

# Transform into one-hot format.
orientation_onehot = np.zeros([4], dtype=np.uint8)
orientation_onehot[orientation] = 1
self._orientation = orientation_onehot

def _should_periodic_relaunch(self) -> bool:
"""Checks if it is time to restart the simulator.
Expand Down Expand Up @@ -178,9 +144,9 @@ def _launch_simulator(self, max_retries: int = 3):
# From here on, the simulator is assumed to be up and running.
self._adb_call_parser = self._create_adb_call_parser()
try:
self._update_settings()
self._device_settings.update(self._config.device_settings)
except errors.AdbControllerError as e:
logging.exception('_update_settings() failed.')
logging.exception('device_settings.update() failed.')
self._stats['relaunch_count_update_settings'] += 1
self._latest_error = e
num_tries += 1
Expand All @@ -205,51 +171,6 @@ def _launch_simulator(self, max_retries: int = 3):
self._stats['relaunch_count'] += 1
break

def _update_settings(self) -> None:
"""Updates some internal state and preferences given in the constructor."""

self._update_screen_size()
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='show_touches',
value='1' if self._config.show_touches else '0',
),
)
)
)
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='pointer_location',
value='1' if self._config.show_pointer_location else '0',
),
)
)
)
if self._config.show_navigation_bar and self._config.show_status_bar:
policy_control_value = 'null*'
elif self._config.show_navigation_bar and not self._config.show_status_bar:
policy_control_value = 'immersive.status=*'
elif not self._config.show_navigation_bar and self._config.show_status_bar:
policy_control_value = 'immersive.navigation=*'
else:
policy_control_value = 'immersive.full=*'
self._adb_call_parser.parse(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=policy_control_value
),
)
)
)

def _create_adb_call_parser(self):
"""Creates a new AdbCallParser instance."""
return adb_call_parser.AdbCallParser(
Expand All @@ -276,16 +197,16 @@ def rl_reset(self) -> dm_env.TimeStep:
if not action_fns.send_action_to_simulator(
action_fns.lift_all_fingers_action(self._config.num_fingers),
self._simulator,
self._screen_width,
self._screen_height,
self._device_settings.screen_width(),
self._device_settings.screen_height(),
self._config.num_fingers,
):
self._stats['relaunch_count_execute_action'] += 1
self._simulator_healthy = False

# Reset the task.
self._task_manager.reset_task()
self._update_device_orientation()
self._device_settings.get_orientation()

# Get data from the simulator.
simulator_signals = self._gather_simulator_signals()
Expand All @@ -307,8 +228,8 @@ def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
if not action_fns.send_action_to_simulator(
agent_action,
self._simulator,
self._screen_width,
self._screen_height,
self._device_settings.screen_width(),
self._device_settings.screen_height(),
self._config.num_fingers,
):
self._stats['relaunch_count_execute_action'] += 1
Expand Down Expand Up @@ -341,7 +262,7 @@ def _gather_simulator_signals(self) -> dict[str, np.ndarray]:

return {
'pixels': self._simulator.get_screenshot(),
'orientation': self._orientation,
'orientation': self._device_settings.get_orientation(),
'timedelta': np.array(timestamp_delta, dtype=np.int64),
}

Expand Down
68 changes: 6 additions & 62 deletions android_env/components/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import task_manager
from android_env.components.simulators import base_simulator
Expand Down Expand Up @@ -54,7 +55,9 @@ def setUp(self):
autospec=True,
return_value=self._adb_call_parser))
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator, task_manager=self._task_manager
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
)

def tearDown(self):
Expand Down Expand Up @@ -92,6 +95,7 @@ def test_lift_all_fingers(self, unused_mock_sleep):
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
config=config_classes.CoordinatorConfig(num_fingers=3),
)
self._coordinator.rl_reset()
Expand Down Expand Up @@ -183,6 +187,7 @@ def test_execute_multitouch_action(self, unused_mock_sleep):
self._coordinator = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
device_settings=device_settings_lib.DeviceSettings(self._simulator),
config=config_classes.CoordinatorConfig(num_fingers=3),
)

Expand Down Expand Up @@ -273,67 +278,6 @@ def test_execute_adb_call(self, unused_mock_sleep):
self.assertEqual(response, expected_response)
self._adb_call_parser.parse.assert_called_with(call)

@parameterized.parameters(
(True, '1'),
(False, '0'),
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_touch_indicator(self, show, expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(show_touches=show),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='show_touches', value=expected_value))))

@parameterized.parameters(
(True, '1'),
(False, '0'),
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_pointer_location(self, show, expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(show_pointer_location=show),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='pointer_location', value=expected_value))))

@parameterized.parameters(
(True, True, 'null*'),
(True, False, 'immersive.status=*'),
(False, True, 'immersive.navigation=*'),
(False, False, 'immersive.full=*'),
(None, None, 'immersive.full=*'), # Defaults to hiding both.
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_bar_visibility(self, show_navigation_bar, show_status_bar,
expected_value, unused_mock_sleep):
_ = coordinator_lib.Coordinator(
simulator=self._simulator,
task_manager=self._task_manager,
config=config_classes.CoordinatorConfig(
show_navigation_bar=show_navigation_bar,
show_status_bar=show_status_bar,
),
)
self._adb_call_parser.parse.assert_any_call(
adb_pb2.AdbRequest(
settings=adb_pb2.AdbRequest.SettingsRequest(
name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
put=adb_pb2.AdbRequest.SettingsRequest.Put(
key='policy_control', value=expected_value))))


if __name__ == '__main__':
absltest.main()
Loading