Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Commit 44ed5efbcc ("working")
Parent: 92909f3f33
@@ -207,6 +207,8 @@ class FlashCausalLMBatch(Batch):
    # Maximum number of blocks
    max_blocks: int

    inputs_embeds: Optional[torch.Tensor] = None

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
@@ -1896,6 +1898,8 @@ class FlashCausalLM(Model):
        if prefill:
            batch.prepare_for_prefill()

            self.get_input_embeddings(batch)

        prefill_logprobs = batch.prefill_next_token_indices is not None

        # Update adapter indices for speculative tokens (if present)
@@ -368,6 +368,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
@@ -377,9 +378,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
        )
        inputs["input_ids"] = None

        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
        logits = self.model.original_forward(
            input_ids=inputs["input_ids"],
            inputs_embeds=inputs_embeds.unsqueeze(0),
            position_ids=inputs["position_ids"],
            past_key_values=None,  # we use self.kv_cache instead of transformers cache object
            use_cache=False,  # we use self.kv_cache instead of transformers cache object
@@ -568,3 +572,44 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
        inputs["cache_position"] = position_ids
        inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
        return inputs

    def get_vision_embeds(self, pixel_values, image_sizes=None):
        image_features = self.model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
            vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy,
            image_sizes=image_sizes,
        )

        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.model.multi_modal_projector(vision_flat)
        return projected_vision_flat

    def get_input_embeds(self, input_ids, vision_embeddings=None):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)

        if vision_embeddings is not None:
            original_inputs_embeds_shape = inputs_embeds.shape
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
            )
            final_mask = special_image_mask.to(inputs_embeds.device)
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))

            final_mask_1d = final_mask[..., 0].reshape(-1)
            num_tokens_to_fill = final_mask_1d.sum()

            if num_tokens_to_fill != vision_embeddings.size(0):
                raise ValueError(
                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                    f"but multi_modal_projector returned {vision_embeddings.size(0)}"
                )

            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                -1, inputs_embeds.size(-1)
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                expanded_mask, vision_embeddings
            )
            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
        return inputs_embeds
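For readers unfamiliar with the masked_scatter pattern used in get_input_embeds above, here is a minimal standalone sketch (not part of this diff; the vocabulary size, hidden size, and image token id are invented) of how placeholder token positions get overwritten by vision embeddings:

import torch

# Hypothetical setup: vocab of 10 tokens, hidden size 4, image placeholder id 9.
IMAGE_TOKEN_ID = 9
embed = torch.nn.Embedding(10, 4)

input_ids = torch.tensor([[1, 9, 9, 2]])  # two image placeholders
vision_embeddings = torch.randn(2, 4)     # one row per placeholder token

inputs_embeds = embed(input_ids)          # (1, 4, 4)
flat = inputs_embeds.view(-1, inputs_embeds.size(-1))

mask_1d = (input_ids == IMAGE_TOKEN_ID).reshape(-1)
assert mask_1d.sum() == vision_embeddings.size(0)  # same consistency check as the ValueError above

expanded_mask = mask_1d.unsqueeze(-1).expand(-1, flat.size(-1))
merged = flat.masked_scatter(expanded_mask, vision_embeddings)
merged = merged.view(inputs_embeds.shape)  # rows 1 and 2 now hold the vision features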
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import torch
from PIL import Image
from io import BytesIO
@@ -115,7 +116,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str
        return image_str, IDEFICS2_FAKE_TOKEN
    if config.model_type == "idefics3":
        # TODO: implement this in a more general way
        n_rows = image_input["rows"][0][image_id]
@@ -132,7 +133,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
            image_token=IDEFICS3_IMAGE_TOKEN,
            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
        )
        return image_str
        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
@@ -141,32 +142,32 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features
        return "<image>" * num_features, "<image>"

    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
        return "<image>" * config.text_config.num_image_tokens, "<image>"
    elif config.model_type == "qwen2_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
    elif config.model_type == "qwen2_5_vl":
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
    elif config.model_type == "gemma3":
        # TODO: get correct number of features via reviewing the Gemma3 architecture
        # and calculating the number of image tokens
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
    elif config.model_type == "llama4":
        patch_size = config.vision_config.patch_size
        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
        aspect_ratios = image_input["aspect_ratios"][image_id]
        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
        aspect_ratios = image_input[image_id]["aspect_ratios"][0]
        image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]

        num_patches_per_chunk = int(
            (image_height // patch_size)
@@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
            aspect_ratios, num_patches_per_chunk
        )

        return tokens_for_this_image
        return tokens_for_this_image, "<|image_start|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -244,7 +245,42 @@ def get_number_of_features(height: int, width: int, config) -> int:
    return unpadded_features + newline_features + base_features


def scatter_image_embeds(embeds, is_embed):
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders


def gather_image_embeds(embeds, is_embed):
    if is_embed is None:
        return embeds

    gathered = embeds[is_embed]

    if len(gathered) == 0:
        return None
    return gathered


@dataclass
class ImagePositions:
    offset: int
    length: int
    id: int
    num_placeholder_tokens: int
    is_embed: Optional[torch.Tensor] = None


class VlmCausalLMBatch(FlashCausalLMBatch):
    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
    image_positions: Optional[List[List[ImagePositions]]]
    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
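A small, self-contained sketch of the scatter/gather idea above (illustrative only, with invented shapes): real embedding rows are scattered into a NaN-filled buffer at the positions marked by is_embed, and indexing with the same mask recovers exactly those rows.

import torch

embeds = torch.arange(6, dtype=torch.float32).view(3, 2)   # 3 real embedding rows
is_embed = torch.tensor([True, False, True, True, False])  # 5 token slots, 3 belong to the image

# Scatter: NaN placeholders everywhere, real rows at the masked positions.
placeholders = embeds.new_full((is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan)
placeholders[is_embed] = embeds

# Gather: boolean indexing with the same mask returns the real rows again.
recovered = placeholders[is_embed]
assert torch.equal(recovered, embeds)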
@@ -276,8 +312,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        # Process images first. We need all of them so that the processor
        # can make the image splits the same size. And we need the final
        # sizes to insert correct number of image tokens.
        images = []
        for r in requests:
        kwargs = {}
        if (
            hasattr(processor, "image_processor_class")
            and processor.image_processor_class == "Idefics3ImageProcessor"
        ):
            kwargs["return_row_col_info"] = True

        batch_image_inputs = []
        for i, r in enumerate(requests):
            image_inputs = []
            batch_image_inputs.append(None)
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
@@ -292,46 +337,54 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                        h = image.height * 2
                        image = image.resize((w, h))

                    if config.model_type == "llava_next":
                        images.append(image)
                    elif config.model_type == "gemma3":
                        images.append(image)
                    elif config.model_type == "llama4":
                        images.append(image)
                    if config.model_type in {"llava_next", "gemma3", "llama4"}:
                        image = image
                    else:
                        images.append([image])
                        image = [image]

                    pixel_values = processor.image_processor(
                        [image], return_tensors="pt", **kwargs
                    )

                    image_inputs.append(pixel_values)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            kwargs = {}
            if (
                hasattr(processor, "image_processor_class")
                and processor.image_processor_class == "Idefics3ImageProcessor"
            ):
                kwargs["return_row_col_info"] = True

            image_inputs = processor.image_processor(
                images, return_tensors="pt", **kwargs
            )
        else:
            image_inputs = None
            if len(image_inputs) > 0:
                batch_image_inputs[i] = image_inputs

        batch_image_positions = []
        batch_tokenized_inputs = []
        max_length = 0
        image_id = 0
        for r in requests:
        for i, r in enumerate(requests):
            full_text = ""
            image_tokens = []
            batch_image_positions.append(None)
            image_id = 0
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    full_text += chunk.text
                elif chunk_type == "image":
                    full_text += image_text_replacement(
                        processor, image_inputs, config, image_id
                    image_text, id_token = image_text_replacement(
                        processor, batch_image_inputs[i], config, image_id
                    )
                    full_text += image_text

                    if config.model_type == "gemma3":
                        image_text = image_text.replace("\n\n", "")

                    image_tokens.append(
                        (
                            image_id,
                            id_token,
                            tokenizer(image_text, add_special_tokens=False)[
                                "input_ids"
                            ],
                        )
                    )
                    image_id += 1
            # from pdb import set_trace; set_trace()

            full_text = image_text_replacement_fixup(config, full_text)
            input_ids = tokenizer(
                full_text,
@@ -342,7 +395,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
            max_length = max(max_length, len(input_ids))
            batch_tokenized_inputs.append(input_ids)

        return batch_tokenized_inputs, image_inputs
            prev = 0
            image_positions = []
            for image_id, id_token, tokens in image_tokens:
                id_token = tokenizer.get_vocab()[id_token]

                length = len(tokens)
                index = input_ids[prev:].index(id_token)

                index = index + prev
                assert (
                    input_ids[index : index + length] == tokens
                ), "Image token not found in input_ids"

                if config.model_type in {"llava_next", "paligemma"}:
                    is_embed = None
                    num_placeholder_tokens = length
                else:
                    is_embed = torch.tensor(tokens) == config.image_token_index
                    num_placeholder_tokens = is_embed.sum().item()

                pos = ImagePositions(
                    offset=index,
                    length=length,
                    id=image_id,
                    num_placeholder_tokens=num_placeholder_tokens,
                    is_embed=is_embed,
                )

                image_positions.append(pos)
                prev = index + length

            if len(image_positions) > 0:
                batch_image_positions[i] = image_positions

        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions

    @classmethod
    def from_pb_processor(
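For intuition, a toy version of the position-finding loop above (token ids are made up; this is not the repository's code): the id of each image's start token is located in the tokenized prompt, and the slice of matching length records where that image's tokens sit.

# Hypothetical ids: 5 = image-start marker, 7 = image placeholder token.
input_ids = [1, 2, 5, 7, 7, 7, 3, 5, 7, 7, 4]
image_tokens = [(0, 5, [5, 7, 7, 7]), (1, 5, [5, 7, 7])]  # (image_id, marker_id, tokens)

prev = 0
positions = []
for image_id, marker_id, tokens in image_tokens:
    index = input_ids[prev:].index(marker_id) + prev         # first marker at or after prev
    assert input_ids[index : index + len(tokens)] == tokens  # the whole span must match
    positions.append((image_id, index, len(tokens)))
    prev = index + len(tokens)

print(positions)  # [(0, 2, 4), (1, 7, 3)]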
@@ -354,27 +441,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        batch_tokenized_inputs, image_inputs, image_positions = (
            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
            if "image_grid_thw" in image_inputs:
                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
            else:
                batch.image_grid_thw = None
        else:
        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = [{} for _ in range(len(pb.requests))]
        if len(image_inputs):
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
@@ -394,58 +468,77 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
            r,
            cache_length,
            input_length,
            prompt_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.requests,
                self.cache_lengths,
                self.input_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for mm_inputs in batch_mm_inputs[i]:
                for j, mm_input in enumerate(mm_input):
                    image_id = mm_input.id
                    pixel_values = self.all_pixel_values[i][j].pixel_values
            for j, image_position in enumerate(self.image_positions[i]):
                image_id = image_position.id
                pixel_values = self.image_inputs[i][j]

                    start_pos = mm_input.offset
                    length = mm_input.length
                    num_placeholder_tokens = mm_input.num_placeholder_tokens
                start_pos = image_position.offset
                length = image_position.length

                    if start_pos >= cache_length + input_length:
                        # No encoder input required at this step
                        break
                    if start_pos + length <= cache_length:
                        # The encoder input is already processed
                        continue
                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step
                    break
                if start_pos + length <= cache_length:
                    # The encoder input is already processed
                    continue

                    self.has_image = True
                self.has_image = True

                    if image_id not in self.encoder_cache[i][image_id]:
                        self.scheduled_image_input.append((i, mm_input))
                        scheduled_image_pixel_values.append(pixel_values)
                if image_id not in self.encoder_cache[i]:
                    self.scheduled_image_input.append((i, image_position))
                    scheduled_image_pixel_values.append(pixel_values)

        if self.has_image and len(scheduled_image_pixel_values):
            self.pixel_values = torch.cat([scheduled_image_pixel_values], dim=0).to(device)
            self.pixel_values = torch.cat(
                [d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
            ).to(device)

            if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
                self.pixel_attention_mask = torch.cat(
                    [d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
                    dim=0,
                ).to(device)

            if "image_sizes" in scheduled_image_pixel_values[0]:
                self.image_sizes = torch.cat(
                    [d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
                ).to(device)

            if "image_grid_thw" in scheduled_image_pixel_values[0]:
                self.image_grid_thw = torch.cat(
                    [d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
                ).to(device)
        else:
            self.pixel_values = None
            self.pixel_attention_mask = None
            self.image_sizes = None
            self.image_grid_thw = None

    def update_encoder_cache(self, encoder_outputs):
        prev = 0
        for i, input in self.scheduled_image_input:
            length = input.num_placeholder_tokens
            output = encoder_outputs[prev:length]
            batch.encoder_cache[i][image_id] = self.scatter_image_embed(output, input.is_embed)
            self.encoder_cache[i][input.id] = scatter_image_embeds(
                encoder_outputs[prev : prev + length], input.is_embed
            )

            prev = length
            prev = prev + length

    def get_mm_embeddings(self):
        device = self.input_ids.device

        mm_embeds = []
        for i, (
            r,
            cache_length,
@@ -461,51 +554,43 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                self.prefilling_mask,
            )
        ):
            if not request_prefilling:
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for mm_inputs in batch_mm_inputs[i]:
                for j, mm_input in enumerate(mm_input):
                    image_id = mm_input.id
                    pixel_values = self.all_pixel_values[i][j].pixel_values
            for j, image_position in enumerate(self.image_positions[i]):
                image_id = image_position.id

                    start_pos = mm_input.offset
                    length = mm_input.length
                    num_placeholder_tokens = mm_input.num_placeholder_tokens
                start_pos = image_position.offset
                length = image_position.length

                    if start_pos >= cache_length + input_length:
                        # No encoder input required at this step
                        break
                    if start_pos + length <= cache_length:
                        # The encoder input is already processed
                        continue
                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step
                    break
                if start_pos + length <= cache_length:
                    # The encoder input is already processed
                    continue

                    self.has_image = True
                start_idx = max(cache_length - start_pos, 0)
                end_idx = min(cache_length - start_pos + input_length, length)

                    if image_id not in self.encoder_cache[i][image_id]:
                        self.scheduled_image_input.append(mm_input)
                        scheduled_image_pixel_values.append(pixel_values)
                assert (
                    image_id in self.encoder_cache[i]
                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
                encoder_output = self.encoder_cache[i][image_id]

                is_embed = image_position.is_embed
                if is_embed is not None:
                    is_embed = is_embed[start_idx:end_idx]

                    start_idx = max(cache_length - start_pos, 0)
                    end_idx = min(
                        cache_length - start_pos + input_length,
                        length)

                    encoder_output = self.encoder_cache[i][image_id]

                    is_embed = pos_info.is_embed
                    if is_embed is not None:
                        is_embed = is_embed[start_idx:end_idx]

                    mm_embeds_item = gather_mm_embeds(
                        encoder_output[start_idx:end_idx],
                        is_embed=is_embed,
                    )
                    mm_embeds.append(mm_embeds_item)
                mm_embeds_item = gather_image_embeds(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)

        return torch.cat(mm_embeds, dim=0).to(device)


class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
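A toy illustration of the start_idx/end_idx windowing above (numbers invented, not the repository's code): when a long prompt is prefilled in chunks, only the slice of the cached image embedding that falls inside the current chunk is gathered.

import torch

# Hypothetical prefill state: the image occupies prompt positions [10, 18),
# 12 prompt tokens are already in the KV cache, and this step prefills 4 more.
start_pos, length = 10, 8
cache_length, input_length = 12, 4

encoder_output = torch.randn(length, 16)  # one cached embedding row per image token

start_idx = max(cache_length - start_pos, 0)                    # 2: rows already consumed
end_idx = min(cache_length - start_pos + input_length, length)  # 6: stop at the chunk boundary

window = encoder_output[start_idx:end_idx]  # 4 rows feed this prefill chunk
print(window.shape)  # torch.Size([4, 16])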
@@ -544,12 +629,24 @@ class VlmCausalLM(FlashCausalLM):

    def get_mm_embeddings(self, batch):
        if batch.pixel_values is not None:
            encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)

            encoder_outputs = self.get_vision_embeds(batch.pixel_values)
            batch.update_encoder_cache(encoder_outputs)

            batch.pixel_values = None
        return batch.get_mm_embeddings()

    def get_input_embeddings(self, batch):
        if batch.has_image:
            vision_embeddings = self.get_mm_embeddings(batch)
            batch.has_image = False
        else:
            vision_embeddings = None

        inputs_embeds = self.get_input_embeds(
            batch.input_ids, vision_embeddings=vision_embeddings
        )

        batch.inputs_embeds = inputs_embeds

    def forward(
        self,
        batch: VlmCausalLMBatch,
@@ -600,6 +697,7 @@ class VlmCausalLM(FlashCausalLM):
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            inputs_embeds = batch.inputs_embeds
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
@@ -659,6 +757,7 @@ class VlmCausalLM(FlashCausalLM):
        )
        logits, speculative_logits = self.model.forward(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,