Skip to content

Commit

Permalink
isort: add jax as a first-party
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Jul 9, 2024
1 parent 5a49cdf commit c658c13
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 5 deletions.
3 changes: 2 additions & 1 deletion arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
"""
from functools import partial, reduce

import jax.numpy as jnp
import numpy as np

import jax.numpy as jnp

from arraycontext.container import NotAnArrayContainerError, serialize_container
from arraycontext.container.traversal import (
rec_map_array_container,
Expand Down
3 changes: 0 additions & 3 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,6 @@ def __init__(self,
unstable.
"""
import jax.numpy as jnp

import pytato as pt
super().__init__(compile_trace_callback=compile_trace_callback)
self.array_types = (pt.Array, jnp.ndarray)
Expand Down Expand Up @@ -766,7 +765,6 @@ def zeros_like(self, ary):

def from_numpy(self, array):
import jax

import pytato as pt

def _from_numpy(ary):
Expand All @@ -791,7 +789,6 @@ def freeze(self, array):
return array

import jax.numpy as jnp

import pytato as pt

from arraycontext.container.traversal import rec_keyed_map_array_container
Expand Down
1 change: 0 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def __init__(self, *args, **kwargs):
def is_available(cls) -> bool:
try:
import jax # noqa: F401

import pytato # noqa: F401
return True
except ImportError:
Expand Down

0 comments on commit c658c13

Please sign in to comment.