Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 13:52:07 +00:00
add encoder cache free
parent 44ed5efbcc
commit 526a8785ed
@@ -1899,7 +1899,9 @@ class FlashCausalLM(Model):
        batch.prepare_for_prefill()

        self.get_input_embeddings(batch)
        from pdb import set_trace

        set_trace()
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
@@ -304,6 +304,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
            device=device,
            rank=rank,
            world_size=world_size,
            support_chunking=True,
        )

        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
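The `# Monkey patch` comment in the hunk above refers to a standard Python/PyTorch pattern. A generic sketch of it follows (illustrative only, not TGI code; `TinyModel`, `patched_forward`, and `original_forward` are made-up names): the module's bound `forward` is replaced with a wrapper so that calling code expecting a different signature can drive the same model object without duplicating its logic.

import types

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real model's logits computation.
        return input_ids.float() * 2


def patched_forward(self, input_ids: torch.Tensor, **unused_kwargs) -> torch.Tensor:
    # Accept and drop extra keyword arguments so generic calling code that passes a
    # superset of arguments keeps working unchanged.
    return self.original_forward(input_ids)


model = TinyModel()
model.original_forward = model.forward                     # keep the original implementation
model.forward = types.MethodType(patched_forward, model)   # rebind the wrapper on this instance

print(model(torch.tensor([1, 2, 3]), position_ids=None))   # tensor([2., 4., 6.])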
@@ -499,6 +499,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                if image_id not in self.encoder_cache[i]:
                    self.scheduled_image_input.append((i, image_position))
                    scheduled_image_pixel_values.append(pixel_values)
                    self.image_inputs[i][j] = None

        if self.has_image and len(scheduled_image_pixel_values):
            self.pixel_values = torch.cat(
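A minimal standalone sketch (not part of this diff; `ImagePosition` and `schedule_images` are illustrative names) of the scheduling idea in the hunk above: an image is only queued for the vision encoder when its output is not already in the per-request encoder cache, so a chunked prefill does not re-encode images it has already processed, and the raw pixel values are dropped once they have been queued.

from dataclasses import dataclass


@dataclass
class ImagePosition:
    id: str
    offset: int   # first prompt position occupied by the image tokens
    length: int   # number of image tokens


def schedule_images(encoder_cache: dict, image_positions: list, image_inputs: list):
    # Collect the images of one request that still need to go through the vision encoder.
    scheduled, pixel_values = [], []
    for j, position in enumerate(image_positions):
        if position.id not in encoder_cache:    # skip images encoded in an earlier chunk
            scheduled.append((j, position))
            pixel_values.append(image_inputs[j])
            image_inputs[j] = None               # raw pixels are no longer needed once queued
    return scheduled, pixel_values


# Example: the first image is already cached, so only the second one gets scheduled.
cache = {"img0": "embeds(img0)"}
positions = [ImagePosition("img0", 0, 16), ImagePosition("img1", 40, 16)]
inputs = ["pixels0", "pixels1"]
print(schedule_images(cache, positions, inputs))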
@@ -590,6 +591,37 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

        return torch.cat(mm_embeds, dim=0).to(device)

    def free_encoder_cache(self):
        for i, (
            r,
            cache_length,
            input_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for j, image_position in enumerate(self.image_positions[i]):
                image_id = image_position.id

                start_pos = image_position.offset
                length = image_position.length

                cache_length = cache_length + input_length
                if start_pos >= cache_length:
                    # No encoder input required at this step
                    break

                if start_pos + length <= cache_length:
                    self.encoder_cache[i][image_id] = None


class VlmCausalLM(FlashCausalLM):
    def __init__(
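A minimal standalone sketch (not part of this diff; `images_to_free` is an illustrative name) of the condition `free_encoder_cache` checks above: once the prompt positions processed so far (`cache_length + input_length`) extend past the end of an image's token span (`offset + length`), its cached encoder output is no longer needed; an image whose span starts beyond that point has not been reached yet, so the scan can stop there.

def images_to_free(image_positions, cache_length, input_length):
    # image_positions: (image_id, offset, length) tuples, ordered by offset.
    seen = cache_length + input_length   # prompt positions covered once this chunk is done
    freed = []
    for image_id, offset, length in image_positions:
        if offset >= seen:
            break                        # this image (and all later ones) is not reached yet
        if offset + length <= seen:
            freed.append(image_id)       # fully consumed: encoder output can be dropped
    return freed


# A 100-token prefill chunk fully covers the first image but only part of the second.
positions = [("img0", 10, 32), ("img1", 80, 64)]
print(images_to_free(positions, cache_length=0, input_length=100))   # ['img0']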
@@ -831,4 +863,6 @@ class VlmCausalLM(FlashCausalLM):
            else None
        )
        logits = cuda_graph["logits"][:bs]

        batch.free_encoder_cache()
        return logits, speculative_logits