diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
index 84835ab8..7c8a6926 100644
--- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -661,7 +661,7 @@ class BloomModel(BloomPreTrainedModel):
 
         return combined_attention_mask
 
-    def set_input_embeddings(self, new_embeddings: torch.Tensor):
+    def set_inputs_embeds(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
     def forward(
diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 9fc9bca6..2da36ecc 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
 
-    # def set_input_embeddings(self, value):
+    # def set_inputs_embeds(self, value):
     #     self.embed_tokens = value
 
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 581cbde8..695c4af3 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -207,8 +207,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    inputs_embeds: Optional[torch.Tensor] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -1346,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def get_input_embeddings(self, batch):
-        batch.inputs_embeds = None
-
     def init_kv_cache(
         self,
         num_blocks: int,
@@ -1901,7 +1896,9 @@ class FlashCausalLM(Model):
 
         if prefill:
             batch.prepare_for_prefill()
-            self.get_input_embeddings(batch)
+        if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+            self.set_inputs_embeds(batch)
+
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)
diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py
index 9078160b..7e041c1a 100644
--- a/server/text_generation_server/models/mllama_causal_lm.py
+++ b/server/text_generation_server/models/mllama_causal_lm.py
@@ -199,7 +199,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 
 
 class MllamaCausalLM(VlmCausalLM):
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
+        # Set the input embeddings to None, as we are using the input_ids for the model
         batch.inputs_embeds = None
 
     def forward(
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 953081a9..872418e8 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -205,7 +205,8 @@ def preprocess_image(config, img):
     elif model_type == "paligemma":
         img = img.convert("RGB")
 
-    if model_type in {"llava_next", "gemma3", "llama4"}:
+    if model_type not in {"llava_next", "gemma3", "llama4"}:
+        # TODO: check if this is needed
         img = [img]
 
     return img
@@ -307,6 +308,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     image_grid_thw: Optional[torch.Tensor]
     cache_entries_to_free: List[Tuple[int, int]]
     has_image_inputs: bool = False
+    inputs_embeds: Optional[torch.Tensor] = None
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -334,6 +336,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
+
         # To be filled in prepare_for_prefill
         batch.has_image_inputs = False
         batch.cache_entries_to_free = []
@@ -349,7 +353,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         image_positions = []
         encoder_cache = []
 
-        for i, request_id in enumerate(request_ids):
+        for request_id in request_ids:
             idx = self.requests_idx_mapping[request_id]
             image_inputs.append(self.image_inputs[idx])
             image_positions.append(self.image_positions[idx])
@@ -360,6 +364,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
 
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
@@ -383,10 +388,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         max_length = 0
         vocab = tokenizer.get_vocab()
-        config.image_token_index = (
-            config.image_token_index
-            if hasattr(config, "image_token_index")
-            else config.image_token_id
+
+        config.image_token_index = getattr(
+            config, "image_token_index", config.image_token_id
         )
 
         batch_tokenized_inputs: List[List[int]] = []
@@ -551,6 +555,12 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         self.pixel_values = []
 
+        assert (
+            len(self.cache_lengths)
+            == len(self.input_lengths)
+            == len(self.prefilling_mask)
+        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
         for i, (
             cache_length,
             input_length,
@@ -915,7 +925,7 @@ class VlmCausalLM(FlashCausalLM):
 
         batch.pixel_values = None
 
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
             self.encode_images(batch)
             vision_embeds = batch.gather_vision_embeds()