mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
update pixel_values
This commit is contained in:
parent
52e4186c2a
commit
7237e8e6bf
@ -517,10 +517,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
self.has_image = False
|
||||
self.encoder_cache_to_free = []
|
||||
self.scheduled_image_input = []
|
||||
scheduled_image_pixel_values = []
|
||||
|
||||
device = self.input_ids.device
|
||||
self.pixel_values = []
|
||||
|
||||
for i, (
|
||||
r,
|
||||
@ -540,7 +538,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
for j, image_position in enumerate(self.image_positions[i]):
|
||||
image_id = image_position.id
|
||||
pixel_values = self.image_inputs[i][j]
|
||||
image_inputs = self.image_inputs[i][j]
|
||||
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
@ -555,38 +553,38 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
self.has_image = True
|
||||
|
||||
if image_id not in self.encoder_cache[i]:
|
||||
self.scheduled_image_input.append((i, image_position))
|
||||
scheduled_image_pixel_values.append(pixel_values)
|
||||
self.pixel_values.append((i, image_position, image_inputs))
|
||||
# scheduled_image_pixel_values.append(image_inputs)
|
||||
self.image_inputs[i][j] = None
|
||||
|
||||
if self.has_image and len(scheduled_image_pixel_values):
|
||||
self.pixel_values = [
|
||||
d["pixel_values"].to(device) for d in scheduled_image_pixel_values
|
||||
]
|
||||
# if self.has_image and len(scheduled_image_pixel_values):
|
||||
# self.pixel_values = [
|
||||
# d["pixel_values"].to(device) for d in scheduled_image_pixel_values
|
||||
# ]
|
||||
|
||||
if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
|
||||
self.pixel_attention_mask = [
|
||||
d["pixel_attention_mask"].to(device)
|
||||
for d in scheduled_image_pixel_values
|
||||
]
|
||||
# if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
|
||||
# self.pixel_attention_mask = [
|
||||
# d["pixel_attention_mask"].to(device)
|
||||
# for d in scheduled_image_pixel_values
|
||||
# ]
|
||||
|
||||
if "image_sizes" in scheduled_image_pixel_values[0]:
|
||||
self.image_sizes = [
|
||||
d["image_sizes"].to(device) for d in scheduled_image_pixel_values
|
||||
]
|
||||
# if "image_sizes" in scheduled_image_pixel_values[0]:
|
||||
# self.image_sizes = [
|
||||
# d["image_sizes"].to(device) for d in scheduled_image_pixel_values
|
||||
# ]
|
||||
|
||||
if "image_grid_thw" in scheduled_image_pixel_values[0]:
|
||||
self.image_grid_thw = [
|
||||
d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
|
||||
]
|
||||
else:
|
||||
# if "image_grid_thw" in scheduled_image_pixel_values[0]:
|
||||
# self.image_grid_thw = [
|
||||
# d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
|
||||
# ]
|
||||
if not self.has_image:
|
||||
self.pixel_values = None
|
||||
self.pixel_attention_mask = None
|
||||
self.image_sizes = None
|
||||
self.image_grid_thw = None
|
||||
|
||||
def update_encoder_cache(self, encoder_outputs, batch, input):
|
||||
self.encoder_cache[batch][input.id] = scatter_image_embeds(
|
||||
def update_encoder_cache(self, encoder_outputs, request_id, input):
|
||||
self.encoder_cache[request_id][input.id] = scatter_image_embeds(
|
||||
encoder_outputs, input.is_embed
|
||||
)
|
||||
|
||||
@ -717,24 +715,26 @@ class VlmCausalLM(FlashCausalLM):
|
||||
|
||||
def get_mm_embeddings(self, batch):
|
||||
if batch.pixel_values is not None:
|
||||
for i, image_input in batch.scheduled_image_input:
|
||||
from pdb import set_trace
|
||||
device = batch.input_ids.device
|
||||
for request_id, image_position, image_input in batch.pixel_values:
|
||||
pixel_values = image_input["pixel_values"].to(device)
|
||||
|
||||
set_trace()
|
||||
pixel_values = batch.pixel_values[i]
|
||||
pixel_attention_mask = (
|
||||
batch.pixel_attention_mask[i]
|
||||
if batch.pixel_attention_mask is not None
|
||||
else None
|
||||
)
|
||||
image_sizes = (
|
||||
batch.image_sizes[i] if batch.image_sizes is not None else None
|
||||
)
|
||||
image_grid_thw = (
|
||||
batch.image_grid_thw[i]
|
||||
if batch.image_grid_thw is not None
|
||||
else None
|
||||
)
|
||||
if "pixel_attention_mask" in image_input:
|
||||
pixel_attention_mask = image_input["pixel_attention_mask"].to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
pixel_attention_mask = None
|
||||
|
||||
if "image_sizes" in image_input:
|
||||
image_sizes = image_input["image_sizes"].to(device)
|
||||
else:
|
||||
image_sizes = None
|
||||
|
||||
if "image_grid_thw" in image_input:
|
||||
image_grid_thw = image_input["image_grid_thw"].to(device)
|
||||
else:
|
||||
image_grid_thw = None
|
||||
|
||||
encoder_outputs = self.get_vision_embeds(
|
||||
pixel_values=pixel_values,
|
||||
@ -742,12 +742,9 @@ class VlmCausalLM(FlashCausalLM):
|
||||
image_sizes=image_sizes,
|
||||
image_grid_thw=image_grid_thw,
|
||||
)
|
||||
batch.update_encoder_cache(encoder_outputs, i, image_input)
|
||||
batch.update_encoder_cache(encoder_outputs, request_id, image_position)
|
||||
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
return batch.get_mm_embeddings()
|
||||
|
||||
def get_input_embeddings(self, batch):
|
||||
|
Loading…
Reference in New Issue
Block a user