diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md
index 6505a08d..b8f0d950 100644
--- a/docs/source/reference/launcher.md
+++ b/docs/source/reference/launcher.md
@@ -248,9 +248,18 @@ Options:
   -p, --port <PORT>
           The port to listen on

           [env: PORT=]
           [default: 3000]

+```
+## PROMETHEUS_PORT
+```shell
+      --prometheus-port <PROMETHEUS_PORT>
+          The Prometheus port to listen on
+
+          [env: PROMETHEUS_PORT=]
+          [default: 9000]
+
 ```
 ## SHARD_UDS_PATH
 ```shell
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
index 35b29ce0..d5ba9a8a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
@@ -777,7 +777,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
         image_features = image_features.view(-1, image_features.shape[-1])
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
@@ -820,11 +820,7 @@ class Gemma3ForConditionalGeneration(nn.Module):
             max_s += 1
             position_ids += 1

-        image_token_mask = (input_ids == self.config.image_token_index).to(
-            input_ids.device
-        )
-
-        if torch.any(image_token_mask):
+        if pixel_values:
             attention_mask = self.get_attention_mask(
                 input_ids,
                 cu_seqlen_prefill,
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index a6ceade8..d1117e39 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -80,7 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         )
         return image_features

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index 110f0d05..bb4c6ca3 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -801,7 +801,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         image_hidden_states = torch.stack(all_states, dim=0)
         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
diff --git a/server/text_generation_server/models/custom_modeling/idefics3.py b/server/text_generation_server/models/custom_modeling/idefics3.py
index 71e00442..d4165502 100644
--- a/server/text_generation_server/models/custom_modeling/idefics3.py
+++ b/server/text_generation_server/models/custom_modeling/idefics3.py
@@ -543,7 +543,7 @@ class Idefics3ForConditionalGeneration(nn.Module):

         return image_hidden_states.view(-1, image_hidden_states.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
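Across the modeling files above, `get_input_embeds` is renamed to `get_inputs_embeds`, and every implementation keeps the optional `vision_embeds` argument. The snippet below is a self-contained toy, not code from this repository: the class, the hidden size, and the placeholder token id 9 are made up, and it only illustrates the contract the rename preserves (token embeddings are computed first, and precomputed vision embeddings are spliced into the placeholder positions when they are provided).

```python
# Toy illustration of the get_inputs_embeds() contract; nothing here is the
# real TGI implementation.
import torch
import torch.nn as nn


class ToyVlm(nn.Module):
    def __init__(self, vocab_size=16, hidden=8, image_token_id=9):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.image_token_id = image_token_id

    def get_inputs_embeds(self, input_ids, vision_embeds=None):
        inputs_embeds = self.embed_tokens(input_ids)
        if vision_embeds is not None:
            # Overwrite the placeholder rows with the precomputed vision rows.
            mask = input_ids == self.image_token_id
            inputs_embeds[mask] = vision_embeds.to(inputs_embeds.dtype)
        return inputs_embeds


model = ToyVlm()
input_ids = torch.tensor([1, 9, 9, 2])      # prompt with two image placeholders
vision_embeds = torch.randn(2, 8)           # one row per placeholder position
print(model.get_inputs_embeds(input_ids, vision_embeds=vision_embeds).shape)
```

How the placeholder rows are actually located differs per model (several rely on `is_embed` masks, see `vlm_causal_lm.py` below), so treat the masking line as a stand-in rather than the real splicing logic.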
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 27861375..bb6da022 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -250,7 +250,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         image_features = torch.stack(new_image_features, dim=0)
         return image_features.view(-1, image_features.shape[-1])

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
index 1d48cb34..addb9032 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
@@ -931,7 +931,7 @@ class Qwen2_5VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 7c5ed470..0ca41c1d 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -509,7 +509,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
         return image_embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: torch.Tensor = None,
diff --git a/server/text_generation_server/models/transformers_flash_vlm.py b/server/text_generation_server/models/transformers_flash_vlm.py
index 1f4a053b..518e5972 100644
--- a/server/text_generation_server/models/transformers_flash_vlm.py
+++ b/server/text_generation_server/models/transformers_flash_vlm.py
@@ -590,7 +590,7 @@ class TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):
         projected_vision_flat = self.model.multi_modal_projector(vision_flat)
         return projected_vision_flat

-    def get_input_embeds(self, input_ids, vision_embeds=None):
+    def get_inputs_embeds(self, input_ids, vision_embeds=None):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)

         if vision_embeds is not None:
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 49617cbb..4aca153e 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -110,7 +110,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def image_text_replacement(processor, image_input, config, image_id: int) -> str:
+def image_text_replacement(processor, image_input, config) -> str:
     if config.model_type == "idefics2":
         image_seq_len = 64
         image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
@@ -119,8 +119,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         return image_str, IDEFICS2_FAKE_TOKEN
     if config.model_type == "idefics3":
         # TODO: implement this in a more general way
-        n_rows = image_input[image_id]["rows"][0][0]
-        n_cols = image_input[image_id]["cols"][0][0]
+        n_rows = image_input["rows"][0][0]
+        n_cols = image_input["cols"][0][0]
         image_seq_len = int(
             ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
             / (config.scale_factor**2)
@@ -135,7 +135,7 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         )
         return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
     elif config.model_type == "llava_next":
-        height, width = image_input[image_id]["image_sizes"][0]
+        height, width = image_input["image_sizes"][0]
         num_features = get_number_of_features(height, width, config)

         log_master(
@@ -147,12 +147,12 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "paligemma":
         return "<image>" * config.text_config.num_image_tokens, "<image>"
     elif config.model_type == "qwen2_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
     elif config.model_type == "qwen2_5_vl":
-        grid_t, grid_h, grid_w = image_input[image_id]["image_grid_thw"][0]
+        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
         num_pads = grid_t * grid_h * grid_w // 4
         padding = "<|image_pad|>" * num_pads
         return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
@@ -166,8 +166,8 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         patch_size = config.vision_config.patch_size
         pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
         downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
-        aspect_ratios = image_input[image_id]["aspect_ratios"][0]
-        image_height, image_width = image_input[image_id]["pixel_values"][0].shape[-2:]
+        aspect_ratios = image_input["aspect_ratios"][0]
+        image_height, image_width = image_input["pixel_values"][0].shape[-2:]

         num_patches_per_chunk = int(
             (image_height // patch_size)
@@ -264,7 +264,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features


-def scatter_image_embeds(embeds, is_embed):
+def scatter_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> torch.Tensor:
     if is_embed is None:
         return embeds

@@ -276,15 +278,13 @@ def scatter_image_embeds(embeds, is_embed):
     return placeholders


-def gather_image_embeds(embeds, is_embed):
+def gather_image_embeds(
+    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
+) -> Optional[torch.Tensor]:
     if is_embed is None:
         return embeds
-
-    gathered = embeds[is_embed]
-
-    if len(gathered) == 0:
-        return None
-    return gathered
+    sel = embeds[is_embed]
+    return sel if sel.numel() else None


 @dataclass
@@ -304,6 +304,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    cache_entries_to_free: List[Tuple[int, int]]
+    has_image_inputs: bool = False

     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -331,6 +333,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
+
         return batch

     @tracer.start_as_current_span("filter")
@@ -357,6 +363,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         batch.image_inputs = image_inputs
         batch.image_positions = image_positions
         batch.encoder_cache = encoder_cache
+
+        # To be filled in prepare_for_prefill
+        batch.has_image_inputs = False
+        batch.cache_entries_to_free = []
         return batch

     @classmethod
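The `scatter_image_embeds` / `gather_image_embeds` pair above stores vision embeddings in a dense, placeholder-aligned layout so that a later chunked prefill can slice them by position. The round-trip below is a standalone sketch: the gather logic mirrors the diff, but the middle of `scatter_image_embeds` is not shown in this hunk, so its body here (zero-filled placeholders) is an assumption made only to keep the example runnable.

```python
# Standalone round-trip sketch for the scatter/gather helpers; the scatter
# body is a guess (the diff only shows its signature and tail).
from typing import Optional

import torch


def scatter_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    if is_embed is None:
        return embeds
    # Assumed body: allocate one row per placeholder position and fill the
    # rows that correspond to real image-embedding positions.
    placeholders = embeds.new_full((is_embed.shape[0], embeds.shape[-1]), 0.0)
    placeholders[is_embed] = embeds
    return placeholders


def gather_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    if is_embed is None:
        return embeds
    sel = embeds[is_embed]
    return sel if sel.numel() else None


is_embed = torch.tensor([True, False, True])    # image rows interleaved with a text row
dense = torch.randn(2, 4)                       # embeddings for the two True positions
cached = scatter_image_embeds(dense, is_embed)  # shape (3, 4), aligned with the prompt span
assert torch.equal(gather_image_embeds(cached, is_embed), dense)
```

Returning `None` from the gather when nothing is selected is what lets `gather_vision_embeds` below skip a request whose image span falls entirely outside the current prefill chunk.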
@@ -404,7 +414,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     image_inputs.append(image_input)

                     img_text, img_start_token_str = image_text_replacement(
-                        processor, image_input, config, 0
+                        processor, image_input, config
                     )

                     text_parts.append(img_text)
@@ -456,12 +466,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         for i in range(num_images):
             image_id, img_start_token_str, img_text = image_texts[i]
             img_text = image_text_replacement_fixup(config, img_text)
+
             if config.model_type == "gemma3":
                 img_text = img_text.replace("\n\n", "")

             tokens = tokenizer(img_text, add_special_tokens=False)["input_ids"]
             length = len(tokens)
+            assert length <= len(
+                input_ids
+            ), f"{length} > {len(input_ids)} Image is truncated, try increasing --max-batch-prefill-tokens"
+
             pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
             index = img_start_token_pos[pos]

@@ -502,151 +517,6 @@ class VlmCausalLMBatch(FlashCausalLMBatch):

         return image_positions

-    @classmethod
-    def batch_tokenized_inputs2(
-        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
-    ):
-        # sizes to insert correct number of image tokens.
-        kwargs = {}
-        if (
-            hasattr(processor, "image_processor_class")
-            and processor.image_processor_class == "Idefics3ImageProcessor"
-        ):
-            kwargs["return_row_col_info"] = True
-
-        batch_image_inputs = []
-        for i, r in enumerate(requests):
-            image_inputs = []
-            batch_image_inputs.append(None)
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    pass
-                elif chunk_type == "image":
-                    image = Image.open(BytesIO(chunk.image.data))
-                    # qwen2_vl expects images to be greater than 20 pixels, this is for warmup since the
-                    # default warmup image is 20x20
-                    if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
-                        if image.width <= 20:
-                            w = image.width * 2
-                            h = image.height * 2
-                            image = image.resize((w, h))
-
-                    if config.model_type in {"llava_next", "gemma3", "llama4"}:
-                        image = image
-                    elif config.model_type in {"paligemma"}:
-                        image = image.convert("RGB")
-                    else:
-                        image = [image]
-                    image_input = processor.image_processor(
-                        [image], return_tensors="pt", **kwargs
-                    )
-
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
-
-            if len(image_inputs) > 0:
-                batch_image_inputs[i] = image_inputs
-
-        batch_image_positions = []
-        batch_tokenized_inputs = []
-        max_length = 0
-        for i, r in enumerate(requests):
-            full_text = ""
-            image_tokens = []
-            batch_image_positions.append(None)
-            image_id = 0
-            for chunk in r.input_chunks.chunks:
-                chunk_type = chunk.WhichOneof("chunk")
-                if chunk_type == "text":
-                    full_text += chunk.text
-                elif chunk_type == "image":
-                    image_text, id_token = image_text_replacement(
-                        processor, batch_image_inputs[i], config, image_id
-                    )
-                    full_text += image_text
-
-                    if config.model_type == "gemma3":
-                        image_text = image_text.replace("\n\n", "")
-                    elif config.model_type == "idefics2":
-                        image_text = image_text_replacement_fixup(config, image_text)
-
-                    image_tokens.append(
-                        (
-                            image_id,
-                            id_token,
-                            tokenizer(image_text, add_special_tokens=False)[
-                                "input_ids"
-                            ],
-                        )
-                    )
-                    image_id += 1
-
-            full_text = image_text_replacement_fixup(config, full_text)
-            input_ids = tokenizer(
-                full_text,
-                truncation=True,
-                max_length=r.truncate,
-                add_special_tokens=r.add_special_tokens,
-            )["input_ids"]
-            max_length = max(max_length, len(input_ids))
-            batch_tokenized_inputs.append(input_ids)
-
-            prev = 0
-            image_positions = []
-            idefics_replacement = False
-
-            config.image_token_index = (
-                config.image_token_index
-                if hasattr(config, "image_token_index")
-                else config.image_token_id
-            )
-            for image_id, id_token, tokens in image_tokens:
-                id_token = tokenizer.get_vocab()[id_token]
-
-                if config.model_type == "idefics2" and idefics_replacement:
-                    id_token = config.image_token_index
-                    tokens = tokens[1:]
-
-                length = len(tokens)
-                index = input_ids[prev:].index(id_token)
-
-                index = index + prev
-                assert (
-                    input_ids[index : index + length] == tokens
-                ), "Image token not found in input_ids"
-
-                if config.model_type in {"llava_next", "paligemma"}:
-                    is_embed = None
-                    num_placeholder_tokens = length
-                else:
-                    is_embed = torch.tensor(tokens) == config.image_token_index
-                    num_placeholder_tokens = is_embed.sum().item()
-
-                pos = ImagePositions(
-                    offset=index,
-                    length=length,
-                    id=image_id,
-                    num_placeholder_tokens=num_placeholder_tokens,
-                    is_embed=is_embed,
-                )
-
-                image_positions.append(pos)
-                prev = index + length
-
-                if config.model_type == "idefics2" and prev != len(input_ids):
-                    if input_ids[prev] == config.image_token_id:
-                        # means this is replaced by image_text_replacement_fixup
-                        idefics_replacement = True
-                    else:
-                        idefics_replacement = False
-
-            if len(image_positions) > 0:
-                batch_image_positions[i] = image_positions
-
-        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions
-
     @classmethod
     def from_pb_processor(
         cls,
@@ -674,19 +544,17 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
     def prepare_for_prefill(self):
         super().prepare_for_prefill()

-        self.has_image = False
-        self.encoder_cache_to_free = []
+        self.has_image_inputs = False
+        self.cache_entries_to_free = []
         self.pixel_values = []

         for i, (
-            r,
             cache_length,
             input_length,
             request_prefilling,
         ) in enumerate(
             zip(
-                self.requests,
                 self.cache_lengths,
                 self.input_lengths,
                 self.prefilling_mask,
@@ -695,10 +563,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-                image_inputs = self.image_inputs[i][j]
-
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
                 start_pos = image_position.offset
                 length = image_position.length

@@ -709,47 +576,46 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     # The encode input is already processed
                    continue

-                self.has_image = True
+                self.has_image_inputs = True

-                if image_id not in self.encoder_cache[i]:
-                    self.pixel_values.append((i, image_position, image_inputs))
-                    self.image_inputs[i][j] = None
+                if image_position.id not in self.encoder_cache[i]:
+                    image_inputs = self.image_inputs[i][image_position.id]
+                    self.pixel_values.append((i, image_position.id, image_inputs))

-        if not self.has_image:
+                    # Remove the image from the image_inputs
+                    self.image_inputs[i][image_position.id] = None
+
+        if not self.has_image_inputs:
             self.pixel_values = None
             self.pixel_attention_mask = None
             self.image_sizes = None
             self.image_grid_thw = None

-    def update_encoder_cache(self, encoder_outputs, request_id, input):
-        self.encoder_cache[request_id][input.id] = scatter_image_embeds(
-            encoder_outputs, input.is_embed
+    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
+        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
+            encoder_outputs, img_pos.is_embed
         )

-    def get_mm_embeddings(self):
+    def gather_vision_embeds(self):
         device = self.input_ids.device
-        mm_embeds = []
+        chunks = []

-        for i, (
-            r,
+        for (
+            i,
             cache_length,
             input_length,
-            prompt_length,
             request_prefilling,
-        ) in enumerate(
-            zip(
-                self.requests,
-                self.cache_lengths,
-                self.input_lengths,
-                self.prompt_lengths,
-                self.prefilling_mask,
-            )
+        ) in zip(
+            range(len(self.requests)),
+            self.cache_lengths,
+            self.input_lengths,
+            self.prefilling_mask,
         ):
             if not request_prefilling or self.image_positions[i] is None:
                 continue

-            for j, image_position in enumerate(self.image_positions[i]):
-                image_id = image_position.id
-
+            for image_position in self.image_positions[i]:
+                if image_position is None:
+                    continue
                 start_pos = image_position.offset
                 length = image_position.length

@@ -763,13 +629,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 start_idx = max(cache_length - start_pos, 0)
                 end_idx = min(cache_length - start_pos + input_length, length)

-                if end_idx == length:
-                    self.encoder_cache_to_free.append((i, image_id))
-
                 assert (
-                    image_id in self.encoder_cache[i]
-                ), f"image_id {image_id} not in encoder_cache {self.encoder_cache[i]}"
-                encoder_output = self.encoder_cache[i][image_id]
+                    image_position.id in self.encoder_cache[i]
+                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
+                encoder_output = self.encoder_cache[i][image_position.id]

                 is_embed = image_position.is_embed
                 if is_embed is not None:
@@ -778,25 +641,31 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     from loguru import logger

                     logger.info(
-                        f"image_id {image_id} start_idx {start_idx} end_idx {end_idx}, length {length}"
+                        f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
                     )

-                mm_embeds_item = gather_image_embeds(
+                embeds = gather_image_embeds(
                     encoder_output[start_idx:end_idx],
                     is_embed=is_embed,
                 )
-                if mm_embeds_item is not None:
-                    mm_embeds.append(mm_embeds_item)
+                if embeds is not None:
+                    chunks.append(embeds)

-        if len(mm_embeds) == 0:
+                if end_idx == length:
+                    self.cache_entries_to_free.append((i, image_position.id))
+                    self.image_positions[i][image_position.id] = None
+
+        if len(chunks) == 0:
             return None
-        return torch.cat(mm_embeds, dim=0).to(device)
+        return torch.cat(chunks, dim=0).to(device)

     def free_encoder_cache(self):
-        for i, image_id in self.encoder_cache_to_free:
-            self.encoder_cache[i][image_id] = None
+        for i, image_id in self.cache_entries_to_free:
+            self.encoder_cache[i].pop(image_id, None)

-        self.encoder_cache_to_free = []
+        self.cache_entries_to_free = []
+
+        # popping the entry drops the last reference so the cached tensor can be freed

 class VlmCausalLM(FlashCausalLM):
@@ -997,20 +866,20 @@ class VlmCausalLM(FlashCausalLM):
         )
         return embeds

-    def get_input_embeds(
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
         vision_embeds: Optional[torch.Tensor] = None,
     ):
-        return self.model.get_input_embeds(
+        return self.model.get_inputs_embeds(
             input_ids=input_ids,
             vision_embeds=vision_embeds,
         )

-    def get_mm_embeddings(self, batch):
+    def encode_images(self, batch):
         if batch.pixel_values is not None:
             device = batch.input_ids.device
-            for request_id, image_position, image_input in batch.pixel_values:
+            for request_id, image_id, image_input in batch.pixel_values:
                 pixel_values = image_input["pixel_values"].to(device)

                 if "pixel_attention_mask" in image_input:
@@ -1036,19 +905,23 @@ class VlmCausalLM(FlashCausalLM):
                     image_sizes=image_sizes,
                     image_grid_thw=image_grid_thw,
                 )
-                batch.update_encoder_cache(encoder_outputs, request_id, image_position)
+                batch.update_encoder_cache(
+                    encoder_outputs,
+                    request_id,
+                    batch.image_positions[request_id][image_id],
+                )

             batch.pixel_values = None
-        return batch.get_mm_embeddings()

     def get_input_embeddings(self, batch):
-        if batch.has_image:
-            vision_embeds = self.get_mm_embeddings(batch)
-            batch.has_image = False
+        if batch.has_image_inputs:
+            self.encode_images(batch)
+            vision_embeds = batch.gather_vision_embeds()
+            batch.has_image_inputs = False
         else:
             vision_embeds = None

-        inputs_embeds = self.get_input_embeds(
+        inputs_embeds = self.get_inputs_embeds(
             batch.input_ids, vision_embeds=vision_embeds
         )

@@ -1127,6 +1000,7 @@ class VlmCausalLM(FlashCausalLM):
             attention_mask = self.model.get_attention_mask(
                 input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
             ).reshape(-1)
+            batch.pixel_values = 1
         else:
             attention_mask = None
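For readers following the new cache bookkeeping in `vlm_causal_lm.py`: encoder outputs are stored per request under the image's id, queued in `cache_entries_to_free` once their placeholder span has been fully consumed by prefill, and popped in `free_encoder_cache`. The snippet below is a toy illustration of that lifecycle with made-up shapes and indices; it only mirrors the names and the `pop(..., None)` semantics shown in the diff, not the real batch class.

```python
# Toy walk-through of the per-request encoder cache bookkeeping; values and
# shapes are invented for illustration.
import torch

encoder_cache = [{0: torch.randn(4, 8)}]   # request 0 caches image 0 (4 placeholder rows)
cache_entries_to_free = []

# In gather_vision_embeds(): the prefill chunk consumed the whole span [0, 4).
end_idx, length = 4, 4
if end_idx == length:
    cache_entries_to_free.append((0, 0))

# In free_encoder_cache(): pop() tolerates entries that are already gone.
for i, image_id in cache_entries_to_free:
    encoder_cache[i].pop(image_id, None)
cache_entries_to_free.clear()

print(encoder_cache)   # [{}] -- the cached tensor can now be garbage collected
```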