20 changes: 17 additions & 3 deletions src/transformers/trainer.py
@@ -1,4 +1,4 @@
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -3948,6 +3948,19 @@ def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.

        return contextlib.nullcontext, inputs

    def _reduce_loss(self, loss, num_items_in_batch=None):
        """
        Reduce the loss according to the GPU setup and token averaging.

        If running on multiple GPUs:
        - If `num_items_in_batch` is provided, sum the per-device losses and divide by the
          total number of items (e.g., tokens).
        - Otherwise, average the per-device losses.
        On a single GPU, the loss is returned unchanged.
        """
        if self.args.n_gpu > 1:
            if num_items_in_batch is not None:
                return loss.sum() / num_items_in_batch
            return loss.mean()
        return loss

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
@@ -4044,8 +4057,9 @@ def training_step(
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        # Reduce the loss across devices, using the token count for averaging when it is available
        loss = self._reduce_loss(loss, num_items_in_batch)

        if self.use_apex:
            from apex import amp
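For illustration, here is a minimal, self-contained sketch of the reduction semantics the new `_reduce_loss` helper implements (the `reduce_loss` function and the numbers below are not part of the diff). It assumes that when `num_items_in_batch` is supplied, each per-device entry in the gathered loss is an unnormalized sum of token losses, so the correct global value is the sum over devices divided by the total token count.

import torch

def reduce_loss(per_device_loss: torch.Tensor, n_gpu: int, num_items_in_batch=None):
    # Mirrors the logic of Trainer._reduce_loss from the diff above:
    # token-weighted average when the item count is known, plain per-device
    # mean otherwise, and a pass-through on a single GPU.
    if n_gpu > 1:
        if num_items_in_batch is not None:
            return per_device_loss.sum() / num_items_in_batch
        return per_device_loss.mean()
    return per_device_loss

# GPU 0 saw 30 tokens with summed loss 60.0; GPU 1 saw 10 tokens with summed loss 10.0.
per_device_loss = torch.tensor([60.0, 10.0])

# Token-weighted reduction: (60 + 10) / 40 = 1.75, the true average loss per token.
print(reduce_loss(per_device_loss, n_gpu=2, num_items_in_batch=40))

# Without a token count the helper falls back to the per-device mean (35.0 here),
# which is the old behaviour and over-weights the device that saw fewer tokens.
print(reduce_loss(per_device_loss, n_gpu=2))

Dividing by the global item count keeps the loss scale independent of how tokens happen to be split across devices, which a plain per-device mean cannot guarantee.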
16 changes: 16 additions & 0 deletions tests/test_trainer.py
@@ -0,0 +1,16 @@
def test_loss_aggregation_multi_gpu(tmp_path):
    import torch
    from transformers import Trainer, TrainingArguments, AutoModelForCausalLM

    # tiny model for testing
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    args = TrainingArguments(output_dir=tmp_path, per_device_train_batch_size=2)

    trainer = Trainer(model=model, args=args)
    trainer.args._n_gpu = 2  # simulate 2 GPUs

    # simulate per-GPU losses
    dummy_losses = torch.tensor([2.0, 2.0])

    # with a token count, losses are summed and divided by the total items, not averaged per device
    reduced = trainer._reduce_loss(dummy_losses, num_items_in_batch=2)
    assert torch.isclose(reduced, torch.tensor(2.0))
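As context for the test above, here is a short sketch (not part of the diff) of where a value like `num_items_in_batch` typically comes from during causal-LM training: it is usually the number of label positions that actually contribute to the loss, i.e. those not set to the -100 ignore index.

import torch

# Two sequences with right-padded labels; -100 marks positions ignored by the loss.
labels = torch.tensor([
    [5, 7, 9, -100, -100],     # 3 target tokens
    [2, 4, -100, -100, -100],  # 2 target tokens
])
num_items_in_batch = (labels != -100).sum()  # tensor(5)

# With per-device summed losses of 3.0 and 2.0, the new helper would return
# (3.0 + 2.0) / 5 = 1.0, i.e. the loss per contributing token.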
109 changes: 109 additions & 0 deletions tests/trainer/test_loss_reduction.py
@@ -0,0 +1,109 @@
import torch
import pytest
from torch import nn
from transformers import TrainingArguments, Trainer


class DummyModel(nn.Module):
    """Simple dummy model for testing."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, **kwargs):
        return torch.tensor(1.0)


class DummyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call_reduce_loss(self, loss, num_items_in_batch=None):
        return self._reduce_loss(loss, num_items_in_batch=num_items_in_batch)


@pytest.fixture
def dummy_trainer(tmp_path):
    args = TrainingArguments(output_dir=tmp_path, per_device_train_batch_size=2)
    model = DummyModel()
    return DummyTrainer(model=model, args=args)


def test_reduce_loss_single_gpu_no_reduction(dummy_trainer):
    """Test that on single GPU, loss is returned unchanged."""
    dummy_trainer.args._n_gpu = 1
    loss = torch.tensor([2.0, 4.0])
    reduced = dummy_trainer.call_reduce_loss(loss)
    assert torch.equal(reduced, loss)


def test_reduce_loss_multi_gpu_mean(dummy_trainer):
    """Test that on multi-GPU without token count, loss.mean() is used."""
    dummy_trainer.args._n_gpu = 2
    loss = torch.tensor([2.0, 4.0])
    reduced = dummy_trainer.call_reduce_loss(loss)
    expected = loss.mean()  # Should be 3.0
    assert torch.isclose(reduced, expected)


def test_reduce_loss_multi_gpu_sum_with_tokens(dummy_trainer):
    """Test that on multi-GPU with token count, loss.sum() / num_items_in_batch is used."""
    dummy_trainer.args._n_gpu = 2
    # Simulate two partial losses on two GPUs
    loss = torch.tensor([2.0, 4.0])
    num_items_in_batch = 2
    reduced = dummy_trainer.call_reduce_loss(loss, num_items_in_batch=num_items_in_batch)
    # Should do sum / num_items_in_batch = (2 + 4) / 2 = 3
    expected = loss.sum() / num_items_in_batch
    assert torch.isclose(reduced, expected)


def test_reduce_loss_multi_gpu_different_batch_sizes(dummy_trainer):
    """Test loss reduction with different batch sizes."""
    dummy_trainer.args._n_gpu = 4
    loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
    num_items_in_batch = 10  # Total tokens across all devices
    reduced = dummy_trainer.call_reduce_loss(loss, num_items_in_batch=num_items_in_batch)
    # Should do sum / num_items_in_batch = (1 + 2 + 3 + 4) / 10 = 1.0
    expected = loss.sum() / num_items_in_batch
    assert torch.isclose(reduced, expected)


def test_reduce_loss_single_element_tensor(dummy_trainer):
    """Test with single element tensor (common case)."""
    dummy_trainer.args._n_gpu = 2
    loss = torch.tensor(5.0)
    reduced = dummy_trainer.call_reduce_loss(loss)
    # Single element mean is the element itself
    assert torch.isclose(reduced, loss)


def test_reduce_loss_zero_loss(dummy_trainer):
    """Test with zero loss values."""
    dummy_trainer.args._n_gpu = 2
    loss = torch.tensor([0.0, 0.0])
    num_items_in_batch = 5
    reduced = dummy_trainer.call_reduce_loss(loss, num_items_in_batch=num_items_in_batch)
    assert torch.isclose(reduced, torch.tensor(0.0))


def test_reduce_loss_preserves_gradients(dummy_trainer):
    """Test that gradient information is preserved."""
    dummy_trainer.args._n_gpu = 2
    loss = torch.tensor([2.0, 4.0], requires_grad=True)
    reduced = dummy_trainer.call_reduce_loss(loss)
    assert reduced.requires_grad

    # Test that the backward pass works
    reduced.backward()
    assert loss.grad is not None


def test_reduce_loss_with_tensor_token_count(dummy_trainer):
    """Test with tensor token count (as would come from actual training)."""
    dummy_trainer.args._n_gpu = 2
    loss = torch.tensor([3.0, 6.0])
    num_items_in_batch = torch.tensor(3)  # Tensor instead of int
    reduced = dummy_trainer.call_reduce_loss(loss, num_items_in_batch=num_items_in_batch)
    expected = loss.sum() / num_items_in_batch
    assert torch.isclose(reduced, expected)
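The individual cases above could also be folded into a single parametrized test. A sketch, not part of the PR, reusing the `dummy_trainer` fixture and `call_reduce_loss` helper defined earlier in this file:

import pytest
import torch


@pytest.mark.parametrize(
    "n_gpu, loss, num_items, expected",
    [
        (1, [2.0, 4.0], None, [2.0, 4.0]),   # single GPU: pass-through
        (2, [2.0, 4.0], None, 3.0),          # multi-GPU, no token count: mean
        (2, [2.0, 4.0], 2, 3.0),             # multi-GPU with token count: sum / items
        (4, [1.0, 2.0, 3.0, 4.0], 10, 1.0),  # larger world size, explicit item count
    ],
)
def test_reduce_loss_parametrized(dummy_trainer, n_gpu, loss, num_items, expected):
    dummy_trainer.args._n_gpu = n_gpu
    reduced = dummy_trainer.call_reduce_loss(torch.tensor(loss), num_items_in_batch=num_items)
    assert torch.allclose(reduced, torch.tensor(expected))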