From 65646f38f81d1b495b591195d9731daeca3a7372 Mon Sep 17 00:00:00 2001 From: DrJessop Date: Wed, 18 Dec 2024 22:47:28 +0000 Subject: [PATCH] Clean up global search and replace --- equinox/_ad.py | 1 - equinox/_enum.py | 1 - equinox/_errors.py | 1 - equinox/_jit.py | 1 - equinox/_misc.py | 1 - equinox/_vmap_pmap.py | 1 - equinox/debug/_backward_nan.py | 1 - equinox/debug/_dce.py | 1 - equinox/internal/_loop/checkpointed.py | 1 - tests/test_tree.py | 1 - 10 files changed, 10 deletions(-) diff --git a/equinox/_ad.py b/equinox/_ad.py index bbb1e7c3..665f7792 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -18,7 +18,6 @@ import jax import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.interpreters.ad as ad import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/_enum.py b/equinox/_enum.py index 7d107f08..1716a60c 100644 --- a/equinox/_enum.py +++ b/equinox/_enum.py @@ -3,7 +3,6 @@ import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.numpy as jnp import numpy as np from jaxtyping import Array, ArrayLike, Bool, Int diff --git a/equinox/_errors.py b/equinox/_errors.py index 0ce20bed..3297fb3d 100644 --- a/equinox/_errors.py +++ b/equinox/_errors.py @@ -9,7 +9,6 @@ import jax import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/_jit.py b/equinox/_jit.py index 4a2d5445..1d6b190a 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -11,7 +11,6 @@ import jax._src.dispatch import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.errors import jax.numpy as jnp from jaxtyping import PyTree diff --git a/equinox/_misc.py b/equinox/_misc.py index 20127f31..14183278 100644 --- a/equinox/_misc.py +++ b/equinox/_misc.py @@ -1,6 +1,5 @@ import jax import jax.core -import jax.extend import jax.numpy as jnp from jaxtyping import Array diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 588e2aa4..ef917500 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -8,7 +8,6 @@ import jax import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.numpy as jnp import jax.tree_util as jtu import numpy as np diff --git a/equinox/debug/_backward_nan.py b/equinox/debug/_backward_nan.py index b8ba0fc0..b1c78b16 100644 --- a/equinox/debug/_backward_nan.py +++ b/equinox/debug/_backward_nan.py @@ -3,7 +3,6 @@ import jax import jax._src.traceback_util as traceback_util import jax.core -import jax.extend import jax.numpy as jnp import jax.tree_util as jtu diff --git a/equinox/debug/_dce.py b/equinox/debug/_dce.py index 10578f93..8b70a74a 100644 --- a/equinox/debug/_dce.py +++ b/equinox/debug/_dce.py @@ -2,7 +2,6 @@ import jax import jax.core -import jax.extend import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import PyTree diff --git a/equinox/internal/_loop/checkpointed.py b/equinox/internal/_loop/checkpointed.py index 3973121f..fd5ed7c8 100644 --- a/equinox/internal/_loop/checkpointed.py +++ b/equinox/internal/_loop/checkpointed.py @@ -54,7 +54,6 @@ import jax import jax.core -import jax.extend import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu diff --git a/tests/test_tree.py b/tests/test_tree.py index 85c85342..b4cae5b1 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -3,7 +3,6 @@ import equinox as eqx import jax import jax.core -import jax.extend import jax.nn as jnn import jax.numpy as jnp import jax.random as jrandom