We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1a9f4d6 commit 62fc630Copy full SHA for 62fc630
axlearn/common/ops/_optimization_barrier.py
@@ -5,8 +5,6 @@
5
from typing import Any
6
7
import jax
8
-from jax._src import ad_checkpoint # pylint: disable=protected-access
9
-
10
11
@jax.custom_jvp
12
@jax.custom_batching.custom_vmap # Must be wrapped in this before custom_jvp.
axlearn/common/utils_test.py
@@ -14,7 +14,6 @@
14
15
# pylint: disable=no-self-use
16
17
-import jaxlib
18
import numpy as np
19
import pytest
20
import tensorflow as tf
0 commit comments