mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix qwen test
This commit is contained in:
parent
996473164a
commit
6a5955a78c
@ -603,6 +603,18 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
|||||||
self.pixel_attention_mask = None
|
self.pixel_attention_mask = None
|
||||||
self.image_sizes = None
|
self.image_sizes = None
|
||||||
self.image_grid_thw = None
|
self.image_grid_thw = None
|
||||||
|
else:
|
||||||
|
image_grid_thw_list = [
|
||||||
|
x[2]["image_grid_thw"]
|
||||||
|
for x in self.pixel_values
|
||||||
|
if "image_grid_thw" in x[2]
|
||||||
|
]
|
||||||
|
if image_grid_thw_list:
|
||||||
|
self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(
|
||||||
|
self.input_ids.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.image_grid_thw = None
|
||||||
|
|
||||||
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
|
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
|
||||||
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
|
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
|
||||||
@ -886,7 +898,6 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def encode_images(self, batch):
|
def encode_images(self, batch):
|
||||||
image_grid_thw = None
|
|
||||||
if batch.pixel_values is not None:
|
if batch.pixel_values is not None:
|
||||||
device = batch.input_ids.device
|
device = batch.input_ids.device
|
||||||
for request_id, image_id, image_input in batch.pixel_values:
|
for request_id, image_id, image_input in batch.pixel_values:
|
||||||
@ -924,7 +935,6 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
batch.pixel_values = None
|
batch.pixel_values = None
|
||||||
batch.pixel_attention_mask = None
|
batch.pixel_attention_mask = None
|
||||||
batch.image_sizes = None
|
batch.image_sizes = None
|
||||||
batch.image_grid_thw = image_grid_thw
|
|
||||||
|
|
||||||
def set_inputs_embeds(self, batch):
|
def set_inputs_embeds(self, batch):
|
||||||
if batch.has_image_inputs:
|
if batch.has_image_inputs:
|
||||||
|
Loading…
Reference in New Issue
Block a user