mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00
rename vars
This commit is contained in:
parent
136b9897d4
commit
63ddba24b4
@ -248,9 +248,18 @@ Options:
|
||||
-p, --port <PORT>
|
||||
The port to listen on
|
||||
|
||||
[env: PORT=]
|
||||
[env: PORT=80]
|
||||
[default: 3000]
|
||||
|
||||
```
|
||||
## PROMETHEUS_PORT
|
||||
```shell
|
||||
-p, --prometheus-port <PROMETHEUS_PORT>
|
||||
The Prometheus port to listen on
|
||||
|
||||
[env: PROMETHEUS_PORT=]
|
||||
[default: 9000]
|
||||
|
||||
```
|
||||
## SHARD_UDS_PATH
|
||||
```shell
|
||||
|
@ -777,7 +777,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
image_features = image_features.view(-1, image_features.shape[-1])
|
||||
return image_features
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
@ -820,11 +820,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
|
||||
max_s += 1
|
||||
position_ids += 1
|
||||
|
||||
image_token_mask = (input_ids == self.config.image_token_index).to(
|
||||
input_ids.device
|
||||
)
|
||||
|
||||
if torch.any(image_token_mask):
|
||||
if pixel_values:
|
||||
attention_mask = self.get_attention_mask(
|
||||
input_ids,
|
||||
cu_seqlen_prefill,
|
||||
|
@ -80,7 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return image_features
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -801,7 +801,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -543,7 +543,7 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||
|
||||
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -250,7 +250,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
return image_features.view(-1, image_features.shape[-1])
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -931,7 +931,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
|
||||
return image_embeds
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -509,7 +509,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
|
||||
return image_embeds
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: torch.Tensor = None,
|
||||
|
@ -590,7 +590,7 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
|
||||
projected_vision_flat = self.model.multi_modal_projector(vision_flat)
|
||||
return projected_vision_flat
|
||||
|
||||
def get_input_embeds(self, input_ids, vision_embeds=None):
|
||||
def get_inputs_embeds(self, input_ids, vision_embeds=None):
|
||||
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
||||
|
||||
if vision_embeds is not None:
|
||||
|
@ -110,7 +110,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
|
||||
def image_text_replacement(processor, image_input, config) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
image_seq_len = 64
|
||||
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
|
||||
@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
return image_str, IDEFICS2_FAKE_TOKEN
|
||||
if config.model_type == "idefics3":
|
||||
# TODO: implement this in a more general way
|
||||
n_rows = image_input[image_id]["rows"][0][0]
|
||||
n_cols = image_input[image_id]["cols"][0][0]
|
||||
n_rows = image_input["rows"][0][0]
|
||||
n_cols = image_input["cols"][0][0]
|
||||
image_seq_len = int(
|
||||
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
|
||||
/ (config.scale_factor**2)
|
||||
@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
)
|
||||
return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input[image_id]["image_sizes"][0]
|
||||
height, width = image_input["image_sizes"][0]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
|
||||
log_master(
|
||||
@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
elif config.model_type == "paligemma":
|
||||
return "<image>" * config.text_config.num_image_tokens, "<image>"
|
||||
elif config.model_type == "qwen2_vl":
|
||||
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
||||
num_pads = grid_t * grid_h * grid_w // 4
|
||||
padding = "<|image_pad|>" * num_pads
|
||||
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
||||
elif config.model_type == "qwen2_5_vl":
|
||||
grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
|
||||
grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
|
||||
num_pads = grid_t * grid_h * grid_w // 4
|
||||
padding = "<|image_pad|>" * num_pads
|
||||
return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
|
||||
@ -166,8 +166,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
|
||||
patch_size = config.vision_config.patch_size
|
||||
pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
|
||||
downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
|
||||
aspect_ratios = image_input[image_id]["aspect_ratios"][0]
|
||||
image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
|
||||
aspect_ratios = image_input["aspect_ratios"][0]
|
||||
image_height, image_width = image_input["pixel_values"][0].shape[-2:]
|
||||
|
||||
num_patches_per_chunk = int(
|
||||
(image_height // patch_size)
|
||||
@ -264,7 +264,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
||||
return unpadded_features + newline_features + base_features
|
||||
|
||||
|
||||
def scatter_image_embeds(embeds, is_embed):
|
||||
def scatter_image_embeds(
|
||||
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
@ -276,15 +278,13 @@ def scatter_image_embeds(embeds, is_embed):
|
||||
return placeholders
|
||||
|
||||
|
||||
def gather_image_embeds(embeds, is_embed):
|
||||
def gather_image_embeds(
|
||||
embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
|
||||
) -> Optional[torch.Tensor]:
|
||||
if is_embed is None:
|
||||
return embeds
|
||||
|
||||
gathered = embeds[is_embed]
|
||||
|
||||
if len(gathered) == 0:
|
||||
return None
|
||||
return gathered
|
||||
sel = embeds[is_embed]
|
||||
return sel if sel.numel() else None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -304,6 +304,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
image_grid_thw: Optional[torch.Tensor]
|
||||
cache_entries_to_free: List[Tuple[int, int]]
|
||||
has_image_inputs: bool = False
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
@ -331,6 +333,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
batch.image_grid_thw = None
|
||||
# To be filled in prepare_for_prefill
|
||||
batch.has_image_inputs = False
|
||||
batch.cache_entries_to_free = []
|
||||
|
||||
return batch
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
@ -357,6 +363,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
batch.image_inputs = image_inputs
|
||||
batch.image_positions = image_positions
|
||||
batch.encoder_cache = encoder_cache
|
||||
|
||||
# To be filled in prepare_for_prefill
|
||||
batch.has_image_inputs = False
|
||||
batch.cache_entries_to_free = []
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
@ -404,7 +414,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
image_inputs.append(image_input)
|
||||
|
||||
img_text, img_start_token_str = image_text_replacement(
|
||||
processor, image_input, config, 0
|
||||
processor, image_input, config
|
||||
)
|
||||
text_parts.append(img_text)
|
||||
|
||||
@ -456,12 +466,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
for i in range(num_images):
|
||||
image_id, img_start_token_str, img_text = image_texts[i]
|
||||
img_text = image_text_replacement_fixup(config, img_text)
|
||||
|
||||
if config.model_type == "gemma3":
|
||||
img_text = img_text.replace("\n\n", "")
|
||||
|
||||
tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
|
||||
length = len(tokens)
|
||||
|
||||
assert length <= len(
|
||||
input_ids
|
||||
), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens"
|
||||
|
||||
pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
|
||||
index = img_start_token_pos[pos]
|
||||
|
||||
@ -502,151 +517,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
|
||||
return image_positions
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs2(
|
||||
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
|
||||
):
|
||||
# sizes to insert correct number of image tokens.
|
||||
kwargs = {}
|
||||
if (
|
||||
hasattr(processor, "image_processor_class")
|
||||
and processor.image_processor_class == "Idefics3ImageProcessor"
|
||||
):
|
||||
kwargs["return_row_col_info"] = True
|
||||
|
||||
batch_image_inputs = []
|
||||
for i, r in enumerate(requests):
|
||||
image_inputs = []
|
||||
batch_image_inputs.append(None)
|
||||
for chunk in r.input_chunks.chunks:
|
||||
chunk_type = chunk.WhichOneof("chunk")
|
||||
if chunk_type == "text":
|
||||
pass
|
||||
elif chunk_type == "image":
|
||||
image = Image.open(BytesIO(chunk.image.data))
|
||||
# qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
|
||||
# default warmup image is 20x20
|
||||
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
|
||||
if image.width <= 20:
|
||||
w = image.width * 2
|
||||
h = image.height * 2
|
||||
image = image.resize((w, h))
|
||||
|
||||
if config.model_type in {"llava_next", "gemma3", "llama4"}:
|
||||
image = image
|
||||
elif config.model_type in {"paligemma"}:
|
||||
image = image.convert("RGB")
|
||||
else:
|
||||
image = [image]
|
||||
image_input = processor.image_processor(
|
||||
[image], return_tensors="pt", **kwargs
|
||||
)
|
||||
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
|
||||
if len(image_inputs) > 0:
|
||||
batch_image_inputs[i] = image_inputs
|
||||
|
||||
batch_image_positions = []
|
||||
batch_tokenized_inputs = []
|
||||
max_length = 0
|
||||
for i, r in enumerate(requests):
|
||||
full_text = ""
|
||||
image_tokens = []
|
||||
batch_image_positions.append(None)
|
||||
image_id = 0
|
||||
for chunk in r.input_chunks.chunks:
|
||||
chunk_type = chunk.WhichOneof("chunk")
|
||||
if chunk_type == "text":
|
||||
full_text += chunk.text
|
||||
elif chunk_type == "image":
|
||||
image_text, id_token = image_text_replacement(
|
||||
processor, batch_image_inputs[i], config, image_id
|
||||
)
|
||||
full_text += image_text
|
||||
|
||||
if config.model_type == "gemma3":
|
||||
image_text = image_text.replace("\n\n", "")
|
||||
elif config.model_type == "idefics2":
|
||||
image_text = image_text_replacement_fixup(config, image_text)
|
||||
|
||||
image_tokens.append(
|
||||
(
|
||||
image_id,
|
||||
id_token,
|
||||
tokenizer(image_text, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
],
|
||||
)
|
||||
)
|
||||
image_id += 1
|
||||
|
||||
full_text = image_text_replacement_fixup(config, full_text)
|
||||
input_ids = tokenizer(
|
||||
full_text,
|
||||
truncation=True,
|
||||
max_length=r.truncate,
|
||||
add_special_tokens=r.add_special_tokens,
|
||||
)["input_ids"]
|
||||
max_length = max(max_length, len(input_ids))
|
||||
batch_tokenized_inputs.append(input_ids)
|
||||
|
||||
prev = 0
|
||||
image_positions = []
|
||||
idefics_replacement = False
|
||||
|
||||
config.image_token_index = (
|
||||
config.image_token_index
|
||||
if hasattr(config, "image_token_index")
|
||||
else config.image_token_id
|
||||
)
|
||||
for image_id, id_token, tokens in image_tokens:
|
||||
id_token = tokenizer.get_vocab()[id_token]
|
||||
|
||||
if config.model_type == "idefics2" and idefics_replacement:
|
||||
id_token = config.image_token_index
|
||||
tokens = tokens[1:]
|
||||
|
||||
length = len(tokens)
|
||||
index = input_ids[prev:].index(id_token)
|
||||
|
||||
index = index + prev
|
||||
assert (
|
||||
input_ids[index : index + length] == tokens
|
||||
), "Image token not found in input_ids"
|
||||
|
||||
if config.model_type in {"llava_next", "paligemma"}:
|
||||
is_embed = None
|
||||
num_placeholder_tokens = length
|
||||
else:
|
||||
is_embed = torch.tensor(tokens) == config.image_token_index
|
||||
num_placeholder_tokens = is_embed.sum().item()
|
||||
|
||||
pos = ImagePositions(
|
||||
offset=index,
|
||||
length=length,
|
||||
id=image_id,
|
||||
num_placeholder_tokens=num_placeholder_tokens,
|
||||
is_embed=is_embed,
|
||||
)
|
||||
|
||||
image_positions.append(pos)
|
||||
prev = index + length
|
||||
|
||||
if config.model_type == "idefics2" and prev != len(input_ids):
|
||||
if input_ids[prev] == config.image_token_id:
|
||||
# means this is replaced by image_text_replacement_fixup
|
||||
idefics_replacement = True
|
||||
else:
|
||||
idefics_replacement = False
|
||||
|
||||
if len(image_positions) > 0:
|
||||
batch_image_positions[i] = image_positions
|
||||
|
||||
return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
@ -674,19 +544,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
def prepare_for_prefill(self):
|
||||
super().prepare_for_prefill()
|
||||
|
||||
self.has_image = False
|
||||
self.encoder_cache_to_free = []
|
||||
self.has_image_inputs = False
|
||||
self.cache_entries_to_free = []
|
||||
|
||||
self.pixel_values = []
|
||||
|
||||
for i, (
|
||||
r,
|
||||
cache_length,
|
||||
input_length,
|
||||
request_prefilling,
|
||||
) in enumerate(
|
||||
zip(
|
||||
self.requests,
|
||||
self.cache_lengths,
|
||||
self.input_lengths,
|
||||
self.prefilling_mask,
|
||||
@ -695,10 +563,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
if not request_prefilling or self.image_positions[i] is None:
|
||||
continue
|
||||
|
||||
for j, image_position in enumerate(self.image_positions[i]):
|
||||
image_id = image_position.id
|
||||
image_inputs = self.image_inputs[i][j]
|
||||
|
||||
for image_position in self.image_positions[i]:
|
||||
if image_position is None:
|
||||
continue
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
|
||||
@ -709,47 +576,46 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
# The encode input is already processed
|
||||
continue
|
||||
|
||||
self.has_image = True
|
||||
self.has_image_inputs = True
|
||||
|
||||
if image_id not in self.encoder_cache[i]:
|
||||
self.pixel_values.append((i, image_position, image_inputs))
|
||||
self.image_inputs[i][j] = None
|
||||
if image_position.id not in self.encoder_cache[i]:
|
||||
image_inputs = self.image_inputs[i][image_position.id]
|
||||
self.pixel_values.append((i, image_position.id, image_inputs))
|
||||
|
||||
if not self.has_image:
|
||||
# Remove the image from the image_inputs
|
||||
self.image_inputs[i][image_position.id] = None
|
||||
|
||||
if not self.has_image_inputs:
|
||||
self.pixel_values = None
|
||||
self.pixel_attention_mask = None
|
||||
self.image_sizes = None
|
||||
self.image_grid_thw = None
|
||||
|
||||
def update_encoder_cache(self, encoder_outputs, request_id, input):
|
||||
self.encoder_cache[request_id][input.id] = scatter_image_embeds(
|
||||
encoder_outputs, input.is_embed
|
||||
def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
|
||||
self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
|
||||
encoder_outputs, img_pos.is_embed
|
||||
)
|
||||
|
||||
def get_mm_embeddings(self):
|
||||
def gather_vision_embeds(self):
|
||||
device = self.input_ids.device
|
||||
mm_embeds = []
|
||||
for i, (
|
||||
r,
|
||||
chunks = []
|
||||
for (
|
||||
i,
|
||||
cache_length,
|
||||
input_length,
|
||||
prompt_length,
|
||||
request_prefilling,
|
||||
) in enumerate(
|
||||
zip(
|
||||
self.requests,
|
||||
) in zip(
|
||||
range(len(self.requests)),
|
||||
self.cache_lengths,
|
||||
self.input_lengths,
|
||||
self.prompt_lengths,
|
||||
self.prefilling_mask,
|
||||
)
|
||||
):
|
||||
if not request_prefilling or self.image_positions[i] is None:
|
||||
continue
|
||||
|
||||
for j, image_position in enumerate(self.image_positions[i]):
|
||||
image_id = image_position.id
|
||||
|
||||
for image_position in self.image_positions[i]:
|
||||
if image_position is None:
|
||||
continue
|
||||
start_pos = image_position.offset
|
||||
length = image_position.length
|
||||
|
||||
@ -763,13 +629,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
start_idx = max(cache_length - start_pos, 0)
|
||||
end_idx = min(cache_length - start_pos + input_length, length)
|
||||
|
||||
if end_idx == length:
|
||||
self.encoder_cache_to_free.append((i, image_id))
|
||||
|
||||
assert (
|
||||
image_id in self.encoder_cache[i]
|
||||
), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
|
||||
encoder_output = self.encoder_cache[i][image_id]
|
||||
image_position.id in self.encoder_cache[i]
|
||||
), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
|
||||
encoder_output = self.encoder_cache[i][image_position.id]
|
||||
|
||||
is_embed = image_position.is_embed
|
||||
if is_embed is not None:
|
||||
@ -778,25 +641,31 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
|
||||
from loguru import logger
|
||||
|
||||
logger.info(
|
||||
f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
||||
f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
||||
)
|
||||
|
||||
mm_embeds_item = gather_image_embeds(
|
||||
embeds = gather_image_embeds(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
if mm_embeds_item is not None:
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
if embeds is not None:
|
||||
chunks.append(embeds)
|
||||
|
||||
if len(mm_embeds) == 0:
|
||||
if end_idx == length:
|
||||
self.cache_entries_to_free.append((i, image_position.id))
|
||||
self.image_positions[i][image_position.id] = None
|
||||
|
||||
if len(chunks) == 0:
|
||||
return None
|
||||
return torch.cat(mm_embeds, dim=0).to(device)
|
||||
return torch.cat(chunks, dim=0).to(device)
|
||||
|
||||
def free_encoder_cache(self):
|
||||
for i, image_id in self.encoder_cache_to_free:
|
||||
self.encoder_cache[i][image_id] = None
|
||||
for i, image_id in self.cache_entries_to_free:
|
||||
self.encoder_cache[i].pop(image_id, None)
|
||||
|
||||
self.encoder_cache_to_free = []
|
||||
self.cache_entries_to_free = []
|
||||
|
||||
# release any freed GPU memory immediately?
|
||||
|
||||
|
||||
class VlmCausalLM(FlashCausalLM):
|
||||
@ -997,20 +866,20 @@ class VlmCausalLM(FlashCausalLM):
|
||||
)
|
||||
return embeds
|
||||
|
||||
def get_input_embeds(
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
vision_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.model.get_input_embeds(
|
||||
return self.model.get_inputs_embeds(
|
||||
input_ids=input_ids,
|
||||
vision_embeds=vision_embeds,
|
||||
)
|
||||
|
||||
def get_mm_embeddings(self, batch):
|
||||
def encode_images(self, batch):
|
||||
if batch.pixel_values is not None:
|
||||
device = batch.input_ids.device
|
||||
for request_id, image_position, image_input in batch.pixel_values:
|
||||
for request_id, image_id, image_input in batch.pixel_values:
|
||||
pixel_values = image_input["pixel_values"].to(device)
|
||||
|
||||
if "pixel_attention_mask" in image_input:
|
||||
@ -1036,19 +905,23 @@ class VlmCausalLM(FlashCausalLM):
|
||||
image_sizes=image_sizes,
|
||||
image_grid_thw=image_grid_thw,
|
||||
)
|
||||
batch.update_encoder_cache(encoder_outputs, request_id, image_position)
|
||||
batch.update_encoder_cache(
|
||||
encoder_outputs,
|
||||
request_id,
|
||||
batch.image_positions[request_id][image_id],
|
||||
)
|
||||
|
||||
batch.pixel_values = None
|
||||
return batch.get_mm_embeddings()
|
||||
|
||||
def get_input_embeddings(self, batch):
|
||||
if batch.has_image:
|
||||
vision_embeds = self.get_mm_embeddings(batch)
|
||||
batch.has_image = False
|
||||
if batch.has_image_inputs:
|
||||
self.encode_images(batch)
|
||||
vision_embeds = batch.gather_vision_embeds()
|
||||
batch.has_image_inputs = False
|
||||
else:
|
||||
vision_embeds = None
|
||||
|
||||
inputs_embeds = self.get_input_embeds(
|
||||
inputs_embeds = self.get_inputs_embeds(
|
||||
batch.input_ids, vision_embeds=vision_embeds
|
||||
)
|
||||
|
||||
@ -1127,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM):
|
||||
attention_mask = self.model.get_attention_mask(
|
||||
input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
|
||||
).reshape(-1)
|
||||
batch.pixel_values = 1
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user