|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
| 14 | +import logging |
| 15 | +from pathlib import Path |
| 16 | +from unittest import mock |
| 17 | +from unittest.mock import ANY |
| 18 | + |
14 | 19 | import pytest
|
15 | 20 |
|
16 | 21 | import pytorch_lightning as pl
|
17 | 22 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
18 | 23 | from pytorch_lightning.utilities.migration import migrate_checkpoint
|
19 | 24 | from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
|
| 25 | +from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main |
20 | 26 |
|
21 | 27 |
|
22 | 28 | @pytest.mark.parametrize(
|
@@ -47,3 +53,71 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
|
47 | 53 | updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
|
48 | 54 | assert updated_checkpoint == old_checkpoint == new_checkpoint
|
49 | 55 | assert _get_version(updated_checkpoint) == pl.__version__
|
| 56 | + |
| 57 | + |
| 58 | +def test_upgrade_checkpoint_file_missing(tmp_path, caplog): |
| 59 | + # path to single file (missing) |
| 60 | + file = tmp_path / "checkpoint.ckpt" |
| 61 | + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): |
| 62 | + with caplog.at_level(logging.ERROR): |
| 63 | + with pytest.raises(SystemExit): |
| 64 | + upgrade_main() |
| 65 | + assert f"The path {file} does not exist" in caplog.text |
| 66 | + |
| 67 | + caplog.clear() |
| 68 | + |
| 69 | + # path to non-empty directory, but no checkpoints with matching extension |
| 70 | + file.touch() |
| 71 | + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]): |
| 72 | + with caplog.at_level(logging.ERROR): |
| 73 | + with pytest.raises(SystemExit): |
| 74 | + upgrade_main() |
| 75 | + assert "No checkpoint files with extension .other were found" in caplog.text |
| 76 | + |
| 77 | + |
| 78 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") |
| 79 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") |
| 80 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") |
| 81 | +def test_upgrade_checkpoint_single_file(migrate_mock, load_mock, save_mock, tmp_path): |
| 82 | + file = tmp_path / "checkpoint.ckpt" |
| 83 | + file.touch() |
| 84 | + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(file)]): |
| 85 | + upgrade_main() |
| 86 | + |
| 87 | + load_mock.assert_called_once_with(Path(file)) |
| 88 | + migrate_mock.assert_called_once() |
| 89 | + save_mock.assert_called_once_with(ANY, Path(file)) |
| 90 | + |
| 91 | + |
| 92 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.save") |
| 93 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.torch.load") |
| 94 | +@mock.patch("pytorch_lightning.utilities.upgrade_checkpoint.migrate_checkpoint") |
| 95 | +def test_upgrade_checkpoint_directory(migrate_mock, load_mock, save_mock, tmp_path): |
| 96 | + top_files = [tmp_path / "top0.ckpt", tmp_path / "top1.ckpt"] |
| 97 | + nested_files = [ |
| 98 | + tmp_path / "subdir0" / "nested0.ckpt", |
| 99 | + tmp_path / "subdir0" / "nested1.other", |
| 100 | + tmp_path / "subdir1" / "nested2.ckpt", |
| 101 | + ] |
| 102 | + |
| 103 | + for file in top_files + nested_files: |
| 104 | + file.parent.mkdir(exist_ok=True, parents=True) |
| 105 | + file.touch() |
| 106 | + |
| 107 | + # directory with recursion |
| 108 | + with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path)]): |
| 109 | + upgrade_main() |
| 110 | + |
| 111 | + assert {c[0][0] for c in load_mock.call_args_list} == { |
| 112 | + tmp_path / "top0.ckpt", |
| 113 | + tmp_path / "top1.ckpt", |
| 114 | + tmp_path / "subdir0" / "nested0.ckpt", |
| 115 | + tmp_path / "subdir1" / "nested2.ckpt", |
| 116 | + } |
| 117 | + assert migrate_mock.call_count == 4 |
| 118 | + assert {c[0][1] for c in save_mock.call_args_list} == { |
| 119 | + tmp_path / "top0.ckpt", |
| 120 | + tmp_path / "top1.ckpt", |
| 121 | + tmp_path / "subdir0" / "nested0.ckpt", |
| 122 | + tmp_path / "subdir1" / "nested2.ckpt", |
| 123 | + } |
0 commit comments