
current_gradient_accumulation_steps is undefined when eval_on_start==True #4010

@konstantinjdobler

Description


Reproduction

This is similar to the issue already raised in #3983.

However, that issue concerned errors during the training loop, which are indeed fixed by using the correct transformers version. When eval_on_start is True, though, we enter eval before current_gradient_accumulation_steps is first set in the Trainer's _inner_training_loop.

In transformers.Trainer this is fine, because current_gradient_accumulation_steps is not referenced during eval. But in TRL we reference current_gradient_accumulation_steps in _compute_loss, which is also called during eval, so we get an AttributeError.

Fix: we probably need to set an initial value for current_gradient_accumulation_steps, at least when eval_on_start is True. We should also check whether the current eval logic is even correct: if this value can differ between train and eval, we may need to set current_gradient_accumulation_steps ourselves during eval.
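A minimal sketch of the proposed fix, using a stand-in class rather than the real GRPOTrainer (the attribute name and the division come from the traceback below; defaulting to the configured gradient_accumulation_steps is an assumption, not the confirmed fix):

```python
class GRPOTrainerSketch:
    """Stand-in for trl's GRPOTrainer, reduced to the failure mode."""

    def __init__(self, gradient_accumulation_steps: int):
        # Proposed fix: give the attribute an initial value at construction
        # time, so eval_on_start cannot observe it unset. Today the real
        # trainer only sets it inside the inner training loop.
        self.current_gradient_accumulation_steps = gradient_accumulation_steps

    def _compute_loss(self, raw_loss: float) -> float:
        # Mirrors the failing line in grpo_trainer.py, which raises
        # AttributeError when reached via eval_on_start before training
        # has set the attribute.
        return raw_loss / self.current_gradient_accumulation_steps
```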

Traceback:

Traceback (most recent call last):
  File "<...>/grpo.py", line 328, in main
    trainer.train()
  File "<...>/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2581, in _inner_training_loop
    self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
  File "/mnt/task_runtime/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3176, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 4469, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 4665, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1667, in prediction_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 98, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1539, in compute_loss
    return self._compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<...>/.venv/lib/python3.12/site-packages/trl/trainer/grpo_trainer.py", line 1613, in _compute_loss
    loss = loss / self.current_gradient_accumulation_steps
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'GRPOTrainer' object has no attribute 'current_gradient_accumulation_steps'
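Until this is fixed upstream, a possible user-side stop-gap is to set the attribute on the trainer instance before calling train(). The helper below is hypothetical (the function name and fallback_steps parameter are mine), and whether a training-time default is even correct during eval is exactly the open question noted above:

```python
from types import SimpleNamespace


def ensure_grad_accum_attr(trainer, fallback_steps: int):
    """Set current_gradient_accumulation_steps if the trainer lacks it.

    Stop-gap only: if eval should use a different value than training,
    this fallback may be wrong.
    """
    if not hasattr(trainer, "current_gradient_accumulation_steps"):
        trainer.current_gradient_accumulation_steps = fallback_steps
    return trainer


# Usage with a dummy object standing in for a GRPOTrainer instance:
trainer = ensure_grad_accum_attr(SimpleNamespace(), fallback_steps=4)
```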

System Info

I am on transformers==4.56.0 and trl==0.22.2.

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete


Labels

🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
