|
29 | 29 | V1QueueServerType,
|
30 | 30 | V1SourceType,
|
31 | 31 | V1UserRequestedComputeConfig,
|
| 32 | + V1UserRequestedFlowComputeConfig, |
32 | 33 | V1Work,
|
33 | 34 | )
|
34 | 35 |
|
|
37 | 38 | from lightning_app.storage import Drive, Mount
|
38 | 39 | from lightning_app.utilities.cloud import _get_project
|
39 | 40 | from lightning_app.utilities.dependency_caching import get_hash
|
| 41 | +from lightning_app.utilities.packaging.cloud_compute import CloudCompute |
40 | 42 |
|
41 | 43 |
|
42 | 44 | class MyWork(LightningWork):
|
@@ -66,6 +68,47 @@ def run(self):
|
66 | 68 | class TestAppCreationClient:
|
67 | 69 | """Testing the calls made using GridRestClient to create the app."""
|
68 | 70 |
|
| 71 | + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) |
| 72 | + def test_run_with_custom_flow_compute_config(self, monkeypatch): |
| 73 | + mock_client = mock.MagicMock() |
| 74 | + mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( |
| 75 | + memberships=[V1Membership(name="test-project", project_id="test-project-id")] |
| 76 | + ) |
| 77 | + mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = ( |
| 78 | + V1ListLightningappInstancesResponse(lightningapps=[]) |
| 79 | + ) |
| 80 | + cloud_backend = mock.MagicMock() |
| 81 | + cloud_backend.client = mock_client |
| 82 | + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) |
| 83 | + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock()) |
| 84 | + app = mock.MagicMock() |
| 85 | + app.flows = [] |
| 86 | + app.frontend = {} |
| 87 | + app.flow_cloud_compute = CloudCompute(name="t2.medium") |
| 88 | + cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file="entrypoint.py") |
| 89 | + cloud_runtime._check_uploaded_folder = mock.MagicMock() |
| 90 | + |
| 91 | + monkeypatch.setattr(Path, "is_file", lambda *args, **kwargs: False) |
| 92 | + monkeypatch.setattr(cloud, "Path", Path) |
| 93 | + cloud_runtime.dispatch() |
| 94 | + body = Body8( |
| 95 | + app_entrypoint_file=mock.ANY, |
| 96 | + enable_app_server=True, |
| 97 | + flow_servers=[], |
| 98 | + image_spec=None, |
| 99 | + works=[], |
| 100 | + local_source=True, |
| 101 | + dependency_cache_key=mock.ANY, |
| 102 | + user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig( |
| 103 | + name="t2.medium", |
| 104 | + preemptible=False, |
| 105 | + shm_size=0, |
| 106 | + ), |
| 107 | + ) |
| 108 | + cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( |
| 109 | + project_id="test-project-id", app_id=mock.ANY, body=body |
| 110 | + ) |
| 111 | + |
69 | 112 | @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock())
|
70 | 113 | def test_run_on_byoc_cluster(self, monkeypatch):
|
71 | 114 | mock_client = mock.MagicMock()
|
@@ -100,6 +143,7 @@ def test_run_on_byoc_cluster(self, monkeypatch):
|
100 | 143 | works=[],
|
101 | 144 | local_source=True,
|
102 | 145 | dependency_cache_key=mock.ANY,
|
| 146 | + user_requested_flow_compute_config=mock.ANY, |
103 | 147 | )
|
104 | 148 | cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
|
105 | 149 | project_id="default-project-id", app_id=mock.ANY, body=body
|
@@ -142,6 +186,7 @@ def test_requirements_file(self, monkeypatch):
|
142 | 186 | works=[],
|
143 | 187 | local_source=True,
|
144 | 188 | dependency_cache_key=mock.ANY,
|
| 189 | + user_requested_flow_compute_config=mock.ANY, |
145 | 190 | )
|
146 | 191 | cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with(
|
147 | 192 | project_id="test-project-id", app_id=mock.ANY, body=body
|
@@ -264,6 +309,7 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir):
|
264 | 309 | enable_app_server=True,
|
265 | 310 | flow_servers=[],
|
266 | 311 | dependency_cache_key=get_hash(requirements_file),
|
| 312 | + user_requested_flow_compute_config=mock.ANY, |
267 | 313 | image_spec=Gridv1ImageSpec(
|
268 | 314 | dependency_file_info=V1DependencyFileInfo(
|
269 | 315 | package_manager=V1PackageManager.PIP, path="requirements.txt"
|
@@ -431,6 +477,7 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch
|
431 | 477 | enable_app_server=True,
|
432 | 478 | flow_servers=[],
|
433 | 479 | dependency_cache_key=get_hash(requirements_file),
|
| 480 | + user_requested_flow_compute_config=mock.ANY, |
434 | 481 | image_spec=Gridv1ImageSpec(
|
435 | 482 | dependency_file_info=V1DependencyFileInfo(
|
436 | 483 | package_manager=V1PackageManager.PIP, path="requirements.txt"
|
@@ -590,6 +637,7 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
|
590 | 637 | enable_app_server=True,
|
591 | 638 | flow_servers=[],
|
592 | 639 | dependency_cache_key=get_hash(requirements_file),
|
| 640 | + user_requested_flow_compute_config=mock.ANY, |
593 | 641 | image_spec=Gridv1ImageSpec(
|
594 | 642 | dependency_file_info=V1DependencyFileInfo(
|
595 | 643 | package_manager=V1PackageManager.PIP, path="requirements.txt"
|
@@ -623,6 +671,7 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
|
623 | 671 | enable_app_server=True,
|
624 | 672 | flow_servers=[],
|
625 | 673 | dependency_cache_key=get_hash(requirements_file),
|
| 674 | + user_requested_flow_compute_config=mock.ANY, |
626 | 675 | image_spec=Gridv1ImageSpec(
|
627 | 676 | dependency_file_info=V1DependencyFileInfo(
|
628 | 677 | package_manager=V1PackageManager.PIP, path="requirements.txt"
|
@@ -756,6 +805,7 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo
|
756 | 805 | package_manager=V1PackageManager.PIP, path="requirements.txt"
|
757 | 806 | )
|
758 | 807 | ),
|
| 808 | + user_requested_flow_compute_config=mock.ANY, |
759 | 809 | works=[
|
760 | 810 | V1Work(
|
761 | 811 | name="test-work",
|
|
0 commit comments