Skip to content

Commit 3028fd2

Browse files
carmocca, awaelchli, pre-commit-ci[bot], akihironitta
authored
Fix TPU test CI (#14926)
* Fix TPU test CI * +x first * Lite first to uncover errors faster * Fixes * One more * Simplify XLALauncher wrapping to avoid pickle error * debug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug commit successful. Trying local definitions * Require tpu for mock test * ValueError: The number of devices must be either 1 or 8, got 4 instead * Fix mock test * Simplify call, rely on defaults * Skip OSError for now. Maybe upgrading will help * Simplify launch tests, move some to lite * Stricter typing * RuntimeError: Accessing the XLA device before processes have spawned is not allowed. * Revert "RuntimeError: Accessing the XLA device before processes have spawned is not allowed." This reverts commit f65107e. * Alternative boring solution to the reverted commit * Fix failing test on CUDA machine * Workarounds * Try latest mkl * Revert "Try latest mkl" This reverts commit d06813a. * Wrong exception * xfail * Mypy * Comment change * Spawn launch refactor * Accept that we cannot lazy init now * Fix mypy and launch test failures * The base dockerfile already includes mkl-2022.1.0 - what if we use it? * try a different mkl version * Revert mkl version changes Co-authored-by: awaelchli <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <[email protected]>
1 parent e290c20 commit 3028fd2

File tree

7 files changed

+60
-56
lines changed

7 files changed

+60
-56
lines changed

.circleci/config.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ parameters:
1414
GHA_Event:
1515
type: string
1616
default: ""
17+
GHA_Meta:
18+
type: string
19+
default: ""
1720

1821
references:
1922

@@ -49,9 +52,10 @@ references:
4952
update_jsonnet: &update_jsonnet
5053
run:
5154
name: Update jsonnet
55+
environment:
56+
PR_NUMBER: << pipeline.parameters.GHA_Meta >>
5257
command: |
5358
export SHA=$(git rev-parse --short HEAD)
54-
export PR_NUMBER=$(git ls-remote origin "pull/*/head" | grep -F -f $SHA | awk -F'/' '{print $3}')
5559
python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; data = open(fname).read().replace('{PYTORCH_VERSION}', '$XLA_VER')
5660
data = data.replace('{PYTHON_VERSION}', '$PYTHON_VER').replace('{PR_NUMBER}', '$PR_NUMBER').replace('{SHA}', '$SHA') ; open(fname, 'w').write(data)"
5761
cat dockers/tpu-tests/tpu_test_cases.jsonnet

.github/workflows/ci-circleci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,5 @@ jobs:
3030
- uses: CircleCI-Public/[email protected]
3131
env:
3232
CCI_TOKEN: ${{ secrets.CCI_TOKEN }}
33+
with:
34+
GHA_Meta: ${{ github.event.pull_request.number }}

dockers/tpu-tests/tpu_test_cases.jsonnet

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,40 @@ local tputests = base.BaseTest {
2020

2121
command: utils.scriptCommand(
2222
|||
23+
set +x # turn off tracing, spammy
24+
set -e # exit on error
25+
2326
source ~/.bashrc
24-
set -e
2527
conda activate lightning
26-
mkdir -p /home/runner/work/lightning && cd /home/runner/work/lightning
27-
git clone https://github.com/Lightning-AI/lightning.git
28-
cd lightning
29-
echo $PWD
30-
git ls-remote --refs origin
31-
git fetch origin "refs/pull/{PR_NUMBER}/head"
32-
git checkout {SHA}
33-
export PACKAGE_NAME=pytorch
34-
export FREEZE_REQUIREMENTS=1
35-
pip install -e .[test]
28+
29+
echo "--- Fetch the SHA's changes ---"
30+
git clone --single-branch --depth 1 https://github.com/Lightning-AI/lightning.git /home/runner/work/lightning
31+
cd home/runner/work/lightning
32+
git fetch origin --depth 1 pull/{PR_NUMBER}/head:test/{PR_NUMBER}
33+
git -c advice.detachedHead=false checkout {SHA}
34+
35+
echo "--- Install PL ---"
36+
PACKAGE_NAME=pytorch FREEZE_REQUIREMENTS=1 pip install -e .[test]
37+
pip list
38+
3639
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
3740
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
38-
export PL_RUN_TPU_TESTS=1
39-
cd tests/tests_pytorch
40-
coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
41-
echo "\n||| Running standalone tests |||\n"
42-
export PL_STANDALONE_TESTS_SOURCE=pytorch_lightning
43-
export PL_STANDALONE_TESTS_BATCH_SIZE=1
44-
bash run_standalone_tests.sh
45-
echo "\n||| END PYTEST LOGS |||\n"
41+
42+
echo "--- Running Lite tests ---"
43+
cd tests/tests_lite
44+
PL_RUN_TPU_TESTS=1 coverage run --source=lightning_lite -m pytest -vv --durations=0 ./
45+
46+
echo "--- Running standalone Lite tests ---"
47+
PL_STANDALONE_TESTS_SOURCE=lightning_lite PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh
48+
49+
echo "--- Running PL tests ---"
50+
cd ../tests_pytorch
51+
PL_RUN_TPU_TESTS=1 coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
52+
53+
echo "--- Running standalone PL tests ---"
54+
PL_STANDALONE_TESTS_SOURCE=pytorch_lightning PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh
55+
56+
echo "--- Generating coverage ---"
4657
coverage xml
4758
cat coverage.xml | tr -d '\t'
4859
|||

src/lightning_lite/strategies/launchers/xla.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import time
15-
from functools import wraps
1615
from multiprocessing.queues import SimpleQueue
17-
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING
16+
from typing import Any, Callable, Optional, TYPE_CHECKING
1817

19-
from torch.multiprocessing import get_context, ProcessContext
18+
from torch.multiprocessing import get_context
2019

2120
from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
2221
from lightning_lite.utilities import _TPU_AVAILABLE
@@ -67,7 +66,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
6766
"""
6867
context = get_context(self._start_method)
6968
return_queue = context.SimpleQueue()
70-
_save_spawn(
69+
xmp.spawn(
7170
self._wrapping_function,
7271
args=(function, args, kwargs, return_queue),
7372
nprocs=self._strategy.num_processes,
@@ -90,30 +89,16 @@ def _wrapping_function(
9089
if process_idx == 0:
9190
return_queue.put(move_data_to_device(results, "cpu"))
9291

92+
_rank_teardown(process_idx)
9393

94-
def _save_spawn(
95-
fn: Callable,
96-
args: Tuple = (),
97-
nprocs: Optional[int] = None,
98-
join: bool = True,
99-
daemon: bool = False,
100-
start_method: str = "spawn",
101-
) -> Optional[ProcessContext]:
102-
"""Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
103-
processes."""
104-
105-
@wraps(fn)
106-
def wrapped(rank: int, *_args: Any) -> None:
107-
fn(rank, *_args)
108-
109-
import torch_xla.core.xla_model as xm
110-
111-
# Make all processes wait for each other before joining
112-
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
113-
xm.rendezvous("end-process")
114-
# Ensure that the rank 0 process is the one exiting last
115-
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
116-
if rank == 0:
117-
time.sleep(1)
118-
119-
return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
94+
95+
def _rank_teardown(rank: int) -> None:
96+
import torch_xla.core.xla_model as xm
97+
98+
# Make all processes wait for each other before joining
99+
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
100+
xm.rendezvous("end-process")
101+
# Ensure that the rank 0 process is the one exiting last
102+
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
103+
if rank == 0:
104+
time.sleep(1)

src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch.multiprocessing as mp
1919

2020
import pytorch_lightning as pl
21-
from lightning_lite.strategies.launchers.xla import _save_spawn
21+
from lightning_lite.strategies.launchers.xla import _rank_teardown
2222
from lightning_lite.utilities import move_data_to_device
2323
from pytorch_lightning.strategies.launchers.multiprocessing import (
2424
_FakeQueue,
@@ -74,7 +74,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
7474
"""
7575
context = mp.get_context(self._start_method)
7676
return_queue = context.SimpleQueue()
77-
_save_spawn(
77+
xmp.spawn(
7878
self._wrapping_function,
7979
args=(trainer, function, args, kwargs, return_queue),
8080
nprocs=self._strategy.num_processes,
@@ -106,6 +106,8 @@ def _wrapping_function(
106106
if process_idx == 0:
107107
return_queue.put(move_data_to_device(results, "cpu"))
108108

109+
_rank_teardown(process_idx)
110+
109111
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
110112
rank_zero_debug("Collecting results from rank 0 process.")
111113
checkpoint_callback = trainer.checkpoint_callback

tests/tests_lite/strategies/launchers/test_xla.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from unittest import mock
2-
from unittest.mock import ANY, Mock
2+
from unittest.mock import Mock
33

44
from tests_lite.helpers.runif import RunIf
55

@@ -29,11 +29,9 @@ def test_xla_launcher_xmp_spawn(get_context_mock, xmp_mock):
2929
queue = get_context_mock.return_value.SimpleQueue.return_value
3030
get_context_mock.assert_called_with("fork")
3131
xmp_mock.spawn.assert_called_with(
32-
ANY,
32+
launcher._wrapping_function,
3333
args=(function, ("positional-arg",), {"keyword_arg": 0}, queue),
3434
nprocs=strategy.num_processes,
35-
join=True,
36-
daemon=False,
3735
start_method="fork",
3836
)
3937
queue.get.assert_called_once_with()

tests/tests_pytorch/accelerators/test_tpu.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121
import torch
2222
from torch import nn
23+
from torch.multiprocessing import ProcessExitedException
2324
from torch.utils.data import DataLoader
2425

2526
from pytorch_lightning import Trainer
@@ -69,6 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
6970

7071
@RunIf(tpu=True)
7172
@mock.patch.dict(os.environ, {}, clear=True)
73+
@pytest.mark.xfail(raises=ProcessExitedException, reason="https://github.com/pytorch/xla/issues/1666")
7274
def test_if_test_works_after_train(tmpdir):
7375
"""Ensure that .test() works after .fit()"""
7476
model = BoringModel()

0 commit comments

Comments
 (0)