Mirror of https://github.com/huggingface/text-generation-inference.git

commit 7237e8e6bf (parent 52e4186c2a)

    update pixel_values
@@ -517,10 +517,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         self.has_image = False
         self.encoder_cache_to_free = []
-        self.scheduled_image_input = []
-        scheduled_image_pixel_values = []

-        device = self.input_ids.device
+        self.pixel_values = []

         for i, (
             r,
@@ -540,7 +538,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

             for j, image_position in enumerate(self.image_positions[i]):
                 image_id = image_position.id
-                pixel_values = self.image_inputs[i][j]
+                image_inputs = self.image_inputs[i][j]

                 start_pos = image_position.offset
                 length = image_position.length
@@ -555,38 +553,38 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 self.has_image = True

                 if image_id not in self.encoder_cache[i]:
-                    self.scheduled_image_input.append((i, image_position))
-                    scheduled_image_pixel_values.append(pixel_values)
+                    self.pixel_values.append((i, image_position, image_inputs))
+                    # scheduled_image_pixel_values.append(image_inputs)
                     self.image_inputs[i][j] = None

-        if self.has_image and len(scheduled_image_pixel_values):
-            self.pixel_values = [
-                d["pixel_values"].to(device) for d in scheduled_image_pixel_values
-            ]
+        # if self.has_image and len(scheduled_image_pixel_values):
+        #     self.pixel_values = [
+        #         d["pixel_values"].to(device) for d in scheduled_image_pixel_values
+        #     ]

-            if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
-                self.pixel_attention_mask = [
-                    d["pixel_attention_mask"].to(device)
-                    for d in scheduled_image_pixel_values
-                ]
+        #     if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
+        #         self.pixel_attention_mask = [
+        #             d["pixel_attention_mask"].to(device)
+        #             for d in scheduled_image_pixel_values
+        #         ]

-            if "image_sizes" in scheduled_image_pixel_values[0]:
-                self.image_sizes = [
-                    d["image_sizes"].to(device) for d in scheduled_image_pixel_values
-                ]
+        #     if "image_sizes" in scheduled_image_pixel_values[0]:
+        #         self.image_sizes = [
+        #             d["image_sizes"].to(device) for d in scheduled_image_pixel_values
+        #         ]

-            if "image_grid_thw" in scheduled_image_pixel_values[0]:
-                self.image_grid_thw = [
-                    d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
-                ]
-        else:
+        #     if "image_grid_thw" in scheduled_image_pixel_values[0]:
+        #         self.image_grid_thw = [
+        #             d["image_grid_thw"].to(device) for d in scheduled_image_pixel_values
+        #         ]
+        if not self.has_image:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None

-    def update_encoder_cache(self, encoder_outputs, batch, input):
-        self.encoder_cache[batch][input.id] = scatter_image_embeds(
+    def update_encoder_cache(self, encoder_outputs, request_id, input):
+        self.encoder_cache[request_id][input.id] = scatter_image_embeds(
             encoder_outputs, input.is_embed
         )
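Note: the hunks above replace the eager, per-field device lists with a single
CPU-side work queue. Below is a minimal sketch of that pattern, assuming
nothing from TGI itself: BatchSketch, ImagePosition, and schedule_images are
hypothetical stand-ins for VlmCausalLMBatch, its image position entries, and
the scheduling logic in this diff.

    import torch
    from dataclasses import dataclass

    @dataclass
    class ImagePosition:
        # Hypothetical stand-in: where an image's placeholder tokens sit in
        # the prompt, plus the id used to key the encoder cache.
        id: int
        offset: int
        length: int

    class BatchSketch:
        def __init__(self, image_inputs, image_positions):
            self.image_inputs = image_inputs        # per-request processor dicts
            self.image_positions = image_positions  # per-request ImagePositions
            self.encoder_cache = [dict() for _ in image_inputs]
            self.pixel_values = []                  # the new work queue
            self.has_image = False

        def schedule_images(self):
            # Mirrors the diff: queue (request_idx, position, raw inputs) for
            # every image whose embeddings are not cached yet. No .to(device)
            # happens here any more; the transfer is deferred to encode time.
            self.pixel_values = []
            self.has_image = False
            for i, positions in enumerate(self.image_positions):
                for j, position in enumerate(positions):
                    self.has_image = True
                    if position.id not in self.encoder_cache[i]:
                        self.pixel_values.append(
                            (i, position, self.image_inputs[i][j])
                        )
                        self.image_inputs[i][j] = None
            if not self.has_image:
                self.pixel_values = None

    # One request carrying one image straight from the processor (CPU tensors).
    batch = BatchSketch(
        image_inputs=[[{"pixel_values": torch.randn(1, 3, 224, 224)}]],
        image_positions=[[ImagePosition(id=0, offset=5, length=64)]],
    )
    batch.schedule_images()
    assert len(batch.pixel_values) == 1  # one image waiting for the vision tower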
@@ -717,24 +715,26 @@ class VlmCausalLM(FlashCausalLM):

     def get_mm_embeddings(self, batch):
         if batch.pixel_values is not None:
-            for i, image_input in batch.scheduled_image_input:
-                from pdb import set_trace
-
-                set_trace()
-                pixel_values = batch.pixel_values[i]
-                pixel_attention_mask = (
-                    batch.pixel_attention_mask[i]
-                    if batch.pixel_attention_mask is not None
-                    else None
-                )
-                image_sizes = (
-                    batch.image_sizes[i] if batch.image_sizes is not None else None
-                )
-                image_grid_thw = (
-                    batch.image_grid_thw[i]
-                    if batch.image_grid_thw is not None
-                    else None
-                )
+            device = batch.input_ids.device
+            for request_id, image_position, image_input in batch.pixel_values:
+                pixel_values = image_input["pixel_values"].to(device)
+
+                if "pixel_attention_mask" in image_input:
+                    pixel_attention_mask = image_input["pixel_attention_mask"].to(
+                        device
+                    )
+                else:
+                    pixel_attention_mask = None
+
+                if "image_sizes" in image_input:
+                    image_sizes = image_input["image_sizes"].to(device)
+                else:
+                    image_sizes = None
+
+                if "image_grid_thw" in image_input:
+                    image_grid_thw = image_input["image_grid_thw"].to(device)
+                else:
+                    image_grid_thw = None
+
                 encoder_outputs = self.get_vision_embeds(
                     pixel_values=pixel_values,
@@ -742,12 +742,9 @@ class VlmCausalLM(FlashCausalLM):
                     image_sizes=image_sizes,
                     image_grid_thw=image_grid_thw,
                 )
-                batch.update_encoder_cache(encoder_outputs, i, image_input)
+                batch.update_encoder_cache(encoder_outputs, request_id, image_position)

             batch.pixel_values = None
-            batch.pixel_attention_mask = None
-            batch.image_sizes = None
-            batch.image_grid_thw = None

         return batch.get_mm_embeddings()

     def get_input_embeddings(self, batch):
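Note: the consumption side of the same queue, sketched under the same
assumptions as above. fake_vision_embeds and get_mm_embeddings_sketch are
hypothetical stand-ins for self.get_vision_embeds and the get_mm_embeddings
loop in this diff; the optional-key handling mirrors the per-image .to(device)
transfers that replaced the old indexed batch lists.

    import torch

    def fake_vision_embeds(pixel_values, **extras):
        # Hypothetical stand-in for self.get_vision_embeds: one row per patch.
        return torch.randn(16, 8)

    def get_mm_embeddings_sketch(batch, device="cpu"):
        if batch.pixel_values is not None:
            for request_id, image_position, image_input in batch.pixel_values:
                # Each image's tensors reach the device only when it is
                # actually encoded, and only the keys the processor produced.
                pixel_values = image_input["pixel_values"].to(device)
                extras = {
                    key: image_input[key].to(device)
                    for key in (
                        "pixel_attention_mask",
                        "image_sizes",
                        "image_grid_thw",
                    )
                    if key in image_input
                }
                encoder_outputs = fake_vision_embeds(pixel_values, **extras)
                batch.encoder_cache[request_id][image_position.id] = encoder_outputs
            batch.pixel_values = None  # queue drained; nothing left to encode

    # Reusing the BatchSketch instance from the previous sketch:
    get_mm_embeddings_sketch(batch)
    assert 0 in batch.encoder_cache[0]  # embeddings now cached per image id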