diff --git a/src/gluonts/time_feature/_base.py b/src/gluonts/time_feature/_base.py index 0d88971002..3aa53a55ef 100644 --- a/src/gluonts/time_feature/_base.py +++ b/src/gluonts/time_feature/_base.py @@ -11,6 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +from packaging.version import Version from typing import Any, Callable, Dict, List import numpy as np @@ -196,7 +197,10 @@ def norm_freq_str(freq_str: str) -> str: # Note: Secondly ("S") frequency exists, where we don't want to remove the # "S"! if len(base_freq) >= 2 and base_freq.endswith("S"): - return base_freq[:-1] + base_freq = base_freq[:-1] + # In pandas >= 2.2, period end frequencies have been renamed, e.g. "M" -> "ME" + if Version(pd.__version__) >= Version("2.2.0"): + base_freq += "E" return base_freq @@ -252,17 +256,13 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: Unsupported frequency {freq_str} The following frequencies are supported: - - Y - yearly - alias: A - Q - quarterly - M - monthly - W - weekly - D - daily - B - business days - H - hourly - T - minutely - alias: min - S - secondly + """ + + for offset_cls in features_by_offsets: + offset = offset_cls() + supported_freq_msg += ( + f"\t{offset.freqstr.split('-')[0]} - {offset_cls.__name__}" + ) + raise RuntimeError(supported_freq_msg) diff --git a/src/gluonts/time_feature/seasonality.py b/src/gluonts/time_feature/seasonality.py index 62026fc691..9cb2581a24 100644 --- a/src/gluonts/time_feature/seasonality.py +++ b/src/gluonts/time_feature/seasonality.py @@ -33,6 +33,7 @@ "ME": 12, "B": 5, "Q": 4, + "QE": 4, } diff --git a/test/time_feature/__init__.py b/test/time_feature/__init__.py new file mode 100644 index 0000000000..f342912f9b --- /dev/null +++ b/test/time_feature/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/test/time_feature/common.py b/test/time_feature/common.py new file mode 100644 index 0000000000..89e19a23f8 --- /dev/null +++ b/test/time_feature/common.py @@ -0,0 +1,28 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import pandas as pd +from packaging.version import Version + +if Version(pd.__version__) <= Version("2.2.0"): + S = "S" + H = "H" + M = "M" + Q = "Q" + Y = "A" +else: + S = "s" + H = "h" + M = "ME" + Q = "QE" + Y = "YE" diff --git a/test/time_feature/test_agg_lags.py b/test/time_feature/test_agg_lags.py index dd3b2f2d9b..6e299e498d 100644 --- a/test/time_feature/test_agg_lags.py +++ b/test/time_feature/test_agg_lags.py @@ -16,7 +16,6 @@ import pytest from gluonts.dataset.common import ListDataset - from gluonts.dataset.field_names import FieldName from gluonts.transform import AddAggregateLags diff --git a/test/time_feature/test_base.py b/test/time_feature/test_base.py index 8e249eba86..e448156b2b 100644 --- a/test/time_feature/test_base.py +++ b/test/time_feature/test_base.py @@ -11,21 +11,23 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +import pytest from pandas.tseries.frequencies import to_offset from gluonts.time_feature import norm_freq_str +from .common import M, Q, S, Y -def test_norm_freq_str(): - assert norm_freq_str(to_offset("Y").name) in ["A", "YE"] - assert norm_freq_str(to_offset("YS").name) in ["A", "Y"] - assert norm_freq_str(to_offset("A").name) in ["A", "YE"] - assert norm_freq_str(to_offset("AS").name) in ["A", "Y"] - assert norm_freq_str(to_offset("Q").name) in ["Q", "QE"] - assert norm_freq_str(to_offset("QS").name) == "Q" - - assert norm_freq_str(to_offset("M").name) in ["M", "ME"] - assert norm_freq_str(to_offset("MS").name) in ["M", "ME"] - - assert norm_freq_str(to_offset("S").name) in ["S", "s"] +@pytest.mark.parametrize( + " aliases, normalized_freq_str", + [ + (["Y", "YS", "A", "AS"], Y), + (["Q", "QS"], Q), + (["M", "MS"], M), + (["S"], S), + ], +) +def test_norm_freq_str(aliases, normalized_freq_str): + for alias in aliases: + assert norm_freq_str(to_offset(alias).name) == normalized_freq_str diff --git a/test/time_feature/test_features.py b/test/time_feature/test_features.py index 96c590fbf2..1c59db9909 100644 --- a/test/time_feature/test_features.py +++ b/test/time_feature/test_features.py @@ -16,7 +16,6 @@ import pytest from gluonts import zebras as zb - from gluonts.time_feature import ( Constant, TimeFeature, diff --git a/test/time_feature/test_lag.py b/test/time_feature/test_lag.py index 951a5f9cb4..2ce9651e0c 100644 --- a/test/time_feature/test_lag.py +++ b/test/time_feature/test_lag.py @@ -15,12 +15,16 @@ Test the lags computed for different frequencies. """ +import pytest + import gluonts.time_feature.lag as date_feature_set +from .common import H, M, Q, Y + # These are the expected lags for common frequencies and corner cases. # By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7]. # Remaining lags correspond to the same `season` (+/- `delta`) in previous `k` cycles. -expected_lags = { +EXPECTED_LAGS = { # (apart from the default lags) centered around each of the last 3 hours (delta = 2) "4S": [ 1, @@ -179,7 +183,7 @@ ] + [329, 330, 331, 494, 495, 496, 659, 660, 661, 707, 708, 709], # centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1) - "H": [1, 2, 3, 4, 5, 6, 7] + H: [1, 2, 3, 4, 5, 6, 7] + [ 23, 24, @@ -206,7 +210,7 @@ + [335, 336, 337, 503, 504, 505, 671, 672, 673, 719, 720, 721], # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + # last 8th and 12th weeks (delta = 0) - "6H": [ + ("6" + H): [ 1, 2, 3, @@ -237,21 +241,21 @@ + [224, 336], # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + # last 8th and 12th weeks (delta = 0) + last year (delta = 1) - "12H": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ("12" + H): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + [27, 28, 29, 41, 42, 43, 55, 56, 57] + [59, 60, 61] + [112, 168] + [727, 728, 729], # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) - "23H": [1, 2, 3, 4, 5, 6, 7, 8] + ("23" + H): [1, 2, 3, 4, 5, 6, 7, 8] + [13, 14, 15, 20, 21, 22, 28, 29] + [30, 31, 32] + [58, 87] + [378, 379, 380, 758, 759, 760, 1138, 1139, 1140], # centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) + # last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1) - "25H": [1, 2, 3, 4, 5, 6, 7] + ("25" + H): [1, 2, 3, 4, 5, 6, 7] + [12, 13, 14, 19, 20, 21, 25, 26, 27] + [28, 29] + [53, 80] @@ -285,64 +289,31 @@ # centered around each of the last 3 years (delta = 1) "5W": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 19, 20, 21, 30, 31, 32], # centered around each of the last 3 years (delta = 1) - "M": [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37], + M: [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37], # default - "6M": [1, 2, 3, 4, 5, 6, 7], + "6" + M: [1, 2, 3, 4, 5, 6, 7], # default - "12M": [1, 2, 3, 4, 5, 6, 7], + "12" + M: [1, 2, 3, 4, 5, 6, 7], + Q: [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13], + "QS": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13], + Y: [1, 2, 3, 4, 5, 6, 7], + "YS": [1, 2, 3, 4, 5, 6, 7], } # For the default multiple (1) -for freq in ["min", "H", "D", "W", "M"]: - expected_lags["1" + freq] = expected_lags[freq] +for freq in ["min", H, "D", "W", M]: + EXPECTED_LAGS["1" + freq] = EXPECTED_LAGS[freq] # For frequencies that do not have unique form -expected_lags["60min"] = expected_lags["1H"] -expected_lags["24H"] = expected_lags["1D"] -expected_lags["7D"] = expected_lags["1W"] - - -def test_lags(): - freq_strs = [ - "4S", - "min", - "1min", - "15min", - "30min", - "59min", - "60min", - "61min", - "H", - "1H", - "6H", - "12H", - "23H", - "24H", - "25H", - "D", - "1D", - "2D", - "6D", - "7D", - "8D", - "W", - "1W", - "3W", - "4W", - "5W", - "M", - "6M", - "12M", - ] +EXPECTED_LAGS["60min"] = EXPECTED_LAGS["1" + H] +EXPECTED_LAGS["24" + H] = EXPECTED_LAGS["1D"] +EXPECTED_LAGS["7D"] = EXPECTED_LAGS["1W"] - for freq_str in freq_strs: - lags = date_feature_set.get_lags_for_frequency(freq_str) - assert ( - lags == expected_lags[freq_str] - ), "lags do not match for the frequency '{}':\nexpected: {},\nprovided: {}".format( - freq_str, expected_lags[freq_str], lags - ) +@pytest.mark.parametrize("freq_str, expected_lags", EXPECTED_LAGS.items()) +def test_lags(freq_str, expected_lags): + lags = date_feature_set.get_lags_for_frequency(freq_str) + assert lags == expected_lags if __name__ == "__main__": diff --git a/test/time_feature/test_seasonality.py b/test/time_feature/test_seasonality.py index 0323e52ebe..3817416c2c 100644 --- a/test/time_feature/test_seasonality.py +++ b/test/time_feature/test_seasonality.py @@ -15,25 +15,42 @@ from gluonts.time_feature import get_seasonality +from .common import H, M, Q, Y -@pytest.mark.parametrize( - "freq, expected_seasonality", - [ - ("30min", 48), - ("1H", 24), - ("H", 24), - ("2H", 12), - ("3H", 8), - ("4H", 6), - ("15H", 1), - ("5B", 1), - ("1B", 5), - ("2W", 1), - ("3M", 4), - ("1D", 1), - ("7D", 1), - ("8D", 1), - ], -) +TEST_CASES = [ + ("30min", 48), + ("5B", 1), + ("1B", 5), + ("2W", 1), + ("1D", 1), + ("7D", 1), + ("8D", 1), + # Monthly + ("MS", 12), + ("3MS", 4), + (M, 12), + ("3" + M, 4), + # Quarterly + ("QS", 4), + ("2QS", 2), + (Q, 4), + ("2" + Q, 2), + ("3" + Q, 1), + # Hourly + ("1" + H, 24), + (H, 24), + ("2" + H, 12), + ("3" + H, 8), + ("4" + H, 6), + ("15" + H, 1), + # Yearly + (Y, 1), + ("2" + Y, 1), + ("YS", 1), + ("2YS", 1), +] + + +@pytest.mark.parametrize("freq, expected_seasonality", TEST_CASES) def test_get_seasonality(freq, expected_seasonality): assert get_seasonality(freq) == expected_seasonality