Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)

Commit 526a8785ed (parent 44ed5efbcc): add encoder cache free
@@ -1899,7 +1899,9 @@ class FlashCausalLM(Model):
         batch.prepare_for_prefill()

         self.get_input_embeddings(batch)
+        from pdb import set_trace

+        set_trace()
         prefill_logprobs = batch.prefill_next_token_indices is not None

         # Update adapter indices for speculative tokens (if present)
@@ -304,6 +304,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
+            support_chunking=True,
         )

         # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
@@ -499,6 +499,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if image_id not in self.encoder_cache[i]:
                     self.scheduled_image_input.append((i, image_position))
                     scheduled_image_pixel_values.append(pixel_values)
+                    self.image_inputs[i][j] = None

         if self.has_image and len(scheduled_image_pixel_values):
             self.pixel_values = torch.cat(
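A rough standalone sketch of the bookkeeping in this hunk, assuming `encoder_cache[i]` maps image ids to cached encoder outputs and `image_inputs[i]` holds raw pixel tensors: images without a cached encoder output are queued for encoding, and the added line releases the raw input once it has been scheduled. The function and data layout below are assumptions for illustration, not the actual TGI classes.

from types import SimpleNamespace

def schedule_images(encoder_cache, image_inputs, image_positions):
    """Queue every image with no cached encoder output yet and release its raw input."""
    scheduled = []      # (request_index, image_position) pairs to run through the vision encoder
    pixel_batches = []  # raw pixel tensors handed to the encoder
    for i, positions in enumerate(image_positions):
        if positions is None:
            continue
        for j, pos in enumerate(positions):
            if pos.id in encoder_cache[i]:
                continue  # encoder output already cached for this image
            scheduled.append((i, pos))
            pixel_batches.append(image_inputs[i][j])
            image_inputs[i][j] = None  # mirrors the added `self.image_inputs[i][j] = None`
    return scheduled, pixel_batches

# Toy usage: one request with a single image that has not been encoded yet.
pos = SimpleNamespace(id="img-0", offset=4, length=16)
inputs = [["<pixel tensor>"]]
scheduled, pixels = schedule_images(encoder_cache=[{}], image_inputs=inputs, image_positions=[[pos]])
assert scheduled == [(0, pos)] and pixels == ["<pixel tensor>"] and inputs[0][0] is None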
@@ -590,6 +591,37 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

         return torch.cat(mm_embeds, dim=0).to(device)

+    def free_encoder_cache(self):
+        for i, (
+            r,
+            cache_length,
+            input_length,
+            request_prefilling,
+        ) in enumerate(
+            zip(
+                self.requests,
+                self.cache_lengths,
+                self.input_lengths,
+                self.prefilling_mask,
+            )
+        ):
+            if not request_prefilling or self.image_positions[i] is None:
+                continue
+
+            for j, image_position in enumerate(self.image_positions[i]):
+                image_id = image_position.id
+
+                start_pos = image_position.offset
+                length = image_position.length
+
+                cache_length = cache_length + input_length
+                if start_pos >= cache_length:
+                    # No encoder input required at this step
+                    break
+
+                if start_pos + length <= cache_length:
+                    self.encoder_cache[i][image_id] = None
+

 class VlmCausalLM(FlashCausalLM):
     def __init__(
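Read alongside the new method: `cache_length + input_length` is the number of prompt tokens processed once this prefill chunk completes, and an image's cached encoder output is only needed while its token span `[offset, offset + length)` still extends past that point. Below is a minimal standalone sketch of that rule with made-up chunk sizes; the helper name and the numbers are illustrative, not taken from the repository.

def can_free(offset: int, length: int, cache_length: int, input_length: int) -> bool:
    """True once the image's token span is fully covered by the tokens processed so far."""
    processed = cache_length + input_length  # prompt tokens handled after this prefill chunk
    if offset >= processed:
        # Image not reached yet: its encoder output is still pending (the loop above breaks here).
        return False
    return offset + length <= processed  # span fully consumed -> cache entry can be dropped

# Image occupies prompt positions [10, 74), prefill runs in 64-token chunks.
print(can_free(offset=10, length=64, cache_length=0, input_length=64))   # False: span only partly consumed
print(can_free(offset=10, length=64, cache_length=64, input_length=64))  # True: span fully consumed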
@@ -831,4 +863,6 @@ class VlmCausalLM(FlashCausalLM):
             else None
         )
         logits = cuda_graph["logits"][:bs]
+
+        batch.free_encoder_cache()
         return logits, speculative_logits