Repository: https://github.com/huggingface/text-generation-inference.git

commit d58ec388bf (parent 15926210d3)

    review comments

@@ -661,7 +661,7 @@ class BloomModel(BloomPreTrainedModel):
         return combined_attention_mask

-    def set_input_embeddings(self, new_embeddings: torch.Tensor):
+    def set_inputs_embeds(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

     def forward(

@@ -1084,7 +1084,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
     # def get_input_embeddings(self):
     #     return self.embed_tokens

-    # def set_input_embeddings(self, value):
+    # def set_inputs_embeds(self, value):
     #     self.embed_tokens = value

     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask

@@ -207,8 +207,6 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int

-    inputs_embeds: Optional[torch.Tensor] = None
-
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,

@@ -1346,9 +1344,6 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch

-    def get_input_embeddings(self, batch):
-        batch.inputs_embeds = None
-
    def init_kv_cache(
        self,
        num_blocks: int,

@@ -1901,7 +1896,9 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()

-            self.get_input_embeddings(batch)
+            if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
+                self.set_inputs_embeds(batch)
+
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)

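Note: the base FlashCausalLM no longer defines the embedding hook (see the removal above); the hasattr/callable guard means text-only models simply skip the call, while subclasses that define set_inputs_embeds get it invoked at prefill. A minimal sketch of that dispatch with stand-in classes (not the TGI classes themselves):

# Sketch only: stand-in classes illustrating the guarded hook dispatch.
class BaseModelSketch:
    def generate_step(self, batch):
        # Dispatch the hook only when a subclass actually provides it.
        if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
            self.set_inputs_embeds(batch)
        return batch

class VlmModelSketch(BaseModelSketch):
    def set_inputs_embeds(self, batch):
        # A real VLM would build embeddings here; the stand-in just records the call.
        batch["inputs_embeds"] = "built"

batch = {}
BaseModelSketch().generate_step(batch)   # no hook -> batch left untouched
VlmModelSketch().generate_step(batch)    # hook defined -> inputs_embeds populated
print(batch)                             # {'inputs_embeds': 'built'}
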
@@ -199,7 +199,8 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):


 class MllamaCausalLM(VlmCausalLM):
-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
+        # Set the input embeddings to None, as we are using the input_ids for the model
         batch.inputs_embeds = None

     def forward(

@@ -205,7 +205,8 @@ def preprocess_image(config, img):
     elif model_type == "paligemma":
         img = img.convert("RGB")

-    if model_type in {"llava_next", "gemma3", "llama4"}:
+    if model_type not in {"llava_next", "gemma3", "llama4"}:
+        # TODO: check if this is needed
         img = [img]

     return img

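Note: with the condition inverted, model types outside the set now get the image wrapped in a single-element list before being handed to the processor, while llava_next/gemma3/llama4 are passed through unwrapped. A stand-in illustration of just that branch (placeholder values, not the real PIL handling):

# Stand-in for the inverted branch in preprocess_image; "img" is a placeholder value.
def wrap_for_processor(model_type, img):
    if model_type not in {"llava_next", "gemma3", "llama4"}:
        # The diff leaves a TODO asking whether this wrapping is still needed.
        img = [img]
    return img

print(wrap_for_processor("qwen2_vl", "img"))    # ['img'] -> wrapped in a list
print(wrap_for_processor("llava_next", "img"))  # 'img'   -> passed through as-is
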
@@ -307,6 +308,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     image_grid_thw: Optional[torch.Tensor]
     cache_entries_to_free: List[Tuple[int, int]]
     has_image_inputs: bool = False
+    inputs_embeds: Optional[torch.Tensor] = None

     @classmethod
     @tracer.start_as_current_span("concatenate")

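Note: together with the removal from FlashCausalLMBatch above, the Optional inputs_embeds field now lives only on the VLM batch and stays None until the embedding hook fills it. A minimal dataclass sketch of just these two fields (not the full batch definition):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class VlmBatchSketch:
    # Defaults mirror the diff: text-only batches keep inputs_embeds as None
    # and the model falls back to input_ids.
    has_image_inputs: bool = False
    inputs_embeds: Optional[torch.Tensor] = None
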
@@ -334,6 +336,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None
+
         # To be filled in prepare_for_prefill
         batch.has_image_inputs = False
         batch.cache_entries_to_free = []

@@ -349,7 +353,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         image_positions = []
         encoder_cache = []

-        for i, request_id in enumerate(request_ids):
+        for request_id in request_ids:
             idx = self.requests_idx_mapping[request_id]
             image_inputs.append(self.image_inputs[idx])
             image_positions.append(self.image_positions[idx])

@@ -360,6 +364,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.inputs_embeds = None

         batch.image_inputs = image_inputs
         batch.image_positions = image_positions

@@ -383,10 +388,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

         max_length = 0
         vocab = tokenizer.get_vocab()
-        config.image_token_index = (
-            config.image_token_index
-            if hasattr(config, "image_token_index")
-            else config.image_token_id
+        config.image_token_index = getattr(
+            config, "image_token_index", config.image_token_id
         )

         batch_tokenized_inputs: List[List[int]] = []

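Note: the getattr form is equivalent to the old conditional expression whenever image_token_index is a plain config attribute; a quick check of the equivalence on stand-in config objects:

from types import SimpleNamespace

def old_style(config):
    return (
        config.image_token_index
        if hasattr(config, "image_token_index")
        else config.image_token_id
    )

def new_style(config):
    return getattr(config, "image_token_index", config.image_token_id)

with_index = SimpleNamespace(image_token_index=32000, image_token_id=257152)
without_index = SimpleNamespace(image_token_id=257152)

assert old_style(with_index) == new_style(with_index) == 32000
assert old_style(without_index) == new_style(without_index) == 257152

One subtle difference: getattr evaluates the config.image_token_id default eagerly, so a config that defines image_token_index but not image_token_id would now raise, whereas the old expression never touched image_token_id in that case.
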
@@ -551,6 +555,12 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

         self.pixel_values = []

+        assert (
+            len(self.cache_lengths)
+            == len(self.input_lengths)
+            == len(self.prefilling_mask)
+        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
+
         for i, (
             cache_length,
             input_length,

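Note: the new assertion guards the parallel-list unpack that follows it; without the check, zip-style iteration over lists of different lengths silently truncates to the shortest one and a request could be skipped. A small illustration of the failure mode the assertion surfaces (toy lists, not real batch state):

cache_lengths = [0, 0, 0]
input_lengths = [5, 7]          # one entry short
prefilling_mask = [True, True, True]

# zip() stops at the shortest list, so one request would be dropped silently.
print(list(zip(cache_lengths, input_lengths, prefilling_mask)))  # only two tuples

# The added assertion fails loudly instead.
try:
    assert (
        len(cache_lengths) == len(input_lengths) == len(prefilling_mask)
    ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"
except AssertionError as err:
    print(err)
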
@@ -915,7 +925,7 @@ class VlmCausalLM(FlashCausalLM):

         batch.pixel_values = None

-    def get_input_embeddings(self, batch):
+    def set_inputs_embeds(self, batch):
         if batch.has_image_inputs:
             self.encode_images(batch)
             vision_embeds = batch.gather_vision_embeds()

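Note: a rough sketch of the control flow the renamed VlmCausalLM.set_inputs_embeds implements, using only the names visible in this hunk; how the gathered vision embeddings are merged into batch.inputs_embeds is outside the hunk and is stubbed with an ellipsis, and the text-only fallback shown here is an assumption based on the Mllama override above:

class VlmCausalLMSketch:
    def set_inputs_embeds(self, batch):
        if batch.has_image_inputs:
            # Run the vision encoder and collect the per-request embeddings.
            self.encode_images(batch)
            vision_embeds = batch.gather_vision_embeds()
            ...  # merging vision_embeds into batch.inputs_embeds is not shown in this diff
        else:
            # Assumed fallback: text-only batches keep using input_ids.
            batch.inputs_embeds = None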