Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
add logic

This commit is contained in:
parent 84ab88d843 · commit 92909f3f33
@@ -381,6 +381,130 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        batch.image_grid_thw = None
        return batch
    def prepare_for_prefill(self):
        super().prepare_for_prefill()

        self.has_image = False
        self.scheduled_image_input = []
        scheduled_image_pixel_values = []

        device = self.input_ids.device

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
                continue

            # batch_mm_inputs is assumed to be built elsewhere in this commit.
            for mm_inputs in batch_mm_inputs[i]:
                for j, mm_input in enumerate(mm_inputs):
                    image_id = mm_input.id
                    pixel_values = self.all_pixel_values[i][j].pixel_values

                    start_pos = mm_input.offset
                    length = mm_input.length
                    num_placeholder_tokens = mm_input.num_placeholder_tokens

                    if start_pos >= cache_length + input_length:
                        # No encoder input required at this step
                        break
                    if start_pos + length <= cache_length:
                        # The encoder input is already processed
                        continue

                    self.has_image = True

                    # Schedule the image only if its embedding is not cached yet.
                    if image_id not in self.encoder_cache[i]:
                        self.scheduled_image_input.append((i, mm_input))
                        scheduled_image_pixel_values.append(pixel_values)

        if self.has_image and len(scheduled_image_pixel_values):
            self.pixel_values = torch.cat(scheduled_image_pixel_values, dim=0).to(device)
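
Note: the two window checks above are an interval-overlap test between an image's placeholder span and the tokens handled in the current (possibly chunked) prefill step. A minimal standalone sketch, with hypothetical numbers that are not part of the diff:

    # A placeholder spanning [start_pos, start_pos + length) needs encoder
    # work only if it overlaps [cache_length, cache_length + input_length).
    def overlaps_current_step(start_pos, length, cache_length, input_length):
        if start_pos >= cache_length + input_length:
            return False  # placeholder starts after this step's window
        if start_pos + length <= cache_length:
            return False  # placeholder was fully handled in earlier steps
        return True

    assert overlaps_current_step(start_pos=2, length=4, cache_length=0, input_length=8)
    assert not overlaps_current_step(start_pos=10, length=4, cache_length=0, input_length=8)
    assert not overlaps_current_step(start_pos=2, length=4, cache_length=8, input_length=8)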

    def update_encoder_cache(self, encoder_outputs):
        prev = 0
        for i, mm_input in self.scheduled_image_input:
            length = mm_input.num_placeholder_tokens
            output = encoder_outputs[prev : prev + length]
            self.encoder_cache[i][mm_input.id] = self.scatter_image_embed(
                output, mm_input.is_embed
            )

            prev += length
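
Note: scatter_image_embed is not shown in this hunk. A hedged sketch of what it is assumed to do, namely place the encoder rows at the positions marked in is_embed so the cached tensor lines up with placeholder token positions:

    # Hypothetical sketch only; the real helper is defined outside this hunk.
    # Assumes encoder_output is (num_embeds, hidden) and is_embed is a bool
    # tensor over placeholder positions, or None.
    import torch

    def scatter_image_embed_sketch(encoder_output, is_embed):
        if is_embed is None:
            return encoder_output
        out = encoder_output.new_zeros(is_embed.shape[0], encoder_output.shape[-1])
        out[is_embed] = encoder_output  # rows land where the mask is True
        return out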

    def get_mm_embeddings(self):
        device = self.input_ids.device
        mm_embeds = []

        for i, (
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
                continue

            for mm_inputs in batch_mm_inputs[i]:
                for j, mm_input in enumerate(mm_inputs):
                    image_id = mm_input.id

                    start_pos = mm_input.offset
                    length = mm_input.length

                    if start_pos >= cache_length + input_length:
                        # No encoder input required at this step
                        break
                    if start_pos + length <= cache_length:
                        # The encoder input is already processed
                        continue

                    # Slice out the part of the cached encoder output that
                    # overlaps the tokens processed at this step.
                    start_idx = max(cache_length - start_pos, 0)
                    end_idx = min(cache_length - start_pos + input_length, length)

                    encoder_output = self.encoder_cache[i][image_id]

                    is_embed = mm_input.is_embed
                    if is_embed is not None:
                        is_embed = is_embed[start_idx:end_idx]

                    mm_embeds_item = gather_mm_embeds(
                        encoder_output[start_idx:end_idx],
                        is_embed=is_embed,
                    )
                    mm_embeds.append(mm_embeds_item)

        return torch.cat(mm_embeds, dim=0).to(device)
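
Note: under chunked prefill an image can straddle two steps, and the start_idx/end_idx arithmetic above hands out the cached embedding in consecutive slices instead of recomputing it. A worked example with hypothetical numbers:

    # A placeholder at offset 5 with length 10, split across two prefill
    # chunks of 8 tokens each.
    cache_length, input_length = 0, 8               # first chunk: tokens [0, 8)
    start_pos, length = 5, 10
    start_idx = max(cache_length - start_pos, 0)                    # 0
    end_idx = min(cache_length - start_pos + input_length, length)  # 3
    assert (start_idx, end_idx) == (0, 3)           # rows [0, 3) gathered now

    cache_length = 8                                # second chunk: tokens [8, 16)
    start_idx = max(cache_length - start_pos, 0)                    # 3
    end_idx = min(cache_length - start_pos + input_length, length)  # 10
    assert (start_idx, end_idx) == (3, 10)          # remaining rows, no recompute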

class VlmCausalLM(FlashCausalLM):
    def __init__(
@@ -418,6 +542,14 @@ class VlmCausalLM(FlashCausalLM):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return self.batch_class

    def get_mm_embeddings(self, batch):
        # Encode only when new pixel values were scheduled for this step,
        # then refresh the per-request encoder cache before gathering.
        if batch.pixel_values is not None:
            encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)
            batch.update_encoder_cache(encoder_outputs)

        return batch.get_mm_embeddings()
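
Note: the per-step flow this method assumes, as a hedged sketch; the method names come from this diff, the driver loop is hypothetical:

    batch.prepare_for_prefill()                 # schedules uncached images, sets batch.pixel_values
    mm_embeds = model.get_mm_embeddings(batch)  # encode -> cache -> gather slices
    # mm_embeds is then merged with the text token embeddings for the forward pass.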

    def forward(
        self,
        batch: VlmCausalLMBatch,