mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00

commit 63ddba24b4 ("rename vars")
parent 136b9897d4
@@ -248,9 +248,18 @@ Options:
   -p, --port <PORT>
           The port to listen on

-          [env: PORT=]
+          [env: PORT=80]
           [default: 3000]

+```
+## PROMETHEUS_PORT
+```shell
+  -p, --prometheus-port <PROMETHEUS_PORT>
+          The Prometheus port to listen on
+
+          [env: PROMETHEUS_PORT=]
+          [default: 9000]
+
 ```
 ## SHARD_UDS_PATH
 ```shell
@@ -777,7 +777,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
         image_features = image_features.view(-1, image_features.shape[-1])
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -820,11 +820,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
             max_s += 1
             position_ids += 1

-        image_token_mask = (input_ids == self.config.image_token_index).to(
-            input_ids.device
-        )
-
-        if torch.any(image_token_mask):
+        if pixel_values:
             attention_mask = self.get_attention_mask(
                 input_ids,
                 cu_seqlen_prefill,
@@ -80,7 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         )
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -801,7 +801,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_hidden_states = torch.stack(all_states, dim=0)
         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -543,7 +543,7 @@ class Idefics3ForConditionalGeneration(nn.Module):

         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -250,7 +250,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_features = torch.stack(new_image_features, dim=0)
         return image_features.view(-1, image_features.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -931,7 +931,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -509,7 +509,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -590,7 +590,7 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         projected_vision_flat = self.model.multi_modal_projector(vision_flat)
         return projected_vision_flat

-    def get_input_embeds(self, input_ids, vision_embeds=None):
+    def get_inputs_embeds(self, input_ids, vision_embeds=None):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)

         if vision_embeds is not None:
@@ -110,7 +110,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+def image_text_replacement(processor, image_input, config) -> str:
     if config.model_type == "idefics2":
         image_seq_len = 64
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
@@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         return image_str, IDEFICS2_FAKE_TOKEN
     if config.model_type == "idefics3":
         # TODO: implement this in a more general way
-        n_rows = image_input[image_id]["rows"][0][0]
-        n_cols = image_input[image_id]["cols"][0][0]
+        n_rows = image_input["rows"][0][0]
+        n_cols = image_input["cols"][0][0]
         image_seq_len = int(
             ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
             / (config.scale_factor**2)
@@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         )
         return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
     elif config.model_type == "llava_next":
-        height, width = image_input[image_id]["image_sizes"][0]
+        height, width = image_input["image_sizes"][0]
         num_features = get_number_of_features(height, width, config)

         log_master(
@@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "paligemma":
         return "<image>" * config.text_config.num_image_tokens, "<image>"
     elif config.model_type == "qwen2_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "qwen2_5_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
@@ -166,8 +166,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         patch_size = config.vision_config.patch_size
         pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
         downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
-        aspect_ratios = image_input[image_id]["aspect_ratios"][0]
-        image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
+        aspect_ratios = image_input["aspect_ratios"][0]
+        image_height, image_width = image_input["pixel_values"][0].shape[-2:]

         num_patches_per_chunk = int(
             (image_height // patch_size)
@@ -264,7 +264,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


-def scatter_image_embeds(embeds, is_embed):
+def scatter_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> torch.Tensor:
     if is_embed is None:
         return embeds

@@ -276,15 +278,13 @@ def scatter_image_embeds(embeds, is_embed):
     return placeholders


-def gather_image_embeds(embeds, is_embed):
+def gather_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-
-    gathered = embeds[is_embed]
-
-    if len(gathered) == 0:
-        return None
-    return gathered
+    sel = embeds[is_embed]
+    return sel if sel.numel() else None


 @dataclass
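For reference, a minimal sketch of the two helpers as they read after this commit. The `gather_image_embeds` body matches the new lines in the hunk above; the `scatter_image_embeds` body is an assumption filled in around the `return placeholders` context line, since the diff only shows its signature change.

```python
from typing import Optional

import torch


def scatter_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    if is_embed is None:
        return embeds
    # Assumed body: write each embedding row into the positions where
    # is_embed is True, leaving the other rows zeroed.
    placeholders = embeds.new_zeros(is_embed.shape[0], embeds.shape[-1])
    placeholders[is_embed] = embeds
    return placeholders


def gather_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    # Body as added by this commit: select the placeholder rows,
    # or None if nothing is selected.
    if is_embed is None:
        return embeds
    sel = embeds[is_embed]
    return sel if sel.numel() else None
```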
@@ -304,6 +304,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    cache_entries_to_free: List[Tuple[int, int]]
+    has_image_inputs: bool = False

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -331,6 +333,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
+
         return batch

     @tracer.start_as_current_span("filter")
@@ -357,6 +363,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
         batch.encoder_cache = encoder_cache
+
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
         return batch

     @classmethod
@@ -404,7 +414,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 image_inputs.append(image_input)

                 img_text, img_start_token_str = image_text_replacement(
-                    processor, image_input, config, 0
+                    processor, image_input, config
                 )
                 text_parts.append(img_text)

@@ -456,12 +466,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         for i in range(num_images):
             image_id, img_start_token_str, img_text = image_texts[i]
             img_text = image_text_replacement_fixup(config, img_text)

             if config.model_type == "gemma3":
                 img_text = img_text.replace("\n\n", "")

             tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
             length = len(tokens)
+
+            assert length <= len(
+                input_ids
+            ), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens"
+
             pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
             index = img_start_token_pos[pos]

@@ -502,151 +517,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

         return image_positions

-    @classmethod
-    def batch_tokenized_inputs2(
-        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
-    ):
-        # sizes to insert correct number of image tokens.
-        kwargs = {}
-        if (
-            hasattr(processor, "image_processor_class")
-            and processor.image_processor_class == "Idefics3ImageProcessor"
-        ):
-            kwargs["return_row_col_info"] = True
-
-        batch_image_inputs = []
-        for i, r in enumerate(requests):
-            image_inputs = []
-            batch_image_inputs.append(None)
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    pass
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
-                    # default warmup image is 20x20
-                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
-                        if image.width <= 20:
-                            w = image.width * 2
-                            h = image.height * 2
-                            image = image.resize((w, h))
-
-                    if config.model_type in {"llava_next", "gemma3", "llama4"}:
-                        image = image
-                    elif config.model_type in {"paligemma"}:
-                        image = image.convert("RGB")
-                    else:
-                        image = [image]
-                    image_input = processor.image_processor(
-                        [image], return_tensors="pt", **kwargs
-                    )
-
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-
-            if len(image_inputs) > 0:
-                batch_image_inputs[i] = image_inputs
-
-        batch_image_positions = []
-        batch_tokenized_inputs = []
-        max_length = 0
-        for i, r in enumerate(requests):
-            full_text = ""
-            image_tokens = []
-            batch_image_positions.append(None)
-            image_id = 0
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += chunk.text
-                elif chunk_type == "image":
-                    image_text, id_token = image_text_replacement(
-                        processor, batch_image_inputs[i], config, image_id
-                    )
-                    full_text += image_text
-
-                    if config.model_type == "gemma3":
-                        image_text = image_text.replace("\n\n", "")
-                    elif config.model_type == "idefics2":
-                        image_text = image_text_replacement_fixup(config, image_text)
-
-                    image_tokens.append(
-                        (
-                            image_id,
-                            id_token,
-                            tokenizer(image_text, add_special_tokens=False)[
-                                "input_ids"
-                            ],
-                        )
-                    )
-                    image_id += 1
-
-            full_text = image_text_replacement_fixup(config, full_text)
-            input_ids = tokenizer(
-                full_text,
-                truncation=True,
-                max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
-            )["input_ids"]
-            max_length = max(max_length, len(input_ids))
-            batch_tokenized_inputs.append(input_ids)
-
-            prev = 0
-            image_positions = []
-            idefics_replacement = False
-
-            config.image_token_index = (
-                config.image_token_index
-                if hasattr(config, "image_token_index")
-                else config.image_token_id
-            )
-            for image_id, id_token, tokens in image_tokens:
-                id_token = tokenizer.get_vocab()[id_token]
-
-                if config.model_type == "idefics2" and idefics_replacement:
-                    id_token = config.image_token_index
-                    tokens = tokens[1:]
-
-                length = len(tokens)
-                index = input_ids[prev:].index(id_token)
-
-                index = index + prev
-                assert (
-                    input_ids[index : index + length] == tokens
-                ), "Image token not found in input_ids"
-
-                if config.model_type in {"llava_next", "paligemma"}:
-                    is_embed = None
-                    num_placeholder_tokens = length
-                else:
-                    is_embed = torch.tensor(tokens) == config.image_token_index
-                    num_placeholder_tokens = is_embed.sum().item()
-
-                pos = ImagePositions(
-                    offset=index,
-                    length=length,
-                    id=image_id,
-                    num_placeholder_tokens=num_placeholder_tokens,
-                    is_embed=is_embed,
-                )
-
-                image_positions.append(pos)
-                prev = index + length
-
-                if config.model_type == "idefics2" and prev != len(input_ids):
-                    if input_ids[prev] == config.image_token_id:
-                        # means this is replaced by image_text_replacement_fixup
-                        idefics_replacement = True
-                    else:
-                        idefics_replacement = False
-
-            if len(image_positions) > 0:
-                batch_image_positions[i] = image_positions
-
-        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
-
     @classmethod
     def from_pb_processor(
         cls,
@@ -674,19 +544,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     def prepare_for_prefill(self):
         super().prepare_for_prefill()

-        self.has_image = False
-        self.encoder_cache_to_free = []
+        self.has_image_inputs = False
+        self.cache_entries_to_free = []

         self.pixel_values = []

         for i, (
-            r,
             cache_length,
             input_length,
             request_prefilling,
         ) in enumerate(
             zip(
-                self.requests,
                 self.cache_lengths,
                 self.input_lengths,
                 self.prefilling_mask,
@@ -695,10 +563,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-                image_inputs = self.image_inputs[i][j]
-
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
                 start_pos = image_position.offset
                 length = image_position.length

@@ -709,47 +576,46 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     # The encode input is already processed
                     continue

-                self.has_image = True
+                self.has_image_inputs = True

-                if image_id not in self.encoder_cache[i]:
-                    self.pixel_values.append((i, image_position, image_inputs))
-                    self.image_inputs[i][j] = None
+                if image_position.id not in self.encoder_cache[i]:
+                    image_inputs = self.image_inputs[i][image_position.id]
+                    self.pixel_values.append((i, image_position.id, image_inputs))

-        if not self.has_image:
+                    # Remove the image from the image_inputs
+                    self.image_inputs[i][image_position.id] = None
+
+        if not self.has_image_inputs:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None

-    def update_encoder_cache(self, encoder_outputs, request_id, input):
-        self.encoder_cache[request_id][input.id] = scatter_image_embeds(
-            encoder_outputs, input.is_embed
+    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
+        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
+            encoder_outputs, img_pos.is_embed
         )

-    def get_mm_embeddings(self):
+    def gather_vision_embeds(self):
         device = self.input_ids.device
-        mm_embeds = []
-        for i, (
-            r,
+        chunks = []
+        for (
+            i,
             cache_length,
             input_length,
-            prompt_length,
             request_prefilling,
-        ) in enumerate(
-            zip(
-                self.requests,
-                self.cache_lengths,
-                self.input_lengths,
-                self.prompt_lengths,
-                self.prefilling_mask,
-            )
+        ) in zip(
+            range(len(self.requests)),
+            self.cache_lengths,
+            self.input_lengths,
+            self.prefilling_mask,
         ):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
                 start_pos = image_position.offset
                 length = image_position.length

@@ -763,13 +629,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 start_idx = max(cache_length - start_pos, 0)
                 end_idx = min(cache_length - start_pos + input_length, length)

-                if end_idx == length:
-                    self.encoder_cache_to_free.append((i, image_id))
-
                 assert (
-                    image_id in self.encoder_cache[i]
-                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
-                encoder_output = self.encoder_cache[i][image_id]
+                    image_position.id in self.encoder_cache[i]
+                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
+                encoder_output = self.encoder_cache[i][image_position.id]

                 is_embed = image_position.is_embed
                 if is_embed is not None:
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_embeds_item = gather_image_embeds(
|
embeds = gather_image_embeds(
|
||||||
encoder_output[start_idx:end_idx],
|
encoder_output[start_idx:end_idx],
|
||||||
is_embed=is_embed,
|
is_embed=is_embed,
|
||||||
)
|
)
|
||||||
if mm_embeds_item is not None:
|
if embeds is not None:
|
||||||
mm_embeds.append(mm_embeds_item)
|
chunks.append(embeds)
|
||||||
|
|
||||||
if len(mm_embeds) == 0:
|
if end_idx == length:
|
||||||
|
self.cache_entries_to_free.append((i, image_position.id))
|
||||||
|
self.image_positions[i][image_position.id] = None
|
||||||
|
|
||||||
|
if len(chunks) == 0:
|
||||||
return None
|
return None
|
||||||
return torch.cat(mm_embeds, dim=0).to(device)
|
return torch.cat(chunks, dim=0).to(device)
|
||||||
|
|
||||||
def free_encoder_cache(self):
|
def free_encoder_cache(self):
|
||||||
for i, image_id in self.encoder_cache_to_free:
|
for i, image_id in self.cache_entries_to_free:
|
||||||
self.encoder_cache[i][image_id] = None
|
self.encoder_cache[i].pop(image_id, None)
|
||||||
|
|
||||||
self.encoder_cache_to_free = []
|
self.cache_entries_to_free = []
|
||||||
|
|
||||||
|
# release any freed GPU memory immediately?
|
||||||
|
|
||||||
|
|
||||||
class VlmCausalLM(FlashCausalLM):
|
class VlmCausalLM(FlashCausalLM):
|
||||||
@@ -997,20 +866,20 @@ class VlmCausalLM(FlashCausalLM):
         )
         return embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: Optional[torch.Tensor] = None,
     ):
-        return self.model.get_input_embeds(
+        return self.model.get_inputs_embeds(
             input_ids=input_ids,
             vision_embeds=vision_embeds,
         )

-    def get_mm_embeddings(self, batch):
+    def encode_images(self, batch):
         if batch.pixel_values is not None:
             device = batch.input_ids.device
-            for request_id, image_position, image_input in batch.pixel_values:
+            for request_id, image_id, image_input in batch.pixel_values:
                 pixel_values = image_input["pixel_values"].to(device)

                 if "pixel_attention_mask" in image_input:
@@ -1036,19 +905,23 @@ class VlmCausalLM(FlashCausalLM):
                     image_sizes=image_sizes,
                     image_grid_thw=image_grid_thw,
                 )
-                batch.update_encoder_cache(encoder_outputs, request_id, image_position)
+                batch.update_encoder_cache(
+                    encoder_outputs,
+                    request_id,
+                    batch.image_positions[request_id][image_id],
+                )

             batch.pixel_values = None
-        return batch.get_mm_embeddings()

     def get_input_embeddings(self, batch):
-        if batch.has_image:
-            vision_embeds = self.get_mm_embeddings(batch)
-            batch.has_image = False
+        if batch.has_image_inputs:
+            self.encode_images(batch)
+            vision_embeds = batch.gather_vision_embeds()
+            batch.has_image_inputs = False
         else:
             vision_embeds = None

-        inputs_embeds = self.get_input_embeds(
+        inputs_embeds = self.get_inputs_embeds(
             batch.input_ids, vision_embeds=vision_embeds
         )

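Taken together, the renamed methods describe one prefill step's handling of image embeddings. The sketch below assembles that flow from the names visible in this diff; the call sites and ordering (in particular where `free_encoder_cache` runs) are assumptions rather than repository code.

```python
# Hedged sketch (not repository code): the per-step lifecycle of image
# embeddings after this commit. `model` is a VlmCausalLM-like object and
# `batch` a VlmCausalLMBatch-like one.
def prefill_step_vision_flow(model, batch):
    # 1. prepare_for_prefill() marks images overlapping this chunk and fills
    #    batch.pixel_values with (request_idx, image_id, image_inputs) triples
    #    for images not yet in the per-request encoder cache.
    batch.prepare_for_prefill()

    if batch.has_image_inputs:
        # 2. encode_images() runs the vision tower on batch.pixel_values and
        #    stores each output via batch.update_encoder_cache(...).
        model.encode_images(batch)
        # 3. gather_vision_embeds() slices the cached outputs down to the
        #    placeholder positions covered by this prefill chunk.
        vision_embeds = batch.gather_vision_embeds()
        batch.has_image_inputs = False
    else:
        vision_embeds = None

    # 4. Text token embeddings and gathered vision embeddings are merged
    #    before the language model runs.
    inputs_embeds = model.get_inputs_embeds(batch.input_ids, vision_embeds=vision_embeds)

    # 5. Entries fully consumed in this step were queued in
    #    batch.cache_entries_to_free; freeing them here is an assumption.
    batch.free_encoder_cache()
    return inputs_embeds
```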
@@ -1127,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM):
             attention_mask = self.model.get_attention_mask(
                 input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
             ).reshape(-1)
+            batch.pixel_values = 1
         else:
             attention_mask = None
