Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)

Commit 44ed5efbcc ("working")
Parent 92909f3f33
@@ -207,6 +207,8 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of blocks
     max_blocks: int
 
+    inputs_embeds: Optional[torch.Tensor] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -1896,6 +1898,8 @@ class FlashCausalLM(Model):
         if prefill:
             batch.prepare_for_prefill()
 
+            self.get_input_embeddings(batch)
+
             prefill_logprobs = batch.prefill_next_token_indices is not None
 
             # Update adapter indices for speculative tokens (if present)
@@ -368,6 +368,7 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
         image_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
         # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
         logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
@@ -377,9 +378,12 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
         )
+        inputs["input_ids"] = None
+
         # This is equivalent to `self.model.forward`, see the monkey patch in __init__
         logits = self.model.original_forward(
             input_ids=inputs["input_ids"],
+            inputs_embeds=inputs_embeds.unsqueeze(0),
             position_ids=inputs["position_ids"],
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
@@ -568,3 +572,44 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         inputs["cache_position"] = position_ids
         inputs["attention_mask"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)
         return inputs
+
+    def get_vision_embeds(self, pixel_values, image_sizes=None):
+        image_features = self.model.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.model.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
+    def get_input_embeds(self, input_ids, vision_embeddings=None):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+
+        if vision_embeddings is not None:
+            original_inputs_embeds_shape = inputs_embeds.shape
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
+                -1
+            )
+            final_mask = special_image_mask.to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+            final_mask_1d = final_mask[..., 0].reshape(-1)
+            num_tokens_to_fill = final_mask_1d.sum()
+
+            if num_tokens_to_fill != vision_embeddings.size(0):
+                raise ValueError(
+                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+                    f"but multi_modal_projector returned {vision_embeddings.size(0)}"
+                )
+
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
+                -1, inputs_embeds.size(-1)
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                expanded_mask, vision_embeddings
+            )
+            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
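The new get_input_embeds above leans on Tensor.masked_scatter to slot the projected vision features into the embedding rows that hold image placeholder tokens. A minimal standalone sketch of that mechanism, using toy sizes and an assumed placeholder id of 99 (neither value comes from this patch):

import torch

hidden, image_token_index = 4, 99                 # assumed toy hidden size and placeholder id
input_ids = torch.tensor([1, 99, 99, 2])          # two placeholder positions
inputs_embeds = torch.zeros(len(input_ids), hidden)
vision_embeddings = torch.ones(2, hidden)         # one projected row per placeholder

mask = (input_ids == image_token_index).unsqueeze(-1).expand(-1, hidden)
# masked_scatter fills the masked rows with the vision rows, in order
inputs_embeds = inputs_embeds.masked_scatter(mask, vision_embeddings)
assert inputs_embeds[1].eq(1).all() and inputs_embeds[0].eq(0).all()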
@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -115,7 +116,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
         if processor.image_processor.do_image_splitting:
             image_str *= 5
-        return image_str
+        return image_str, IDEFICS2_FAKE_TOKEN
     if config.model_type == "idefics3":
         # TODO: implement this in a more general way
         n_rows = image_input["rows"][0][image_id]
@@ -132,7 +133,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             image_token=IDEFICS3_IMAGE_TOKEN,
             global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
         )
-        return image_str
+        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
@@ -141,32 +142,32 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             logger.info,
             f"Found {num_features} features in image of resolution {height}x{width}",
         )
-        return "<image>" * num_features
+        return "<image>" * num_features, "<image>"
 
     elif config.model_type == "paligemma":
-        return "<image>" * config.text_config.num_image_tokens
+        return "<image>" * config.text_config.num_image_tokens, "<image>"
     elif config.model_type == "qwen2_vl":
         grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
-        return f"<|vision_start|>{padding}<|vision_end|>"
+        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "qwen2_5_vl":
         grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
-        return f"<|vision_start|>{padding}<|vision_end|>"
+        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "gemma3":
         # TODO: get correct number of features via reviewing the Gemma3 architecture
         # and calculating the number of image tokens
         num_pads = 256
         padding = "<image_soft_token>" * num_pads
-        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
+        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
     elif config.model_type == "llama4":
         patch_size = config.vision_config.patch_size
         pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
         downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
-        aspect_ratios = image_input["aspect_ratios"][image_id]
-        image_height, image_width = image_input["pixel_values"][image_id].shape[-2:]
+        aspect_ratios = image_input[image_id]["aspect_ratios"][0]
+        image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
 
         num_patches_per_chunk = int(
             (image_height // patch_size)
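For the qwen2_vl and qwen2_5_vl branches above, the padding count is grid_t * grid_h * grid_w // 4, and the function now also returns the marker token that later locates the image in the tokenized prompt. A rough worked example with an assumed grid of (1, 16, 16):

grid_t, grid_h, grid_w = 1, 16, 16        # assumed example grid, not taken from the patch
num_pads = grid_t * grid_h * grid_w // 4  # -> 64
image_text = f"<|vision_start|>{'<|image_pad|>' * num_pads}<|vision_end|>"
id_token = "<|vision_start|>"             # anchor token returned alongside the replacement text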
@@ -177,7 +178,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
             aspect_ratios, num_patches_per_chunk
         )
 
-        return tokens_for_this_image
+        return tokens_for_this_image, "<|image_start|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
@@ -244,7 +245,42 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
+def scatter_image_embeds(embeds, is_embed):
+    if is_embed is None:
+        return embeds
+
+    placeholders = embeds.new_full(
+        (is_embed.shape[0], embeds.shape[-1]),
+        fill_value=torch.nan,
+    )
+    placeholders[is_embed] = embeds
+    return placeholders
+
+
+def gather_image_embeds(embeds, is_embed):
+    if is_embed is None:
+        return embeds
+
+    gathered = embeds[is_embed]
+
+    if len(gathered) == 0:
+        return None
+    return gathered
+
+
+@dataclass
+class ImagePositions:
+    offset: int
+    length: int
+    id: int
+    num_placeholder_tokens: int
+    is_embed: Optional[torch.Tensor] = None
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
+    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
+    image_positions: Optional[List[List[ImagePositions]]]
+    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
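scatter_image_embeds and gather_image_embeds defined above are inverses over an is_embed mask: scatter pads the compact encoder output out to the full placeholder span (NaN where no embedding belongs) and gather recovers the compact rows. A small round-trip sketch with an assumed span of three placeholder tokens carrying two real embeddings:

import torch

embeds = torch.arange(4.0).reshape(2, 2)       # two embeddings of dim 2 (assumed toy values)
is_embed = torch.tensor([True, False, True])   # span of three placeholder tokens

placeholders = embeds.new_full((is_embed.shape[0], embeds.shape[-1]), fill_value=torch.nan)
placeholders[is_embed] = embeds                # what scatter_image_embeds does
recovered = placeholders[is_embed]             # what gather_image_embeds does
assert torch.equal(recovered, embeds)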
@@ -276,8 +312,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         # Process images first. We need all of them so that the processor
         # can make the image splits the same size. And we need the final
         # sizes to insert correct number of image tokens.
-        images = []
-        for r in requests:
+        kwargs = {}
+        if (
+            hasattr(processor, "image_processor_class")
+            and processor.image_processor_class == "Idefics3ImageProcessor"
+        ):
+            kwargs["return_row_col_info"] = True
+
+        batch_image_inputs = []
+        for i, r in enumerate(requests):
+            image_inputs = []
+            batch_image_inputs.append(None)
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
@@ -292,46 +337,54 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         h = image.height * 2
                         image = image.resize((w, h))
 
-                    if config.model_type == "llava_next":
-                        images.append(image)
-                    elif config.model_type == "gemma3":
-                        images.append(image)
-                    elif config.model_type == "llama4":
-                        images.append(image)
+                    if config.model_type in {"llava_next", "gemma3", "llama4"}:
+                        image = image
                     else:
-                        images.append([image])
+                        image = [image]
+
+                    pixel_values = processor.image_processor(
+                        [image], return_tensors="pt", **kwargs
+                    )
+
+                    image_inputs.append(pixel_values)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
-        if images:
-            kwargs = {}
-            if (
-                hasattr(processor, "image_processor_class")
-                and processor.image_processor_class == "Idefics3ImageProcessor"
-            ):
-                kwargs["return_row_col_info"] = True
-
-            image_inputs = processor.image_processor(
-                images, return_tensors="pt", **kwargs
-            )
-        else:
-            image_inputs = None
+            if len(image_inputs) > 0:
+                batch_image_inputs[i] = image_inputs
 
+        batch_image_positions = []
         batch_tokenized_inputs = []
         max_length = 0
-        image_id = 0
-        for r in requests:
+        for i, r in enumerate(requests):
             full_text = ""
+            image_tokens = []
+            batch_image_positions.append(None)
+            image_id = 0
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(
-                        processor, image_inputs, config, image_id
+                    image_text, id_token = image_text_replacement(
+                        processor, batch_image_inputs[i], config, image_id
                     )
+                    full_text += image_text
+
+                    if config.model_type == "gemma3":
+                        image_text = image_text.replace("\n\n", "")
+
+                    image_tokens.append(
+                        (
+                            image_id,
+                            id_token,
+                            tokenizer(image_text, add_special_tokens=False)[
+                                "input_ids"
+                            ],
+                        )
+                    )
                     image_id += 1
-                    # from pdb import set_trace; set_trace()
+
             full_text = image_text_replacement_fixup(config, full_text)
             input_ids = tokenizer(
                 full_text,
@@ -342,7 +395,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             max_length = max(max_length, len(input_ids))
             batch_tokenized_inputs.append(input_ids)
 
-        return batch_tokenized_inputs, image_inputs
+            prev = 0
+            image_positions = []
+            for image_id, id_token, tokens in image_tokens:
+                id_token = tokenizer.get_vocab()[id_token]
+
+                length = len(tokens)
+                index = input_ids[prev:].index(id_token)
+
+                index = index + prev
+                assert (
+                    input_ids[index : index + length] == tokens
+                ), "Image token not found in input_ids"
+
+                if config.model_type in {"llava_next", "paligemma"}:
+                    is_embed = None
+                    num_placeholder_tokens = length
+                else:
+                    is_embed = torch.tensor(tokens) == config.image_token_index
+                    num_placeholder_tokens = is_embed.sum().item()
+
+                pos = ImagePositions(
+                    offset=index,
+                    length=length,
+                    id=image_id,
+                    num_placeholder_tokens=num_placeholder_tokens,
+                    is_embed=is_embed,
+                )
+
+                image_positions.append(pos)
+                prev = index + length
+
+            if len(image_positions) > 0:
+                batch_image_positions[i] = image_positions
+
+        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
 
     @classmethod
     def from_pb_processor(
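The loop added above locates each image's token run by searching input_ids for the marker token id after the previous match, then checks that the next length tokens are exactly the image's own tokenization. A toy illustration with made-up token ids (7 as the marker, 8 as padding; neither is a real vocabulary value):

input_ids = [5, 7, 8, 8, 6]   # assumed toy prompt ids
tokens = [7, 8, 8]            # tokenization of this image's replacement text
prev = 0

index = input_ids[prev:].index(7) + prev                  # -> 1
assert input_ids[index : index + len(tokens)] == tokens   # run matches
prev = index + len(tokens)                                # next image is searched after this span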
@@ -354,27 +441,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
+        batch_tokenized_inputs, image_inputs, image_positions = (
+            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-        if image_inputs is not None:
-            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            if "pixel_attention_mask" in image_inputs:
-                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
-                    device=device
-                )
-            else:
-                batch.pixel_attention_mask = None
-            if "image_sizes" in image_inputs:
-                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
-            else:
-                batch.image_sizes = None
-            if "image_grid_thw" in image_inputs:
-                batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
-            else:
-                batch.image_grid_thw = None
-        else:
+        batch.image_inputs = image_inputs
+        batch.image_positions = image_positions
+        batch.encoder_cache = [{} for _ in range(len(pb.requests))]
+        if len(image_inputs):
             batch.pixel_values = None
             batch.pixel_attention_mask = None
             batch.image_sizes = None
@@ -394,28 +468,24 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             r,
             cache_length,
             input_length,
-            prompt_length,
             request_prefilling,
         ) in enumerate(
             zip(
                 self.requests,
                 self.cache_lengths,
                 self.input_lengths,
-                self.prompt_lengths,
                 self.prefilling_mask,
             )
         ):
-            if not request_prefilling:
+            if not request_prefilling or self.image_positions[i] is None:
                 continue
 
-            for mm_inputs in batch_mm_inputs[i]:
-                for j, mm_input in enumerate(mm_input):
-                    image_id = mm_input.id
-                    pixel_values = self.all_pixel_values[i][j].pixel_values
+            for j, image_position in enumerate(self.image_positions[i]):
+                image_id = image_position.id
+                pixel_values = self.image_inputs[i][j]
 
-                    start_pos = mm_input.offset
-                    length = mm_input.length
-                    num_placeholder_tokens = mm_input.num_placeholder_tokens
+                start_pos = image_position.offset
+                length = image_position.length
 
                 if start_pos >= cache_length + input_length:
                     # No encoder input required at this step
@@ -426,26 +496,49 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
                 self.has_image = True
 
-                if image_id not in self.encoder_cache[i][image_id]:
-                    self.scheduled_image_input.append((i, mm_input))
+                if image_id not in self.encoder_cache[i]:
+                    self.scheduled_image_input.append((i, image_position))
                     scheduled_image_pixel_values.append(pixel_values)
 
         if self.has_image and len(scheduled_image_pixel_values):
-            self.pixel_values = torch.cat([scheduled_image_pixel_values], dim=0).to(device)
+            self.pixel_values = torch.cat(
+                [d["pixel_values"] for d in scheduled_image_pixel_values], dim=0
+            ).to(device)
+
+            if "pixel_attention_mask" in scheduled_image_pixel_values[0]:
+                self.pixel_attention_mask = torch.cat(
+                    [d["pixel_attention_mask"] for d in scheduled_image_pixel_values],
+                    dim=0,
+                ).to(device)
+
+            if "image_sizes" in scheduled_image_pixel_values[0]:
+                self.image_sizes = torch.cat(
+                    [d["image_sizes"] for d in scheduled_image_pixel_values], dim=0
+                ).to(device)
+
+            if "image_grid_thw" in scheduled_image_pixel_values[0]:
+                self.image_grid_thw = torch.cat(
+                    [d["image_grid_thw"] for d in scheduled_image_pixel_values], dim=0
+                ).to(device)
+        else:
+            self.pixel_values = None
+            self.pixel_attention_mask = None
+            self.image_sizes = None
+            self.image_grid_thw = None
 
     def update_encoder_cache(self, encoder_outputs):
         prev = 0
         for i, input in self.scheduled_image_input:
             length = input.num_placeholder_tokens
-            output = encoder_outputs[prev:length]
-            batch.encoder_cache[i][image_id] = self.scatter_image_embed(output, input.is_embed)
+            self.encoder_cache[i][input.id] = scatter_image_embeds(
+                encoder_outputs[prev : prev + length], input.is_embed
+            )
 
-            prev = length
+            prev = prev + length
 
     def get_mm_embeddings(self):
         device = self.input_ids.device
+        mm_embeds = []
         for i, (
             r,
             cache_length,
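update_encoder_cache above walks the concatenated encoder output, carving off num_placeholder_tokens rows per scheduled image with encoder_outputs[prev : prev + length] and advancing prev by length. A sketch of just that slicing, assuming two images with 3 and 2 placeholder tokens:

import torch

encoder_outputs = torch.randn(5, 8)   # 3 + 2 placeholder rows, assumed embedding dim 8
lengths = [3, 2]                      # num_placeholder_tokens per scheduled image

prev, slices = 0, []
for length in lengths:
    slices.append(encoder_outputs[prev : prev + length])
    prev = prev + length
assert [s.shape[0] for s in slices] == lengths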
@@ -461,17 +554,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 self.prefilling_mask,
             )
         ):
-            if not request_prefilling:
+            if not request_prefilling or self.image_positions[i] is None:
                 continue
 
-            for mm_inputs in batch_mm_inputs[i]:
-                for j, mm_input in enumerate(mm_input):
-                    image_id = mm_input.id
-                    pixel_values = self.all_pixel_values[i][j].pixel_values
+            for j, image_position in enumerate(self.image_positions[i]):
+                image_id = image_position.id
 
-                    start_pos = mm_input.offset
-                    length = mm_input.length
-                    num_placeholder_tokens = mm_input.num_placeholder_tokens
+                start_pos = image_position.offset
+                length = image_position.length
 
                 if start_pos >= cache_length + input_length:
                     # No encoder input required at this step
@@ -480,25 +570,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     # The encode input is already processed
                     continue
 
-                self.has_image = True
-
-                if image_id not in self.encoder_cache[i][image_id]:
-                    self.scheduled_image_input.append(mm_input)
-                    scheduled_image_pixel_values.append(pixel_values)
-
-
                 start_idx = max(cache_length - start_pos, 0)
-                end_idx = min(
-                    cache_length - start_pos + input_length,
-                    length)
+                end_idx = min(cache_length - start_pos + input_length, length)
 
+                assert (
+                    image_id in self.encoder_cache[i]
+                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
                 encoder_output = self.encoder_cache[i][image_id]
 
-                is_embed = pos_info.is_embed
+                is_embed = image_position.is_embed
                 if is_embed is not None:
                     is_embed = is_embed[start_idx:end_idx]
 
-                mm_embeds_item = gather_mm_embeds(
+                mm_embeds_item = gather_image_embeds(
                     encoder_output[start_idx:end_idx],
                     is_embed=is_embed,
                 )
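With chunked prefill, only the part of an image span that falls inside the current window of input_length tokens (after cache_length already-processed tokens) is gathered from the cache, via start_idx = max(cache_length - start_pos, 0) and end_idx = min(cache_length - start_pos + input_length, length). A worked example with assumed numbers:

cache_length, input_length = 10, 16   # assumed: 10 tokens already cached, 16 prefilled this step
start_pos, length = 8, 12             # image span starts at position 8 and covers 12 tokens

start_idx = max(cache_length - start_pos, 0)                     # -> 2 (first 2 rows already consumed)
end_idx = min(cache_length - start_pos + input_length, length)   # -> 12 (span ends inside this window)
# rows [2, 12) of this image's cached encoder output feed the current forward pass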
@@ -506,6 +590,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
         return torch.cat(mm_embeds, dim=0).to(device)
 
+
 class VlmCausalLM(FlashCausalLM):
     def __init__(
         self,
@@ -544,12 +629,24 @@ class VlmCausalLM(FlashCausalLM):
 
     def get_mm_embeddings(self, batch):
         if batch.pixel_values is not None:
-            encoder_outputs = self.model.get_mm_embeddings(batch.pixel_values)
+            encoder_outputs = self.get_vision_embeds(batch.pixel_values)
 
             batch.update_encoder_cache(encoder_outputs)
+            batch.pixel_values = None
         return batch.get_mm_embeddings()
 
+    def get_input_embeddings(self, batch):
+        if batch.has_image:
+            vision_embeddings = self.get_mm_embeddings(batch)
+            batch.has_image = False
+        else:
+            vision_embeddings = None
+
+        inputs_embeds = self.get_input_embeds(
+            batch.input_ids, vision_embeddings=vision_embeddings
+        )
+
+        batch.inputs_embeds = inputs_embeds
+
     def forward(
         self,
         batch: VlmCausalLMBatch,
@@ -600,6 +697,7 @@ class VlmCausalLM(FlashCausalLM):
             position_ids = new_position_ids
         else:
             input_ids = batch.input_ids
+            inputs_embeds = batch.inputs_embeds
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
             kv_cache = self.kv_cache
@@ -659,6 +757,7 @@ class VlmCausalLM(FlashCausalLM):
         )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,