Skip to content

Commit

Permalink
Remove merge_into and just have merged which copies inputs to avo…
Browse files Browse the repository at this point in the history
…id footguns
  • Loading branch information
reivilibre committed Jan 17, 2024
1 parent 29541fd commit 8c71575
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions docker/configure_workers_and_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import sys
from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -321,37 +322,42 @@ def flush_buffers() -> None:
sys.stderr.flush()


def merge_into(dest: Any, new: Any) -> None:
def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
"""
Merges `new` into `dest` with the following rules:
Merges `a` and `b` together, returning the result.
The merge is performed with the following rules:
- dicts: values with the same key will be merged recursively
- lists: `new` will be appended to `dest`
- primitives: they will be checked for equality and inequality will result
in a ValueError
It is an error for `dest` and `new` to be of different types.
"""
if isinstance(dest, dict) and isinstance(new, dict):
for k, v in new.items():
if k in dest:
merge_into(dest[k], v)
else:
dest[k] = v
elif isinstance(dest, list) and isinstance(new, list):
dest.extend(new)
elif type(dest) != type(new):
raise TypeError(f"Cannot merge {type(dest).__name__} and {type(new).__name__}")
elif dest != new:
raise ValueError(f"Cannot merge primitive values: {dest!r} != {new!r}")

def merged(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
"""
Merges `b` into `a` and returns `a`. Here because we can't use `merge_into`
in a lamba conveniently.
It is an error for `a` and `b` to be of different types.
"""
merge_into(a, b)
if isinstance(a, dict) and isinstance(b, dict):
result = {}
for key in set(a.keys()) | set(b.keys):
if key in a and key in b:
result[key] = merged(a[key], b[key])
elif key in a:
result[key] = deepcopy(a[key])
else:
result[key] = deepcopy(b[key])

return result
elif isinstance(a, list) and isinstance(b, list):
return deepcopy(a) + deepcopy(b)
elif type(a) != type(b):
raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}")
elif a != b:
raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}")

if type(a) not in {str, int, float, bool, None.__class__}:
raise TypeError(
f"Cannot use `merged` on type {a} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)"
)
return a


Expand Down Expand Up @@ -454,10 +460,10 @@ def instantiate_worker_template(
Returns: worker configuration dictionary
"""
worker_config_dict = dataclasses.asdict(template)
stream_writers_dict = {
writer: worker_name for writer in template.stream_writers
}
worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict)
stream_writers_dict = {writer: worker_name for writer in template.stream_writers}
worker_config_dict["shared_extra_conf"] = merged(
template.shared_extra_conf(worker_name), stream_writers_dict
)
worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns)
worker_config_dict["listener_resources"] = sorted(template.listener_resources)
return worker_config_dict
Expand Down Expand Up @@ -786,7 +792,7 @@ def generate_worker_files(
)

# Update the shared config with any options needed to enable this worker.
merge_into(shared_config, worker_config["shared_extra_conf"])
shared_config = merged(shared_config, worker_config["shared_extra_conf"])

if using_unix_sockets:
healthcheck_urls.append(
Expand Down

0 comments on commit 8c71575

Please sign in to comment.