Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow empty tunable values to represent the defaults #868

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"type": ["string", "number", "boolean"]
}
},
"minProperties": 1,
"not": {
"required": ["tunable_values"]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{
"$schema": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/tunables/tunable-values-schema.json",

"foo": "bar",
"int": 1,
"float": 1.1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"$schema": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/tunables/tunable-values-schema.json"
// empty tunable values represents the defaults
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
// empty tunable values represents the defaults
}
18 changes: 18 additions & 0 deletions mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None:
)


def test_tunables_assign_defaults(tunable_groups: TunableGroups) -> None:
"""Make sure we can assign the default values using an empty dictionary."""
tunable_groups_defaults = tunable_groups.copy()
assert tunable_groups.is_defaults()
# Assign the default values.
# Also reset the _is_updated flag to avoid considering that in the comparison.
tunable_groups.assign({}).reset()
assert tunable_groups_defaults == tunable_groups
assert tunable_groups.is_defaults()
tunable_groups.assign({"vmSize": "Standard_B2s"}).reset()
assert tunable_groups_defaults != tunable_groups
assert not tunable_groups.is_defaults()
tunable_groups.assign({}).reset()
assert tunable_groups["vmSize"] != "Standard_B2s"
assert tunable_groups.is_defaults()
assert tunable_groups_defaults == tunable_groups


def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None:
"""Regular assignment for categorical tunable."""
# Must be one of: {"Standard_B2s", "Standard_B2ms", "Standard_B4ms"}
Expand Down
12 changes: 12 additions & 0 deletions mlos_bench/mlos_bench/tunables/tunable_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#
"""TunableGroups definition."""
import copy
import logging
from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.tunables.covariant_group import CovariantTunableGroup
from mlos_bench.tunables.tunable import Tunable, TunableValue

_LOG = logging.getLogger(__name__)


class TunableGroups:
"""A collection of covariant groups of tunable parameters."""
Expand Down Expand Up @@ -346,11 +349,20 @@ def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
param_values : Mapping[str, TunableValue]
Dictionary mapping Tunable parameter names to new values.

As a special behavior when the mapping is empty the method will restore
the default values rather than no-op.
This allows an empty dictionary in json configs to be used to reset the
tunables to defaults without having to copy the original values from the
tunable_params definition.

Returns
-------
self : TunableGroups
Self-reference for chaining.
"""
if not param_values:
_LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
return self.restore_defaults()
for key, value in param_values.items():
self[key] = value
return self
Loading