Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,87 @@ def test_vdpo_trainer(self, model_id):
continue
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")

def test_dpo_trainer_gemma3_vision_model_detection(self):
"""CPU-only check that Gemma 3 routes via vision path and preserves pixel tensors."""

# Minimal Gemma 3-like model
class _MinimalConfig:
def __init__(self):
self.model_type = "gemma3"
self.is_encoder_decoder = False
self._name_or_path = "dummy"

class _MinimalModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.config = _MinimalConfig()
self.warnings_issued = {}

model = _MinimalModel()

# Mock tokenizer with required token IDs to avoid external dependencies
mock_tokenizer = MagicMock()
mock_tokenizer.pad_token_id = 0
mock_tokenizer.eos_token_id = 1
mock_tokenizer.bos_token_id = 2

def _tok_call(text, add_special_tokens=False):
return {"input_ids": [11, 22]}

mock_tokenizer.side_effect = _tok_call

# Mock processor that returns pixel_values to simulate vision processing
processor = MagicMock()
processor.tokenizer = mock_tokenizer
# Ensure DPOTrainer reads an integer pad_token_id
processor.pad_token_id = 0

def _proc_call(images=None, text=None, add_special_tokens=False):
return {
"input_ids": [[101, 102]],
"pixel_values": [np.zeros((3, 8, 8), dtype=np.float32)],
}

processor.side_effect = _proc_call

# Tiny dataset with one 16x16 image
img = Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))
ds = Dataset.from_list(
[
{
"prompt": "Describe the image.",
"chosen": "Black square.",
"rejected": "White circle.",
"images": [img],
}
]
)
# Test-optimized config to avoid multiprocessing and reference model creation
args = DPOConfig(
output_dir=self.tmp_dir,
dataset_num_proc=1,
precompute_ref_log_probs=True,
gradient_checkpointing=False,
report_to="none",
)

trainer = DPOTrainer(model=model, args=args, processing_class=processor, train_dataset=ds)

# 1) Model detected as vision-text
self.assertTrue(trainer.is_vision_model, "Expected Gemma 3 to be detected as vision-text model")

# 2) Dataset processed via process_row (pixel_values present)
row = trainer.train_dataset[0]
self.assertIn("pixel_values", row, "process_row did not add pixel_values")

# 3) Signature columns include vision fields
trainer._set_signature_columns_if_needed()
self.assertIn("pixel_values", trainer._signature_columns)

# 4) Collator preserves pixel tensors
batch = trainer.data_collator([row])
self.assertIn("pixel_values", batch, "pixel_values missing in collated batch")


if __name__ == "__main__":
unittest.main()
6 changes: 4 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
is_mlflow_available,
is_wandb_available,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_liger_kernel_available, is_peft_available
Expand Down Expand Up @@ -330,7 +330,7 @@ def __init__(
)

self.is_encoder_decoder = model.config.is_encoder_decoder
self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys()
self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
Expand Down Expand Up @@ -788,6 +788,8 @@ def _set_signature_columns_if_needed(self):
"prompt_input_ids",
"chosen_input_ids",
"rejected_input_ids",
"pixel_values",
"pixel_attention_mask",
"image_sizes",
"ref_chosen_logps",
"ref_rejected_logps",
Expand Down