review comments

Mohit Sharma 2025-04-24 09:49:29 +00:00
parent 15926210d3
commit d58ec388bf
5 changed files with 24 additions and 16 deletions


@@ -661,7 +661,7 @@ class BloomModel(BloomPreTrainedModel):
 
         return combined_attention_mask
 
-    def set_input_embeddings(self, new_embeddings: torch.Tensor):
+    def set_inputs_embeds(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
     def forward(


@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
 
-    # def set_input_embeddings(self, value):
+    # def set_inputs_embeds(self, value):
     #     self.embed_tokens = value
 
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask


@@ -207,8 +207,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    inputs_embeds: Optional[torch.Tensor] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -1346,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def get_input_embeddings(self, batch):
-        batch.inputs_embeds = None
-
     def init_kv_cache(
         self,
         num_blocks: int,
@@ -1901,7 +1896,9 @@ class FlashCausalLM(Model):
 
         if prefill:
             batch.prepare_for_prefill()
-            self.get_input_embeddings(batch)
+            if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+                self.set_inputs_embeds(batch)
+
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)
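
Note on the guard above: FlashCausalLM itself no longer defines the embedding hook (the old get_input_embeddings stub is deleted in the hunk at -1346), so the hasattr/callable check makes the call optional per subclass. A minimal sketch of that dispatch pattern, with illustrative class names (only set_inputs_embeds comes from the diff):

    class Base:
        def step(self, batch):
            # Only subclasses that define the hook pay for the call;
            # the base class stays free of a no-op override.
            if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
                self.set_inputs_embeds(batch)

    class Multimodal(Base):
        def set_inputs_embeds(self, batch):
            batch["inputs_embeds"] = "merged text + vision embeddings"

    batch = {}
    Multimodal().step(batch)  # hook defined: inputs_embeds is set
    Base().step(batch)        # hook missing: guard is False, nothing happens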


@@ -199,7 +199,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 
 
 class MllamaCausalLM(VlmCausalLM):
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
+        # Set the input embeddings to None, as we are using the input_ids for the model
         batch.inputs_embeds = None
 
     def forward(


@@ -205,7 +205,8 @@ def preprocess_image(config, img):
 
     elif model_type == "paligemma":
         img = img.convert("RGB")
-    if model_type in {"llava_next", "gemma3", "llama4"}:
+    if model_type not in {"llava_next", "gemma3", "llama4"}:
+        # TODO: check if this is needed
         img = [img]
 
     return img
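
The fix above inverts the condition: llava_next, gemma3 and llama4 now receive the bare image, while every other model type gets it wrapped in a list. A small sketch of the corrected behavior (the wrap helper is illustrative; the model names are from the diff):

    def wrap(img, model_type):
        # llava_next, gemma3 and llama4 consume a single image;
        # the remaining model types expect a list of images.
        if model_type not in {"llava_next", "gemma3", "llama4"}:
            img = [img]
        return img

    assert wrap("pixels", "paligemma") == ["pixels"]  # wrapped
    assert wrap("pixels", "gemma3") == "pixels"       # left as-is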
@@ -307,6 +308,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     image_grid_thw: Optional[torch.Tensor]
     cache_entries_to_free: List[Tuple[int, int]]
     has_image_inputs: bool = False
+    inputs_embeds: Optional[torch.Tensor] = None
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -334,6 +336,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
+
         # To be filled in prepare_for_prefill
         batch.has_image_inputs = False
         batch.cache_entries_to_free = []
@@ -349,7 +353,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         image_inputs = []
         image_positions = []
         encoder_cache = []
-        for i, request_id in enumerate(request_ids):
+        for request_id in request_ids:
             idx = self.requests_idx_mapping[request_id]
             image_inputs.append(self.image_inputs[idx])
             image_positions.append(self.image_positions[idx])
@@ -360,6 +364,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
 
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
@@ -383,10 +388,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         max_length = 0
         vocab = tokenizer.get_vocab()
-        config.image_token_index = (
-            config.image_token_index
-            if hasattr(config, "image_token_index")
-            else config.image_token_id
+
+        config.image_token_index = getattr(
+            config, "image_token_index", config.image_token_id
         )
 
         batch_tokenized_inputs: List[List[int]] = []
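
The getattr rewrite above is behavior-preserving: getattr with a default performs the same hasattr-style fallback in a single call. A runnable comparison, using SimpleNamespace as a stand-in for the real model config:

    from types import SimpleNamespace

    # Config without image_token_index, only the fallback attribute.
    config = SimpleNamespace(image_token_id=32000)

    old = (
        config.image_token_index
        if hasattr(config, "image_token_index")
        else config.image_token_id
    )
    new = getattr(config, "image_token_index", config.image_token_id)
    assert old == new == 32000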
@@ -551,6 +555,12 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         self.pixel_values = []
 
+        assert (
+            len(self.cache_lengths)
+            == len(self.input_lengths)
+            == len(self.prefilling_mask)
+        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
         for i, (
             cache_length,
             input_length,
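
The new assert guards the loop that follows: the tuple unpacking in the for header suggests the three lists are consumed in lockstep, and zip-style iteration silently truncates to the shortest input instead of failing. A sketch of the failure mode it catches (the values are made up):

    cache_lengths = [0, 0]
    input_lengths = [5, 7, 9]  # one entry too many
    prefilling_mask = [True, True]

    # Raises AssertionError before any per-request state is touched,
    # rather than silently processing only two of the three requests.
    assert (
        len(cache_lengths) == len(input_lengths) == len(prefilling_mask)
    ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"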
@@ -915,7 +925,7 @@ class VlmCausalLM(FlashCausalLM):
 
         batch.pixel_values = None
 
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
             self.encode_images(batch)
             vision_embeds = batch.gather_vision_embeds()
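
The rename here (and in the Mllama hunk above) matches what the hook actually does: it writes batch.inputs_embeds in place and returns nothing, so "set" describes it better than "get". A sketch of that contract; the body is illustrative, and only the method name plus the has_image_inputs / inputs_embeds fields appear in the diff:

    class Batch:
        has_image_inputs = False
        inputs_embeds = None

    def set_inputs_embeds(batch):
        # Mutates the batch instead of returning embeddings.
        if batch.has_image_inputs:
            batch.inputs_embeds = "text embeddings merged with vision embeddings"
        else:
            batch.inputs_embeds = None  # model falls back to input_ids

    b = Batch()
    set_inputs_embeds(b)
    assert b.inputs_embeds is None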