rename vars

Mohit Sharma 2025-04-22 12:46:36 +00:00
parent 136b9897d4
commit 63ddba24b4
10 changed files with 112 additions and 233 deletions


@@ -248,9 +248,18 @@ Options:
   -p, --port <PORT>
           The port to listen on
-          [env: PORT=]
+          [env: PORT=80]
           [default: 3000]
 ```
+## PROMETHEUS_PORT
+```shell
+  -p, --prometheus-port <PROMETHEUS_PORT>
+          The Prometheus port to listen on
+          [env: PROMETHEUS_PORT=]
+          [default: 9000]
+```
 ## SHARD_UDS_PATH
 ```shell
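Aside for reviewers: with the new flag, metrics should be scrapeable on the Prometheus port. A quick smoke test might look like the sketch below; it assumes a locally running server with the default `--prometheus-port 9000` and the usual `/metrics` path.

```python
# Minimal sketch: read the Prometheus exposition endpoint on --prometheus-port.
# Assumes a local server and default port 9000; adjust the URL as needed.
import urllib.request

with urllib.request.urlopen("http://localhost:9000/metrics", timeout=5) as resp:
    body = resp.read().decode("utf-8")

# Exposition format: "# HELP"/"# TYPE" comments plus "name{labels} value" samples.
for line in body.splitlines():
    if line and not line.startswith("#"):
        print(line)
```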


@@ -777,7 +777,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
         image_features = image_features.view(-1, image_features.shape[-1])
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -820,11 +820,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
             max_s += 1
             position_ids += 1

-        image_token_mask = (input_ids == self.config.image_token_index).to(
-            input_ids.device
-        )
-        if torch.any(image_token_mask):
+        if pixel_values:
             attention_mask = self.get_attention_mask(
                 input_ids,
                 cu_seqlen_prefill,
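The same `get_input_embeds` -> `get_inputs_embeds` rename repeats across every modeling file below. The shared contract, sketched here with made-up sizes (the real bodies differ per model, so this is an illustration, not code from this commit), is: embed the token ids, then scatter the precomputed vision embeddings into the image-placeholder slots.

```python
# Illustrative sketch of the get_inputs_embeds contract; all names and sizes
# here are assumptions, not the exact code from any file in this commit.
import torch
import torch.nn as nn

vocab_size, hidden_size, image_token_index = 100, 8, 7
embed_tokens = nn.Embedding(vocab_size, hidden_size)

def get_inputs_embeds(input_ids: torch.Tensor, vision_embeds: torch.Tensor = None):
    inputs_embeds = embed_tokens(input_ids)
    if vision_embeds is not None:
        # overwrite the placeholder positions with vision embeddings, in order
        mask = (input_ids == image_token_index).unsqueeze(-1)
        inputs_embeds = inputs_embeds.masked_scatter(mask, vision_embeds)
    return inputs_embeds

ids = torch.tensor([1, 7, 7, 2])  # two image placeholders
out = get_inputs_embeds(ids, torch.randn(2, hidden_size))
```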


@@ -80,7 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         )
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -801,7 +801,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_hidden_states = torch.stack(all_states, dim=0)
         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -543,7 +543,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -250,7 +250,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_features = torch.stack(new_image_features, dim=0)
         return image_features.view(-1, image_features.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -931,7 +931,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -509,7 +509,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,


@@ -590,7 +590,7 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         projected_vision_flat = self.model.multi_modal_projector(vision_flat)
         return projected_vision_flat

-    def get_input_embeds(self, input_ids, vision_embeds=None):
+    def get_inputs_embeds(self, input_ids, vision_embeds=None):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)

         if vision_embeds is not None:


@@ -110,7 +110,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size

-def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+def image_text_replacement(processor, image_input, config) -> str:
     if config.model_type == "idefics2":
         image_seq_len = 64
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
@@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         return image_str, IDEFICS2_FAKE_TOKEN
     if config.model_type == "idefics3":
         # TODO: implement this in a more general way
-        n_rows = image_input[image_id]["rows"][0][0]
-        n_cols = image_input[image_id]["cols"][0][0]
+        n_rows = image_input["rows"][0][0]
+        n_cols = image_input["cols"][0][0]
         image_seq_len = int(
             ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
             / (config.scale_factor**2)
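The idefics3 branch is pure arithmetic; with representative values (assumed here, not taken from this diff) the sequence length per image tile works out as:

```python
# Worked example for the image_seq_len formula above (values assumed).
image_size, patch_size, scale_factor = 364, 14, 2
image_seq_len = int(((image_size // patch_size) ** 2) / (scale_factor**2))
print(image_seq_len)  # (364 // 14) ** 2 / 4 = 26 ** 2 / 4 = 169
```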
@@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         )
         return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
     elif config.model_type == "llava_next":
-        height, width = image_input[image_id]["image_sizes"][0]
+        height, width = image_input["image_sizes"][0]
         num_features = get_number_of_features(height, width, config)

         log_master(
@@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     elif config.model_type == "paligemma":
         return "<image>" * config.text_config.num_image_tokens, "<image>"
     elif config.model_type == "qwen2_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "qwen2_5_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
@@ -166,8 +166,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
         patch_size = config.vision_config.patch_size
         pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
         downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
-        aspect_ratios = image_input[image_id]["aspect_ratios"][0]
-        image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
+        aspect_ratios = image_input["aspect_ratios"][0]
+        image_height, image_width = image_input["pixel_values"][0].shape[-2:]

         num_patches_per_chunk = int(
             (image_height // patch_size)
@@ -264,7 +264,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features

-def scatter_image_embeds(embeds, is_embed):
+def scatter_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> torch.Tensor:
     if is_embed is None:
         return embeds
@ -276,15 +278,13 @@ def scatter_image_embeds(embeds, is_embed):
return placeholders return placeholders
def gather_image_embeds(embeds, is_embed): def gather_image_embeds(
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
if is_embed is None: if is_embed is None:
return embeds return embeds
sel = embeds[is_embed]
gathered = embeds[is_embed] return sel if sel.numel() else None
if len(gathered) == 0:
return None
return gathered
@dataclass @dataclass
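The newly typed scatter/gather pair round-trips embeddings through a padded placeholder buffer keyed by the `is_embed` mask. A self-contained sketch of the behavior (mirroring, not importing, the functions above; the placeholder construction is an assumption, since it sits outside the hunk):

```python
from typing import Optional
import torch

def scatter_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    if is_embed is None:
        return embeds
    # one row per position; real embeddings land where the mask is True
    placeholders = embeds.new_zeros(is_embed.shape[0], embeds.shape[-1])
    placeholders[is_embed] = embeds
    return placeholders

def gather_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    if is_embed is None:
        return embeds
    sel = embeds[is_embed]
    return sel if sel.numel() else None

mask = torch.tensor([True, False, True])  # 2 real embeddings, 1 structural token
dense = torch.randn(2, 4)
padded = scatter_image_embeds(dense, mask)
assert torch.equal(gather_image_embeds(padded, mask), dense)
```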
@@ -304,6 +304,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    cache_entries_to_free: List[Tuple[int, int]]
+    has_image_inputs: bool = False

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -331,6 +333,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
+
         return batch

     @tracer.start_as_current_span("filter")
@@ -357,6 +363,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
         batch.encoder_cache = encoder_cache
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
+
         return batch

     @classmethod
@@ -404,7 +414,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             image_inputs.append(image_input)

             img_text, img_start_token_str = image_text_replacement(
-                processor, image_input, config, 0
+                processor, image_input, config
             )
             text_parts.append(img_text)
@@ -456,12 +466,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         for i in range(num_images):
             image_id, img_start_token_str, img_text = image_texts[i]
             img_text = image_text_replacement_fixup(config, img_text)
+
             if config.model_type == "gemma3":
                 img_text = img_text.replace("\n\n", "")

             tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
             length = len(tokens)
+            assert length <= len(
+                input_ids
+            ), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens"
+
             pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
             index = img_start_token_pos[pos]
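The `torch.searchsorted` call finds the first image-start token at or after `last_pos`; a toy version of the lookup with assumed positions:

```python
import torch

img_start_token_pos = torch.tensor([3, 17, 42])  # sorted start positions (assumed)
last_pos = torch.tensor(10)

# right=False: index of the first element >= last_pos
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
index = img_start_token_pos[pos]
print(pos.item(), index.item())  # 1 17
```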
@@ -502,151 +517,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         return image_positions

-    @classmethod
-    def batch_tokenized_inputs2(
-        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
-    ):
-        # sizes to insert correct number of image tokens.
-        kwargs = {}
-        if (
-            hasattr(processor, "image_processor_class")
-            and processor.image_processor_class == "Idefics3ImageProcessor"
-        ):
-            kwargs["return_row_col_info"] = True
-
-        batch_image_inputs = []
-        for i, r in enumerate(requests):
-            image_inputs = []
-            batch_image_inputs.append(None)
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    pass
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
-                    # default warmup image is 20x20
-                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
-                        if image.width <= 20:
-                            w = image.width * 2
-                            h = image.height * 2
-                            image = image.resize((w, h))
-
-                    if config.model_type in {"llava_next", "gemma3", "llama4"}:
-                        image = image
-                    elif config.model_type in {"paligemma"}:
-                        image = image.convert("RGB")
-                    else:
-                        image = [image]
-
-                    image_input = processor.image_processor(
-                        [image], return_tensors="pt", **kwargs
-                    )
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-            if len(image_inputs) > 0:
-                batch_image_inputs[i] = image_inputs
-
-        batch_image_positions = []
-        batch_tokenized_inputs = []
-        max_length = 0
-        for i, r in enumerate(requests):
-            full_text = ""
-            image_tokens = []
-            batch_image_positions.append(None)
-            image_id = 0
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += chunk.text
-                elif chunk_type == "image":
-                    image_text, id_token = image_text_replacement(
-                        processor, batch_image_inputs[i], config, image_id
-                    )
-                    full_text += image_text
-
-                    if config.model_type == "gemma3":
-                        image_text = image_text.replace("\n\n", "")
-                    elif config.model_type == "idefics2":
-                        image_text = image_text_replacement_fixup(config, image_text)
-
-                    image_tokens.append(
-                        (
-                            image_id,
-                            id_token,
-                            tokenizer(image_text, add_special_tokens=False)[
-                                "input_ids"
-                            ],
-                        )
-                    )
-                    image_id += 1
-
-            full_text = image_text_replacement_fixup(config, full_text)
-            input_ids = tokenizer(
-                full_text,
-                truncation=True,
-                max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
-            )["input_ids"]
-            max_length = max(max_length, len(input_ids))
-            batch_tokenized_inputs.append(input_ids)
-
-            prev = 0
-            image_positions = []
-            idefics_replacement = False
-
-            config.image_token_index = (
-                config.image_token_index
-                if hasattr(config, "image_token_index")
-                else config.image_token_id
-            )
-
-            for image_id, id_token, tokens in image_tokens:
-                id_token = tokenizer.get_vocab()[id_token]
-
-                if config.model_type == "idefics2" and idefics_replacement:
-                    id_token = config.image_token_index
-                    tokens = tokens[1:]
-
-                length = len(tokens)
-                index = input_ids[prev:].index(id_token)
-                index = index + prev
-                assert (
-                    input_ids[index : index + length] == tokens
-                ), "Image token not found in input_ids"
-
-                if config.model_type in {"llava_next", "paligemma"}:
-                    is_embed = None
-                    num_placeholder_tokens = length
-                else:
-                    is_embed = torch.tensor(tokens) == config.image_token_index
-                    num_placeholder_tokens = is_embed.sum().item()
-
-                pos = ImagePositions(
-                    offset=index,
-                    length=length,
-                    id=image_id,
-                    num_placeholder_tokens=num_placeholder_tokens,
-                    is_embed=is_embed,
-                )
-
-                image_positions.append(pos)
-                prev = index + length
-
-                if config.model_type == "idefics2" and prev != len(input_ids):
-                    if input_ids[prev] == config.image_token_id:
-                        # means this is replaced by image_text_replacement_fixup
-                        idefics_replacement = True
-                    else:
-                        idefics_replacement = False
-
-            if len(image_positions) > 0:
-                batch_image_positions[i] = image_positions
-
-        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
-
     @classmethod
     def from_pb_processor(
         cls,
@@ -674,19 +544,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     def prepare_for_prefill(self):
         super().prepare_for_prefill()

-        self.has_image = False
-        self.encoder_cache_to_free = []
+        self.has_image_inputs = False
+        self.cache_entries_to_free = []
         self.pixel_values = []

         for i, (
-            r,
             cache_length,
             input_length,
             request_prefilling,
         ) in enumerate(
             zip(
-                self.requests,
                 self.cache_lengths,
                 self.input_lengths,
                 self.prefilling_mask,
@@ -695,10 +563,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-                image_inputs = self.image_inputs[i][j]
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue

                 start_pos = image_position.offset
                 length = image_position.length
@@ -709,47 +576,46 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     # The encode input is already processed
                     continue

-                self.has_image = True
+                self.has_image_inputs = True

-                if image_id not in self.encoder_cache[i]:
-                    self.pixel_values.append((i, image_position, image_inputs))
-                    self.image_inputs[i][j] = None
-
-        if not self.has_image:
+                if image_position.id not in self.encoder_cache[i]:
+                    image_inputs = self.image_inputs[i][image_position.id]
+                    self.pixel_values.append((i, image_position.id, image_inputs))
+
+                    # Remove the image from the image_inputs
+                    self.image_inputs[i][image_position.id] = None
+
+        if not self.has_image_inputs:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None

-    def update_encoder_cache(self, encoder_outputs, request_id, input):
-        self.encoder_cache[request_id][input.id] = scatter_image_embeds(
-            encoder_outputs, input.is_embed
+    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
+        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
+            encoder_outputs, img_pos.is_embed
         )

-    def get_mm_embeddings(self):
+    def gather_vision_embeds(self):
         device = self.input_ids.device
-        mm_embeds = []
-        for i, (
-            r,
+        chunks = []
+        for (
+            i,
             cache_length,
             input_length,
-            prompt_length,
             request_prefilling,
-        ) in enumerate(
-            zip(
-                self.requests,
-                self.cache_lengths,
-                self.input_lengths,
-                self.prompt_lengths,
-                self.prefilling_mask,
-            )
+        ) in zip(
+            range(len(self.requests)),
+            self.cache_lengths,
+            self.input_lengths,
+            self.prefilling_mask,
         ):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue

                 start_pos = image_position.offset
                 length = image_position.length
@@ -763,13 +629,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 start_idx = max(cache_length - start_pos, 0)
                 end_idx = min(cache_length - start_pos + input_length, length)

-                if end_idx == length:
-                    self.encoder_cache_to_free.append((i, image_id))
-
                 assert (
-                    image_id in self.encoder_cache[i]
-                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
+                    image_position.id in self.encoder_cache[i]
+                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"

-                encoder_output = self.encoder_cache[i][image_id]
+                encoder_output = self.encoder_cache[i][image_position.id]

                 is_embed = image_position.is_embed
                 if is_embed is not None:
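The `start_idx`/`end_idx` math selects just the slice of a cached image embedding that falls inside the current prefill chunk. A worked example with assumed numbers: an image 50 tokens long starting at prompt offset 80, 100 prompt tokens already processed, and a 40-token chunk being prefilled:

```python
# All numbers assumed, for illustration of the windowing above.
cache_length, input_length = 100, 40  # already-processed tokens / chunk size
start_pos, length = 80, 50            # image offset and length in the prompt

start_idx = max(cache_length - start_pos, 0)                    # 20
end_idx = min(cache_length - start_pos + input_length, length)  # min(60, 50) = 50
print(start_idx, end_idx)  # this chunk consumes embedding rows [20:50]
# end_idx == length, so this is the image's final chunk and its cache
# entry becomes eligible for freeing.
```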
@@ -778,25 +641,31 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 from loguru import logger

                 logger.info(
-                    f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
+                    f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
                 )

-                mm_embeds_item = gather_image_embeds(
+                embeds = gather_image_embeds(
                     encoder_output[start_idx:end_idx],
                     is_embed=is_embed,
                 )
-                if mm_embeds_item is not None:
-                    mm_embeds.append(mm_embeds_item)
+                if embeds is not None:
+                    chunks.append(embeds)

-        if len(mm_embeds) == 0:
+                if end_idx == length:
+                    self.cache_entries_to_free.append((i, image_position.id))
+                    self.image_positions[i][image_position.id] = None
+
+        if len(chunks) == 0:
             return None
-        return torch.cat(mm_embeds, dim=0).to(device)
+        return torch.cat(chunks, dim=0).to(device)

     def free_encoder_cache(self):
-        for i, image_id in self.encoder_cache_to_free:
-            self.encoder_cache[i][image_id] = None
+        for i, image_id in self.cache_entries_to_free:
+            self.encoder_cache[i].pop(image_id, None)

-        self.encoder_cache_to_free = []
+        self.cache_entries_to_free = []
 class VlmCausalLM(FlashCausalLM):
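The switch from assigning `None` to `dict.pop` matters: `prepare_for_prefill` re-encodes an image only when its id is absent (`not in self.encoder_cache[i]`), so a lingering key mapped to `None` would suppress re-encoding while serving no embeddings. A standalone repro of the difference:

```python
cache = {0: "embeds"}
cache[0] = None       # old behaviour: key survives
print(0 in cache)     # True  -> "already cached", but the value is gone

cache = {0: "embeds"}
cache.pop(0, None)    # new behaviour: key removed
print(0 in cache)     # False -> the image would be re-encoded
```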
@@ -997,20 +866,20 @@ class VlmCausalLM(FlashCausalLM):
         )
         return embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: Optional[torch.Tensor] = None,
     ):
-        return self.model.get_input_embeds(
+        return self.model.get_inputs_embeds(
             input_ids=input_ids,
             vision_embeds=vision_embeds,
         )

-    def get_mm_embeddings(self, batch):
+    def encode_images(self, batch):
         if batch.pixel_values is not None:
             device = batch.input_ids.device
-            for request_id, image_position, image_input in batch.pixel_values:
+            for request_id, image_id, image_input in batch.pixel_values:
                 pixel_values = image_input["pixel_values"].to(device)

                 if "pixel_attention_mask" in image_input:
@@ -1036,19 +905,23 @@ class VlmCausalLM(FlashCausalLM):
                     image_sizes=image_sizes,
                     image_grid_thw=image_grid_thw,
                 )
-                batch.update_encoder_cache(encoder_outputs, request_id, image_position)
+                batch.update_encoder_cache(
+                    encoder_outputs,
+                    request_id,
+                    batch.image_positions[request_id][image_id],
+                )

             batch.pixel_values = None
-        return batch.get_mm_embeddings()

     def get_input_embeddings(self, batch):
-        if batch.has_image:
-            vision_embeds = self.get_mm_embeddings(batch)
-            batch.has_image = False
+        if batch.has_image_inputs:
+            self.encode_images(batch)
+            vision_embeds = batch.gather_vision_embeds()
+            batch.has_image_inputs = False
         else:
             vision_embeds = None

-        inputs_embeds = self.get_input_embeds(
+        inputs_embeds = self.get_inputs_embeds(
             batch.input_ids, vision_embeds=vision_embeds
         )
@@ -1127,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM):
             attention_mask = self.model.get_attention_mask(
                 input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
             ).reshape(-1)
+            batch.pixel_values = 1
         else:
             attention_mask = None
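Context for this last hunk: gemma3 builds a non-causal prefill mask so image tokens can attend to each other bidirectionally, and `batch.pixel_values = 1` appears to act as a truthy flag for the model-side check (`if pixel_values:`) seen earlier. A self-contained sketch of that masking idea, with assumed shapes and no claim to match `get_attention_mask` exactly:

```python
import torch

def bidirectional_image_mask(input_ids: torch.Tensor, image_token_id: int) -> torch.Tensor:
    """Causal mask, except tokens in the same contiguous image run attend
    to each other in both directions. Sketch only; assumptions throughout."""
    n = input_ids.shape[0]
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
    is_img = input_ids == image_token_id
    # cumsum over text tokens is constant within an image run, so it gives
    # each contiguous run of image tokens a distinct label
    run_id = torch.cumsum((~is_img).int(), dim=0)
    same_run = run_id.unsqueeze(0) == run_id.unsqueeze(1)
    bidir = same_run & is_img.unsqueeze(0) & is_img.unsqueeze(1)
    return causal | bidir

mask = bidirectional_image_mask(torch.tensor([5, 9, 9, 9, 6]), image_token_id=9)
print(mask.int())  # positions 1-3 are mutually visible; text stays causal
```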