mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00

commit d58ec388bf (parent 15926210d3): review comments
@@ -661,7 +661,7 @@ class BloomModel(BloomPreTrainedModel):
 
         return combined_attention_mask
 
-    def set_input_embeddings(self, new_embeddings: torch.Tensor):
+    def set_inputs_embeds(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings
 
     def forward(
@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens
 
-    # def set_input_embeddings(self, value):
+    # def set_inputs_embeds(self, value):
     #     self.embed_tokens = value
 
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
@@ -207,8 +207,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
-    inputs_embeds: Optional[torch.Tensor] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -1346,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
-    def get_input_embeddings(self, batch):
-        batch.inputs_embeds = None
-
     def init_kv_cache(
         self,
         num_blocks: int,
@@ -1901,7 +1896,9 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()
 
-            self.get_input_embeddings(batch)
+            if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+                self.set_inputs_embeds(batch)
+
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)
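For orientation, a minimal, self-contained sketch of the guarded hook call this hunk introduces: set_inputs_embeds is only invoked when the model object actually defines a callable of that name, so models that never populate batch.inputs_embeds need no stub. Only the hook name and the inputs_embeds attribute come from the diff; the classes and the run_step helper below are illustrative, not TGI's real types.

# Sketch of the optional set_inputs_embeds hook; classes are illustrative.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Batch:
    inputs_embeds: Optional[object] = None


class TextOnlyModel:
    # Defines no set_inputs_embeds; the guard below simply skips the call.
    pass


class VisionModel:
    def set_inputs_embeds(self, batch: Batch) -> None:
        # A real model would merge text and vision embeddings here.
        batch.inputs_embeds = "placeholder embeddings"


def run_step(model, batch: Batch) -> None:
    # Same guard as in the hunk above: call the hook only if it exists and is callable.
    if hasattr(model, "set_inputs_embeds") and callable(model.set_inputs_embeds):
        model.set_inputs_embeds(batch)


text_batch, vision_batch = Batch(), Batch()
run_step(TextOnlyModel(), text_batch)   # inputs_embeds stays None
run_step(VisionModel(), vision_batch)   # inputs_embeds is populated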
@@ -199,7 +199,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 
 
 class MllamaCausalLM(VlmCausalLM):
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
+        # Set the input embeddings to None, as we are using the input_ids for the model
         batch.inputs_embeds = None
 
     def forward(
@@ -205,7 +205,8 @@ def preprocess_image(config, img):
     elif model_type == "paligemma":
         img = img.convert("RGB")
 
-    if model_type in {"llava_next", "gemma3", "llama4"}:
+    if model_type not in {"llava_next", "gemma3", "llama4"}:
+        # TODO: check if this is needed
         img = [img]
 
     return img
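To make the flipped condition concrete: after this change the image is wrapped in a single-element list for every model type except llava_next, gemma3, and llama4, instead of only for those three. A simplified stand-in for the changed branch (not the full preprocess_image):

# Simplified stand-in for the inverted branch above.
def wrap_image(model_type: str, img):
    if model_type not in {"llava_next", "gemma3", "llama4"}:
        # TODO in the original code: check if this is needed
        img = [img]
    return img


assert wrap_image("qwen2_vl", "IMG") == ["IMG"]  # wrapped in a list
assert wrap_image("llava_next", "IMG") == "IMG"  # passed through unchanged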
@@ -307,6 +308,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     image_grid_thw: Optional[torch.Tensor]
     cache_entries_to_free: List[Tuple[int, int]]
     has_image_inputs: bool = False
+    inputs_embeds: Optional[torch.Tensor] = None
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -334,6 +336,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
+
         # To be filled in prepare_for_prefill
         batch.has_image_inputs = False
         batch.cache_entries_to_free = []
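The pattern here, mirrored in the filter hunks below, is that per-batch derived state such as inputs_embeds is only valid for the exact set of requests it was built from, so it is cleared when batches are merged and rebuilt at the next prefill. A toy sketch of that invalidate-then-rebuild shape; everything except the inputs_embeds name is illustrative.

# Toy illustration of clearing derived state on concatenate.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyBatch:
    input_ids: List[int] = field(default_factory=list)
    # Derived state, only valid for the requests it was computed from.
    inputs_embeds: Optional[List[float]] = None

    @classmethod
    def concatenate(cls, batches: List["ToyBatch"]) -> "ToyBatch":
        merged = cls(input_ids=[t for b in batches for t in b.input_ids])
        # Stale after merging; to be filled again at the next prefill.
        merged.inputs_embeds = None
        return merged


a = ToyBatch(input_ids=[1, 2], inputs_embeds=[0.1, 0.2])
b = ToyBatch(input_ids=[3], inputs_embeds=[0.3])
assert ToyBatch.concatenate([a, b]).inputs_embeds is None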
@@ -349,7 +353,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         image_positions = []
         encoder_cache = []
 
-        for i, request_id in enumerate(request_ids):
+        for request_id in request_ids:
             idx = self.requests_idx_mapping[request_id]
             image_inputs.append(self.image_inputs[idx])
             image_positions.append(self.image_positions[idx])
@@ -360,6 +364,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
 
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
@@ -383,10 +388,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         max_length = 0
         vocab = tokenizer.get_vocab()
-        config.image_token_index = (
-            config.image_token_index
-            if hasattr(config, "image_token_index")
-            else config.image_token_id
-        )
+        config.image_token_index = getattr(
+            config, "image_token_index", config.image_token_id
+        )
 
         batch_tokenized_inputs: List[List[int]] = []
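The replacement collapses the four-line conditional expression into a single getattr with a default: the attribute is used if the config defines it, otherwise config.image_token_id is taken as the fallback. A small equivalent shown on dummy config objects (the token ids are made up):

# getattr-with-default, as in the hunk above, on dummy config objects.
from types import SimpleNamespace

# Config that only defines image_token_id: the fallback is used.
config = SimpleNamespace(image_token_id=32000)
config.image_token_index = getattr(config, "image_token_index", config.image_token_id)
assert config.image_token_index == 32000

# Config that already defines image_token_index: its value is kept.
config2 = SimpleNamespace(image_token_index=151655, image_token_id=151654)
config2.image_token_index = getattr(config2, "image_token_index", config2.image_token_id)
assert config2.image_token_index == 151655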
@@ -551,6 +555,12 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         self.pixel_values = []
 
+        assert (
+            len(self.cache_lengths)
+            == len(self.input_lengths)
+            == len(self.prefilling_mask)
+        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
         for i, (
             cache_length,
             input_length,
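The new assertion uses Python's chained comparison, so a single statement verifies that the three per-request lists are the same length before they are zipped together in the loop that follows. A minimal demonstration with made-up values:

# Chained comparison: true only if all three lengths agree.
cache_lengths = [0, 0, 5]
input_lengths = [12, 7, 3]
prefilling_mask = [True, True, False]

assert (
    len(cache_lengths)
    == len(input_lengths)
    == len(prefilling_mask)
), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"

# A missing entry trips the assert instead of silently mis-zipping.
try:
    assert len(cache_lengths) == len(input_lengths) == len(prefilling_mask[:2])
except AssertionError:
    print("length mismatch detected")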
@@ -915,7 +925,7 @@ class VlmCausalLM(FlashCausalLM):
 
             batch.pixel_values = None
 
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
             self.encode_images(batch)
             vision_embeds = batch.gather_vision_embeds()
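Unlike the Mllama override, the VLM version of set_inputs_embeds does real work: it encodes pending images and gathers the resulting vision embeddings. The rest of the method is cut off in this view, so the sketch below only shows one plausible shape for such a hook, scattering vision embeddings into the text-token embeddings at image-token positions; every name other than set_inputs_embeds, has_image_inputs, and inputs_embeds is an assumption, not TGI's actual implementation.

# Hedged sketch of a VLM set_inputs_embeds hook; illustrative only.
import torch


class SketchBatch:
    def __init__(self, input_ids: torch.Tensor, has_image_inputs: bool):
        self.input_ids = input_ids
        self.has_image_inputs = has_image_inputs
        self.inputs_embeds = None
        self.vision_embeds = None


class SketchVlm:
    def __init__(self, embed_tokens: torch.nn.Embedding, image_token_id: int):
        self.embed_tokens = embed_tokens
        self.image_token_id = image_token_id

    def encode_images(self, batch: SketchBatch) -> None:
        # Stand-in for the vision tower: one embedding row per image-placeholder token.
        num_image_tokens = int((batch.input_ids == self.image_token_id).sum())
        batch.vision_embeds = torch.randn(num_image_tokens, self.embed_tokens.embedding_dim)

    def set_inputs_embeds(self, batch: SketchBatch) -> None:
        if batch.has_image_inputs:
            self.encode_images(batch)
            # Start from the text embeddings, then overwrite the image-token positions.
            inputs_embeds = self.embed_tokens(batch.input_ids).detach()
            mask = batch.input_ids == self.image_token_id
            inputs_embeds[mask] = batch.vision_embeds
            batch.inputs_embeds = inputs_embeds
        else:
            batch.inputs_embeds = None


model = SketchVlm(torch.nn.Embedding(100, 16), image_token_id=9)
batch = SketchBatch(torch.tensor([1, 9, 9, 2]), has_image_inputs=True)
model.set_inputs_embeds(batch)
assert batch.inputs_embeds.shape == (4, 16)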