Skip to content

Commit 9918d13

Browse files
MengAiDevsayakpaul
andauthored
fix(training_utils): wrap device in list for DiffusionPipeline (#12178)
- Modify offload_models function to handle DiffusionPipeline correctly - Ensure compatibility with both single and multiple module inputs Co-authored-by: Sayak Paul <[email protected]>
1 parent e824660 commit 9918d13

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/training_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def offload_models(
339339
original_devices = [next(m.parameters()).device for m in modules]
340340
else:
341341
assert len(modules) == 1
342-
original_devices = modules[0].device
342+
# For DiffusionPipeline, wrap the device in a list to make it iterable
343+
original_devices = [modules[0].device]
343344
# move to target device
344345
for m in modules:
345346
m.to(device)

0 commit comments

Comments
 (0)