Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions axlearn/add_one_colocated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pathwaysutils

print("initializing pathwaysutils")
pathwaysutils.initialize()
print("pathwaysutils initialized")

import numpy as np
import jax
from jax.experimental import colocated_python
import os
import shutil
import orbax.checkpoint as ocp

print("jax version on cpu host:", jax.__version__)

print("getting tpu devices")
tpu_devices = jax.devices()
print("tpu devices: ", tpu_devices)
print("getting cpu devices")
cpu_devices = colocated_python.colocated_cpu_devices(tpu_devices)
print("cpu devices: ", cpu_devices)

import cloudpickle

print("JAX_PLATFORMS is 'proxy'. Setting up pathways colocated python checkpointing.")
print(f" Using jax version {jax.__version__} and cloudpickle version {cloudpickle.__version__}")


print("def add_one")


@colocated_python.colocated_python
def add_one(x):
import sys

sys.stderr.write("In colocated python function \n")
sys.stderr.write(f"[Colocated] jax version: {jax.__version__} \n")
sys.stderr.write("[Colocated] add_one")
sys.stderr.write(f"[Colocated] x: {x} on device: {x.device } \n")
return x+1


print("creating input 1")
x = np.array(1)
print("putting on device")
x = jax.device_put(x, cpu_devices[0])

print("adding one to input 1")
out = add_one(x)
print("getting out")
out = jax.device_get(out)
print("out 1: ", out)

print("creating input 2")
x = np.array(5)
print("putting on device")
x = jax.device_put(x, cpu_devices[0])

assert out == 2, f"out: {out}"

print("adding one to input 2")
out = add_one(x)
print("getting out")
out = jax.device_get(out)
print("out 2: ", out)

assert out == 6, f"out: {out}"
32 changes: 30 additions & 2 deletions axlearn/cloud/gcp/pathways_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,23 @@
# The port used by pathways worker server.
# The specific value is not important, as long as clients and servers use the same port.
_PATHWAYS_WORKER_PORT = 29001
_COLOCATED_CONTAINER_PORT = 50051
# Pin to specific pathways image version for stable release.
# There is no guarantee that this image will work with newer Jax releases.
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
# This image extends GRPC timeout for long context models.
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"

# The docker image used by pathways proxy container.
_PATHWAYS_PROXY_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_proxy_server_maxtext:latest"
)
# The docker image used by pathways resource manager container and worker container.
_PATHWAYS_SERVER_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server_maxtext:latest"
)
_COLOCATED_PYTHON_IMAGE = (
"gcr.io/cloud-tpu-multipod-dev/ksadi_sidecar_maxtext:latest"
)
# The container name of pathways resourcemanager.
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
Expand All @@ -63,6 +68,10 @@
# The k8s replicatedJob name for pathways-worker pods.
_PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker"

_COLOCATED_PYTHON_SIDECAR_NAME = "colocated-python-sidecar"



# Add node-selector for cpu workload to avoid sharing nodes with system services.
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type"
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload"
Expand Down Expand Up @@ -382,6 +391,23 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
],
),
]

def _colocated_python_container(self):

return dict(
name=_COLOCATED_PYTHON_SIDECAR_NAME,
image=_COLOCATED_PYTHON_IMAGE,
restartPolicy="Always",
env=[
{
"name": "GRPC_SERVER_ADDRESS",
"value": f"0.0.0.0:{_COLOCATED_CONTAINER_PORT}",
},
],
imagePullPolicy="Always",
ports=[dict(containerPort=_COLOCATED_CONTAINER_PORT)],

)

def _build_pathways_head_pod(self) -> Nested[Any]:
"""Builds a pathways head pod. The pod includes a head container,
Expand Down Expand Up @@ -563,6 +589,8 @@ def _build_pathways_worker_pod(
pod_spec["containers"] = [
self._build_pathways_worker_container(pathways_worker_replicated_job_index)
]
pod_spec["initContainers"]=[self._colocated_python_container()]

worker_pod["spec"] = pod_spec

# Service account for nodes.
Expand Down
2 changes: 1 addition & 1 deletion axlearn/experiments/text/gpt/envy.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,4 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
make_single_host_config_func = functools.partial(make_single_host_config, config_name)
config_map[f"{config_name}-single-host"] = make_single_host_config_func

return config_map
return config_map
16 changes: 8 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
# Minimal requirments for axlearn/common/config.py.
dependencies = [
"attrs>=23.1.0", # We use `type` in `attrs.field`
"numpy==1.26.4", # verified with tensorflow 2.14 RaggedTensor
"numpy==2.1.1", # verified with tensorflow 2.14 RaggedTensor
]

[project.optional-dependencies]
Expand All @@ -23,9 +23,9 @@ core = [
"absl-py==2.1.0",
"chex==0.1.88",
"importlab==0.8.1", # breaks pytype on 0.8
"jax==0.5.3",
"jaxlib==0.5.3",
"ml-dtypes==0.4.1",
"jax==0.6.2",
"jaxlib==0.6.2",
"ml-dtypes==0.5.1",
"msgpack==1.1.0", # for checkpointing.
"nltk==3.7", # for text preprocessing
"optax==0.1.7", # optimizers (0.1.0 has known bugs).
Expand All @@ -34,14 +34,14 @@ core = [
"protobuf>=3.20.3",
"tensorboard-plugin-profile==2.20.4",
# This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13.
"tensorflow==2.17.1",
"tensorflow==2.19.1",
"tensorflow-datasets>=4.9.2",
"tensorflow-io>=0.37.1", # for tensorflow-2.16. Note that 0.37.0 results in "pure virtual method called".
"tensorflow_text==2.17.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
"tensorflow_text==2.19.0; platform_machine == 'x86_64'", # implied by seqio, but also used directly for text processing
"tensorstore>=0.1.63", # used for supporting GDA checkpoints
"toml", # for config management
"typing-extensions==4.12.2",
"scipy==1.12.0", # to avoid "module 'scipy.linalg' has no attribute 'tril'"
"scipy==1.15.0", # to avoid "module 'scipy.linalg' has no attribute 'tril'"
"seqio==0.0.18", # used for inputs
"aqtp==0.8.2", # Updated from 0.4.0; compatible with Python 3.10
"flax==0.10.2", # for AQT, param converter and adapter.
Expand Down Expand Up @@ -107,7 +107,7 @@ gcp = [
# Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
tpu = [
"axlearn[gcp]",
"jax[tpu]==0.5.3", # must be >=0.4.19 for compat with v5p.
"jax[tpu]==0.6.2", # must be >=0.4.19 for compat with v5p.
"pathwaysutils==0.1.1", # For JAX+Pathways single-controller accelerator coordinator.
]
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
Expand Down
Loading