Skip to content

Commit

Permalink
feat(python): Support Python Enum values in lit (#16858)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jun 11, 2024
1 parent 1a2707d commit 58e438c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
9 changes: 8 additions & 1 deletion py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
from datetime import date, datetime, time, timedelta, timezone
from typing import TYPE_CHECKING, Any

Expand All @@ -12,7 +13,7 @@
timedelta_to_int,
)
from polars._utils.wrap import wrap_expr
from polars.datatypes import Date, Datetime, Duration, Time
from polars.datatypes import Date, Datetime, Duration, Enum, Time
from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np

Expand Down Expand Up @@ -126,6 +127,12 @@ def lit(
elif isinstance(value, (list, tuple)):
return lit(pl.Series("literal", [value], dtype=dtype))

elif isinstance(value, enum.Enum):
lit_value = value.value
if dtype is None and isinstance(value, str):
dtype = Enum(value.__class__.__members__.values())
return lit(lit_value, dtype=dtype)

if dtype:
return wrap_expr(plr.lit(value, allow_object)).cast(dtype)

Expand Down
37 changes: 37 additions & 0 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
from datetime import datetime, timedelta
from typing import Any

Expand Down Expand Up @@ -99,6 +100,42 @@ def test_lit_unsupported_type() -> None:
pl.lit(pl.LazyFrame({"a": [1, 2, 3]}))


def test_lit_enum_input_16668() -> None:
# https://github.com/pola-rs/polars/issues/16668

class State(str, enum.Enum):
VIC = "victoria"
NSW = "new south wales"

value = State.VIC

result = pl.lit(value)
assert pl.select(result).dtypes[0] == pl.Enum(["victoria", "new south wales"])
assert pl.select(result).item() == "victoria"

result = pl.lit(value, dtype=pl.String)
assert pl.select(result).dtypes[0] == pl.String
assert pl.select(result).item() == "victoria"


def test_lit_enum_input_non_string() -> None:
# https://github.com/pola-rs/polars/issues/16668

class State(int, enum.Enum):
ONE = 1
TWO = 2

value = State.ONE

result = pl.lit(value)
assert pl.select(result).dtypes[0] == pl.Int32
assert pl.select(result).item() == 1

result = pl.lit(value, dtype=pl.Int8)
assert pl.select(result).dtypes[0] == pl.Int8
assert pl.select(result).item() == 1


@given(value=datetimes("ns"))
def test_datetime_ns(value: datetime) -> None:
result = pl.select(pl.lit(value, dtype=pl.Datetime("ns")))["literal"][0]
Expand Down

0 comments on commit 58e438c

Please sign in to comment.