
Conversation

moonrunnerkc

This PR fixes incorrect loss normalization in the Trainer when running on multiple GPUs.
The previous implementation always averaged the per-device losses, which under-reported loss values in token-level training.
The new implementation provides a token-aware reduction that behaves consistently across single- and multi-GPU setups.

Fixes #37474 (Trainer.training_step incorrectly normalizes mean token loss when n_gpu > 1).

Motivation and Context:
When using multiple GPUs, Trainer.training_step reported losses that were too small because the per-device losses were always reduced with mean().
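For illustration, a minimal worked example of the effect. It assumes the convention that, when a global token count is passed, each replica divides its sum of token losses by the global token count before the Trainer averages across devices; the numbers and variable names are made up.

```python
import torch

n_gpu, total_tokens = 2, 8
# Hypothetical per-device sums of token losses.
per_device_token_loss_sums = torch.tensor([20.0, 28.0])

# True mean token loss over the whole batch: (20 + 28) / 8 = 6.0
correct = per_device_token_loss_sums.sum() / total_tokens

# Buggy path (assumed convention): each replica has already divided its
# sum by the *global* token count, and the Trainer then calls .mean(),
# which divides by n_gpu a second time: 6.0 / 2 = 3.0
reported = (per_device_token_loss_sums / total_tokens).mean()
```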

This PR introduces _reduce_loss, a dedicated helper method that handles three cases (see the sketch after this list):
Single GPU: returns the loss unchanged.
Multi-GPU without token counts: averages across devices.
Multi-GPU with token counts: sums the per-device losses and divides by the actual number of tokens.
This ensures loss reporting and optimization are accurate, matching expected values such as log(vocab_size) during early training.
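A minimal sketch of what such a helper could look like. The signature, the n_gpu and num_tokens parameters, and the assumption that each per-device entry is an unnormalized sum of token losses are illustrative guesses, not the merged implementation.

```python
from typing import Optional

import torch


def _reduce_loss(loss: torch.Tensor, n_gpu: int,
                 num_tokens: Optional[int] = None) -> torch.Tensor:
    """Hypothetical token-aware loss reduction (illustrative only)."""
    if n_gpu <= 1 or loss.dim() == 0:
        # Single GPU (or already-reduced scalar): return the loss unchanged.
        return loss
    if num_tokens is None:
        # No token counts available: fall back to averaging across devices.
        return loss.mean()
    # Token-aware path: assuming each entry of `loss` is that device's sum
    # of token losses, normalize the summed total by the true token count
    # instead of averaging per-device values.
    return loss.sum() / num_tokens
```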

What was changed:
Added _reduce_loss method inside the Trainer class.
Updated training_step to use _reduce_loss instead of hard-coded loss.mean().
Added a new test suite tests/trainer/test_loss_reduction.py covering single/multi-GPU scenarios, token-aware averaging, gradient preservation, and edge cases (an illustrative sketch of such tests follows this list).
Added a minimal regression test in tests/test_trainer.py.
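For illustration only, a few self-contained pytest-style cases in the spirit of those scenarios. They exercise the hypothetical reduction sketched above (re-stated inline as reduce_loss), not the actual tests added in the PR.

```python
import torch


def reduce_loss(loss, n_gpu, num_tokens=None):
    # Stand-in for the hypothetical helper sketched above.
    if n_gpu <= 1 or loss.dim() == 0:
        return loss
    if num_tokens is None:
        return loss.mean()
    return loss.sum() / num_tokens


def test_single_gpu_loss_is_returned_unchanged():
    loss = torch.tensor(2.5)
    assert reduce_loss(loss, n_gpu=1) is loss


def test_token_aware_reduction_matches_hand_computation():
    per_device_sums = torch.tensor([20.0, 28.0])
    reduced = reduce_loss(per_device_sums, n_gpu=2, num_tokens=8)
    assert torch.isclose(reduced, torch.tensor(6.0))


def test_reduction_preserves_gradients():
    per_device_sums = torch.tensor([20.0, 28.0], requires_grad=True)
    reduce_loss(per_device_sums, n_gpu=2, num_tokens=8).backward()
    assert per_device_sums.grad is not None
```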

Tests:
✅ New tests added (8 total cases) and all pass locally.
✅ All existing tests continue to pass (excluding documented skips for distributed tests).
✅ No regressions introduced.
✅ Code imports and runs without errors.

Notes:
The implementation is backward compatible with existing code.
The design follows existing codebase patterns and keeps the reduction logic in a single helper.

Maintainers may wish to further integrate this with annotations or future loss utilities, but this fix addresses the immediate normalization bug.

@Rocketknight1
Member

cc @SunMarc
