Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import PixtralVisionConfig, TensorType
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
Expand Down Expand Up @@ -163,10 +163,12 @@ def __call__(
images_processed.append(image_processed)
images_tokens.append(image_tokens)

return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
}
return BatchFeature({
"input_ids":
torch.cat(images_tokens)[None].expand(len(text), -1),
"images":
images_processed,
})


class PixtralProcessingInfo(BaseProcessingInfo):
Expand Down
12 changes: 7 additions & 5 deletions vllm/model_executor/models/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import TensorType, WhisperConfig
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
Expand Down Expand Up @@ -156,10 +156,12 @@ def __call__(
audios_tokens.append(torch.tensor(audio_tokens))
audios_processed.append(torch.tensor(audio))

return {
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
"audio_arrays": audios_processed,
}
return BatchFeature({
"input_ids":
torch.cat(audios_tokens)[None].expand(len(text), -1),
"audio_arrays":
audios_processed,
})


class VoxtralProcessingInfo(BaseProcessingInfo):
Expand Down
30 changes: 17 additions & 13 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,16 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.version: int = int(_mistral_version_str.split("v")[-1])

tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.tekken import (
SpecialTokenPolicy, Tekkenizer)
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

self.is_tekken = isinstance(tokenizer_, Tekkenizer)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
if self.is_tekken:
# Make sure special tokens will not raise
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
elif self.is_spm:
pass
else:
self._special_token_policy = (SpecialTokenPolicy.IGNORE
if self.is_tekken else None)
if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

self._vocab = tokenizer_.vocab()
Expand Down Expand Up @@ -430,7 +428,8 @@ def _token_to_id(t: str):
return self.tokenizer.unk_id

ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids)
decoded = self.tokenizer.decode(ids,
self._special_token_policy)
else:
decoded = "".join(tokens)
else:
Expand All @@ -444,15 +443,17 @@ def _token_to_id(t: str):
if token in special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens))
self.tokenizer.decode(regular_tokens,
self._special_token_policy))
regular_tokens = []
decoded_list.append(token)
else:
regular_tokens.append(token)

if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens)) # type: ignore
self.tokenizer.decode(regular_tokens,
self._special_token_policy))

decoded = ''.join(decoded_list)

Expand All @@ -470,7 +471,7 @@ def decode(self,

if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids)
return self.tokenizer.decode(ids, self._special_token_policy)

def convert_ids_to_tokens(
self,
Expand Down Expand Up @@ -511,6 +512,9 @@ def convert_ids_to_tokens(
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizeir is sentencepiece, we just add "�"
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
tokens = [
self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
for id in ids
]

return tokens