12 changes: 12 additions & 0 deletions tests/test_sft_trainer.py
@@ -162,6 +162,18 @@ def test_pad_to_multiple_of(self):
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

def test_pad_to_multiple_of_and_padding_free(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

result = collator(examples)

torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))

def test_custom_position_ids(self):
"""Test handling of custom position IDs in examples."""
self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
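The new test above covers `padding_free` combined with `pad_to_multiple_of`, but not the case where examples already carry `seq_lengths` from `pack_dataset`. The snippet below is a hedged illustration of that case, inferred from the collator changes further down in this diff; the expected values are not asserted by the test suite shown here.

```python
import torch

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

# Inferred from this PR: when `seq_lengths` is present and `padding_free=True`,
# the collator derives restarting position IDs and returns no attention_mask,
# so FlashAttention can build cu_seq_lens from position_ids instead.
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [{"input_ids": [1, 2, 3, 9], "seq_lengths": [3, 1]}]

result = collator(examples)

# Expected (inferred, not taken from the tests above):
# result["input_ids"]    -> tensor([[1, 2, 3, 9]])
# result["position_ids"] -> tensor([[0, 1, 2, 0]])  # restarts at each boundary
# result["labels"]       -> tensor([[1, 2, 3, 9]])
# and no "attention_mask" key in `result`.
```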
5 changes: 3 additions & 2 deletions trl/data_utils.py
@@ -663,8 +663,9 @@ def pack_dataset(
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
{'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
'seq_lengths': [[3, 1], [3], [2]]}
```
"""
if map_kwargs is None:
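The updated docstring shows that `strategy="bfd"` (best-fit decreasing) now keeps each packed row's document boundaries in a `seq_lengths` column instead of merging them behind an attention-mask trick. For readers unfamiliar with best-fit decreasing, here is a minimal, self-contained sketch of the bin-packing idea only; it is not TRL's implementation and ignores tokenization, truncation, and the batched `map` that `pack_dataset` actually uses.

```python
def bfd_pack(docs, seq_length):
    """Best-fit-decreasing sketch: place each document (longest first) into the
    bin with the least remaining room that still fits it, else open a new bin."""
    bins = []  # each bin: {"input_ids": [...], "seq_lengths": [...]}
    for doc in sorted(docs, key=len, reverse=True):
        candidates = [b for b in bins if seq_length - len(b["input_ids"]) >= len(doc)]
        if candidates:
            best = min(candidates, key=lambda b: seq_length - len(b["input_ids"]))
        else:
            best = {"input_ids": [], "seq_lengths": []}
            bins.append(best)
        best["input_ids"].extend(doc)
        best["seq_lengths"].append(len(doc))
    return bins


# With the docstring's documents and seq_length=4 this reproduces the new output:
# [{'input_ids': [1, 2, 3, 9], 'seq_lengths': [3, 1]},
#  {'input_ids': [6, 7, 8], 'seq_lengths': [3]},
#  {'input_ids': [4, 5], 'seq_lengths': [2]}]
print(bfd_pack([[1, 2, 3], [4, 5], [6, 7, 8], [9]], seq_length=4))
```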
92 changes: 47 additions & 45 deletions trl/trainer/sft_trainer.py
@@ -124,7 +124,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
that are not in the completion.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly. The attention mask will be set to 1 for all tokens.
generated accordingly.
return_position_ids (`bool`, *optional*, defaults to `True`):
Whether to return position IDs. If `True`, position IDs are generated and returned. If `False`, attention
masks are generated and returned instead. Note that when using FlashAttention, this should be set to
`True`.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -139,8 +143,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
@@ -154,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[-100, 2, 3],
Expand All @@ -182,21 +182,23 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]

# Check if we have meaningful seq_lengths from packing (restarting sequences)
has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free

# For packing with position_ids, we should NOT create attention_mask as it causes
# FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
if not has_packed_position_ids:
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]

# In practice, self.return_position_ids is True when using FlashAttention, in which case we should NOT create an
# attention_mask: an all-1s mask causes FlashAttention to ignore position_ids and compute wrong cu_seq_lens.
if self.return_position_ids:
if "seq_lengths" in examples[0]:
position_ids = self.get_position_ids_from_packed_seq_lengths(
[example["seq_lengths"] for example in examples]
)
else:
position_ids = [torch.arange(len(ids)) for ids in input_ids]
else:
if "seq_lengths" in examples[0]:
logger.warning(
"The input examples contain `seq_lengths` but `return_position_ids` is set to `False`. "
"`seq_lengths` will be ignored."
)
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
Expand All @@ -206,48 +208,48 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
if "assistant_masks" in examples[0]:
assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]

# Pad
# If padding_free, flatten everything into a single sequence
output = {}
if self.padding_free:
output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
if not has_packed_position_ids:
output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
input_ids = [torch.cat(input_ids, dim=0)]
if self.return_position_ids:
output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
position_ids = [torch.cat(position_ids, dim=0)]
else:
attention_mask = [torch.cat(attention_mask, dim=0)]
labels = [torch.cat(labels, dim=0)]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
output["labels"][completion_mask == 0] = -100
completion_mask = [torch.cat(completion_mask, dim=0)]
if "assistant_masks" in examples[0]:
assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
output["labels"][assistant_masks == 0] = -100
else:
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
assistant_masks = [torch.cat(assistant_masks, dim=0)]

# Pad
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
else:
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
if "assistant_masks" in examples[0]:
assistant_masks = pad(
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][assistant_masks == 0] = -100
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
if "assistant_masks" in examples[0]:
assistant_masks = pad(
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][assistant_masks == 0] = -100
return output

@staticmethod
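The diff is cut off before the body of `get_position_ids_from_packed_seq_lengths`, which `torch_call` calls above to turn each example's `seq_lengths` into position IDs. The behavior it needs, inferred from the surrounding code and the `pack_dataset` docstring, is to restart the position counter at every packed-sequence boundary. The sketch below illustrates that behavior only; it is not the PR's actual implementation.

```python
import torch


def position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]:
    # One tensor per packed row; e.g. seq_lengths [3, 1] -> positions [0, 1, 2, 0].
    return [
        torch.cat([torch.arange(length) for length in seq_lengths])
        for seq_lengths in batch_seq_lengths
    ]


# Example: the three packed rows from the pack_dataset docstring above.
position_ids_from_packed_seq_lengths([[3, 1], [3], [2]])
# -> [tensor([0, 1, 2, 0]), tensor([0, 1, 2]), tensor([0, 1])]
```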