Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

edgecolor fix #582

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyleoclim/core/geoseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib.pyplot as plt
import re
import pandas as pd
import numpy as np

#from copy import deepcopy
from matplotlib import gridspec
Expand Down Expand Up @@ -497,9 +498,12 @@ def map_neighbors(self, mgs, radius=3000, projection='Orthographic', proj_defaul
neighbor_coloring = ['w' for ik in range(len(neighborhood))]
neighbor_coloring[-1] = 'k'
neighborhood['original'] =neighbor_coloring
neighborhood['neighbors'] =neighbor_coloring

# plot neighbors

fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size, marker=marker, projection=projection,
fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size,
marker=marker, projection=projection,
proj_default=proj_default,
background=background, borders=borders, rivers=rivers, lakes=lakes,
ocean=ocean, land=land,
Expand Down
96 changes: 69 additions & 27 deletions pyleoclim/utils/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco
# geos_df = pd.DataFrame(value_d)
# return geos_df

def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor='w',
def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor_var=None, edgecolor='w',
ax=None, ax_d=None, proj=None, scatter_kwargs=None, legend=True, lgd_kwargs=None, colorbar=None,
fig=None, color_scale_type=None, # gs_slot=None,
cmap=None, **kwargs):
Expand Down Expand Up @@ -773,16 +773,28 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
}
missing_val = missing_d['label']

# if 'edgecolor' in scatter_kwargs:
# edgecolor = scatter_kwargs['edgecolor']
# scatter_kwargs['edgecolors'] = edgecolor
# if 'edgecolors' not in scatter_kwargs:
# scatter_kwargs['edgecolors'] = edgecolor
#
if 'edgecolor' in scatter_kwargs:
edgecolor = scatter_kwargs['edgecolor']# = edgecolor
if isinstance(scatter_kwargs, dict):
edgecolor = scatter_kwargs.pop('edgecolor', edgecolor)

if 'neighbor' in df.columns:
edgecolor_var = 'neighbor'
# if ~isinstance(edgecolor, np.ndarray):
# if isinstance(edgecolor, str):
# edgecolor = [edgecolor]
# edgecolor = np.array(edgecolor)


if isinstance(edgecolor, (list, np.ndarray)):
if len(edgecolor) == len(_df):
_df['edgecolor'] = edgecolor
elif len(edgecolor) == 1:
_df['edgecolor'] = edgecolor[0]
# try making a column populated by the edgecolor
elif isinstance(edgecolor, str):
_df['edgecolor'] = edgecolor
elif isinstance(edgecolor, dict):
if edgecolor_var in _df.columns:
_df['edgecolor'] = _df[edgecolor_var].map(edgecolor)

hue_var = hue_var if hue_var in _df.columns else None
hue_var_type_numeric = False
Expand Down Expand Up @@ -935,23 +947,42 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
_df['edgecolor'] = edgecolor
_df['neighbor'] = _df['edgecolor'].map({'k': 'target', 'w': 'neighbor'})

# handle missing values
hue_data = _df[_df[hue_var] == missing_val]

if len(hue_data) > 0:
sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var,
style=marker_var, transform=transform,edgecolor='w',
# change to transform=scatter_kwargs['transform']
if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values, linewidth=2,
style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, legend=False,
**scatter_kwargs)

sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=None, linewidth=0,
style=marker_var, hue=hue_var,
palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, **scatter_kwargs)
ax=ax,
**scatter_kwargs)
# sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var,
# style=marker_var, transform=transform,edgecolor=edgecolor,
# # change to transform=scatter_kwargs['transform']
# palette=[missing_d['hue'] for ik in range(len(hue_data))],
# ax=ax, **scatter_kwargs)
missing_handles, missing_labels = ax.get_legend_handles_labels()
if 'neighbor' in hue_data.columns:
if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1:
_edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0]
else:
_edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor']
sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
transform=transform, edgecolor=_edgecolor,
style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)

# # if the missing values are being handled for map_neighbors
# if 'neighbor' in hue_data.columns:
# # if the missing value is actually the target
# if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1:
# _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0]
# else:
# _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor']
# sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
# transform=transform, edgecolor=_edgecolor,
# style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)
# available values
else:
missing_handles, missing_labels = [], []

Expand All @@ -960,12 +991,23 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if hue_norm is not None:
scatter_kwargs['hue_norm'] = hue_norm

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor='w',
style=marker_var, palette=palette, ax=ax, **scatter_kwargs)
# sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor=edgecolor,
# style=marker_var, palette=palette, ax=ax, **scatter_kwargs)

if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
transform=transform, edgecolor=hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0],
style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values,linewidth=2,
style=marker_var, hue=hue_var, palette=palette, ax=ax, legend=False, **scatter_kwargs)
if not isinstance(edgecolor, str):
edgecolor = None
linewidth = 0
else:
linewidth = 1

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,
edgecolor=edgecolor,linewidth=linewidth,
style=marker_var, palette=palette, ax=ax, **scatter_kwargs)

else:
scatter_kwargs['zorder'] = 13
Expand Down
Loading