Skip to content

Commit

Permalink
Remove usage of jax2tf.PolyShape in Scenic.
Browse files Browse the repository at this point in the history
jax2tf.PolyShape has been deprecated since January 2024.
Instead, we can use simple strings.
The code changed here was always using strings, but referenced jax2tf.PolyShape in type declarations.

PiperOrigin-RevId: 707421419
  • Loading branch information
gnecula authored and Scenic Authors committed Dec 18, 2024
1 parent 9364f20 commit 97d6ac5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion scenic/common_lib/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def convert_and_save_model(
Sequence[dict[str, tf.TensorSpec]],
],
polymorphic_shapes: Optional[
Union[str, jax2tf.PolyShape, dict[str, str]]
Union[str, dict[str, str]]
] = None,
with_gradient: bool = False,
enable_xla: bool = True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@
" *,\n",
" input_signatures: Union[Sequence[tf.TensorSpec],\n",
" Sequence[Sequence[tf.TensorSpec]]],\n",
" polymorphic_shapes: Optional[Union[str, jax2tf.PolyShape]] = None,\n",
" polymorphic_shapes: Optional[str] = None,\n",
" with_gradient: bool = False,\n",
" enable_xla: bool = True,\n",
" compile_model: bool = True,\n",
Expand Down

0 comments on commit 97d6ac5

Please sign in to comment.