-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ability to load config from drive
Internal-tag: [#56628] Signed-off-by: Robert Winkler <[email protected]>
- Loading branch information
Showing
3 changed files
with
235 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# Copyright (c) 2024 Antmicro <www.antmicro.com> | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
import pytest | ||
import yaml | ||
|
||
from topwrap.config import Config, ConfigManager, RepositoryEntry | ||
|
||
|
||
class TestConfigManager: | ||
@pytest.fixture | ||
def config_dict(self): | ||
return Config.Schema().dump( | ||
Config( | ||
force_interface_compliance=True, | ||
repositories=[ | ||
RepositoryEntry("My topwrap repo", "~/custom/repo/path"), | ||
], | ||
) | ||
) | ||
|
||
@pytest.fixture | ||
def custom_config_dicts(self): | ||
return [ | ||
( | ||
"custom/path/cfg.yml", | ||
Config.Schema().dump(Config(repositories=[RepositoryEntry("repo1", "path1")])), | ||
), | ||
( | ||
"/global/path/mycfg.yaml", | ||
Config.Schema().dump(Config(repositories=[RepositoryEntry("repo2", "path2")])), | ||
), | ||
] | ||
|
||
@pytest.fixture | ||
def incorrect_config_dicts(self): | ||
return [ | ||
{ | ||
"force_interface_compliance": True, | ||
"repositories": [ | ||
{ | ||
"name": "My topwrap repo", | ||
"path": "~/custom/repo/path", | ||
"info": "Info should not be here", | ||
} | ||
], | ||
}, | ||
{ | ||
"force_interface_compliance": True, | ||
"meta": "A missing 'repositories' entry is correct, an additional custom entry is not", | ||
}, | ||
] | ||
|
||
@staticmethod | ||
def contains_warnings_in_log(caplog): | ||
for name, level, msg in caplog.record_tuples: | ||
if name == "topwrap.config" and level == logging.WARNING: | ||
return True | ||
return False | ||
|
||
def test_adding_repo_duplicates(self, fs, config_dict, caplog): | ||
(repo_dict,) = config_dict["repositories"] | ||
|
||
manager = ConfigManager() | ||
for path in manager.search_paths: | ||
config_str = yaml.dump(config_dict) | ||
fs.create_file(path, contents=config_str) | ||
|
||
config = manager.load() | ||
assert len(config.repositories) == 1 | ||
assert not self.contains_warnings_in_log(caplog) | ||
|
||
def test_loading_order(self, fs, config_dict, caplog): | ||
(repo_dict,) = config_dict["repositories"] | ||
|
||
manager = ConfigManager() | ||
for i, path in enumerate(manager.search_paths): | ||
repo_dict["name"] = str(i) | ||
repo_dict["path"] = str(path) | ||
config_str = yaml.dump(config_dict) | ||
fs.create_file(path, contents=config_str) | ||
|
||
config = manager.load() | ||
assert config.repositories == [ | ||
RepositoryEntry(name=str(i), path=str(manager.search_paths[i])) | ||
for i in reversed(range(len(manager.search_paths))) | ||
] | ||
assert not self.contains_warnings_in_log(caplog) | ||
|
||
def test_custom_search_patchs(self, fs, custom_config_dicts, caplog): | ||
for path, config_dict in custom_config_dicts: | ||
config_str = yaml.dump(config_dict) | ||
fs.create_file(path, contents=config_str) | ||
|
||
paths, config_dicts = zip(*custom_config_dicts) | ||
config = ConfigManager(paths).load() | ||
assert len(config.repositories) == len(config_dicts) | ||
assert not self.contains_warnings_in_log(caplog) | ||
|
||
def test_config_override(self, fs, config_dict, caplog): | ||
config_path = Path(ConfigManager.DEFAULT_SEARCH_PATHS[0]).expanduser() | ||
config_str = yaml.dump(config_dict) | ||
fs.create_file(config_path, contents=config_str) | ||
|
||
manager = ConfigManager() | ||
|
||
(repo_dict,) = config_dict["repositories"] | ||
|
||
config = manager.load() | ||
assert config.force_interface_compliance is True | ||
assert config.repositories == [RepositoryEntry(repo_dict["name"], repo_dict["path"])] | ||
|
||
override_config = Config( | ||
force_interface_compliance=False, | ||
repositories=None, | ||
) | ||
|
||
config2 = manager.load(override_config) | ||
assert config2.force_interface_compliance is False | ||
assert config2.repositories == [RepositoryEntry(repo_dict["name"], repo_dict["path"])] | ||
assert not self.contains_warnings_in_log(caplog) | ||
|
||
def test_loading_incorrect_configs(self, fs, incorrect_config_dicts, caplog): | ||
config_path = Path(ConfigManager.DEFAULT_SEARCH_PATHS[0]).expanduser() | ||
for incorrect_config in incorrect_config_dicts: | ||
manager = ConfigManager() | ||
config_str = yaml.dump(incorrect_config) | ||
fs.create_file(config_path, contents=config_str) | ||
manager.load() | ||
assert self.contains_warnings_in_log(caplog) | ||
config_path.unlink() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,109 @@ | ||
# Copyright (c) 2021-2024 Antmicro <www.antmicro.com> | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from dataclasses import field | ||
from os import PathLike | ||
from pathlib import Path | ||
from typing import List, Optional | ||
|
||
import marshmallow | ||
import marshmallow_dataclass | ||
import yaml | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class InvalidConfigError(Exception): | ||
"""Raised when the provided configuration is incorrect""" | ||
|
||
|
||
@marshmallow_dataclass.dataclass | ||
class RepositoryEntry: | ||
"""Contains information about topwrap repository""" | ||
|
||
name: str | ||
path: str | ||
|
||
|
||
@marshmallow_dataclass.dataclass | ||
class Config: | ||
"""Configuration class used to store global choices | ||
for behavior of Topwrap. | ||
"""Global topwrap configuration""" | ||
|
||
force_interface_compliance: Optional[bool] = field( | ||
default=False, metadata={"load_default": None} | ||
) | ||
repositories: Optional[List[RepositoryEntry]] = field( | ||
default_factory=list, metadata={"load_default": None} | ||
) | ||
|
||
def update(self, config: "Config"): | ||
if config.force_interface_compliance is not None: | ||
self.force_interface_compliance = config.force_interface_compliance | ||
|
||
if config.repositories is not None: | ||
if self.repositories is None: | ||
self.repositories = config.repositories | ||
else: | ||
for repo in config.repositories: | ||
if repo not in self.repositories: | ||
self.repositories.append(repo) | ||
|
||
|
||
class ConfigManager: | ||
"""Manager used to load topwrap's configuration from files. | ||
The configuration files are loaded in a specific order, which also | ||
determines the priority of settings that are defined differently | ||
in the files. The list of default search paths is defined in | ||
the `DEFAULT_SEARCH_PATH` class variable. Configuration files that | ||
are specified earlier in the list have higher priority and can | ||
overwrite the settings from the files that follow. The default list of | ||
search paths can be changed by passing a different list to | ||
the ConfigManager constructor. | ||
""" | ||
|
||
def __init__(self, force_interface_compliance=False): | ||
self.force_interface_compliance = force_interface_compliance | ||
DEFAULT_SEARCH_PATHS = [ | ||
"topwrap.yaml", | ||
"~/.config/topwrap/topwrap.yaml", | ||
"~/.config/topwrap/config.yaml", | ||
] | ||
|
||
def __init__(self, search_paths: Optional[List[PathLike]] = None): | ||
if search_paths is None: | ||
search_paths = self.DEFAULT_SEARCH_PATHS | ||
|
||
self.search_paths = [] | ||
for path in search_paths: | ||
self.search_paths += [Path(path).expanduser()] | ||
|
||
def load(self, overrides: Optional[Config] = None, default: Optional[Config] = None): | ||
config = Config() if default is None else default | ||
|
||
for path in reversed(self.search_paths): | ||
if not path.is_file(): | ||
continue | ||
|
||
with open(path) as f: | ||
try: | ||
yaml_dict = yaml.safe_load(f) | ||
except yaml.YAMLError: | ||
logger.warning(f"{path} configuration file is not a valid YAML") | ||
continue | ||
|
||
try: | ||
new_config = Config.Schema().load(yaml_dict) | ||
config.update(new_config) | ||
except marshmallow.ValidationError as e: | ||
logger.warning(f"{path} configuration file is not valid ({e.messages})") | ||
continue | ||
|
||
if overrides is not None: | ||
config.update(overrides) | ||
|
||
logger.debug(f"Final configuration used by topwrap: {config}") | ||
|
||
return config | ||
|
||
|
||
config = Config() | ||
config = ConfigManager().load() |