 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import OutputRecorder, check_model_inputs
 from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
@@ -1064,15 +1064,6 @@ def forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
         """
-        if self.training and self.config._attn_implementation != "eager":
-            msg = (
-                "It is strongly recommended to train T5Gemma models with the `eager` attention implementation "
-                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
-            )
-            if is_torchdynamo_compiling():
-                raise ValueError(msg)
-            else:
-                logger.warning_once(msg)

         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
             # get decoder inputs from shifting lm labels to the right
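
The patch removes the training-time guard that warned (or, under torch.compile, raised) whenever a T5Gemma model was trained with a non-eager attention implementation, along with the now-unused `is_torchdynamo_compiling` import. Choosing the attention backend remains an explicit `from_pretrained` argument. A minimal sketch of that choice, assuming a hypothetical local checkpoint (the path below is a placeholder, not a real checkpoint name):

from transformers import T5GemmaForConditionalGeneration

# attn_implementation selects the attention backend; "eager" was the previously
# recommended setting for training, while "sdpa" or "flash_attention_2" are
# alternatives where supported by the installed environment.
model = T5GemmaForConditionalGeneration.from_pretrained(
    "path/to/t5gemma-checkpoint",  # placeholder path for illustration only
    attn_implementation="eager",
)
model.train()  # with this patch, training no longer emits the eager-attention warning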