-
Notifications
You must be signed in to change notification settings - Fork 369
Update JAX API usage to latest version #1317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
GitOrigin-RevId: d143490
Make sure sink's contribution is added once. Also added tests. GitOrigin-RevId: 8de870c
GitOrigin-RevId: 56cf7e8
* pin nccl version * empty commit * add actual pacakge * trigger new build to address flaky test * Update pyproject.toml GitOrigin-RevId: b8653ad
…ункции flatten/unflatten в utils
Hi, the workflows need approval to run (GitHub Actions are pending). Can someone with write access approve and run them? @ruomingp pls |
GitOrigin-RevId: 65af801
…рать лишний пробел в _enable_numeric_checks
…мый код в dataclass
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of these changes have the potential to break things and don't seem to be necessary, as @changlan mentioned. Could you explain for every change, why it is necessary? Also, please do not mark comments as resolved yourself. To streamline reviewing, we only have PR reviewers mark comments as resolved.
Also please resolve any merge conflicts. |
We need to update the JAX API usage across the codebase to use the latest stable versions.
Changes Required
jax.tree_util
withjax.tree
:register_pytree_with_keys
instead ofregister_pytree_node
tree_map
with new API versionFiles to Modify
Key files that need updates:
axlearn/common/struct.py
axlearn/common/utils.py
axlearn/common/metrics.py
axlearn/common/learner.py
Implementation Details
Success Criteria