mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix mllama crash if bs>0 and filter
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
b1ae4ad260
commit
bba260912c
@ -1075,8 +1075,19 @@ class FlashCausalLMBatch(Batch):
|
|||||||
input_ids = [0] * extra_pad + input_ids
|
input_ids = [0] * extra_pad + input_ids
|
||||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
else:
|
else:
|
||||||
self.input_ids = F.pad(self.input_ids, (extra_pad, 0), value=0)
|
input_ids = self.input_ids.new_zeros(max_padded_input_len * len(self))
|
||||||
input_ids_padded_length.extend([extra_pad] * len(self))
|
src_pos = 0
|
||||||
|
for i in range(len(self)):
|
||||||
|
end_pos = (i + 1) * max_padded_input_len
|
||||||
|
start_pos = end_pos - self.input_lengths[i]
|
||||||
|
input_ids[start_pos:end_pos] = self.input_ids[
|
||||||
|
src_pos : src_pos + self.input_lengths[i]
|
||||||
|
]
|
||||||
|
input_ids_padded_length.append(
|
||||||
|
max_padded_input_len - self.input_lengths[i]
|
||||||
|
)
|
||||||
|
src_pos += self.input_lengths[i]
|
||||||
|
self.input_ids = input_ids
|
||||||
|
|
||||||
self.input_ids = F.pad(
|
self.input_ids = F.pad(
|
||||||
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
|
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=0
|
||||||
|
@ -80,7 +80,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
|
|||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
def filter(self, request_ids: List[int]):
|
def filter(self, request_ids: List[int]):
|
||||||
assert self.image_indices is not None
|
assert self.image_indices is not None
|
||||||
batch = super().filter(request_ids)
|
batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)
|
||||||
assert self.image_indices is not None
|
assert self.image_indices is not None
|
||||||
indices = []
|
indices = []
|
||||||
for i, request_id in enumerate(request_ids):
|
for i, request_id in enumerate(request_ids):
|
||||||
@ -106,6 +106,7 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
batch.cross_attention_states = None
|
batch.cross_attention_states = None
|
||||||
|
batch.pixel_values = None
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user