Skip to content

Commit df67833

Browse files
author
Sherin Thomas
authored
[App] Multiprocessing-safe work pickling (#15836)
1 parent ca5ca0e commit df67833

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818

1919
- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))
2020

21+
- Added a utility for pickling a work object safely, even from a child process ([#15836](https://github.com/Lightning-AI/lightning/pull/15836))
22+
2123
- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))
2224

2325
- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import contextlib
2+
import pickle
3+
import sys
4+
import types
5+
import typing
6+
from copy import deepcopy
7+
from pathlib import Path
8+
9+
from lightning_app.core.work import LightningWork
10+
from lightning_app.utilities.app_helpers import _LightningAppRef
11+
12+
# Names of the work attributes that are temporarily replaced with ``None`` while the
# work is deep-copied for pickling (see their use in ``get_picklable_work``).
NON_PICKLABLE_WORK_ATTRIBUTES = ["_request_queue", "_response_queue", "_backend", "_setattr_replacement"]
13+
14+
15+
@contextlib.contextmanager
def _trimmed_work(work: LightningWork, to_trim: typing.List[str]) -> typing.Iterator[None]:
    """Context manager that temporarily sets the given attributes of ``work`` to ``None``.

    Args:
        work: The object whose attributes are trimmed for the duration of the ``with`` block.
        to_trim: Names of the attributes to replace with ``None``.

    Raises:
        AttributeError: If ``work`` does not have one of the attributes in ``to_trim``.

    The original values are always put back when the block exits, including when the
    body of the ``with`` statement (or the trimming itself) raises.
    """
    holder = {}
    try:
        for arg in to_trim:
            holder[arg] = getattr(work, arg)
            setattr(work, arg, None)
        yield
    finally:
        # Restore only what was actually trimmed, so a failure half-way through the
        # trimming loop (or inside the ``with`` body) never leaves ``work`` mutilated.
        for arg, value in holder.items():
            setattr(work, arg, value)
25+
26+
27+
def get_picklable_work(work: LightningWork) -> LightningWork:
    """Return a deep copy of ``work`` that can be pickled, even from the work process itself.

    Pickling a LightningWork instance fails if done from the work process
    itself. This function is safe to call from the work process within both MultiprocessRuntime
    and Cloud.

    Note: This function modifies the module information of the work object. Specifically, it injects
    the relative module path into the __module__ attribute of the work object. If the object is not
    importable from the CWD, then the pickle load will fail.

    Example:
        for a directory structure like below and the work class is defined in the app.py where
        the app.py is the entrypoint for the app, it will inject `foo.bar.app` into the
        __module__ attribute

        └── foo
            ├── __init__.py
            └── bar
                └── app.py

    Args:
        work: The work to pickle. Only its ``name`` is used, to look up the matching work
            in the current app reference.

    Returns:
        A deep copy of the matching work with the non-picklable attributes set to ``None``.

    Raises:
        RuntimeError: If there is no current app reference.
        ValueError: If no work with the same name is found in the app reference, if the file of
            the work's module cannot be identified, or if that file is not located under the CWD.
    """

    # If the work object is not taken from the app ref, there is a thread lock reference
    # somewhere that's preventing it from being pickled. Investigate it later. We
    # shouldn't be fetching the work object from the app ref. TODO @sherin
    app_ref = _LightningAppRef.get_current()
    if app_ref is None:
        raise RuntimeError("Cannot pickle LightningWork outside of a LightningApp")
    for w in app_ref.works:
        if work.name == w.name:
            # deep-copying the work object to avoid modifying the original work object
            with _trimmed_work(w, to_trim=NON_PICKLABLE_WORK_ATTRIBUTES):
                copied_work = deepcopy(w)
            break
    else:
        raise ValueError(f"Work with name {work.name} not found in the app references")

    # if work is defined in the __main__ or __mp__main__ (the entrypoint file for `lightning run app` command),
    # pickling/unpickling will fail, hence we need patch the module information
    if "_main__" in copied_work.__class__.__module__:
        work_class_module = sys.modules[copied_work.__class__.__module__]
        work_class_file = work_class_module.__file__
        if not work_class_file:
            raise ValueError(
                f"Cannot pickle work class {copied_work.__class__.__name__} because we "
                f"couldn't identify the module file"
            )
        relative_path = Path(work_class_file).relative_to(Path.cwd())
        # Drop only the file suffix: a blanket ``replace(".py", "")`` would also mangle
        # any directory or file name that merely contains ".py".
        expected_module_name = ".".join(relative_path.with_suffix("").parts)
        # TODO @sherin: also check if the module is importable from the CWD
        fake_module = types.ModuleType(expected_module_name)
        fake_module.__dict__.update(work_class_module.__dict__)
        fake_module.__dict__["__name__"] = expected_module_name
        sys.modules[expected_module_name] = fake_module
        for k, v in fake_module.__dict__.items():
            if k.startswith("__"):
                continue
            # ``__module__`` may be missing or ``None`` on some objects; only rewrite
            # the ones that really point at the entrypoint module.
            module_name = getattr(v, "__module__", None)
            if isinstance(module_name, str) and "_main__" in module_name:
                v.__module__ = expected_module_name
    return copied_work
