Skip to content

Commit

Permalink
deprecate from jax.config
Browse files Browse the repository at this point in the history
  • Loading branch information
zgbkdlm committed Oct 30, 2024
1 parent 181130a commit 894fbab
Show file tree
Hide file tree
Showing 50 changed files with 55 additions and 100 deletions.
3 changes: 2 additions & 1 deletion chirpgp/quadratures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import scipy
from chirpgp.models import g
from typing import Callable, NamedTuple, Union, List, Tuple
from functools import partial
Expand Down Expand Up @@ -172,7 +173,7 @@ def gauss_hermite(cls, d: int, order: int = 3):

w_1d = np.zeros(shape=(order,))
for i in range(order):
w_1d[i] = (2 ** (order - 1) * np.math.factorial(order) * np.sqrt(np.pi) /
w_1d[i] = (2 ** (order - 1) * scipy.special.factorial(order) * np.sqrt(np.pi) /
(order ** 2 * (np.polyval(hermite_coeff[order - 1],
hermite_roots[i])) ** 2))

Expand Down
3 changes: 1 addition & 2 deletions demos/cd_ekfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/cd_ghfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/anf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from chirpgp.classical_methods import adaptive_notch_filter
from chirpgp.toymodels import gen_chirp_envelope, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from chirpgp.classical_methods import hilbert_method
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/lascala_ekfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from chirpgp.quadratures import gaussian_expectation
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/lascala_ghfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/mean_spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from chirpgp.classical_methods import mean_power_spectrum
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/classical_methods/mle_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from chirpgp.classical_methods import mle_polynomial
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/ekfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.quadratures import gaussian_expectation
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/ghfs_harmonics_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from chirpgp.toymodels import gen_harmonic_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions demos/ghfs_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from chirpgp.tools import rmse
from jax.config import config

# Use float64
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Time interval, number of times, and time instances.
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions paper_plots_tables/plot_chirp_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from chirpgp.tools import simulate_sde
from chirpgp.models import model_chirp, disc_chirp_lcd
from chirpgp.models import g
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

plt.rcParams.update({
'text.usetex': True,
Expand Down
4 changes: 2 additions & 2 deletions paper_plots_tables/plot_cov_harmonic_sde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Plot the covariance function of harmonic SDE. This generated Figure 2 in the paper.
"""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from chirpgp.cov_funcs import vmap_cov_harmonic_sde
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.01
Expand Down
3 changes: 1 addition & 2 deletions paper_plots_tables/plot_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from chirpgp.models import g
from chirpgp.quadratures import gaussian_expectation
from chirpgp.toymodels import meow_freq, gen_chirp, constant_mag, damped_exp_mag, random_ou_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

mc = 0

Expand Down
3 changes: 1 addition & 2 deletions paper_plots_tables/plot_estimation_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from chirpgp.models import g
from chirpgp.quadratures import gaussian_expectation
from chirpgp.toymodels import meow_freq, gen_harmonic_chirp, constant_mag, damped_exp_mag, random_ou_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# mc = 11
mc = 17
Expand Down
3 changes: 1 addition & 2 deletions paper_plots_tables/print_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from chirpgp.models import g, g_inv, build_chirp_model
from chirpgp.filters_smoothers import ekf, eks
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions real_applications/bats/eptesicus_nilssonii_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from chirpgp.models import g, build_harmonic_chirp_model
from chirpgp.filters_smoothers import sgp_filter, sgp_smoother
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Load data
fs, sound = scipy.io.wavfile.read('./Eptesicus_nilssonii_1_o.wav')
Expand Down
3 changes: 1 addition & 2 deletions real_applications/bats/myotis_myotis_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
from chirpgp.models import g, build_harmonic_chirp_model
from chirpgp.filters_smoothers import sgp_filter, sgp_smoother
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Load data
fs, sound = scipy.io.wavfile.read('Myotis_myotis_2_o.wav')
Expand Down
3 changes: 1 addition & 2 deletions real_applications/ligo/gw_freq_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.models import g, g_inv, build_chirp_model
from chirpgp.filters_smoothers import sgp_filter, sgp_smoother
from chirpgp.quadratures import gaussian_expectation, SigmaPoints
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Load gravitational wave strain data. Please download them by yourself, see README.md
ts, ys = jnp.asarray(np.genfromtxt('./data/fig1-observed-H.txt').T)
Expand Down
3 changes: 1 addition & 2 deletions test/test_classical_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import numpy.testing as npt
from chirpgp.classical_methods import hilbert_method, mean_power_spectrum, adaptive_notch_filter, mle_polynomial
from chirpgp.toymodels import gen_chirp, gen_chirp_envelope, affine_freq, polynomial_freq, constant_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestClassicalMethods:
Expand Down
4 changes: 2 additions & 2 deletions test/test_cov_funcs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import jax
import math
import pytest
import jax.numpy as jnp
import numpy.testing as npt
from chirpgp.cov_funcs import vmap_cov_harmonic_sde, vmap_marginal_cov_harmonic_sde
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestCovFuncs:
Expand Down
3 changes: 1 addition & 2 deletions test/test_crlb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from chirpgp.filters_smoothers import kf, rts
from chirpgp.models import posterior_cramer_rao
from chirpgp.tools import lti_sde_to_disc
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

np.random.seed(666)

Expand Down
4 changes: 1 addition & 3 deletions test/test_ekfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from chirpgp.filters_smoothers import ekf, eks, cd_ekf, cd_eks
from chirpgp.tools import simulate_sde
import tme.base_jax as tme
# import matplotlib.pyplot as plt
import numpy.testing as npt
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

dim_x = 3
kappa = 10.
Expand Down
3 changes: 1 addition & 2 deletions test/test_filters_smoothers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from chirpgp.filters_smoothers import kf, rts, ekf, eks, cd_ekf, cd_eks, \
sgp_filter, sgp_smoother, cd_sgp_filter, cd_sgp_smoother
from chirpgp.quadratures import SigmaPoints
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

np.random.seed(666)

Expand Down
3 changes: 1 addition & 2 deletions test/test_gauss_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import jax.numpy as jnp
import numpy.testing as npt
from chirpgp.gauss_newton import gauss_newton, levenberg_marquardt
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestOptimisers:
Expand Down
4 changes: 2 additions & 2 deletions test/test_m32.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest
import math
import jax
import jax.numpy as jnp
import numpy.testing as npt
from chirpgp.tools import lti_sde_to_disc
from chirpgp.models import _m32_solution
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestM32:
Expand Down
3 changes: 1 addition & 2 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from chirpgp.models import g, g_inv, model_chirp, disc_chirp_lcd, disc_chirp_lcd_cond_v, disc_chirp_tme, \
disc_chirp_euler_maruyama, disc_m32, disc_model_lascala_lcd, model_harmonic_chirp, disc_harmonic_chirp_lcd
from chirpgp.tools import lti_sde_to_disc
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

np.random.seed(666)

Expand Down
4 changes: 2 additions & 2 deletions test/test_quadratures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
import jax
import jax.numpy as jnp
import numpy.testing as npt
from chirpgp.quadratures import SigmaPoints
from jax import vmap
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

d = 1
gh_order = 5
Expand Down
3 changes: 1 addition & 2 deletions test/test_toymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from chirpgp.toymodels import affine_freq, polynomial_freq, meow_freq, random_ou_mag, gen_chirp, gen_harmonic_chirp, \
constant_mag
from functools import partial
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestToyModels:
Expand Down
3 changes: 1 addition & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import numpy.testing as npt
from chirpgp.tools import simulate_lgssm, lti_sde_to_disc, fwd_transformed_pdf, chol_partial_const_diag, rmse
from chirpgp.quadratures import gaussian_expectation
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


class TestUtils:
Expand Down
3 changes: 1 addition & 2 deletions tetralith/generate_chirp_for_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
import numpy as np
from scipy.io import savemat
from chirpgp.toymodels import gen_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions tetralith/generate_harmonic_chirp_for_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
import numpy as np
from scipy.io import savemat
from chirpgp.toymodels import gen_harmonic_chirp, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
3 changes: 1 addition & 2 deletions tetralith/jobs/anf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import chirpgp.tools
from chirpgp.classical_methods import adaptive_notch_filter
from chirpgp.toymodels import gen_chirp_envelope, meow_freq, constant_mag, damped_exp_mag, random_ou_mag
from jax.config import config

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

# Times
dt = 0.001
Expand Down
Loading

0 comments on commit 894fbab

Please sign in to comment.