Repository: https://github.com/huggingface/text-generation-inference.git
Commit: b86919a87a ("fixes")
Parent: 526a8785ed
@@ -290,6 +290,24 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
+
+        batch.image_inputs = []
+        batch.image_positions = []
+        batch.encoder_cache = []
+        for b in batches:
+            if b.image_inputs is not None:
+                batch.image_inputs.extend(b.image_inputs)
+            else:
+                batch.image_inputs.append(None)
+            if b.image_positions is not None:
+                batch.image_positions.extend(b.image_positions)
+            else:
+                batch.image_positions.append(None)
+            if b.encoder_cache is not None:
+                batch.encoder_cache.extend(b.encoder_cache)
+            else:
+                batch.encoder_cache.append(None)
+
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
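The merge pattern introduced in this hunk can be illustrated in isolation. The sketch below is not part of the commit and not the TGI implementation; `MiniBatch` and its single `image_inputs` field are made-up stand-ins for the real batch class, kept only to show how per-request entries are extended, or padded with None, so that indices stay aligned after concatenation.

from typing import List, Optional


class MiniBatch:
    """Toy stand-in for a batch carrying optional per-request image state."""

    def __init__(self, image_inputs: Optional[List] = None):
        self.image_inputs = image_inputs

    @classmethod
    def concatenate(cls, batches: List["MiniBatch"]) -> "MiniBatch":
        merged = cls(image_inputs=[])
        for b in batches:
            if b.image_inputs is not None:
                # Keep one entry per request, in batch order.
                merged.image_inputs.extend(b.image_inputs)
            else:
                # Batch without image state: pad with a None placeholder.
                merged.image_inputs.append(None)
        return merged


# Two batches: one with an image payload, one without.
merged = MiniBatch.concatenate([MiniBatch([{"id": 0}]), MiniBatch(None)])
print(merged.image_inputs)  # [{'id': 0}, None]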
@@ -298,11 +316,28 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
+        if len(request_ids) == 0:
+            raise ValueError("Batch must have at least one request")
+
+        image_inputs = []
+        image_positions = []
+        encoder_cache = []
+
+        for i, request_id in enumerate(request_ids):
+            idx = self.requests_idx_mapping[request_id]
+            image_inputs.append(self.image_inputs[idx])
+            image_positions.append(self.image_positions[idx])
+            encoder_cache.append(self.encoder_cache[idx])
+
         batch = super().filter(request_ids)
         batch.pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+
+        batch.image_inputs = image_inputs
+        batch.image_positions = image_positions
+        batch.encoder_cache = encoder_cache
         return batch
 
     @classmethod
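The filter path above gathers the per-request multimodal state for the surviving request ids before delegating to the parent class. A standalone sketch of that gather step follows; it is illustrative only, and `gather_rows` plus the example mapping are invented names, assuming nothing beyond what the hunk shows, namely that `requests_idx_mapping` maps a request id to its row index.

from typing import Dict, List, Optional


def gather_rows(
    requests_idx_mapping: Dict[int, int],
    rows: List[Optional[dict]],
    request_ids: List[int],
) -> List[Optional[dict]]:
    """Pick the rows belonging to the surviving request ids, in request order."""
    if len(request_ids) == 0:
        raise ValueError("Batch must have at least one request")
    return [rows[requests_idx_mapping[rid]] for rid in request_ids]


# Requests 7 and 9 survive filtering; request 8 is dropped.
mapping = {7: 0, 8: 1, 9: 2}
image_inputs = [{"img": "a"}, {"img": "b"}, None]
print(gather_rows(mapping, image_inputs, [9, 7]))  # [None, {'img': 'a'}]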
@@ -352,7 +387,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
             if len(image_inputs) > 0:
                 batch_image_inputs[i] = image_inputs
+        from pdb import set_trace
 
+        set_trace()
         batch_image_positions = []
         batch_tokenized_inputs = []
         max_length = 0
@@ -459,6 +496,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         super().prepare_for_prefill()
 
         self.has_image = False
+        self.encoder_cache_to_free = []
         self.scheduled_image_input = []
         scheduled_image_pixel_values = []
 
@@ -574,6 +612,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 start_idx = max(cache_length - start_pos, 0)
                 end_idx = min(cache_length - start_pos + input_length, length)
 
+                if end_idx == length:
+                    self.encoder_cache_to_free.append((i, image_id))
+
                 assert (
                     image_id in self.encoder_cache[i]
                 ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
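The new condition marks an encoder-cache entry for freeing once the last of the image's `length` embedding positions falls inside the tokens processed in the current step. A small worked example of the window arithmetic, using made-up numbers rather than values from the repository:

def embed_window(cache_length: int, input_length: int, start_pos: int, length: int):
    """Which slice of an image's embeddings is consumed by one prefill step.

    cache_length : tokens already processed for the request
    input_length : tokens processed in this step
    start_pos    : position of the image's first placeholder token
    length       : number of image embedding positions
    """
    start_idx = max(cache_length - start_pos, 0)
    end_idx = min(cache_length - start_pos + input_length, length)
    return start_idx, end_idx


# Image occupies positions 10..25 (length 16); prefill runs in two 20-token chunks.
print(embed_window(0, 20, 10, 16))   # (0, 10): only the first 10 embeddings used
print(embed_window(20, 20, 10, 16))  # (10, 16): end_idx == length, so the cached
                                     # encoder output can be scheduled for freeing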
@@ -592,35 +633,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return torch.cat(mm_embeds, dim=0).to(device)
 
     def free_encoder_cache(self):
-        for i, (
-            r,
-            cache_length,
-            input_length,
-            request_prefilling,
-        ) in enumerate(
-            zip(
-                self.requests,
-                self.cache_lengths,
-                self.input_lengths,
-                self.prefilling_mask,
-            )
-        ):
-            if not request_prefilling or self.image_positions[i] is None:
-                continue
-
-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-
-                start_pos = image_position.offset
-                length = image_position.length
-
-                cache_length = cache_length + input_length
-                if start_pos >= cache_length:
-                    # No encoder input required at this step
-                    break
-
-                if start_pos + length <= cache_length:
-                    self.encoder_cache[i][image_id] = None
+        for i, image_id in self.encoder_cache_to_free:
+            self.encoder_cache[i][image_id] = None
+
+        self.encoder_cache_to_free = []
 
 
 class VlmCausalLM(FlashCausalLM):
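Taken together, the bookkeeping now follows a two-phase pattern: (request index, image id) pairs are recorded while inputs are prepared, and the cached encoder outputs are released only after the forward pass. A minimal self-contained illustration of that pattern, with hypothetical names (`EncoderCacheBook`, `schedule_free`) that do not exist in the repository:

from typing import Dict, List, Optional, Tuple


class EncoderCacheBook:
    """Toy two-phase free: record (request_index, image_id) now, release later."""

    def __init__(self, caches: List[Dict[int, Optional[str]]]):
        self.encoder_cache = caches  # one {image_id: embedding} dict per request
        self.encoder_cache_to_free: List[Tuple[int, int]] = []

    def schedule_free(self, request_index: int, image_id: int) -> None:
        # Called while preparing inputs, once an image is fully consumed.
        self.encoder_cache_to_free.append((request_index, image_id))

    def free_encoder_cache(self) -> None:
        # Called after the forward pass, mirroring the rewritten method above.
        for i, image_id in self.encoder_cache_to_free:
            self.encoder_cache[i][image_id] = None
        self.encoder_cache_to_free = []


book = EncoderCacheBook([{0: "embed-A"}, {0: "embed-B"}])
book.schedule_free(1, 0)
book.free_encoder_cache()
print(book.encoder_cache)  # [{0: 'embed-A'}, {0: None}]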
@@ -814,6 +830,7 @@ class VlmCausalLM(FlashCausalLM):
                 batch.image_sizes = None
             if batch.image_grid_thw is not None:
                 batch.image_grid_thw = None
+            batch.free_encoder_cache()
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph