diff --git a/pyleoclim/core/__init__.py b/pyleoclim/core/__init__.py index 2c6dd0ed..b697d96b 100644 --- a/pyleoclim/core/__init__.py +++ b/pyleoclim/core/__init__.py @@ -12,8 +12,7 @@ from .surrogateseries import SurrogateSeries from .ensembleseries import EnsembleSeries from .scalograms import Scalogram, MultipleScalogram -from .coherence import Coherence -from .globalcoherence import GlobalCoherence +from .coherences import Coherence, GlobalCoherence from .corr import Corr from .correns import CorrEns from .multivardecomp import MultivariateDecomp diff --git a/pyleoclim/core/coherence.py b/pyleoclim/core/coherence.py index b851f51a..1c6838db 100644 --- a/pyleoclim/core/coherence.py +++ b/pyleoclim/core/coherence.py @@ -13,7 +13,7 @@ from copy import deepcopy from matplotlib.ticker import ScalarFormatter, FormatStrFormatter -from matplotlib import cm +#from matplotlib import cm from matplotlib import gridspec from tqdm import tqdm diff --git a/pyleoclim/core/coherences.py b/pyleoclim/core/coherences.py new file mode 100644 index 00000000..7f4e74e6 --- /dev/null +++ b/pyleoclim/core/coherences.py @@ -0,0 +1,1169 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +The Coherence class stores the result of Series.wavelet_coherence(), whether WWZ or CWT. +It includes wavelet transform coherency and cross-wavelet transform. +""" +from ..utils import plotting +from ..utils import wavelet as waveutils +from ..core.scalograms import Scalogram, MultipleScalogram + +import matplotlib.pyplot as plt +import numpy as np +from copy import deepcopy + +from matplotlib.ticker import ScalarFormatter, FormatStrFormatter +#from matplotlib import cm +from matplotlib import gridspec + +from tqdm import tqdm +from scipy.stats.mstats import mquantiles +import warnings + +class Coherence: + '''Coherence object, meant to receive the WTC and XWT part of Series.wavelet_coherence() + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence method + + ''' + def __init__(self, frequency, scale, time, wtc, xwt, phase, coi=None, + wave_method=None, wave_args=None, + timeseries1=None, timeseries2=None, signif_qs=None, signif_method=None, qs =None, + freq_method=None, freq_kwargs=None, Neff_threshold=3, scale_unit=None, time_label=None): + self.frequency = np.array(frequency) + self.time = np.array(time) + self.scale = np.array(scale) + self.wtc = np.array(wtc) + self.xwt = np.array(xwt) + if coi is not None: + self.coi = np.array(coi) + else: + self.coi = waveutils.make_coi(self.time, Neff_threshold=Neff_threshold) + self.phase = np.array(phase) + self.timeseries1 = timeseries1 + self.timeseries2 = timeseries2 + self.signif_qs = signif_qs + self.signif_method = signif_method + self.freq_method = freq_method + self.freq_kwargs = freq_kwargs + self.wave_method = wave_method + if wave_args is not None: + if 'freq' in wave_args.keys(): + wave_args['freq'] = np.array(wave_args['freq']) + if 'tau' in wave_args.keys(): + wave_args['tau'] = np.array(wave_args['tau']) + self.wave_args = wave_args + self.qs = qs + + if scale_unit is not None: + self.scale_unit = scale_unit + elif timeseries1 is not None: + self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries1.time_unit) + elif timeseries2 is not None: + self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries2.time_unit) + else: + self.scale_unit = None + + if time_label is not None: + self.time_label = time_label + elif timeseries1 is not None: + if timeseries1.time_unit is not None: + self.time_label = f'{timeseries1.time_name} [{timeseries1.time_unit}]' + else: + self.time_label = f'{timeseries1.time_name}' + elif timeseries2 is not None: + if timeseries2.time_unit is not None: + self.time_label = f'{timeseries2.time_name} [{timeseries2.time_unit}]' + else: + self.time_label = f'{timeseries2.time_name}' + else: + self.time_label = None + + def copy(self): + '''Copy object + ''' + return deepcopy(self) + + def plot(self, var='wtc', xlabel=None, ylabel=None, title='auto', figsize=[10, 8], + ylim=None, xlim=None, in_scale=True, yticks=None, contourf_style={}, + phase_style={}, cbar_style={}, savefig_settings={}, ax=None, + signif_clr='white', signif_linestyles='-', signif_linewidths=1, + signif_thresh = 0.95, under_clr='ivory', over_clr='black', bad_clr='dimgray'): + '''Plot the cross-wavelet results + + Parameters + ---------- + + var : str {'wtc', 'xwt'} + + variable to be plotted as color field. Default: 'wtc', the wavelet transform coherency. + 'xwt' plots the cross-wavelet transform instead. + + xlabel : str, optional + + x-axis label. The default is None. + + ylabel : str, optional + + y-axis label. The default is None. + + title : str, optional + + Title of the plot. The default is 'auto', where it is made from object metadata. + To mute, pass title = None. + + figsize : list, optional + + Figure size. The default is [10, 8]. + + ylim : list, optional + + y-axis limits. The default is None. + + xlim : list, optional + + x-axis limits. The default is None. + + in_scale : bool, optional + + Plots scales instead of frequencies The default is True. + + yticks : list, optional + + y-ticks label. The default is None. + + contourf_style : dict, optional + + Arguments for the contour plot. The default is {}. + + phase_style : dict, optional + + Arguments for the phase arrows. The default is {}. It includes: + - 'pt': the default threshold above which phase arrows will be plotted + - 'skip_x': the number of points to skip between phase arrows along the x-axis + - 'skip_y': the number of points to skip between phase arrows along the y-axis + - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver) + - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver) + - 'color': arrow color (see matplotlib.pyplot.quiver) + + cbar_style : dict, optional + + Arguments for the color bar. The default is {}. + + savefig_settings : dict, optional + + The default is {}. + the dictionary of arguments for plt.savefig(); some notes below: + - "path" must be specified; it can be any existed or non-existed path, + with or without a suffix; if the suffix is not given in "path", it will follow "format" + - "format" can be one of {"pdf", "eps", "png", "ps"} + + ax : ax, optional + + Matplotlib axis on which to return the figure. The default is None. + + signif_thresh: float in [0, 1] + + Significance threshold. Default is 0.95. If this quantile is not + found in the qs field of the Coherence object, the closest quantile + will be picked. + + signif_clr : str, optional + + Color of the significance line. The default is 'white'. + + signif_linestyles : str, optional + + Style of the significance line. The default is '-'. + + signif_linewidths : float, optional + + Width of the significance line. The default is 1. + + under_clr : str, optional + + Color for under 0. The default is 'ivory'. + + over_clr : str, optional + + Color for over 1. The default is 'black'. + + bad_clr : str, optional + + Color for missing values. The default is 'dimgray'. + + Returns + ------- + fig, ax + + See also + -------- + + pyleoclim.core.coherence.Coherence.dashboard : plots a a dashboard showing the coherence and the cross-wavelet transform. + + pyleoclim.core.series.Series.wavelet_coherence : computes the coherence from two timeseries. + + matplotlib.pyplot.quiver : quiver plot + + Examples + -------- + + Calculate the wavelet coherence of NINO3 and All India Rainfall and plot it: + .. jupyter-execute:: + + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh.plot() + + Establish significance against an AR(1) benchmark: + + .. jupyter-execute:: + + coh_sig = coh.signif_test(number=20, qs=[.9,.95,.99]) + coh_sig.plot() + + Note that specifiying 3 significance thresholds does not take any more time as the quantiles are + simply estimated from the same ensemble. By default, the plot function looks + for the closest quantile to 0.95, but this is easy to adjust, e.g. for the 99th percentile: + + .. jupyter-execute:: + + coh_sig.plot(signif_thresh = 0.99) + + By default, the function plots the wavelet transform coherency (WTC), which quantifies where + two timeseries exhibit similar behavior in time-frequency space, regardless of whether this + corresponds to regions of high common power. To visualize the latter, you want to plot the + cross-wavelet transform (XWT) instead, like so: + + .. jupyter-execute:: + + coh_sig.plot(var='xwt') + + ''' + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + + # handling NaNs + mask_freq = [] + for i in range(np.size(self.frequency)): + if all(np.isnan(self.wtc[:, i])): + mask_freq.append(False) + else: + mask_freq.append(True) + + if in_scale: + y_axis = self.scale[mask_freq] + if ylabel is None: + ylabel = f'Scale [{self.scale_unit}]' if self.scale_unit is not None else 'Scale' + + if yticks is None: + yticks_default = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 1e4, 2e4, 5e4, 1e5, 2e5, 5e5, 1e6]) + mask = (yticks_default >= np.min(y_axis)) & (yticks_default <= np.max(y_axis)) + yticks = yticks_default[mask] + else: + y_axis = self.frequency[mask_freq] + if ylabel is None: + ylabel = f'Frequency [1/{self.scale_unit}]' if self.scale_unit is not None else 'Frequency' + + if signif_thresh > 1 or signif_thresh < 0: + raise ValueError("The significance threshold must be in [0, 1] ") + + # plot color field for WTC or XWT + contourf_args = { + 'cmap': 'magma', + 'origin': 'lower', + } + contourf_args.update(contourf_style) + + cmap = plt.get_cmap(contourf_args['cmap']) + cmap.set_under(under_clr) + cmap.set_over(over_clr) + cmap.set_bad(bad_clr) + contourf_args['cmap'] = cmap + + if var == 'wtc': + lev = np.linspace(0, 1, 11) + cont = ax.contourf(self.time, y_axis, self.wtc[:, mask_freq].T, + levels = lev, **contourf_args) + elif var == 'xwt': + cont = ax.contourf(self.time, y_axis, self.xwt[:, mask_freq].T, + levels = 11, **contourf_args) # just pass number of contours + else: + raise ValueError("Unknown variable; please choose either 'wtc' or 'xwt'") + + # plot significance levels + if self.signif_qs is not None: + signif_method_label = { + 'ar1': 'AR(1)', + } + if signif_thresh not in self.qs: + isig = np.abs(np.array(self.qs) - signif_thresh).argmin() + print("Significance threshold {:3.2f} not found in qs. Picking the closest, which is {:3.2f}".format(signif_thresh,self.qs[isig])) + else: + isig = self.qs.index(signif_thresh) + + if var == 'wtc': + signif_coh = self.signif_qs[0].scalogram_list[isig] # extract WTC significance threshold + signif_boundary = self.wtc[:, mask_freq].T / signif_coh.amplitude[:, mask_freq].T + elif var == 'xwt': + signif_coh = self.signif_qs[1].scalogram_list[isig] # extract XWT significance threshold + signif_boundary = self.xwt[:, mask_freq].T / signif_coh.amplitude[:, mask_freq].T + + ax.contour(self.time, y_axis, signif_boundary, [-99, 1], + colors=signif_clr, + linestyles=signif_linestyles, + linewidths=signif_linewidths) + if title is not None: + ax.set_title("Lines:" + str(round(self.qs[isig]*100))+"% threshold") + + # plot colorbar + cbar_args = { + 'label': var.upper(), + 'drawedges': False, + 'orientation': 'vertical', + 'fraction': 0.15, + 'pad': 0.05, + 'ticks': cont.levels + } + cbar_args.update(cbar_style) + + # assign colorbar to axis (instead of fig) : https://matplotlib.org/stable/gallery/subplots_axes_and_figures/colorbar_placement.html + cb = plt.colorbar(cont, ax = ax, **cbar_args) + + # plot cone of influence + ax.set_yscale('log') + ax.plot(self.time, self.coi, 'k--') + + if ylim is None: + ylim = [np.min(y_axis), np.min([np.max(y_axis), np.max(self.coi)])] + + ax.fill_between(self.time, self.coi, np.max(self.coi), color='white', alpha=0.5) + + if yticks is not None: + ax.set_yticks(yticks) + ax.yaxis.set_major_formatter(ScalarFormatter()) + ax.yaxis.set_major_formatter(FormatStrFormatter('%g')) + + if xlabel is None: + xlabel = self.time_label + + if xlabel is not None: + ax.set_xlabel(xlabel) + + if ylabel is not None: + ax.set_ylabel(ylabel) + + # plot phase + skip_x = np.max([int(np.size(self.time)//20), 1]) + skip_y = np.max([int(np.size(y_axis)//20), 1]) + + phase_args = {'pt': 0.5, 'skip_x': skip_x, 'skip_y': skip_y, + 'scale': 30, 'width': 0.004} + phase_args.update(phase_style) + + pt = phase_args['pt'] + skip_x = phase_args['skip_x'] + skip_y = phase_args['skip_y'] + scale = phase_args['scale'] + width = phase_args['width'] + + if 'color' in phase_style: + color = phase_style['color'] + else: + color = 'black' + + phase = np.copy(self.phase)[:, mask_freq] + + if self.signif_qs is None: + if var == 'wtc': + phase[self.wtc[:, mask_freq] < pt] = np.nan + else: + field = self.xwt[:, mask_freq] + phase[field < pt*field.max()] = np.nan + else: + phase[signif_boundary.T < 1] = np.nan + + X, Y = np.meshgrid(self.time, y_axis) + U, V = np.cos(phase).T, np.sin(phase).T + + ax.quiver(X[::skip_y, ::skip_x], Y[::skip_y, ::skip_x], + U[::skip_y, ::skip_x], V[::skip_y, ::skip_x], + scale=scale, width=width, zorder=99, color=color) + + ax.set_ylim(ylim) + + if xlim is not None: + ax.set_xlim(xlim) + + lbl1 = self.timeseries1.label + lbl2 = self.timeseries2.label + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + if title is not None and title != 'auto': + fig.suptitle(title) + elif title == 'auto' and lbl1 is not None and lbl1 is not None: + title = 'Wavelet coherency ('+self.wave_method.upper() +') between '+ lbl1 + ' and ' + lbl2 + fig.suptitle(title) + return fig, ax + else: + return ax + + + def dashboard(self, title=None, figsize=[9,12], overlap = True, phase_style = {}, + line_colors = ['tab:blue','tab:orange'], savefig_settings={}, + ts_plot_kwargs = None, wavelet_plot_kwargs= None): + ''' Cross-wavelet dashboard, including the two series, their WTC and XWT. + + Note: this design balances many considerations, and is not easily customizable. + + Parameters + ---------- + title : str, optional + + Title of the plot. The default is None. + + figsize : list, optional + + Figure size. The default is [9, 12], as this is an information-rich figure. + + overlap : boolean, optional + whether to restrict the plot to the period of overlap between the series. Defaults to True + + phase_style : dict, optional + + Arguments for the phase arrows. The default is {}. It includes: + - 'pt': the default threshold above which phase arrows will be plotted + - 'skip_x': the number of points to skip between phase arrows along the x-axis + - 'skip_y': the number of points to skip between phase arrows along the y-axis + - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver) + - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver) + - 'color': arrow color (see matplotlib.pyplot.quiver) + + line_colors : list, optional + + Colors for the 2 traces For nomenclature, see https://matplotlib.org/stable/gallery/color/named_colors.html + + savefig_settings : dict, optional + + The default is {}. + the dictionary of arguments for plt.savefig(); some notes below: + - "path" must be specified; it can be any existed or non-existed path, + with or without a suffix; if the suffix is not given in "path", it will follow "format" + - "format" can be one of {"pdf", "eps", "png", "ps"} + + ts_plot_kwargs : dict + + arguments to be passed to the timeseries subplot, see pyleoclim.core.series.Series.plot for details + + wavelet_plot_kwargs : dict + + arguments to be passed to the contour subplots (XWT and WTC), [see pyleoclim.core.coherence.Coherence.plot for details] + + + Returns + ------- + fig, ax + + See also + -------- + + pyleoclim.core.coherence.Coherence.plot : creates a coherence plot + + pyleoclim.core.series.Series.wavelet_coherence : computes the coherence between two timeseries. + + pyleoclim.core.series.Series.plot: plots a timeseries + + matplotlib.pyplot.quiver: makes a quiver plot + + Examples + -------- + + Calculate the coherence of NINO3 and All India Rainfall and plot it as a dashboard: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh_sig = coh.signif_test(number=10) + + coh_sig.dashboard() + + You may customize colors like so: + + .. jupyter-execute:: + + coh_sig.dashboard(line_colors=['teal','gold']) + + To export the figure, use `savefig_settings`: + + .. jupyter-execute:: + + coh_sig.dashboard(savefig_settings={'path':'./coh_dash.png','dpi':300}) + + ''' + # prepare options dictionaries + savefig_settings = {} if savefig_settings is None else savefig_settings.copy() + wavelet_plot_kwargs={} if wavelet_plot_kwargs is None else wavelet_plot_kwargs.copy() + ts_plot_kwargs={} if ts_plot_kwargs is None else ts_plot_kwargs.copy() + + + # create figure + fig = plt.figure(figsize=figsize) + gs = gridspec.GridSpec(8, 1) + gs.update(wspace=0, hspace=0.5) # add some breathing room + ax = {} + + # assess period of overlap + xlims = np.min(self.time), np.max(self.time) + + # 1) plot timeseries + #plt.rc('ytick', labelsize=8) + ax['ts1'] = plt.subplot(gs[0:2, 0]) + self.timeseries1.plot(ax=ax['ts1'], color=line_colors[0], **ts_plot_kwargs, legend=False) + ax['ts1'].yaxis.label.set_color(line_colors[0]) + ax['ts1'].tick_params(axis='y', colors=line_colors[0],labelsize=8) + ax['ts1'].spines['left'].set_color(line_colors[0]) + ax['ts1'].spines['bottom'].set_visible(False) + ax['ts1'].grid(False) + ax['ts1'].set_xlabel('') + if overlap: + ax['ts1'].set_xlim(xlims) + + ax['ts2'] = ax['ts1'].twinx() + self.timeseries2.plot(ax=ax['ts2'], color=line_colors[1], **ts_plot_kwargs, legend=False) + ax['ts2'].yaxis.label.set_color(line_colors[1]) + ax['ts2'].tick_params(axis='y', colors=line_colors[1],labelsize=8) + ax['ts2'].spines['right'].set_color(line_colors[1]) + ax['ts2'].spines['right'].set_visible(True) + ax['ts2'].spines['left'].set_visible(False) + ax['ts2'].grid(False) + if overlap: + ax['ts2'].set_xlim(xlims) + + # 2) plot WTC + ax['wtc'] = plt.subplot(gs[2:5, 0], sharex=ax['ts1']) + if 'cbar_style' not in wavelet_plot_kwargs: + wavelet_plot_kwargs.update({'cbar_style':{'orientation': 'horizontal', + 'pad': 0.15, 'aspect': 60}}) + self.plot(var='wtc',ax=ax['wtc'], title= None, **wavelet_plot_kwargs) + #ax['wtc'].xaxis.set_visible(False) # hide x axis + ax['wtc'].set_xlabel('') + + # 3) plot XWT + ax['xwt'] = plt.subplot(gs[5:8, 0], sharex=ax['ts1']) + if 'phase_style' not in wavelet_plot_kwargs: + wavelet_plot_kwargs.update({'phase_style':{'color': 'lightgray'}}) + self.plot(var='xwt',ax=ax['xwt'], title= None, + contourf_style={'cmap': 'viridis'}, + cbar_style={'orientation': 'horizontal','pad': 0.2, 'aspect': 60}, + phase_style=wavelet_plot_kwargs['phase_style']) + + #gs.tight_layout(fig) # this does nothing + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + return fig, ax + else: + return ax + + def signif_test(self, number=200, method='ar1sim', seed=None, qs=[0.95], settings=None, mute_pbar=False): + '''Significance testing for Coherence objects + + The method obtains quantiles `qs` of the distribution of coherence between + `number` pairs of Monte Carlo simulations of a process that resembles the original series. + Currently, only AR(1) surrogates are supported. + + Parameters + ---------- + number : int, optional + + Number of surrogate series to create for significance testing. The default is 200. + + method : {'ar1sim','phaseran','CN'}, optional + + Method through which to generate the surrogate series. The default is 'phaseran'. + + seed : int, optional + + Fixes the seed for NumPy's random number generator. + Useful for reproducibility. The default is None, so fresh, unpredictable + entropy will be pulled from the operating system. + + qs : list, optional + + Significance levels to return. The default is [0.95]. + + settings : dict, optional + + Parameters for surrogate model. The default is None. + + mute_pbar : bool, optional + + Mute the progress bar. The default is False. + + Returns + ------- + new : pyleoclim.core.coherence.Coherence + + original Coherence object augmented with significance levels signif_qs, + a list with the following `MultipleScalogram` objects: + * 0: MultipleScalogram for the wavelet transform coherency (WTC) + * 1: MultipleScalogram for the cross-wavelet transform (XWT) + + Each object contains as many Scalogram objects as qs contains values + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence + + pyleoclim.core.scalograms.Scalogram : Scalogram object + + pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object + + pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects + + Examples + -------- + + Calculate the coherence of NINO3 and All India Rainfall and assess significance: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh_sig = coh.signif_test(number=20) + coh_sig.plot() + + By default, significance is assessed against a 95% benchmark derived from + an AR(1) process fit to the data, using 200 Monte Carlo simulations. + To customize, one can increase the number of simulations + (more reliable, but slower), and the quantile levels. + + .. jupyter-execute:: + + coh_sig2 = coh.signif_test(number=100, qs=[.9,.95,.99]) + coh_sig2.plot() + + The plot() function will represent the 95% level as contours by default. + If you need to show 99%, say, use the `signif_thresh` argument: + + .. jupyter-execute:: + + coh_sig2.plot(signif_thresh=0.99) + + Note that if the 99% quantile is not present, the plot method will look + for the closest match, but lines are always labeled appropriately. + For reproducibility purposes, it may be good to specify the (pseudo)random number + generator's seed, like so: + + .. jupyter-execute:: + + coh_sig27 = coh.signif_test(number=20, seed=27) + + This will generate exactly the same set of draws from the + (pseudo)random number at every execution, which may be important for marginal features + in small ensembles. In general, however, we recommend increasing the + number of draws to check that features are robust. + + One can also specifiy a different method to obtain surrogates, e.g. phase randomization: + + .. jupyter-execute:: + + coh.signif_test(method='phaseran').plot() + ''' + from ..core.surrogateseries import SurrogateSeries + + if number == 0: + return self + + new = self.copy() + + surr1 = SurrogateSeries(method=method,number=number, seed=seed) + surr1.from_series(self.timeseries1) + surr2 = SurrogateSeries(method=method,number=number, seed=seed) + surr2.from_series(self.timeseries2) + + # adjust time axis + + wtcs, xwts = [], [] + + for i in tqdm(range(number), desc='Performing wavelet coherence on surrogate pairs', total=number, disable=mute_pbar): + coh_tmp = surr1.series_list[i].wavelet_coherence(surr2.series_list[i], + method = self.wave_method, + settings = self.wave_args) + wtcs.append(coh_tmp.wtc) + xwts.append(coh_tmp.xwt) + + wtcs = np.array(wtcs) + xwts = np.array(xwts) + + + ne, nf, nt = np.shape(wtcs) + + # reshape because mquantiles only accepts inputs of at most 2D + wtcs_r = np.reshape(wtcs, (ne, nf*nt)) + xwts_r = np.reshape(xwts, (ne, nf*nt)) + + # define nd-arrays + nq = len(qs) + wtc_qs = np.ndarray(shape=(nq, nf, nt)) + xwt_qs = np.empty_like(wtc_qs) + + # extract quantiles and reshape + wtc_qs = mquantiles(wtcs_r, qs, axis=0) + wtc_qs = np.reshape(wtc_qs, (nq, nf, nt)) + xwt_qs = mquantiles(xwts_r, qs, axis=0) + xwt_qs = np.reshape(xwt_qs, (nq, nf, nt)) + + # put in Scalogram objects for export + wtc_list, xwt_list = [],[] + + for i in range(nq): + wtc_tmp = Scalogram( + frequency=self.frequency, time=self.time, amplitude=wtc_qs[i,:,:], + coi=self.coi, scale = self.scale, + freq_method=self.freq_method, freq_kwargs=self.freq_kwargs, label=f'{qs[i]*100:g}%', + ) + wtc_list.append(wtc_tmp) + xwt_tmp = Scalogram( + frequency=self.frequency, time=self.time, amplitude=xwt_qs[i,:,:], + coi=self.coi, scale = self.scale, + freq_method=self.freq_method, freq_kwargs=self.freq_kwargs, label=f'{qs[i]*100:g}%', + ) + + xwt_list.append(xwt_tmp) + + new.signif_qs = [] + new.signif_qs.append(MultipleScalogram(scalogram_list=wtc_list)) # Export WTC quantiles + new.signif_qs.append(MultipleScalogram(scalogram_list=xwt_list)) # Export XWT quantiles + new.signif_method = method + new.qs = qs + + return new + + def phase_stats(self, scales, number=1000, level=0.05): + ''' Estimate phase angle statistics of a Coherence object + + As per [1], the strength (consistency) of a phase relationship is assessed using: + + * sigma, the circular standard deviation + + * kappa, an estimate of the Von Mises distribution's concentration parameter. + It is a reciprocal measure of dispersion, so 1/kappa is analogous to the variance) [3]. + + Because of inherent persistence of geophysical signals and of the + reproducing kernel of the continuous wavelet transform [3], phase statistics are + assessed relative to an AR(1) model fit to the angle deviations observed at the requested scale(s). + + Specifically, if `number` is specified, the method simulates `number` + Monte Carlo realizations of an AR(1) process fit to fluctuations around + the mean angle. This ensemble is used to obtain the confidence limits: + `sigma_lo` (`level` quantile) and `kappa_hi` (1-`level` quantile). + These correspond to 1-tailed tests of the strength of the relationship. + + Parameters + ---------- + scales : float + + scale at which to evaluate the phase angle + + number : int, optional + + number of AR(1) series to create for significance testing. The default is 1000. + + level : float, optional + + significance level against which to gauge sigma and kappa. default: 0.05 + + + Returns + ------- + result : dict + + contains angle_mean (the mean angle for those scales), sigma (the + circular standard deviation), kappa, sigma_lo (alpha-level quantile + for sigma) and kappa_hi, the (1-alpha)-level quantile for kappa. + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence + + pyleoclim.core.scalograms.Scalogram : Scalogram object + + pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object + + pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects + + pyleoclime.utils.wavelet.angle_sig : significance of phase angle statistics + + pyleoclim.utils.wavelet.angle_stats: phase angle statistics + + + References + ---------- + + [1] Grinsted, A., J. C. Moore, and S. Jevrejeva (2004), Application of the cross + wavelet transform and wavelet coherence to geophysical time series, + Nonlinear Processes in Geophysics, 11, 561–566. + + [2] Huber, R., Dutra, L. V., & da Costa Freitas, C. (2001). + SAR interferogram phase filtering based on the Von Mises distribution. + In IGARSS 2001. Scanning the Present and Resolving the Future. + Proceedings. IEEE 2001 International Geoscience and Remote Sensing Symposium + (Cat. No. 01CH37217) (Vol. 6, pp. 2816-2818). IEEE. + + [3] Farge, M. and Schneider, K. (2006): Wavelets: application to turbulence + Encyclopedia of Mathematical Physics (Eds. J.-P. Françoise, G. Naber and T.S. Tsun) + pp 408-420. + + Examples + -------- + + Calculate the phase angle between NINO3 and All India Rainfall at 5y scales: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + coh = ts_air.wavelet_coherence(ts_nino) + coh.phase_stats(scales=5) + + One may also obtain phase angle statistics over an interval, like the 2-8y ENSO band: + + .. jupyter-execute:: + + phase = coh.phase_stats(scales=[2,8]) + print("The mean angle is {:4.2f}°".format(phase.mean_angle/np.pi*180)) + print(phase) + + From this example, one diagnoses a strong anti-phased relationship in the ENSO band, + with high von Mises concentration (kappa ~ 3.35 >> kappa_hi) and low circular + dispersion (sigma ~ 0.6 << sigma_lo). This would be strong evidence of a consistent + anti-phasing between NINO3 and AIR at those scales. + + ''' + scales = np.array(scales) + + if scales.max() > self.scale.max(): + warnings.warn("Requested scale exceeds largest scale in object. Truncating to "+str(self.scale.max())) + + if scales.size == 1: + scale_idx = np.argmin(np.abs(self.scale - scales)) + res = waveutils.angle_sig(self.phase[:,scale_idx],nMC=number,level=level) + elif scales.size == 2: + idx_lo = np.argmin(np.abs(self.scale - scales.min())) + idx_hi = np.argmin(np.abs(self.scale - scales.max())) + if (idx_hi >= idx_lo): + raise ValueError("Insufficiently spaced scales. Please pick a single one, or a wider interval") + else: # average phase over those scales + nt, ns = self.phase.shape + phase = np.empty((nt)) + for i in range(nt): + phase[i], _, _ = waveutils.angle_stats(self.phase[i,idx_hi:idx_lo]) + res = waveutils.angle_sig(phase,nMC=number,level=level) # assess significance + + return res + +class GlobalCoherence: + '''Class to store the results of cross spectral analysis + + Attributes + ---------- + + global_coh: numpy array + coherence values + + scale: numpy array + scale values + + frequency: numpy array + frequency values + + coi: numpy array + cone of influence values + + coh: Coherence + Original coherence object + + + See Also + -------- + + pyleoclim.core.series.Series.global_coherence : method to compute the spectral coherence''' + + def __init__(self, global_coh, coh, signif_qs=None,signif_method=None,qs=None, label='Coherence'): + self.global_coh = global_coh + self.label = label + self.coh = coh + self.signif_qs = signif_qs + self.signif_method = signif_method + self.qs = qs + + def copy(self): + '''Copy object + ''' + return deepcopy(self) + + def signif_test(self,method='ar1sim',number=200,qs=[.95]): + '''Perform a significance test on the coherence values + + Parameters + ---------- + method: str; {'ar1sim','CN','phaseran'} + method to use for the surrogate test. Default is 'ar1sim'. + + number: int + number of surrogates to generate. Default is 200 + + qs: list + list of quantiles to compute. Default is [.95] + + Returns + ------- + global_coh: pyleoclim.core.globalcoherence.GlobalCoherence + Global coherence with significance field filled in + + Examples + -------- + + .. jupyter-execute:: + + soi = pyleo.utils.load_dataset('SOI') + nino3 = pyleo.utils.load_dataset('NINO3') + + gcoh = soi.global_coherence(nino3) + gcoh_sig = gcoh.signif_test(number=10) + gcoh_sig.plot() + ''' + + from ..core.surrogateseries import SurrogateSeries + + new = self.copy() + + ts1 = self.coh.timeseries1 + ts2 = self.coh.timeseries2 + + surr1 = SurrogateSeries(method=method,number=number) + surr2 = SurrogateSeries(method=method,number=number) + + surr1.from_series(ts1) + surr2.from_series(ts2) + + coh_array = np.empty((number,len(self.global_coh))) + + wavelet_kwargs = { + 'freq':self.coh.frequency, # pass the frequency axis directly + 'settings':self.coh.wave_args, + 'method':self.coh.wave_method, + } + + for i in range(number): + surr_series1 = surr1.series_list[i] + surr_series2 = surr2.series_list[i] + surr_coh = surr_series1.global_coherence(surr_series2,wavelet_kwargs=wavelet_kwargs) + coh_array[i,:] = surr_coh.global_coh + + quantiles = mquantiles(coh_array,qs,axis=0) + new.signif_qs = quantiles.data + new.signif_method = method + new.qs = qs + + return new + + def plot(self,figsize=(8,8),xlim=None,xlabel=None,label=None,psd_y_label='PSD',coh_y_label='Coherence',coh_line_color='grey',ax=None,coh_ylim=(.4,1),fill_alpha=.3,fill_color='grey',coh_plot_kwargs=None, + savefig_settings=None,spectral_kwargs=None,legend=True,legend_kwargs=None,spec1_plot_kwargs=None,spec2_plot_kwargs=None): + '''Plot the coherence as a function of scale or frequency, alongside the spectrum of the two timeseries (using the same method used for the coherence). + + Parameters + ---------- + figsize: tuple + size of the figure. Default is (8,8). Only used if ax is None + + xlim: tuple + x limits for the plot. Default is None + + label: str + label of the plot + + xlabel: str + x label of the plot + + psd_y_label: str + y label of the power spectral density plot (left hand side) + + coh_y_label: str + y label of the coherence plot (right hand side) + + coh_line_color: str + color of the coherence line + + coh_ylim: tuple + y limits for the coherence plot. Default is (.4,1) + + fill_alpha: float + alpha value for the fill_between plot. Default is .3 + + fill_color : str + color of the fill_between plot + + coh_plot_kwargs: dict + additional arguments to pass to the pyleoclim.utils.plotting.plot_xy + + savefig_settings: dict + settings to pass to the pyleoclim.utils.plotting.savefig function + + spectral_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + spec1_plot_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + spec2_plot_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + legend: bool + whether to include a legend or not + + legend_kwargs: dict + additional arguments to pass to ax.legend + + ax: matplotlib axis + axis to plot on + + Returns + ------- + ax: matplotlib axis + axis with the plot + + Examples + -------- + + .. jupyter-execute:: + + soi = pyleo.utils.load_dataset('SOI') + nino3 = pyleo.utils.load_dataset('NINO3') + + gcoh = soi.global_coherence(nino3) + gcoh.plot()''' + + coh_plot_kwargs = {} if coh_plot_kwargs is None else coh_plot_kwargs.copy() + savefig_settings = {} if savefig_settings is None else savefig_settings.copy() + spectral_kwargs = {} if spectral_kwargs is None else spectral_kwargs.copy() + legend_kwargs = {} if legend_kwargs is None else legend_kwargs.copy() + spec1_plot_kwargs = {} if spec1_plot_kwargs is None else spec1_plot_kwargs.copy() + spec2_plot_kwargs = {} if spec2_plot_kwargs is None else spec2_plot_kwargs.copy() + + if ax is None: + fig,ax = plt.subplots(figsize=figsize) + else: + pass + + coh_dict = self.coh.__dict__ + + if 'method' not in spectral_kwargs: + spectral_kwargs.update({'method': coh_dict['wave_method']}) + if 'freq' not in spectral_kwargs: + spectral_kwargs.update({'freq': coh_dict['freq_method']}) + if 'freq_kwargs' not in spectral_kwargs: + spectral_kwargs.update({'freq_kwargs': coh_dict['freq_kwargs']}) + if spectral_kwargs['method'] == coh_dict['wave_method']: + for key,value in coh_dict['wave_args'].items(): + if key not in spectral_kwargs: + spectral_kwargs.update({key: value}) + + ts1 = coh_dict['timeseries1'] + ts2 = coh_dict['timeseries2'] + + spec1 = ts1.spectral(label=ts1.label, **spectral_kwargs) + spec2 = ts2.spectral(label=ts2.label, **spectral_kwargs) + + spec1.plot(ax=ax,**spec1_plot_kwargs) + spec2.plot(ax=ax,**spec2_plot_kwargs) + + if xlim is not None: + ax.set_xlim(xlim) + if xlabel is not None: + ax.set_xlabel(xlabel) + if psd_y_label is not None: + ax.set_ylabel(psd_y_label) + + ax2 = ax.twinx() + + if coh_line_color is not None: + coh_plot_kwargs.update({'color':coh_line_color}) + if coh_y_label is not None: + ax2.set_ylabel(coh_y_label) + if coh_ylim is not None: + ax2.set_ylim(coh_ylim) + if label is None: + label = self.label + coh_plot_kwargs.update({'label': label}) + + scale = coh_dict['scale'] + + ax2.plot(scale,self.global_coh,**coh_plot_kwargs) + ax2.fill_between(scale, 0, self.global_coh, color=fill_color, alpha=fill_alpha) + ax2.grid(False) + + # plot significance levels if present + if self.signif_qs is not None: + signif_method_label = { + 'ar1sim': 'AR(1) simulations (MoM)', + 'phaseran': 'Phase Randomization', + 'CN': 'Colored Noise' + } + + for i, q in enumerate(self.signif_qs): + ax.plot( + scale, q, + label=f'{signif_method_label[self.signif_method]}, {self.qs[i]} threshold', + color='red', + linestyle='dashed', + linewidth=.8, + ) + + #formatting + if legend: + if len(legend_kwargs) == 0: + ax.legend().set_visible(False) + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2) + else: + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + if 'handles' not in legend_kwargs: + legend_kwargs.update({'handles': lines+lines2}) + if 'labels' not in legend_kwargs: + legend_kwargs.update({'labels': labels+labels2}) + ax.legend(**legend_kwargs) + ax2.legend().set_visible(False) + else: + ax.legend().set_visible(False) + ax2.legend().set_visible(False) + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + return fig, ax + else: + return ax diff --git a/pyleoclim/core/geoseries.py b/pyleoclim/core/geoseries.py index c62aded9..bd9a0da0 100644 --- a/pyleoclim/core/geoseries.py +++ b/pyleoclim/core/geoseries.py @@ -508,9 +508,12 @@ def map_neighbors(self, mgs, radius=3000, projection='Orthographic', proj_defaul neighbor_coloring = ['w' for ik in range(len(neighborhood))] neighbor_coloring[-1] = 'k' neighborhood['original'] =neighbor_coloring + neighborhood['neighbors'] =neighbor_coloring + # plot neighbors - fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size, marker=marker, projection=projection, + fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size, + marker=marker, projection=projection, proj_default=proj_default, background=background, borders=borders, rivers=rivers, lakes=lakes, ocean=ocean, land=land, diff --git a/pyleoclim/core/multiplegeoseries.py b/pyleoclim/core/multiplegeoseries.py index b7461a42..c1db9ae3 100644 --- a/pyleoclim/core/multiplegeoseries.py +++ b/pyleoclim/core/multiplegeoseries.py @@ -7,14 +7,13 @@ from ..core.multipleseries import MultipleSeries from ..utils import mapping as mp from ..utils import plotting -import warnings -import copy + import matplotlib.pyplot as plt import matplotlib as mpl -from matplotlib import cm -from itertools import cycle -import matplotlib.lines as mlines +#from matplotlib import cm +#from itertools import cycle +#import matplotlib.lines as mlines import numpy as np #import warnings diff --git a/pyleoclim/core/multipleseries.py b/pyleoclim/core/multipleseries.py index 2cfab127..0435f064 100644 --- a/pyleoclim/core/multipleseries.py +++ b/pyleoclim/core/multipleseries.py @@ -1323,7 +1323,7 @@ def spectral(self, method='lomb_scargle', freq=None, settings=None, mute_pbar=Fa return psds - def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None, verbose=False, mute_pbar=False): + def wavelet(self, method='cwt', settings={}, freq=None, freq_kwargs=None, verbose=False, mute_pbar=False): '''Wavelet analysis Parameters @@ -1341,7 +1341,12 @@ def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None Settings for the particular method. The default is {}. - freq_method : str; {'log', 'scale', 'nfft', 'lomb_scargle', 'welch'} + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict @@ -1392,7 +1397,6 @@ def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None .. jupyter-execute:: - import pyleoclim as pyleo soi = pyleo.utils.load_dataset('SOI') nino = pyleo.utils.load_dataset('NINO3') ms = (soi & nino) @@ -1403,7 +1407,7 @@ def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None scal_list = [] for s in tqdm(self.series_list, desc='Performing wavelet analysis on individual series', position=0, leave=True, disable=mute_pbar): - scal_tmp = s.wavelet(method=method, settings=settings, freq_method=freq_method, freq_kwargs=freq_kwargs, verbose=verbose) + scal_tmp = s.wavelet(method=method, settings=settings, freq=freq, freq_kwargs=freq_kwargs, verbose=verbose) scal_list.append(scal_tmp) scals = MultipleScalogram(scalogram_list=scal_list) diff --git a/pyleoclim/core/multivardecomp.py b/pyleoclim/core/multivardecomp.py index 861b45b3..0dc7bc39 100644 --- a/pyleoclim/core/multivardecomp.py +++ b/pyleoclim/core/multivardecomp.py @@ -1,10 +1,7 @@ import numpy as np #import pandas as pd from matplotlib import pyplot as plt, gridspec -#from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from matplotlib.ticker import MaxNLocator -#import cartopy.crs as ccrs -#import cartopy.feature as cfeature from ..core import series from ..utils import plotting, mapping, tsbase diff --git a/pyleoclim/core/resolutions.py b/pyleoclim/core/resolutions.py index d29b4c50..e6a05423 100644 --- a/pyleoclim/core/resolutions.py +++ b/pyleoclim/core/resolutions.py @@ -4,10 +4,8 @@ Resolution objects are designed to contain, display, and analyze information on the resolution of the time axis of a Series object. """ -from ..utils import tsutils, plotting, tsmodel, tsbase - -import warnings - +from ..utils import plotting +#import warnings import numpy as np import seaborn as sns import pandas as pd diff --git a/pyleoclim/core/scalograms.py b/pyleoclim/core/scalograms.py index a6ad21f2..0af36346 100644 --- a/pyleoclim/core/scalograms.py +++ b/pyleoclim/core/scalograms.py @@ -1,7 +1,7 @@ # It is unclear why the documentation for these two modules does not build automatically using automodule. It therefore had to be built using autoclass -from ..utils import plotting, lipdutils, tsutils +from ..utils import plotting, tsutils from ..utils import wavelet as waveutils import matplotlib.pyplot as plt @@ -9,12 +9,10 @@ from tabulate import tabulate from copy import deepcopy -from matplotlib.ticker import ScalarFormatter, FormatStrFormatter #, MaxNLocator -from mpl_toolkits.axes_grid1.inset_locator import inset_axes +from matplotlib.ticker import ScalarFormatter, FormatStrFormatter from scipy.stats.mstats import mquantiles -#from ..core import MultipleScalogram class Scalogram: ''' diff --git a/pyleoclim/core/series.py b/pyleoclim/core/series.py index 29c1c33b..3d1e67ca 100644 --- a/pyleoclim/core/series.py +++ b/pyleoclim/core/series.py @@ -10,22 +10,22 @@ import datetime as dt import re -from ..utils import tsutils, plotting, tsmodel, tsbase, lipdutils, jsonutils +from ..utils import tsutils, plotting, tsbase, jsonutils from ..utils import wavelet as waveutils from ..utils import spectral as specutils from ..utils import correlation as corrutils from ..utils import causality as causalutils from ..utils import decomposition from ..utils import filter as filterutils +from ..utils import lipdutils from ..core.psds import PSD from ..core.ssares import SsaRes from ..core.multipleseries import MultipleSeries from ..core.scalograms import Scalogram -from ..core.coherence import Coherence +from ..core.coherences import Coherence, GlobalCoherence from ..core.corr import Corr from ..core.resolutions import Resolution -from .globalcoherence import GlobalCoherence import seaborn as sns import matplotlib.pyplot as plt @@ -1847,10 +1847,9 @@ def summary_plot(self, psd, scalogram, figsize=[8, 10], title=None, .. jupyter-execute:: - import pyleoclim as pyleo series = pyleo.utils.load_dataset('SOI') psd = series.spectral(freq = 'welch') - scalogram = series.wavelet(freq_method = 'welch') + scalogram = series.wavelet(freq = 'welch') fig, ax = series.summary_plot(psd = psd,scalogram = scalogram) @@ -1859,7 +1858,6 @@ def summary_plot(self, psd, scalogram, figsize=[8, 10], title=None, .. jupyter-execute:: - import pyleoclim as pyleo series = pyleo.utils.load_dataset('SOI') psd = series.spectral(freq = 'welch').signif_test(number=20) scalogram = series.wavelet(freq_method = 'welch') @@ -3066,7 +3064,7 @@ def spectral(self, method='lomb_scargle', freq=None, freq_kwargs=None, settings= return psd - def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=None, verbose=False): + def wavelet(self, method='cwt', settings=None, freq=None, freq_kwargs=None, verbose=False): ''' Perform wavelet analysis on a timeseries Parameters @@ -3078,8 +3076,12 @@ def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=No is appropriate for unevenly-spaced series. Default is cwt, returning an error if the Series is unevenly-spaced. - freq_method : str, optional - Can be one of 'log', 'scale', 'nfft', 'lomb_scargle', 'welch'. + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict Arguments for the frequency vector @@ -3177,12 +3179,25 @@ def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=No settings = {} if settings is None else settings.copy() freq_kwargs = {} if freq_kwargs is None else freq_kwargs.copy() - freq = specutils.make_freq_vector(self.time, method=freq_method, **freq_kwargs) - + if 'freq' in settings.keys(): + freq_vec = settings['freq'] + freq_method = "user_specified" + else: + if freq is None: # assign the frequency method automatically based on context + freq_vec = specutils.make_freq_vector(self.time, method='log', **freq_kwargs) + freq_method = "log" + elif isinstance(freq, str): # apply the specified method + freq_vec = specutils.make_freq_vector(self.time, method=freq, **freq_kwargs) + freq_method = freq + elif isinstance(freq,np.ndarray): # use the specified vector if dimensions check out + freq_vec = np.squeeze(freq) + freq_method = "user_specified" + if freq.ndim != 1: + raise ValueError("freq should be a 1-dimensional array") args = {} - args['wwz'] = {'freq': freq} - args['cwt'] = {'freq': freq} + args['wwz'] = {'freq': freq_vec} + args['cwt'] = {'freq': freq_vec} if method == 'wwz': if 'ntau' in settings.keys(): @@ -3223,7 +3238,7 @@ def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=No return scal def wavelet_coherence(self, target_series, method='cwt', settings=None, - freq_method='log', freq_kwargs=None, verbose=False, + freq=None, freq_kwargs=None, verbose=False, common_time_kwargs=None): ''' Performs wavelet coherence analysis with the target timeseries @@ -3237,8 +3252,12 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, if the series share the same evenly-spaced time axis. 'wwz' is designed for unevenly-spaced data, but is far slower. - freq_method : str - {'log','scale', 'nfft', 'lomb_scargle', 'welch'} + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict Arguments for frequency vector @@ -3329,7 +3348,7 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, # by default, the plot function will look for the closest quantile to 0.95, but it is easy to adjust: cwt_sig.plot(signif_thresh = 0.9) - Another plotting option, `dashboard`, allows to visualize both + Another plotting option, `dashboard()`, allows to visualize both timeseries as well as the wavelet transform coherency (WTC), which quantifies where two timeseries exhibit similar behavior in time-frequency space, and the cross-wavelet transform (XWT), which indicates regions of high common power. @@ -3355,11 +3374,26 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, # Process options settings = {} if settings is None else settings.copy() freq_kwargs = {} if freq_kwargs is None else freq_kwargs.copy() - freq = specutils.make_freq_vector(self.time, method=freq_method, **freq_kwargs) + + if 'freq' in settings.keys(): + freq_vec = settings['freq'] + freq_method = "user_specified" + else: + if freq is None: # assign the frequency method automatically based on context + freq_vec = specutils.make_freq_vector(self.time, method='log', **freq_kwargs) + freq_method = "log" + elif isinstance(freq, str): # apply the specified method + freq_vec = specutils.make_freq_vector(self.time, method=freq, **freq_kwargs) + freq_method = freq + elif isinstance(freq,np.ndarray): # use the specified vector if dimensions check out + freq_vec = np.squeeze(freq) + freq_method = "user_specified" + if freq.ndim != 1: + raise ValueError("freq should be a 1-dimensional array") + args = {} - args['wwz'] = {'freq': freq, 'verbose': verbose} - args['cwt'] = {'freq': freq} - + args['wwz'] = {'freq': freq_vec, 'verbose': verbose} + args['cwt'] = {'freq': freq_vec} # put on same time axes if necessary if method == 'cwt' and not np.array_equal(self.time, target_series.time): diff --git a/pyleoclim/core/ssares.py b/pyleoclim/core/ssares.py index b9d9cf82..37adf652 100644 --- a/pyleoclim/core/ssares.py +++ b/pyleoclim/core/ssares.py @@ -10,8 +10,6 @@ import seaborn as sns from matplotlib import pyplot as plt, gridspec from matplotlib.ticker import MaxNLocator - -from ..core import series from ..utils import plotting diff --git a/pyleoclim/tests/test_core_Coherence.py b/pyleoclim/tests/test_core_Coherences.py similarity index 72% rename from pyleoclim/tests/test_core_Coherence.py rename to pyleoclim/tests/test_core_Coherences.py index 07f50419..1639b3db 100644 --- a/pyleoclim/tests/test_core_Coherence.py +++ b/pyleoclim/tests/test_core_Coherences.py @@ -85,4 +85,40 @@ def test_phasestats_t0(self, gen_ts): ts1 = gen_ts ts2 = gen_ts coh = ts2.wavelet_coherence(ts1) - _ = coh.phase_stats(scales=[2,8]) \ No newline at end of file + _ = coh.phase_stats(scales=[2,8]) + +class TestUiGlobalCoherencePlot: + ''' Tests for GlobalCoherence.plot() + ''' + + def test_plot_t0(self, gen_ts): + ''' Test GlobalCoherence.plot with various parameters + ''' + ts1 = gen_ts + ts2 = gen_ts + coh = ts1.global_coherence(ts2) + fig,ax = coh.plot() + pyleo.closefig(fig) + + def test_plot_t1(self, gen_ts): + ''' Test GlobalCoherence.plot with signif tests + ''' + ts1 = gen_ts + ts2 = gen_ts + coh = ts1.global_coherence(ts2).signif_test(number=1) + fig,ax = coh.plot() + pyleo.closefig(fig) + +class TestUiGlobalCoherenceSignifTest: + ''' Tests for GlobalCoherence.signif_test() + ''' + + @pytest.mark.parametrize('method',['ar1sim','phaseran','CN']) + @pytest.mark.parametrize('number',[1,10]) + @pytest.mark.parametrize('qs',[[.95],[.05,.95]]) + def test_signiftest_t0(self,method,number, qs,gen_ts): + ''' Test GlobalCoherence.signif_test + ''' + ts1 = gen_ts + ts2 = gen_ts + _ = ts1.global_coherence(ts2).signif_test(method=method,number=number,qs=qs) \ No newline at end of file diff --git a/pyleoclim/tests/test_core_GeoSeries.py b/pyleoclim/tests/test_core_GeoSeries.py index b6521143..87594d4e 100644 --- a/pyleoclim/tests/test_core_GeoSeries.py +++ b/pyleoclim/tests/test_core_GeoSeries.py @@ -86,7 +86,7 @@ def test_init_dropna(self, evenly_spaced_series): print(ts2.value) assert ~np.isnan(ts2.value[0]) - +@pytest.mark.xfail # will fail until pandas is fixed class TestUIGeoSeriesResample(): ''' test GeoSeries.Resample() ''' diff --git a/pyleoclim/tests/test_core_GlobalCoherence.py b/pyleoclim/tests/test_core_GlobalCoherence.py deleted file mode 100644 index ae4e366d..00000000 --- a/pyleoclim/tests/test_core_GlobalCoherence.py +++ /dev/null @@ -1,53 +0,0 @@ -''' Tests for pyleoclim.core.globalcoherence.GlobalCoherence - -Naming rules: -1. class: Test{filename}{Class}{method} with appropriate camel case -2. function: test_{method}_t{test_id} - -Notes on how to test: -0. Make sure [pytest](https://docs.pytest.org) has been installed: `pip install pytest` -1. execute `pytest {directory_path}` in terminal to perform all tests in all testing files inside the specified directory -2. execute `pytest {file_path}` in terminal to perform all tests in the specified file -3. execute `pytest {file_path}::{TestClass}::{test_method}` in terminal to perform a specific test class/method inside the specified file -4. after `pip install pytest-xdist`, one may execute "pytest -n 4" to test in parallel with number of workers specified by `-n` -5. for more details, see https://docs.pytest.org/en/stable/usage.html -''' - -import pytest -import pyleoclim as pyleo - -class TestUiGlobalCoherencePlot: - ''' Tests for GlobalCoherence.plot() - ''' - - def test_plot_t0(self, gen_ts): - ''' Test GlobalCoherence.plot with various parameters - ''' - ts1 = gen_ts - ts2 = gen_ts - coh = ts1.global_coherence(ts2) - fig,ax = coh.plot() - pyleo.closefig(fig) - - def test_plot_t1(self, gen_ts): - ''' Test GlobalCoherence.plot with signif tests - ''' - ts1 = gen_ts - ts2 = gen_ts - coh = ts1.global_coherence(ts2).signif_test(number=1) - fig,ax = coh.plot() - pyleo.closefig(fig) - -class TestUiGlobalCoherenceSignifTest: - ''' Tests for GlobalCoherence.signif_test() - ''' - - @pytest.mark.parametrize('method',['ar1sim','phaseran','CN']) - @pytest.mark.parametrize('number',[1,10]) - @pytest.mark.parametrize('qs',[[.95],[.05,.95]]) - def test_signiftest_t0(self,method,number, qs,gen_ts): - ''' Test GlobalCoherence.signif_test - ''' - ts1 = gen_ts - ts2 = gen_ts - _ = ts1.global_coherence(ts2).signif_test(method=method,number=number,qs=qs) \ No newline at end of file diff --git a/pyleoclim/tests/test_core_Series.py b/pyleoclim/tests/test_core_Series.py index d94641cf..b35467a4 100644 --- a/pyleoclim/tests/test_core_Series.py +++ b/pyleoclim/tests/test_core_Series.py @@ -974,8 +974,8 @@ def test_xwave_t3(self): v_unevenly = np.delete(ts1.value, deleted_idx) t1_unevenly = np.delete(ts2.time, deleted_idx1) v1_unevenly = np.delete(ts2.value, deleted_idx1) - ts3 = pyleo.Series(time=t_unevenly, value=v_unevenly) - ts4 = pyleo.Series(time=t1_unevenly, value=v1_unevenly) + ts3 = pyleo.Series(time=t_unevenly, value=v_unevenly,auto_time_params=True) + ts4 = pyleo.Series(time=t1_unevenly, value=v1_unevenly,auto_time_params=True) _ = ts3.wavelet_coherence(ts4,method='wwz') def test_xwave_t4(self): @@ -1005,6 +1005,22 @@ def test_xwave_t6(self): ts2 = gen_ts(model='colored_noise') tau = ts1.time[::10] _ = ts1.wavelet_coherence(ts2,method='wwz',settings={'tau':tau}) + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20),'log', 'nfft', 'welch']) + def test_xwave_t7(self, freq): + ''' Test Series.wavelet_coherence() with freq method argument + ''' + ts1 = gen_ts(model='colored_noise') + ts2 = gen_ts(model='colored_noise') + + coh = ts1.wavelet_coherence(ts2,freq=freq) + + if freq is None: + assert coh.freq_method == 'log' + elif isinstance(freq,np.ndarray): + assert coh.freq_method == 'user_specified' + elif isinstance(freq, str): + assert coh.freq_method == freq class TestUISeriesGlobalCoherence(): '''Test global coherence @@ -1016,13 +1032,31 @@ def test_globalcoherence_t0(self): ts2 = gen_ts(model='colored_noise') _ = ts1.global_coherence(ts2) - def test_globalcoherence_t0(self): + def test_globalcoherence_t1(self): ''' Test Series.global_coherence() with passed coh ''' ts1 = gen_ts(model='colored_noise') ts2 = gen_ts(model='colored_noise') coh = ts1.wavelet_coherence(ts2) _ = ts1.global_coherence(coh=coh) + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20),'log', 'nfft', 'welch']) + def test_globalcoherence_t2(self,freq): + ''' Test Series.global_coherence() with wavelet kwargs + ''' + ts1 = gen_ts(model='colored_noise') + ts2 = gen_ts(model='colored_noise') + + kwargs = {} + kwargs['freq'] = freq + gcoh = ts1.global_coherence(target_series=ts2,wavelet_kwargs=kwargs) + + if freq is None: + assert gcoh.coh.freq_method == 'log' + elif isinstance(freq,np.ndarray): + assert gcoh.coh.freq_method == 'user_specified' + elif isinstance(freq, str): + assert gcoh.coh.freq_method == freq class TestUISeriesWavelet(): ''' Test the wavelet functionalities @@ -1042,28 +1076,41 @@ def test_wave_t1(self,wave_method): n = 100 ts = gen_ts(model='colored_noise',nt=n) freq = np.linspace(1/n, 1/2, 20) - _ = ts.wavelet(method=wave_method, settings={'freq': freq}) - + scal = ts.wavelet(method=wave_method, settings={'freq': freq}) + assert scal.freq_method == "user_specified" + def test_wave_t2(self): ''' Test Series.wavelet() ntau option and plot functionality ''' - ts = gen_ts(model='colored_noise',nt=200) + ts = gen_ts(model='colored_noise',nt=100) _ = ts.wavelet(method='wwz',settings={'ntau':10}) @pytest.mark.parametrize('mother',['MORLET', 'PAUL', 'DOG']) def test_wave_t3(self,mother): ''' Test Series.wavelet() with different mother wavelets ''' - ts = gen_ts(model='colored_noise',nt=200) - _ = ts.wavelet(method='cwt',settings={'mother':mother}) + ts = gen_ts(model='colored_noise',nt=100) + _ = ts.wavelet(settings={'mother':mother}) @pytest.mark.parametrize('freq_meth', ['log', 'scale', 'nfft', 'welch']) def test_wave_t4(self,freq_meth): - ''' Test Series.wavelet() with different mother wavelets + ''' Test Series.wavelet() with different frequency methods ''' - ts = gen_ts(model='colored_noise',nt=200) - _ = ts.wavelet(method='cwt',freq_method=freq_meth) - + ts = gen_ts(model='colored_noise',nt=100) + scal = ts.wavelet(freq=freq_meth) + assert scal.freq_method == freq_meth + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20)]) + def test_wave_t5(self,freq): + ''' Test Series.wavelet() with different frequency vectors + ''' + ts = gen_ts(model='colored_noise',nt=100) + scal = ts.wavelet(freq=freq) + if freq is None: + assert scal.freq_method == 'log' + else: + assert scal.freq_method == 'user_specified' + class TestUISeriesSsa(): ''' Test the SSA functionalities ''' diff --git a/pyleoclim/utils/mapping.py b/pyleoclim/utils/mapping.py index 96fab138..35d82b98 100644 --- a/pyleoclim/utils/mapping.py +++ b/pyleoclim/utils/mapping.py @@ -718,7 +718,7 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco # geos_df = pd.DataFrame(value_d) # return geos_df - def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor='w', + def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor_var=None, edgecolor='w', ax=None, ax_d=None, proj=None, scatter_kwargs=None, legend=True, lgd_kwargs=None, colorbar=None, fig=None, color_scale_type=None, # gs_slot=None, cmap=None, **kwargs): @@ -773,16 +773,28 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va } missing_val = missing_d['label'] - # if 'edgecolor' in scatter_kwargs: - # edgecolor = scatter_kwargs['edgecolor'] - # scatter_kwargs['edgecolors'] = edgecolor - # if 'edgecolors' not in scatter_kwargs: - # scatter_kwargs['edgecolors'] = edgecolor - # - if 'edgecolor' in scatter_kwargs: - edgecolor = scatter_kwargs['edgecolor']# = edgecolor + if isinstance(scatter_kwargs, dict): + edgecolor = scatter_kwargs.pop('edgecolor', edgecolor) + + if 'neighbor' in df.columns: + edgecolor_var = 'neighbor' + # if ~isinstance(edgecolor, np.ndarray): + # if isinstance(edgecolor, str): + # edgecolor = [edgecolor] + # edgecolor = np.array(edgecolor) + + if isinstance(edgecolor, (list, np.ndarray)): + if len(edgecolor) == len(_df): + _df['edgecolor'] = edgecolor + elif len(edgecolor) == 1: + _df['edgecolor'] = edgecolor[0] + # try making a column populated by the edgecolor + elif isinstance(edgecolor, str): _df['edgecolor'] = edgecolor + elif isinstance(edgecolor, dict): + if edgecolor_var in _df.columns: + _df['edgecolor'] = _df[edgecolor_var].map(edgecolor) hue_var = hue_var if hue_var in _df.columns else None hue_var_type_numeric = False @@ -935,23 +947,42 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va _df['edgecolor'] = edgecolor _df['neighbor'] = _df['edgecolor'].map({'k': 'target', 'w': 'neighbor'}) + # handle missing values hue_data = _df[_df[hue_var] == missing_val] - if len(hue_data) > 0: - sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, - style=marker_var, transform=transform,edgecolor='w', - # change to transform=scatter_kwargs['transform'] + if 'neighbor' in hue_data.columns: + sns.scatterplot(data=hue_data, x=x, y=y, size=size_var, + transform=transform, + edgecolor=hue_data.edgecolor.values, linewidth=2, + style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))], + ax=ax, legend=False, + **scatter_kwargs) + + sns.scatterplot(data=hue_data, x=x, y=y, size=size_var, + transform=transform, + edgecolor=None, linewidth=0, + style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))], - ax=ax, **scatter_kwargs) + ax=ax, + **scatter_kwargs) + # sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, + # style=marker_var, transform=transform,edgecolor=edgecolor, + # # change to transform=scatter_kwargs['transform'] + # palette=[missing_d['hue'] for ik in range(len(hue_data))], + # ax=ax, **scatter_kwargs) missing_handles, missing_labels = ax.get_legend_handles_labels() - if 'neighbor' in hue_data.columns: - if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1: - _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0] - else: - _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'] - sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var, - transform=transform, edgecolor=_edgecolor, - style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs) + + # # if the missing values are being handled for map_neighbors + # if 'neighbor' in hue_data.columns: + # # if the missing value is actually the target + # if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1: + # _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0] + # else: + # _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'] + # sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var, + # transform=transform, edgecolor=_edgecolor, + # style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs) + # available values else: missing_handles, missing_labels = [], [] @@ -960,12 +991,23 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va if hue_norm is not None: scatter_kwargs['hue_norm'] = hue_norm - sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor='w', - style=marker_var, palette=palette, ax=ax, **scatter_kwargs) + # sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor=edgecolor, + # style=marker_var, palette=palette, ax=ax, **scatter_kwargs) + if 'neighbor' in hue_data.columns: - sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var, - transform=transform, edgecolor=hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0], - style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs) + sns.scatterplot(data=hue_data, x=x, y=y, size=size_var, + transform=transform, + edgecolor=hue_data.edgecolor.values,linewidth=2, + style=marker_var, hue=hue_var, palette=palette, ax=ax, legend=False, **scatter_kwargs) + if not isinstance(edgecolor, str): + edgecolor = None + linewidth = 0 + else: + linewidth = 1 + + sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform, + edgecolor=edgecolor,linewidth=linewidth, + style=marker_var, palette=palette, ax=ax, **scatter_kwargs) else: scatter_kwargs['zorder'] = 13 diff --git a/pyleoclim/utils/tsmodel.py b/pyleoclim/utils/tsmodel.py index 529cc1cc..0f3b48c1 100644 --- a/pyleoclim/utils/tsmodel.py +++ b/pyleoclim/utils/tsmodel.py @@ -726,7 +726,7 @@ def uar1_sim(t, tau, sigma_2=1): def inverse_cumsum(arr): return np.diff(np.concatenate(([0], arr))) -def random_time_axis(n, delta_t_dist = "exponential", param = [1.0]): +def random_time_axis(n, delta_t_dist = "exponential", param = [1.0], seed = None): ''' Generate a random time axis according to a specific probability model @@ -756,12 +756,14 @@ def random_time_axis(n, delta_t_dist = "exponential", param = [1.0]): ''' + if seed is not None: + np.random.seed(seed) + # check for a valid distribution valid_distributions = ["exponential", "poisson", "pareto", "random_choice"] if delta_t_dist not in valid_distributions: raise ValueError("delta_t_dist must be one of: 'exponential', 'poisson', 'pareto', 'random_choice'.") - param = np.array(param) # coerce array type if delta_t_dist == "exponential": diff --git a/setup.py b/setup.py index cf33da89..3eec46f0 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages -version = '1.0.0' +version = '1.0.0b0' # Read the readme file contents into variable def read(fname):