axlearn/common/compiler_options.py (5 changes: 3 additions & 2 deletions)
@@ -334,8 +334,9 @@ def infer_xsc_compiler_options(
xla_tpu_sdc_checker_alternate_megacore_cores=True,
# XLA ICI SDC Checker flags:
# N.B. ICI checker only runs once after first program compilation.
# Enable the interconnect checker on first program call.
xla_tpu_ici_sdc_test_run_on_program_start=True,
# Disable the interconnect checker by default, as it is not meant for production runs.
# In a job with 32k chips, disabling it reduced compilation time from 18 minutes to 15 seconds.
xla_tpu_ici_sdc_test_run_on_program_start=False,
# Max distance between send/recv neighbours.
xla_tpu_ici_sdc_test_max_distance=1,
# Number of repeated send/recv before checking for equivalence.
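For context, a minimal sketch of how the flags touched in this diff could be grouped behind an opt-in switch; the helper name and signature are assumptions for illustration, not AXLearn's actual `infer_xsc_compiler_options` API.

```python
# Illustrative only: flag names come from the diff above; the helper and its
# default are assumptions, not the repo's real implementation.
def ici_sdc_checker_flags(run_on_program_start: bool = False) -> dict:
    """XLA TPU flags for the ICI SDC (interconnect silent-data-corruption) checker."""
    return dict(
        # Off by default: per the comment above, enabling this raised compilation
        # time from 15 seconds to 18 minutes in a 32k-chip job.
        xla_tpu_ici_sdc_test_run_on_program_start=run_on_program_start,
        # Max distance between send/recv neighbours.
        xla_tpu_ici_sdc_test_max_distance=1,
    )
```

Callers that still want the one-shot interconnect check after the first compilation can pass `run_on_program_start=True` explicitly.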
axlearn/common/compiler_options_test.py (2 changes: 1 addition & 1 deletion)
@@ -75,7 +75,7 @@ def test_xsc_compiler_options(self):
xla_tpu_sdc_check_halt_on_detection=False,
xla_tpu_sdc_replicate_llo=True,
xla_tpu_sdc_checker_alternate_megacore_cores=True,
xla_tpu_ici_sdc_test_run_on_program_start=True,
xla_tpu_ici_sdc_test_run_on_program_start=False,
xla_tpu_ici_sdc_test_max_distance=1,
xla_tpu_ici_sdc_test_pipeline_depth=4,
xla_tpu_ici_sdc_test_buffer_size_chunks=32,
axlearn/common/flash_attention/tpu_attention_test.py (4 changes: 2 additions & 2 deletions)
@@ -299,7 +299,7 @@ def test_logit_sink(
# Compare outputs
out = fn(input_batch)
ref_out = ref_fn(input_batch)
self.assertNestedAllClose(out, ref_out, atol=2e-2)
self.assertNestedAllClose(out, ref_out, atol=1e-6 if q_dtype == jnp.float32 else 2e-2)

# Compare gradients
def grad_fn(float_inputs, aux_inputs, f):
@@ -310,7 +310,7 @@ def grad_fn(float_inputs, aux_inputs, f):
aux_inputs = dict(bias=bias, prng_key=prng_key)
grad_out = jax.grad(grad_fn, argnums=0)(float_inputs, aux_inputs, fn)
ref_grad_out = jax.grad(grad_fn, argnums=0)(float_inputs, aux_inputs, ref_fn)
self.assertNestedAllClose(grad_out, ref_grad_out, atol=1e-5)
self.assertNestedAllClose(grad_out, ref_grad_out, atol=1e-6)

def test_logit_sink_shape_validation(self):
"""Test that logit sink shape validation works correctly."""
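The tolerance change above follows a simple pattern: a tight absolute tolerance when the query dtype is float32 and a looser one for lower-precision dtypes. A minimal sketch of that selection, with assumed names:

```python
import jax.numpy as jnp

def atol_for(q_dtype) -> float:
    # Full-precision runs should match the reference closely; bf16 needs more slack.
    return 1e-6 if q_dtype == jnp.float32 else 2e-2
```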
axlearn/common/flash_attention/tpu_splash_attention.py (13 changes: 6 additions & 7 deletions)
@@ -25,9 +25,9 @@
participate in the max and sum computations but do not contribute to the output. When enabled,
the `logit_sink` parameter provides per-head scalar values that are incorporated into the
softmax normalization as follows: the running maximum is initialized with the sink value, and
at each step, the sink's contribution is added to the normalization sum (denominator) as
exp(logit_sink - running_max). The sink does not contribute to the numerator of the
attention-weighted sum, as it has no corresponding value. In the backward pass, gradients for
during the final normalization, the sink's contribution is added once to the normalization sum
(denominator) as exp(logit_sink - running_max). The sink does not contribute to the numerator of
the attention-weighted sum, as it has no corresponding value. In the backward pass, gradients for
the sink logits are computed as the negative sum of their attention weights multiplied by the
output gradients, reflecting their role in the normalization term without direct output
contribution.
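To make the description above concrete, here is a hedged, non-blocked reference for a single head; it is a sketch of the math, not the Pallas kernel, and the shapes and names are illustrative.

```python
import jax.numpy as jnp

def reference_attention_with_sink(q, k, v, logit_sink):
    """q: [Tq, d], k/v: [Tkv, d], logit_sink: scalar sink logit for this head."""
    logits = q @ k.T                                              # [Tq, Tkv]
    # The running maximum is initialized with (here, includes) the sink value.
    m = jnp.maximum(logits.max(axis=-1, keepdims=True), logit_sink)
    p = jnp.exp(logits - m)                                       # unnormalized weights
    # The sink is added once to the denominator; it has no value row,
    # so the numerator (p @ v) is unchanged.
    denom = p.sum(axis=-1, keepdims=True) + jnp.exp(logit_sink - m)
    return (p @ v) / denom
```

Differentiating this reference with `jax.grad` with respect to `logit_sink` gives the backward rule described above: the negative sum over queries of the sink's attention weight times the dot product of the output gradient with the output.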
@@ -219,10 +219,6 @@ def body(kv_compute_index, _):
assert s_curr.shape == (bq, bkv_compute)

l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
# Add sink contribution to normalization sum.
if logit_sink_ref is not None:
sink_value = logit_sink_ref[h].astype(qk.dtype)
l_curr = l_curr + jnp.exp(sink_value - m_next[:, 0:1])
assert l_curr.shape == (bq, NUM_LANES)

alpha = jnp.exp(m_prev - m_next)
@@ -262,6 +258,9 @@ def run():
@pl.when(j == grid_width - 1)
def end():
l = l_scratch_ref[...]
if logit_sink_ref is not None:
sink_value = logit_sink_ref[h].astype(jnp.float32)
l = l + jnp.exp(sink_value - m_scratch_ref[...])
l_inv = pltpu.repeat(1.0 / l, head_dim_repeats, axis=1)
o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype)
if logsumexp_ref is not None:
axlearn/common/metrics.py (2 changes: 1 addition & 1 deletion)
@@ -110,4 +110,4 @@ def _metric_accumulator_unflatten(
MetricAccumulator,
_metric_accumulator_flatten,
_metric_accumulator_unflatten,
)
)
axlearn/common/struct.py (5 changes: 3 additions & 2 deletions)
@@ -76,17 +76,18 @@ def flatten_func(x):
return data, meta

def flatten_with_keys(x) -> tuple[tuple, tuple]:
data = tuple((jax.tree_util.GetAttrKey(name), getattr(x, name)) for name in data_fields)
data = tuple((jax.tree.GetAttrKey(name), getattr(x, name)) for name in data_fields)
meta = tuple(getattr(x, name) for name in meta_fields)
return data, meta


# Note that meta, data are tuples as produced by `flatten_with_keys`.
def unflatten_func(meta: tuple, data: tuple):
# Support unflattening from chex.dataclass which requires handling lists.
data = tuple(data)
return dataklass(**dict(zip(meta_fields + data_fields, meta + data)))

jax.tree_util.register_pytree_with_keys(
jax.tree.register_pytree_with_keys(
dataklass, flatten_with_keys, unflatten_func, flatten_func
)

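For readers unfamiliar with the registration pattern in this file, here is a self-contained sketch using the long-standing `jax.tree_util` spelling (the diff switches to the shorter `jax.tree` aliases, assumed available in the pinned JAX version); `Point` and the helper names are made up for illustration.

```python
import jax

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

def _flatten_with_keys(p):
    # Children paired with key entries, so key paths report attribute names.
    children = (
        (jax.tree_util.GetAttrKey("x"), p.x),
        (jax.tree_util.GetAttrKey("y"), p.y),
    )
    return children, None  # no static metadata

def _flatten(p):
    return (p.x, p.y), None

def _unflatten(meta, children):
    del meta
    return Point(*children)

jax.tree_util.register_pytree_with_keys(Point, _flatten_with_keys, _unflatten, _flatten)

doubled = jax.tree_util.tree_map(lambda v: 2 * v, Point(1, 2))
assert (doubled.x, doubled.y) == (2, 4)
```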
axlearn/common/utils.py (30 changes: 19 additions & 11 deletions)
@@ -88,6 +88,20 @@
# The set of supported floating point dtypes.
_supported_float_dtypes = [jnp.bfloat16, jnp.float32]

def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
"""Generate the (key, value) pairs for the immediate children of a pytree `node`."""
flat = jax.tree.default_registry.flatten_one_level(node)
if flat is None:
return []

if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node):
return [(jax.tree.GetAttrKey(s), getattr(node, s)) for s in node._fields]

key_children, _ = jax.tree.default_registry.flatten_one_level_with_keys(node)
if key_children:
return key_children

return [(jax.tree.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]

@dataclasses.dataclass
class HybridMeshShape:
@@ -1896,31 +1910,25 @@ def thread_stack_traces() -> Sequence[Sequence[str]]:
def pytree_children(node: Any) -> Sequence[tuple[KeyEntry, Any]]:
"""Generate the (key, value) pairs for the immediate children of a pytree `node`.

The returned children match those returned by
`jax.tree_util.default_registry.flatten_one_level()`.

Reference: jax._src.tree_util.generate_key_paths()

Example:
```
assert pytree_children(dict(a=[1,2])) == [(DictKey('a'), [1,2])]
```
"""
# pylint: disable-next=protected-access
registry_with_keypaths = jax._src.tree_util._registry_with_keypaths

key_handler = registry_with_keypaths.get(type(node))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(node)
return key_children

flat = jax.tree_util.default_registry.flatten_one_level(node)
if flat is None:
return []

if isinstance(node, tuple) and hasattr(node, "_fields") and flat[1] == type(node):
# Handle namedtuple as a special case, based on heuristic.
return [(jax.tree_util.GetAttrKey(s), getattr(node, s)) for s in node._fields]

key_children, _ = jax.tree_util.default_registry.flatten_one_level_with_keys(node)
if key_children:
return key_children

return [(jax.tree_util.FlattenedIndexKey(i), c) for i, c in enumerate(flat[0])]


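Assumed usage of the `pytree_children` helper shown above, with expected outputs written as comments (key types follow JAX's key-path conventions):

```python
from collections import namedtuple

Pair = namedtuple("Pair", ["a", "b"])

# pytree_children(dict(a=[1, 2]))  -> [(DictKey('a'), [1, 2])]                       (from the docstring)
# pytree_children(Pair(1, 2))      -> [(GetAttrKey('a'), 1), (GetAttrKey('b'), 2)]   (namedtuple heuristic)
# pytree_children(3.0)             -> []                                             (a leaf has no children)
```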
pyproject.toml (6 changes: 4 additions & 2 deletions)
@@ -30,7 +30,7 @@ core = [
"nltk==3.7", # for text preprocessing
"optax==0.1.7", # optimizers (0.1.0 has known bugs).
"portpicker",
"pyarrow>=20.0.0,<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21
"pyarrow<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21
"protobuf>=3.20.3",
"tensorboard-plugin-profile==2.20.4",
# This has both x86 and arm64 wheels. Underneath the hood it uses tensorflow-macos since 2.13.
@@ -126,7 +126,7 @@ vertexai_tensorboard = [
]
# Dataflow dependencies.
dataflow = [
"pyarrow>=20.0.0,<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21
"pyarrow<21.0.0", # Pin to v20.x to avoid PyExtensionType -> ExtensionType breaking change in v21
"apache-beam==2.55.1",
"apache-beam[gcp]",
"google-apitools", # for beam pipeline
@@ -137,6 +137,8 @@ gpu = [
"triton==2.1.0",
"jax[cuda12]==0.5.3",
"nvidia-ml-py==12.560.30",
# Pin the NCCL version; otherwise jax[cuda12] will pull the latest version.
"nvidia-nccl-cu12==2.27.5",
]
# Open API inference.
open_api = [