diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 9913085fde6..2fbca29aedd 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -12,12 +12,13 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib +from flax import nnx SUPPORTS_SPARSE_TENSORS = True IS_THREAD_SAFE = True -class Variable(KerasVariable): +class Variable(nnx.Param, KerasVariable): def _initialize(self, value): value = jnp.array(value, dtype=self._dtype) # Note that variable.shape is needed by distribution_lib @@ -33,6 +34,14 @@ def _initialize(self, value): self._layout = None self._direct_assign(value) + @property + def _value(self): + return self.raw_value + + @_value.setter + def _value(self, value): + self.raw_value = value + def _direct_assign(self, value): if getattr(self, "_layout", None) is not None: value = distribution_lib.distribute_variable(value, self._layout) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index fbcc4fe5b5c..021895b81e2 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,2 +1,5 @@ -class JaxLayer: - pass +from flax import nnx + +class JaxLayer(nnx.Object): + def __init_subclass__(cls): + super().__init_subclass__() diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index d065fdd2fdf..8f04edbccc2 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -212,6 +212,9 @@ def call(self, inputs): ``` """ + def __init_subclass__(cls): + super(BackendLayer, cls).__init_subclass__() + def __new__(cls, *args, **kwargs): obj = super().__new__(cls, *args, **kwargs) @@ -538,7 +541,9 @@ def add_weight( initializer = "zeros" initializer = initializers.get(initializer) with backend.name_scope(self.name, caller=self): + value = initializer(shape, dtype) variable = backend.Variable( + value, initializer=initializer, shape=shape, dtype=dtype, diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index a289bc5f321..9852c85ceb8 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -10,10 +10,13 @@ from keras.src.utils import python_utils from keras.src.utils import traceback_utils from keras.src.utils.naming import auto_name +from flax import nnx @keras_export("keras.Operation") class Operation: + def __init_subclass__(cls): + super().__init_subclass__() def __init__(self, dtype=None, name=None): if name is None: name = auto_name(self.__class__.__name__) @@ -97,6 +100,7 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. """ instance = super(Operation, cls).__new__(cls) + vars(instance)['_object__state'] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`. arg_names = inspect.getfullargspec(cls.__init__).args