mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-21 15:20:19 +00:00
fixes
This commit is contained in:
parent
526a8785ed
commit
b86919a87a
@ -290,6 +290,24 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches):
|
||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
||||
|
||||
batch.image_inputs = []
|
||||
batch.image_positions = []
|
||||
batch.encoder_cache = []
|
||||
for b in batches:
|
||||
if b.image_inputs is not None:
|
||||
batch.image_inputs.extend(b.image_inputs)
|
||||
else:
|
||||
batch.image_inputs.append(None)
|
||||
if b.image_positions is not None:
|
||||
batch.image_positions.extend(b.image_positions)
|
||||
else:
|
||||
batch.image_positions.append(None)
|
||||
if b.encoder_cache is not None:
|
||||
batch.encoder_cache.extend(b.encoder_cache)
|
||||
else:
|
||||
batch.encoder_cache.append(None)
|
||||
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
@ -298,11 +316,28 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]):
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
|
||||
image_inputs = []
|
||||
image_positions = []
|
||||
encoder_cache = []
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
image_inputs.append(self.image_inputs[idx])
|
||||
image_positions.append(self.image_positions[idx])
|
||||
encoder_cache.append(self.encoder_cache[idx])
|
||||
|
||||
batch = super().filter(request_ids)
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
|
||||
batch.image_inputs = image_inputs
|
||||
batch.image_positions = image_positions
|
||||
batch.encoder_cache = encoder_cache
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
@ -352,7 +387,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
if len(image_inputs) > 0:
|
||||
batch_image_inputs[i] = image_inputs
|
||||
from pdb import set_trace
|
||||
|
||||
set_trace()
|
||||
batch_image_positions = []
|
||||
batch_tokenized_inputs = []
|
||||
max_length = 0
|
||||
@ -459,6 +496,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
super().prepare_for_prefill()
|
||||
|
||||
self.has_image = False
|
||||
self.encoder_cache_to_free = []
|
||||
self.scheduled_image_input = []
|
||||
scheduled_image_pixel_values = []
|
||||
|
||||
@ -574,6 +612,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
start_idx = max(cache_length - start_pos, 0)
|
||||
end_idx = min(cache_length - start_pos + input_length, length)
|
||||
|
||||
if end_idx == length:
|
||||
self.encoder_cache_to_free.append((i, image_id))
|
||||
|
||||
assert (
|
||||
image_id in self.encoder_cache[i]
|
||||
), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
|
||||
@ -592,35 +633,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
return torch.cat(mm_embeds, dim=0).to(device)
|
||||
|
||||
def free_encoder_cache(self):
|
||||
for i, (
|
||||
r,
|
||||
cache_length,
|
||||
input_length,
|
||||
request_prefilling,
|
||||
) in enumerate(
|
||||
zip(
|
||||
self.requests,
|
||||
self.cache_lengths,
|
||||
self.input_lengths,
|
||||
self.prefilling_mask,
|
||||
)
|
||||
):
|
||||
if not request_prefilling or self.image_positions[i] is None:
|
||||
continue
|
||||
for i, image_id in self.encoder_cache_to_free:
|
||||
self.encoder_cache[i][image_id] = None
|
||||
|
||||
for j, image_position in enumerate(self.image_positions[i]):
|
||||
image_id = image_position.id
|
||||
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
|
||||
cache_length = cache_length + input_length
|
||||
if start_pos >= cache_length:
|
||||
# No encoder input required at this step
|
||||
break
|
||||
|
||||
if start_pos + length <= cache_length:
|
||||
self.encoder_cache[i][image_id] = None
|
||||
self.encoder_cache_to_free = []
|
||||
|
||||
|
||||
class VlmCausalLM(FlashCausalLM):
|
||||
@ -814,6 +830,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
batch.image_sizes = None
|
||||
if batch.image_grid_thw is not None:
|
||||
batch.image_grid_thw = None
|
||||
batch.free_encoder_cache()
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
|
Loading…
Reference in New Issue
Block a user