diff --git a/integration-tests/images/cow_beach.png b/integration-tests/images/cow_beach.png new file mode 100644 index 00000000..d67f8a1b Binary files /dev/null and b/integration-tests/images/cow_beach.png differ diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py new file mode 100644 index 00000000..b3415d26 --- /dev/null +++ b/integration-tests/models/test_flash_pali_gemma.py @@ -0,0 +1,39 @@ +import pytest +import requests +import io +import base64 + + +@pytest.fixture(scope="module") +def flash_pali_gemma_handle(launcher): + with launcher( + "Tinkering/test-bvhf", + num_shard=1, + max_input_length=4000, + max_total_tokens=4096, + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_pali_gemma(flash_pali_gemma_handle): + await flash_pali_gemma_handle.health(300) + return flash_pali_gemma_handle.client + + +def get_cow_beach(): + with open("integration-tests/images/cow_beach.png", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()) + return f"data:image/png;base64,{encoded_string.decode('utf-8')}" + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): + cow = get_cow_beach() + inputs = f"Where is the cow standing?\n![]({cow})" + response = await flash_pali_gemma.generate(inputs, max_new_tokens=20) + + # TODO: update this! this is incorrect and just to show the current state of the test + assert response.generated_text == ' - HDS' + # assert response.generated_text == "\nbeach" \ No newline at end of file diff --git a/router/src/config.rs b/router/src/config.rs index 8640ede9..296d0352 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -118,6 +118,22 @@ impl Idefics2 { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "model_type")] +#[serde(rename_all = "snake_case")] +pub struct Paligemma {} + +impl Paligemma { + pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { + // TODO: improve to calculate based on height and width + // 224 = 256 image tokens + // 448 = 1024 image tokens + // 896 = 4096 image tokens + + 256 + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] @@ -139,6 +155,7 @@ pub enum Config { Phi3, Llama, Baichuan, + Paligemma(Paligemma), Gemma, Cohere, Drbx, diff --git a/router/src/validation.rs b/router/src/validation.rs index be4bef00..24a45d60 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -540,6 +540,30 @@ fn prepare_input( inputs = modified_inputs; tokenizer_query } + Some(Config::Paligemma(config)) => { + let mut modified_inputs = String::with_capacity(inputs.len()); + let mut tokenizer_query = String::with_capacity(inputs.len()); + let mut start = 0; + for chunk in RE.find_iter(&inputs) { + let chunk_start = chunk.start(); + let chunk_end = chunk.end(); + if chunk_start != start { + modified_inputs.push_str(&inputs[start..chunk_start]); + tokenizer_query.push_str(&inputs[start..chunk_start]); + } + let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let slots = config.get_number_of_features(height, width); + tokenizer_query.push_str(&"".repeat(slots)); + modified_inputs.push_str(&image_uri); + start = chunk_end; + } + if start != inputs.len() - 1 { + modified_inputs.push_str(&inputs[start..]); + tokenizer_query.push_str(&inputs[start..]); + } + inputs = modified_inputs; + tokenizer_query + } 
Some(Config::Idefics2(config)) => { let mut modified_inputs = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len()); diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e9761dfe..c4c6342a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -75,6 +75,7 @@ try: from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.flash_pali_gemma import FlashPaliGemma from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA except ImportError as e: @@ -433,6 +434,16 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_type == "paligemma": + return FlashPaliGemma( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == "cohere": if FLASH_ATTENTION: return FlashCohere( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 43b90bdd..602a29cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -295,9 +295,9 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix or ''}model.layers.{layer_id}" self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -351,21 +351,30 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() embed_norm = config.hidden_size**0.5 + pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens" self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=pvalue, + weights=weights, + # limit embed_tokens.weight size to the config.vocab_size ) + self.embed_tokens.weight = torch.nn.Parameter( + self.embed_tokens.weight[: config.vocab_size, : config.hidden_size] + ) + + # TODO: double check why this is needed self.embed_tokens.weight *= embed_norm self.layers = nn.ModuleList( [ FlashGemmaLayer( + f"{prefix + '.' if prefix else ''}", layer_id, config, weights, @@ -374,7 +383,9 @@ class FlashGemmaModel(torch.nn.Module): ] ) self.norm = GemmaFastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix + '.' 
if prefix else ''}model.norm", + weights=weights, + eps=config.rms_norm_eps, ) self.gradient_checkpointing = False @@ -385,7 +396,8 @@ class FlashGemmaModel(torch.nn.Module): def forward( self, - input_ids: torch.Tensor, + # input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -394,8 +406,8 @@ class FlashGemmaModel(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - + hidden_states = inputs_embeds + # Get rotary cos and sin for this forward # Avoid to index in each layer cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -423,13 +435,15 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - - self.model = FlashGemmaModel(config, weights) + self.config = config + self.model = FlashGemmaModel(prefix, config, weights) + prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens" + prefix = prefix if config.tie_word_embeddings else "lm_head" self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", + prefix=prefix, weights=weights, ) @@ -445,8 +459,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): max_s: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( - input_ids, + inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py new file mode 100644 index 00000000..e2bc9807 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
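# High-level flow of the model defined in this new file (summary comment, not code):
#   1. `vision_tower` (a SigLIP vision transformer) encodes `pixel_values` into patch embeddings.
#   2. `multi_modal_projector` projects those embeddings to the text hidden size (2048).
#   3. `_merge_input_ids_with_image_features` overwrites the image-placeholder positions
#      in the Gemma input embeddings with the projected vision features.
#   4. The merged embeddings run through the Flash Gemma language model and its `lm_head`.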
+ +import torch +import torch.distributed +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils.layers import TensorParallelColumnLinear +from text_generation_server.models.custom_modeling.vlm import ( + load_text_model, + load_vision_model, +) +from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + GemmaConfig, +) + + +# TODO: prefer using the following config classes +# * instead of the hack inside of the gemma modeling file +class VisionConfig(PretrainedConfig): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + model_type: str, + num_attention_heads: int, + num_hidden_layers: int, + num_image_tokens: int, + patch_size: int, + projection_dim: int, + projector_hidden_act: str, + vision_use_head: bool, + vocab_size: int, + quantize: Optional[str] = None, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.model_type = model_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_image_tokens = num_image_tokens + self.patch_size = patch_size + self.projection_dim = projection_dim + self.projector_hidden_act = projector_hidden_act + self.vision_use_head = vision_use_head + self.vocab_size = vocab_size + self.quantize = quantize + + +class PaliTextConfig(PretrainedConfig): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + model_type: str, + num_attention_heads: int, + num_hidden_layers: int, + num_image_tokens: int, + num_key_value_heads: int, + torch_dtype: str, + vocab_size: int, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.model_type = model_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_image_tokens = num_image_tokens + self.num_key_value_heads = num_key_value_heads + self.torch_dtype = torch_dtype + self.vocab_size = vocab_size + + +class PaliGemmaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=257216, + hidden_size=2048, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + text_config=None, + vision_config=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.text_config = GemmaConfig( + hidden_size=2048, + intermediate_size=16384, + model_type="gemma", + num_attention_heads=8, + num_hidden_layers=18, + num_image_tokens=256, + 
num_key_value_heads=1, + torch_dtype="float32", + vocab_size=257216, + ) + + self.vision_config = VisionConfig( + hidden_size=1152, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + num_image_tokens=256, + patch_size=14, + projection_dim=2048, + projector_hidden_act="gelu_fast", + vision_use_head=False, + vocab_size=257152, + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class FlashPaliGemmaForConditionalGeneration(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + config.vision_config.quantize = config.quantize + + self.vision_tower = load_vision_model( + prefix="vision_tower" if not prefix else f"{prefix}.vision_tower", + config=config.vision_config, + weights=weights, + ).to(weights.device, weights.dtype) + + self.multi_modal_projector = TensorParallelColumnLinear.load( + config, + prefix="multi_modal_projector.linear", + weights=weights, + bias=True, + ).to(weights.device, weights.dtype) + + self.vocab_size = config.vocab_size + self.config = config + + self.language_model = load_text_model( + prefix=prefix, + config=config, + weights=weights, + ).to(weights.device, weights.dtype) + self.pad_token_id = ( + config.pad_token_id if config.pad_token_id is not None else -1 + ) + + def _merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids + ): + """In place merges in vision_embeddings with inputs_embeds.""" + mask = input_ids == self.config.image_token_index + # Let's pray we have enabled enough slots ! + try: + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) + except Exception as e: + raise RuntimeError( + f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. 
If error happens at regular runtime, please fill in an issue: {e}" + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor = None, + past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + pixel_attention_mask=None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + if pixel_values is not None: + pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype) + + # merge text and images + if pixel_values is not None and len(pixel_values) > 0: + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + # TODO: make sure to handle the specialized attention mask correctly + inputs_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + + hidden_states = self.language_model.model( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + ) + + + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.language_model.lm_head(hidden_states) + + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py new file mode 100644 index 00000000..81795f86 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -0,0 +1,578 @@ +from typing import Optional, Tuple, Union + +import math +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", 
weights=weights + ) + # TODO: remove this hack! figure out why off by one + self.position_embedding.weight = torch.nn.Parameter( + self.position_embedding.weight[:256, :] + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, device=weights.device).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = ( + input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.head_size = self.head_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.num_heads = self.num_heads // weights.process_group.size() + self.embed_dim = self.embed_dim // weights.process_group.size() + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=True, + ) + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=True, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, _ = hidden_states.size() + qkv = self.qkv(hidden_states) + query_states, key_states, value_states = qkv.split( + [ + self.head_size * self.num_heads, + ] + * 3, + dim=2, + ) + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_size) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + # scale post matmul + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(attn_weights.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + attn_output = torch.bmm(attn_weights, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size + prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True + ) + self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size + prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True + ) + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + def __init__(self, prefix, config: SiglipConfig, weights): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SiglipAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps + ) + self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + if output_attentions: + return hidden_states, attn_weights + print(hidden_states[0, 0, :5].tolist()) + return hidden_states, None + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +import warnings + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. 
+ # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +from transformers import PreTrainedModel + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + + Args: + config: SiglipConfig + """ + + def __init__(self, prefix, config: SiglipConfig, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[torch.Tensor] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + """ + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + hidden_states, _ = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = SiglipEncoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + r""" + Returns: + + """ + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + # NOTE: up until this point, the code logits are exactly + # the same as the transformers code. The values evaulate + # slightly differently in our encoder layer. + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + ) + last_hidden_state = encoder_outputs + post_last_hidden_state = self.post_layernorm(last_hidden_state) + + return BaseModelOutputWithPooling( + last_hidden_state=post_last_hidden_state, + # pooler_output=pooled_output, + # hidden_states=encoder_outputs, + ) diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 690957d0..9ae4add7 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -11,6 +11,12 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashMistralForCausalLM(prefix, config, weights, name=name) + elif config.model_type == "gemma": + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + + return FlashGemmaForCausalLM(prefix, config, weights) else: raise RuntimeError(f"Unsupported model type {config.model_type}") @@ -24,5 +30,13 @@ def load_vision_model(prefix, config, weights): return CLIPVisionTransformer( prefix=f"{prefix}.vision_model", config=config, weights=weights ) + if config.model_type == "siglip_vision_model": + from text_generation_server.models.custom_modeling.siglip import ( + SiglipVisionTransformer, + ) + + return SiglipVisionTransformer( + prefix=f"vision_tower.vision_model", config=config, weights=weights + ) else: raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5aa7a568..dc9adebc 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -133,6 +133,17 @@ class FlashCausalLMBatch(Batch): device: torch.device, ) -> "FlashCausalLMBatch": batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) + return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + + @classmethod + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": position_ids = [] 
speculative_ids = [] cu_seqlen_prefill = [0] @@ -207,6 +218,7 @@ class FlashCausalLMBatch(Batch): # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() + speculative_length = 0 if speculative_length is None else speculative_length total_tokens = input_length + max_new_tokens - 1 + speculative_length needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 9c00a056..e547e267 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -4,6 +4,7 @@ import torch.distributed from opentelemetry import trace from typing import Optional from transformers.models.gemma import GemmaTokenizerFast +from transformers import AutoConfig, PretrainedConfig from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( @@ -19,15 +20,58 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -class FlashGemma(FlashCausalLM): +class VisionConfig(PretrainedConfig): def __init__( self, + hidden_size: int = 1152, + intermediate_size: int = 4304, + model_type: str = "siglip_vision_model", + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + num_image_tokens: int = 256, + patch_size: int = 14, + projection_dim: int = 2048, + projector_hidden_act: str = "gelu_fast", + vision_use_head: bool = False, + vocab_size: int = 257152, + quantize: Optional[str] = None, + image_size: int = 224, + layer_norm_eps: float = 1e-06, + attention_dropout: float = 0.0, + hidden_act: str = "gelu_pytorch_tanh", + num_channels: int = 3, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.model_type = model_type + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_image_tokens = num_image_tokens + self.patch_size = patch_size + self.projection_dim = projection_dim + self.projector_hidden_act = projector_hidden_act + self.vision_use_head = vision_use_head + self.vocab_size = vocab_size + self.quantize = quantize + self.image_size = image_size + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.num_channels = num_channels + + +class BaseFlashGemma(FlashCausalLM): + def __init__( + self, + model_cls, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + prefix: Optional[str] = None, + config_cls=AutoConfig, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -49,9 +93,39 @@ class FlashGemma(FlashCausalLM): config = GemmaConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) + + is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config") + + if is_vlm: + config.vision_config = VisionConfig( + hidden_size=1152, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + num_image_tokens=256, + patch_size=14, + projection_dim=2048, + projector_hidden_act="gelu_fast", + vision_use_head=False, + vocab_size=257152, + quantize=quantize, + ) + config.quantize = quantize config.speculator = speculator + if is_vlm: + 
config.num_hidden_layers = config.text_config.get("num_hidden_layers") + config.intermediate_size = config.text_config.get("intermediate_size") + config.model_type = config.text_config.get("model_type") + config.num_attention_heads = config.text_config.get("num_attention_heads") + config.num_hidden_layers = config.text_config.get("num_hidden_layers") + config.num_image_tokens = config.text_config.get("num_image_tokens") + config.num_key_value_heads = config.text_config.get("num_key_value_heads") + config.torch_dtype = config.text_config.get("torch_dtype") + config.vocab_size = config.text_config.get("vocab_size") + torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -59,17 +133,49 @@ class FlashGemma(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) - model = FlashGemmaForCausalLM(config, weights) + model = model_cls(prefix, config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( + + if is_vlm: + num_layers = config.num_hidden_layers + num_kv_heads = config.num_key_value_heads + head_size = config.intermediate_size + else: + num_layers = len(model.model.layers) + num_kv_heads = model.model.num_key_value_heads + head_size = model.model.head_size + + super().__init__( model=model, tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_size=head_size, dtype=dtype, device=device, rank=rank, world_size=world_size, ) + + +class FlashGemma(BaseFlashGemma): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(FlashGemma, self).__init__( + model_cls=FlashGemmaForCausalLM, + model_id=model_id, + revision=revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + prefix=None, + ) diff --git a/server/text_generation_server/models/flash_pali_gemma.py b/server/text_generation_server/models/flash_pali_gemma.py new file mode 100644 index 00000000..16955f8c --- /dev/null +++ b/server/text_generation_server/models/flash_pali_gemma.py @@ -0,0 +1,54 @@ +import torch +import torch.distributed +from opentelemetry import trace +from typing import Optional, Tuple +from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM +from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + FlashPaliGemmaForConditionalGeneration, + PaliGemmaConfig, + PaliTextConfig, +) +from transformers import AutoProcessor + +tracer = trace.get_tracer(__name__) + + +class FlashPaliGemma(PaliVlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.processor = AutoProcessor.from_pretrained( + # TODO: load in the correct processor based on the model_id + "google/siglip-base-patch16-224", + # "google/siglip-so400m-patch14-384", + revision=revision, + trust_remote_code=trust_remote_code, + ) + + super().__init__( + config_cls=PaliTextConfig, + model_cls=FlashPaliGemmaForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + 
use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + prefix="language_model", + ) + + def get_layer_config(self, model) -> Tuple[int, int, int]: + return ( + len(model.language_model.model.layers), + model.language_model.model.num_key_value_heads, + model.language_model.model.head_size, + ) + + def max_past(self) -> Optional[int]: + return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 5394feb5..3ec6c319 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -15,6 +15,8 @@ from text_generation_server.models.flash_mistral import ( BaseFlashMistral, FlashMistralBatch, ) +from text_generation_server.models.flash_gemma import BaseFlashGemma +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.cache_manager import ( get_cache_manager, ) @@ -80,6 +82,11 @@ def image_text_replacement(image_input, config, image_id) -> str: logger.info(f"Found {num_features} in image of resolution {height}x{width}") return "" * num_features + + # TODO: double check correct naming for model_type + elif config.model_type == "gemma": + # TODO: use correct number of features + return "" * 256 else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -371,3 +378,238 @@ class VlmCausalLM(BaseFlashMistral): ) logits = cuda_graph["logits"][:bs] return logits, speculative_logits + + +class PaliVlmCausalLMBatch(FlashCausalLMBatch): + pixel_values: Optional[List[torch.Tensor]] + pixel_attention_mask: Optional[List[torch.Tensor]] + image_sizes: Optional[List[Tuple[int, int]]] + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches): + batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch + + @tracer.start_as_current_span("filter") + def filter(self, request_ids: List[int]): + batch = super().filter(request_ids) + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch + + @classmethod + def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + batch_inputs = [] + image_inputs = [] + max_truncation = 0 + for r in requests: + chunks = split(r.inputs) + full_text = "" + image_id = 0 + for chunk in chunks: + if chunk["type"] == "text": + full_text += chunk["content"] + elif chunk["type"] == "image": + image = chunk["content"] + # Should never receive URLs anymore, processing should be done + # On the rust layer. 
+ # This avoid making n queries per TP + # if image.startswith("https://") or image.startswith("http://"): + # image = processor.image_processor.fetch_images(image) + if image.startswith("data:"): + image = load_data_uri(image) + else: + raise RuntimeError( + "Cannot process input image not starting with data:" + ) + image_input = processor.image_processor(image, return_tensors="pt") + full_text += image_text_replacement(image_input, config, image_id) + image_inputs.append(image_input) + else: + raise RuntimeError(f"Invalid chunk type {chunk['type']}") + + batch_inputs.append(full_text) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, truncation=True, max_length=max_truncation + )["input_ids"] + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "pixel_attention_mask" in image_input: + new_image_inputs["pixel_attention_mask"] = torch.cat( + [img["pixel_attention_mask"] for img in image_inputs], dim=0 + ) + if "image_sizes" in image_input: + new_image_inputs["image_sizes"] = torch.cat( + [img["image_sizes"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + else: + image_inputs = None + return batch_tokenized_inputs, image_inputs + + @classmethod + def from_pb_processor( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + processor, + config, + dtype: torch.dtype, + device: torch.device, + ) -> "PaliVlmCausalLMBatch": + batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( + pb.requests, tokenizer, processor, config + ) + batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + if image_inputs is not None: + batch.pixel_values = image_inputs["pixel_values"].to(device=device) + if "pixel_attention_mask" in image_inputs: + batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( + device=device + ) + else: + batch.pixel_attention_mask = None + if "image_sizes" in image_inputs: + batch.image_sizes = image_inputs["image_sizes"].to(device=device) + else: + batch.image_sizes = None + else: + batch.pixel_values = None + batch.pixel_attention_mask = None + batch.image_sizes = None + return batch + + +class PaliVlmCausalLM(BaseFlashGemma): + @property + def batch_type(self) -> Type[PaliVlmCausalLMBatch]: + return PaliVlmCausalLMBatch + + def forward( + self, batch: PaliVlmCausalLMBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Model Forward + if batch.speculative_ids is not None: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + speculative_ids = batch.speculative_ids + + B, speculative_length = speculative_ids.shape + new_length = speculative_length + 1 + new_input_ids = torch.cat( + [input_ids.unsqueeze(-1), speculative_ids], dim=1 + ).reshape(-1) + arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) + arange_int = arange.to(dtype=torch.int32) + new_position_ids = ( + position_ids.unsqueeze(-1).expand(B, new_length) + arange + ).view(-1) + slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + input_lengths = ( + input_lengths.unsqueeze(-1).expand(B, new_length) + 
arange_int + ).view(-1) + + # Add Copy the block tables for all members + block_tables = ( + block_tables.unsqueeze(1) + .expand(B, new_length, -1) + .reshape(B * new_length, -1) + .contiguous() + ) + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids = batch.input_ids + position_ids = batch.position_ids + cu_seqlen_prefill = batch.cu_seqlen_prefill + kv_cache = get_cache_manager().kv_cache + block_tables = batch.block_tables_tensor + slots = batch.slots[batch.slot_indices] + input_lengths = batch.input_lengths_tensor + max_s = batch.max_seqlen + lm_head_indices = batch.prefill_head_indices + + if cu_seqlen_prefill is None and self.max_past() is not None: + # In decode, not prefill, we're actually overwriting the KV-cache + # in a circular buffer mode. + # This makes sure the max_s for the decode pass is correct. + max_s = min(self.max_past(), max_s) + + bs = input_ids.shape[0] + # Try to find an associated cuda graph + bs = input_ids.shape[0] + sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) + if sorted_padded_bs: + # Get associated cuda graph + cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] + else: + cuda_graph = None + if cu_seqlen_prefill is not None or cuda_graph is None: + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + # prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + pixel_values=batch.pixel_values, + ) + # if batch.prefill_cache_indices is not None: + # batch.prefill_cache_indices = None + if batch.pixel_values is not None: + batch.pixel_values = None + if batch.pixel_attention_mask is not None: + batch.pixel_attention_mask = None + if batch.image_sizes is not None: + batch.image_sizes = None + return logits, speculative_logits + + # Copy inputs to the static inputs of the cuda graph + # Static inputs are potentially padded + cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids + cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids + cuda_graph["block_tables"][ + : block_tables.shape[0], : block_tables.shape[1] + ] = block_tables + cuda_graph["slots"].fill_(-1) + cuda_graph["slots"][: slots.shape[0]] = slots + cuda_graph["input_lengths"].zero_() + cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths + + # Replay the graph + cuda_graph["graph"].replay() + + # Slice output to the correct shape + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 9d0571a6..94cf1417 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -14,7 +14,7 @@ from typing import List, Optional from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model -from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch +from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import 
UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch @@ -98,6 +98,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if self.model.batch_type in { IdeficsCausalLMBatch, VlmCausalLMBatch, + PaliVlmCausalLMBatch, }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb_processor( request.batch, @@ -122,6 +123,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if self.model.batch_type in { IdeficsCausalLMBatch, VlmCausalLMBatch, + PaliVlmCausalLMBatch, }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb_processor( request.batch,
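
A note on the `get_number_of_features` TODO in `router/src/config.rs`: the hard-coded 256 matches the 224x224 checkpoint, and the resolution-dependent rule quoted in the comment (224 -> 256, 448 -> 1024, 896 -> 4096) is simply the number of 14-pixel SigLIP patches (the vision config above uses `patch_size=14`). A minimal sketch of that mapping, assuming square inputs divisible by the patch size; the helper name is illustrative, not part of this PR:

def paligemma_image_tokens(height: int, width: int, patch_size: int = 14) -> int:
    # One image token per vision patch: 224 -> 16*16 = 256, 448 -> 32*32 = 1024,
    # 896 -> 64*64 = 4096, matching the comment in Paligemma::get_number_of_features.
    return (height // patch_size) * (width // patch_size)


assert paligemma_image_tokens(224, 224) == 256
assert paligemma_image_tokens(448, 448) == 1024
assert paligemma_image_tokens(896, 896) == 4096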
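
The in-place merge in `_merge_input_ids_with_image_features` is the core of the multimodal path: every input position that the router filled with an image placeholder token is overwritten by one projected vision embedding. A standalone sketch of the same operation (hypothetical helper; shapes assumed from the modeling code above):

import torch


def merge_image_features(
    inputs_embeds: torch.Tensor,   # (num_tokens, hidden_size), text embeddings
    image_features: torch.Tensor,  # (num_images, num_patches, hidden_size)
    input_ids: torch.Tensor,       # (num_tokens,)
    image_token_index: int,
) -> torch.Tensor:
    # Positions holding the image placeholder id receive one vision embedding each;
    # the number of placeholders must equal num_images * num_patches or this raises.
    mask = input_ids == image_token_index
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    return inputs_embeds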
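
On the `TODO: double check why this is needed` next to the embedding scaling in `FlashGemmaModel`: upstream Gemma multiplies token embeddings by sqrt(hidden_size) before the first decoder layer, and this PR folds that normalizer into the embedding weight once at load time. A sketch of the equivalent output-side scaling (sizes taken from the text config above; the wrapper is illustrative):

import torch

hidden_size = 2048
embed_tokens = torch.nn.Embedding(257216, hidden_size)


def scaled_embed(input_ids: torch.Tensor) -> torch.Tensor:
    # Scaling the output here is equivalent to multiplying the weight matrix by
    # sqrt(hidden_size) once, as FlashGemmaModel.__init__ does after loading.
    return embed_tokens(input_ids) * hidden_size**0.5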