trainer.strategy.cluster_environment is not XLAEnvironment when on XLA device/TPU #16802

@Liyang90

Bug description

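    from lightning.pytorch import Trainer

    # Request all 8 local TPU cores; no cluster environment is passed explicitly.
    trainer = Trainer(
        accelerator="tpu",
        devices=8,
    )
    # Inspect which cluster environment the strategy ended up with.
    print("type(trainer.strategy.cluster_environment)", type(trainer.strategy.cluster_environment))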

Expected XLAEnvironment, but got "<class 'lightning.fabric.plugins.environments.lightning.LightningEnvironment'>".

This leads to incorrect strategy.world_size and strategy.global_rank values on a TPU pod, because the torch_xla APIs are not used to compute them. As a result, DistributedSampler raises an error like "ValueError: Invalid rank 20, rank should be in the interval [0, 7]" even though the world size is 32 on a TPU v3-32.
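To illustrate the mismatch, here is a minimal sketch in plain Python. It does not call Lightning or torch_xla; `validate_rank` and `failing_ranks` are hypothetical helpers that only mimic the kind of rank check a distributed sampler performs. The reported world size of 8 is an assumption inferred from the interval [0, 7] in the error message.

```python
def validate_rank(rank: int, num_replicas: int) -> None:
    """Hypothetical stand-in for the rank check a distributed sampler performs."""
    if rank < 0 or rank >= num_replicas:
        raise ValueError(
            f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
        )


def failing_ranks(total_processes: int, reported_world_size: int) -> list:
    """Return the global ranks that would be rejected for a given world size."""
    failures = []
    for rank in range(total_processes):
        try:
            validate_rank(rank, reported_world_size)
        except ValueError:
            failures.append(rank)
    return failures


TOTAL_PROCESSES = 32      # processes on a TPU v3-32 pod, per the report above
REPORTED_WORLD_SIZE = 8   # assumed wrong value, inferred from "[0, 7]"

# With the wrong world size, every global rank from 8 upward is rejected.
print(failing_ranks(TOTAL_PROCESSES, REPORTED_WORLD_SIZE))

# With the world size the pod actually has, no rank is rejected.
print(failing_ranks(TOTAL_PROCESSES, TOTAL_PROCESSES))
```

Under these assumptions, 24 of the 32 processes would fail the check, which matches the symptom of rank 20 being rejected.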

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): Trainer
#- PyTorch Lightning Version (e.g., 1.5.0): master
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration: TPU pod V3-32
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @JackCaoG @steventk-g @Liyang90

Labels

accelerator: tpu, bug
