Skip to content

Commit

Permalink
FIX: Groupby plots will now group timeseries subplots by (#624)
Browse files Browse the repository at this point in the history
* ADD: Groupby plotting capability.

* ENH: Changing author name

* FIX: Plots in examples

* FIX: GroupBy will now override TimeSeriesDisplay default time axes

* ADD: None as default value for group_by.

* FIX: Groupby now groups by year in subplots

---------

Co-authored-by: Robert Jackson <[email protected]>
Co-authored-by: AdamTheisen <[email protected]>
Co-authored-by: Robert Jackson <[email protected]>
Co-authored-by: Robert Jackson <[email protected]>
  • Loading branch information
5 people authored Feb 22, 2023
1 parent 0043566 commit 5f7495a
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions act/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ def __init__(self, display, units):
"""
self.display = display
self._groupby = {}
self.mapping = {}
self.xlims = {}
self.units = units
self.isTimeSeriesDisplay = hasattr(self.display, 'time_height_scatter')
num_groups = 0
datastreams = list(display._obj.keys())
for key in datastreams:
Expand Down Expand Up @@ -350,6 +354,7 @@ def plot_group(self, func_name, dsname=None, **kwargs):
if dsname == key:
self.display._obj = {}
for k, ds in self._groupby[key]:
num_years = len(np.unique(ds.time.dt.year))
self.display._obj[key + '_%d' % k] = ds
if i >= np.prod(subplot_shape):
i = 0
Expand All @@ -363,9 +368,27 @@ def plot_group(self, func_name, dsname=None, **kwargs):
kwargs["subplot_index"] = subplot_index
if "time_rng" in args:
kwargs["time_rng"] = (ds.time.values.min(), ds.time.values.max())
func(dsname=key + '_%d' % k,
**kwargs)

if num_years > 1 and self.isTimeSeriesDisplay:
first_year = ds.time.dt.year[0]
for yr, ds1 in ds.groupby('time.year'):
if ds1.time.dt.year[0] % 4 == 0:
days_in_year = 366
else:
days_in_year = 365
year_diff = ds1.time.dt.year - first_year
time_diff = np.array(
[np.timedelta64(x * days_in_year, 'D') for x in year_diff.values])
ds1['time'] = ds1.time - time_diff
self.display._obj[key + '%d_%d' % (k, yr)] = ds1
func(dsname=key + '%d_%d' % (k, yr), label=str(yr), **kwargs)
self.mapping[key + '%d_%d' % (k, yr)] = subplot_index
self.xlims[key + '%d_%d' % (k, yr)] = (ds1.time.values.min(), ds1.time.values.max())
del self.display._obj[key + '_%d' % k]
else:
func(dsname=key + '_%d' % k, **kwargs)
self.mapping[key + '_%d' % k] = subplot_index
if self.isTimeSeriesDisplay:
self.xlims[key + '_%d' % k] = (ds.time.values.min(), ds.time.values.max())
i = i + 1

if wrap_around is False and i < np.prod(subplot_shape):
Expand All @@ -387,16 +410,11 @@ def plot_group(self, func_name, dsname=None, **kwargs):
except AttributeError:
pass

# Set to min and max for each time period if time series display
# Only the TimeSeriesDisplay has the time_height_scatter function
# So, check for that
if hasattr(self.display, 'time_height_scatter'):
key_list = list(self.display._obj.keys())
if i >= len(key_list):
continue
ds = self.display._obj[key_list[i]]
time_min = ds.time.values.min()
time_max = ds.time.values.max()
if self.isTimeSeriesDisplay:
key_list = list(self.display._obj.keys())
for k in key_list:
time_min, time_max = self.xlims[k]
subplot_index = self.mapping[k]
self.display.set_xrng([time_min, time_max], subplot_index)

self.display._obj = old_obj
Expand Down

0 comments on commit 5f7495a

Please sign in to comment.