83+
84+
85+
def dump(work: LightningWork, f: typing.BinaryIO) -> None:
    """Pickle a picklable copy of ``work`` (built by ``get_picklable_work``) into the binary file ``f``."""
    pickle.dump(get_picklable_work(work), f)
88+
89+
90+
def load(f: typing.BinaryIO) -> typing.Any:
    """Unpickle an object from ``f`` with the current working directory importable.

    The CWD is inserted into ``sys.path`` for the duration of the load so that modules
    referenced by the pickle can be imported relative to it, and is removed again
    afterwards, even if unpickling fails.

    Args:
        f: Binary file object positioned at the start of the pickled data.

    Returns:
        The unpickled object.

    Warning:
        ``pickle.load`` can execute arbitrary code. Only load files that come from a
        trusted source.
    """
    # inject current working directory to sys.path
    cwd = str(Path.cwd())
    sys.path.insert(1, cwd)
    try:
        return pickle.load(f)
    finally:
        # Remove by value rather than by index: imports triggered while unpickling may
        # have modified ``sys.path``, in which case index 1 is no longer our entry.
        try:
            sys.path.remove(cwd)
        except ValueError:
            pass
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import subprocess
2+
from pathlib import Path
3+
4+
5+
def test_safe_pickle_app():
    """Run the safe-pickle app through the ``lightning`` CLI and check its success message on stdout."""
    command = ["lightning", "run", "app", "safe_pickle_app.py", "--open-ui", "false"]
    app_dir = Path(__file__).parent / "testdata"
    process = subprocess.Popen(command, stdout=subprocess.PIPE, cwd=app_dir)
    captured = process.communicate()[0]
    output = captured.decode("UTF-8")
    assert "Exiting the pickling app successfully" in output
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
This app tests three things
3+
1. Can a work pickle `self`
4+
2. Can the pickled work be unpickled in another work
5+
3. Can the pickled work be unpickled from a script
6+
"""
7+
8+
import subprocess
9+
from pathlib import Path
10+
11+
from lightning_app import LightningApp, LightningFlow, LightningWork
12+
from lightning_app.utilities import safe_pickle
13+
14+
15+
class SelfPicklingWork(LightningWork):
    """Work that pickles itself into ``work.pkl`` when it runs."""

    def run(self):
        """Dump this work into ``work.pkl`` with ``safe_pickle``."""
        with open("work.pkl", "wb") as pickle_file:
            safe_pickle.dump(self, pickle_file)

    def get_test_string(self):
        """Return the greeting used to recognise an unpickled instance."""
        greeting = f"Hello from {self.__class__.__name__}!"
        return greeting
22+
23+
24+
class WorkThatLoadsPickledWork(LightningWork):
    """Work that unpickles ``work.pkl`` and checks the restored work's greeting."""

    def run(self):
        """Load ``work.pkl`` with ``safe_pickle`` and verify the loaded work's test string."""
        with open("work.pkl", "rb") as pickle_file:
            loaded_work = safe_pickle.load(pickle_file)
        greeting = loaded_work.get_test_string()
        assert greeting == "Hello from SelfPicklingWork!"
29+
30+
31+
script_load_pickled_work = """
32+
import pickle
33+
work = pickle.load(open("work.pkl", "rb"))
34+
print(work.get_test_string())
35+
"""
36+
37+
38+
class RootFlow(LightningFlow):
    """Flow that runs the pickling round trip: dump in one work, load in another work and in a script."""

    def __init__(self):
        super().__init__()
        self.self_pickling_work = SelfPicklingWork()
        self.work_that_loads_pickled_work = WorkThatLoadsPickledWork()

    def run(self):
        """Run both works, load the pickle from a standalone script, clean up and exit."""
        # Function-scope import so this method does not depend on the module's import block.
        import sys

        self.self_pickling_work.run()
        self.work_that_loads_pickled_work.run()

        with open("script_that_loads_pickled_work.py", "w") as f:
            f.write(script_load_pickled_work)

        # read the output from subprocess
        # ``subprocess.run`` waits for the child and closes its pipe, so no process or
        # file descriptor is left behind. ``sys.executable`` makes the script run under
        # the same interpreter as this app instead of whichever ``python`` is on PATH.
        completed = subprocess.run(
            [sys.executable, "script_that_loads_pickled_work.py"], stdout=subprocess.PIPE
        )
        assert "Hello from SelfPicklingWork" in completed.stdout.decode("UTF-8")

        # deleting the script
        Path("script_that_loads_pickled_work.py").unlink()
        # deleting the pkl file
        Path("work.pkl").unlink()

        self._exit("Exiting the pickling app successfully!!")
61+
62+
63+
# Module-level app object wrapping the root flow.
app = LightningApp(RootFlow())

0 commit comments

Comments
 (0)