From 58e438cc6b797b5de8e68e39297790f8e50ea821 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 11 Jun 2024 08:16:54 +0200 Subject: [PATCH] feat(python): Support Python `Enum` values in `lit` (#16858) --- py-polars/polars/functions/lit.py | 9 +++++- py-polars/tests/unit/functions/test_lit.py | 37 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index b636d5f2544a..3d9a39c071cb 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import enum from datetime import date, datetime, time, timedelta, timezone from typing import TYPE_CHECKING, Any @@ -12,7 +13,7 @@ timedelta_to_int, ) from polars._utils.wrap import wrap_expr -from polars.datatypes import Date, Datetime, Duration, Time +from polars.datatypes import Date, Datetime, Duration, Enum, Time from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np @@ -126,6 +127,12 @@ def lit( elif isinstance(value, (list, tuple)): return lit(pl.Series("literal", [value], dtype=dtype)) + elif isinstance(value, enum.Enum): + lit_value = value.value + if dtype is None and isinstance(value, str): + dtype = Enum(value.__class__.__members__.values()) + return lit(lit_value, dtype=dtype) + if dtype: return wrap_expr(plr.lit(value, allow_object)).cast(dtype) diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py index 5c6fbe1d6ab3..79c7391048b6 100644 --- a/py-polars/tests/unit/functions/test_lit.py +++ b/py-polars/tests/unit/functions/test_lit.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum from datetime import datetime, timedelta from typing import Any @@ -99,6 +100,42 @@ def test_lit_unsupported_type() -> None: pl.lit(pl.LazyFrame({"a": [1, 2, 3]})) +def test_lit_enum_input_16668() -> None: + # https://github.com/pola-rs/polars/issues/16668 + + class State(str, enum.Enum): + VIC = "victoria" + NSW = "new south wales" + + value = State.VIC + + result = pl.lit(value) + assert pl.select(result).dtypes[0] == pl.Enum(["victoria", "new south wales"]) + assert pl.select(result).item() == "victoria" + + result = pl.lit(value, dtype=pl.String) + assert pl.select(result).dtypes[0] == pl.String + assert pl.select(result).item() == "victoria" + + +def test_lit_enum_input_non_string() -> None: + # https://github.com/pola-rs/polars/issues/16668 + + class State(int, enum.Enum): + ONE = 1 + TWO = 2 + + value = State.ONE + + result = pl.lit(value) + assert pl.select(result).dtypes[0] == pl.Int32 + assert pl.select(result).item() == 1 + + result = pl.lit(value, dtype=pl.Int8) + assert pl.select(result).dtypes[0] == pl.Int8 + assert pl.select(result).item() == 1 + + @given(value=datetimes("ns")) def test_datetime_ns(value: datetime) -> None: result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0]