Skip to content

Commit

Permalink
Fix user and group determination (#19)
Browse files Browse the repository at this point in the history
Closes #18
  • Loading branch information
jmsmkn authored Jan 10, 2024
1 parent a9852c5 commit cebf9fe
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 64 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sagemaker-shim"
version = "0.2.3"
version = "0.2.4"
description = "Adapts algorithms that implement the Grand Challenge inference API for running in SageMaker"
authors = ["James Meakin <[email protected]>"]
license = "Apache-2.0"
Expand Down
135 changes: 88 additions & 47 deletions sagemaker_shim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,70 @@ class InferenceResult(BaseModel):
sagemaker_shim_version: str = version("sagemaker-shim")


class UserGroup(NamedTuple):
class UserInfo(NamedTuple):
uid: int | None
gid: int | None
home: str | None
groups: list[int]


def _get_user_info(id_or_name: str) -> UserInfo:
if id_or_name == "":
return UserInfo(uid=None, gid=None, home=None, groups=[])

try:
user = pwd.getpwnam(id_or_name)
except (KeyError, AttributeError):
try:
uid = int(id_or_name)
except ValueError as error:
raise RuntimeError(f"User '{id_or_name}' not found") from error

try:
user = pwd.getpwuid(uid)
except (KeyError, AttributeError):
return UserInfo(uid=uid, gid=None, home=None, groups=[])

return UserInfo(
uid=user.pw_uid,
gid=user.pw_gid,
home=user.pw_dir,
groups=_get_users_groups(user=user),
)


def _get_users_groups(*, user: pwd.struct_passwd) -> list[int]:
users_groups = [
g.gr_gid for g in grp.getgrall() if user.pw_name in g.gr_mem
]
return _put_gid_first(gid=user.pw_gid, groups=users_groups)


def _put_gid_first(*, gid: int | None, groups: list[int]) -> list[int]:
if gid is None:
return groups
else:
user_groups = set(groups)

try:
user_groups.remove(gid)
except KeyError:
pass

return [gid, *sorted(user_groups)]


def _get_group_id(id_or_name: str) -> int | None:
if id_or_name == "":
return None

try:
return grp.getgrnam(id_or_name).gr_gid
except (KeyError, AttributeError):
try:
return int(id_or_name)
except ValueError as error:
raise RuntimeError(f"Group '{id_or_name}' not found") from error


class InferenceTask(BaseModel):
Expand Down Expand Up @@ -234,6 +294,18 @@ def output_path(self) -> Path:
logger.debug(f"{output_path=}")
return output_path

@property
def extra_groups(self) -> list[int] | None:
if (
os.environ.get(
"GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "True"
).lower()
== "true"
):
return self.proc_user.groups
else:
return None

@cached_property
def _s3_client(self) -> S3Client:
return get_s3_client()
Expand Down Expand Up @@ -288,61 +360,29 @@ def proc_env(self) -> dict[str, str]:

return env

@staticmethod
def _get_user_info(id_or_name: str) -> pwd.struct_passwd | None:
if id_or_name == "":
return None

try:
return pwd.getpwnam(id_or_name)
except (KeyError, AttributeError):
try:
return pwd.getpwuid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"User {id_or_name} not found") from error

@staticmethod
def _get_group_info(id_or_name: str) -> grp.struct_group | None:
if id_or_name == "":
return None

try:
return grp.getgrnam(id_or_name)
except (KeyError, AttributeError):
try:
return grp.getgrgid(int(id_or_name))
except (KeyError, ValueError, AttributeError) as error:
raise RuntimeError(f"Group {id_or_name} not found") from error

@cached_property
def proc_user(self) -> UserGroup:
def proc_user(self) -> UserInfo:
if self.user == "":
return UserInfo(uid=None, gid=None, home=None, groups=[])

match = re.fullmatch(
r"^(?P<user>[0-9a-zA-Z]*):?(?P<group>[0-9a-zA-Z]*)$", self.user
)

if match:
user = self._get_user_info(id_or_name=match.group("user"))
group = self._get_group_info(id_or_name=match.group("group"))
info = _get_user_info(id_or_name=match.group("user"))
group_id = _get_group_id(id_or_name=match.group("group"))

if user is None:
uid = None
home = None
else:
uid = user.pw_uid
home = user.pw_dir

if group is None:
if user is None:
gid = None
else:
# Switch to the users primary group
gid = user.pw_gid
else:
gid = group.gr_gid
gid = info.gid if group_id is None else group_id

return UserGroup(uid=uid, gid=gid, home=home)
return UserInfo(
uid=info.uid,
gid=gid,
home=info.home,
groups=_put_gid_first(gid=gid, groups=info.groups),
)
else:
return UserGroup(uid=None, gid=None, home=None)
raise RuntimeError(f"Invalid user '{self.user}'")

async def invoke(self) -> InferenceResult:
"""Run the inference on a single case"""
Expand Down Expand Up @@ -461,6 +501,7 @@ async def execute(self) -> int:
*self.proc_args,
user=self.proc_user.uid,
group=self.proc_user.gid,
extra_groups=self.extra_groups,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.proc_env,
Expand Down
1 change: 1 addition & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_invocations_endpoint(client, tmp_path, monkeypatch, capsys, minio):
"GRAND_CHALLENGE_COMPONENT_OUTPUT_PATH",
str(output_path),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

debug_log = deepcopy(LOGGING_CONFIG)
debug_log["root"]["level"] = "DEBUG"
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_inference_from_task_list(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

runner = CliRunner()
runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_inference_from_s3_uri(minio, monkeypatch, cmd, expected_return_code):
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

definition_key = f"{uuid4()}/invocations.json"

Expand Down Expand Up @@ -183,6 +185,7 @@ def test_logging_setup(minio, monkeypatch):
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=["echo", "hello"]),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down Expand Up @@ -210,6 +213,7 @@ def test_logging_stderr_setup(minio, monkeypatch):
val=["bash", "-c", "echo 'hello' >> /dev/stderr && exit 1"]
),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

runner = CliRunner()
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
Expand Down
1 change: 1 addition & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def test_inference_result_upload(
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
encode_b64j(val=cmd),
)
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")

direct_invocation = await task.invoke()

Expand Down
Loading

0 comments on commit cebf9fe

Please sign in to comment.