Skip to content

Commit

Permalink
edgecolor and legend aesthetics
Browse files Browse the repository at this point in the history
  • Loading branch information
jordanplanders committed Jun 21, 2024
1 parent 245728d commit 36156e5
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions pyleoclim/utils/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def make_df(geo_ms, hue=None, marker=None, size=None, cols=None, d=None):
def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgecolor='k',
proj_default=True, projection='auto', crit_dist=5000,
background=True, borders=False, coastline=True, rivers=False, lakes=False, ocean=True, land=True,
figsize=None, scatter_kwargs=None, gridspec_kwargs=None, extent='global',
figsize=None, scatter_kwargs=None, gridspec_kwargs=None, extent='global', edgecolor_var=None,
lgd_kwargs=None, legend=True, colorbar=True, cmap=None, color_scale_type=None,
fig=None, gs_slot=None, **kwargs):
'''
Expand Down Expand Up @@ -774,23 +774,24 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if isinstance(scatter_kwargs, dict):
linewidth = scatter_kwargs.pop('linewidth', 1)

if isinstance(lgd_kwargs, dict):
handle_size = lgd_kwargs.pop('handle_size', 11)
# if isinstance(lgd_kwargs, dict):
# handle_size = lgd_kwargs.pop('handle_size', 11)

if 'neighbor' in df.columns:
edgecolor_var = 'neighbor'
if 'neighbor_status' in _df.columns:
edgecolor_var = 'neighbor_status'

if isinstance(edgecolor, (list, np.ndarray)):
if len(edgecolor) == len(_df):
_df['edgecolor'] = edgecolor
elif len(edgecolor) == 1:
_df['edgecolor'] = edgecolor[0]

# try making a column populated by the edgecolor
elif isinstance(edgecolor, str):
_df['edgecolor'] = edgecolor
elif isinstance(edgecolor, dict):
if edgecolor_var in _df.columns:
_df['edgecolor'] = _df[edgecolor_var].map(edgecolor)
_df['edgecolor'] = _df[edgecolor_var].apply(lambda x: edgecolor[x])

_df = _df.apply(lambda x: tidy_labels(x) if x.dtype == "str" else x)
hue_var = hue_var if hue_var in _df.columns else None
Expand Down Expand Up @@ -897,7 +898,6 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
palette = {key: value[0] for key, value in plot_defaults.items()}
elif isinstance(hue_var,str): #hue_var) == str:
hue_data = _df[_df[hue_var] != missing_val]

# If scalar mappable was passed, try to extract components.
if ax_sm is not None:
try:
Expand Down Expand Up @@ -950,18 +950,19 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
scatter_kwargs['zorder'] = 13
if isinstance(edgecolor, np.ndarray):
_df['edgecolor'] = edgecolor
_df['neighbor'] = _df['edgecolor'].map({'k': 'target', 'w': 'neighbor'})
_df['neighbor_status'] = _df['edgecolor'].map({'k': 'target', 'w': 'neighbor'})

# handle missing values
hue_data = _df[_df[hue_var] == missing_val]
if len(hue_data) > 0:
if 'neighbor' in hue_data.columns:
if 'neighbor_status' in hue_data.columns:
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values, linewidth=linewidth,
style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, legend=False,
**scatter_kwargs)
scatter_kwargs['zorder'] = 13

sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
Expand Down Expand Up @@ -999,16 +1000,18 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
# sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor=edgecolor,
# style=marker_var, palette=palette, ax=ax, **scatter_kwargs)

if 'neighbor' in hue_data.columns:
if 'neighbor_status' in hue_data.columns:
scatter_kwargs['zorder'] = 14
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values,linewidth=linewidth,
style=marker_var, hue=hue_var, palette=palette, ax=ax, legend=False, **scatter_kwargs)
if not isinstance(edgecolor, str):
scatter_kwargs['zorder'] = 13
# if not isinstance(edgecolor, str):
edgecolor = None
linewidth = 0
else:
linewidth = linewidth
# else:
# linewidth = linewidth

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,
edgecolor=edgecolor,linewidth=linewidth,
Expand Down Expand Up @@ -1092,11 +1095,13 @@ def replace_last(source_string, replace_what, replace_with):
# else:
# d_leg.pop(hue_var, None)
# print(d_leg, d_leg.keys())
tmpHandles = []
for key in [hue_var, marker_var]:
if key in d_leg.keys():
for handle in d_leg[key]['handles']:
handle.set_markersize(handle_size)
# tmpHandles = []

# for key in [hue_var, marker_var]:
# if key in d_leg.keys():
# for handle in d_leg[key]['handles']:
# mpl.artist.setp(handle, markersize=handle_size)
# handle._legmarker.set_markersize(handle_size)
# tmpHandles.append(handle)

# d_leg[hue_var]['handles'] = tmpHandles
Expand Down Expand Up @@ -1141,7 +1146,15 @@ def replace_last(source_string, replace_what, replace_with):
if 'bbox_to_anchor' not in lgd_kwargs.keys():
lgd_kwargs['bbox_to_anchor'] = (-.1, 1) # (1, 1)
if 'labelspacing' not in lgd_kwargs.keys():
lgd_kwargs['labelspacing'] = .275
if len(labels) > 15:
lgd_kwargs['labelspacing'] = .275
else:
lgd_kwargs['labelspacing'] = .5
if 'markerscale' not in lgd_kwargs.keys():
if len(labels)>15:
lgd_kwargs['markerscale'] = .8
else:
lgd_kwargs['markerscale'] = 1

built_legend = ax_leg.legend(handles, labels, **lgd_kwargs)
if headers is True:
Expand Down

0 comments on commit 36156e5

Please sign in to comment.