def get_cow_beach(image_path=None):
    """Return the cow-on-a-beach test image as a base64 PNG data URI.

    Args:
        image_path: Optional path of the PNG file to encode. When omitted, the
            fixture ``images/cow_beach.png`` is located relative to this test
            module (``integration-tests/models/``), so the helper no longer
            depends on the directory pytest is started from.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    # Function-scope import keeps the module's top-level import block untouched.
    from pathlib import Path

    if image_path is None:
        # <repo>/integration-tests/models/<this file> -> <repo>/integration-tests/images/
        image_path = Path(__file__).resolve().parent.parent / "images" / "cow_beach.png"

    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
modified_inputs; tokenizer_query } + Some(Config::Paligemma(config)) => { + let mut modified_inputs = String::with_capacity(inputs.len()); + let mut tokenizer_query = String::with_capacity(inputs.len()); + let mut start = 0; + for chunk in RE.find_iter(&inputs) { + let chunk_start = chunk.start(); + let chunk_end = chunk.end(); + if chunk_start != start { + modified_inputs.push_str(&inputs[start..chunk_start]); + tokenizer_query.push_str(&inputs[start..chunk_start]); + } + let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?; + let slots = config.get_number_of_features(height, width); + tokenizer_query.push_str(&"".repeat(slots)); + modified_inputs.push_str(&image_uri); + start = chunk_end; + } + if start != inputs.len() - 1 { + modified_inputs.push_str(&inputs[start..]); + tokenizer_query.push_str(&inputs[start..]); + } + inputs = modified_inputs; + tokenizer_query + } Some(Config::Idefics2(config)) => { let mut modified_inputs = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len()); diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 7f0efded..9035f6bc 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" @@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and 
python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" +transformers @ git+https://github.com/huggingface/transformers.git@b8aee2e918d7ba2d5e9e80162ae26b4806873307 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 7f0efded..9035f6bc 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -13,7 +13,7 @@ grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" +huggingface-hub==0.23.0 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" @@ -40,7 +40,7 @@ sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" +transformers @ 
@classmethod
def load(cls, prefix, weights, eps=1e-6):
    """Build the norm layer from ``{prefix}.weight``, keeping the scale in float32.

    ``weights.dtype`` is switched to ``torch.float32`` only while the tensor is
    read and is always restored afterwards, even if the lookup raises, so the
    shared ``weights`` object is never left in the temporary dtype.

    Args:
        prefix: Name prefix of the norm weight in the checkpoint.
        weights: Weight loader exposing a mutable ``dtype`` attribute and
            ``get_tensor(name)``.
        eps: Epsilon handed to the layer constructor.

    Returns:
        A new ``cls`` instance whose ``dtype`` attribute records the loader's
        original dtype (the dtype the layer's output is cast back to).
    """
    dtype = weights.dtype
    weights.dtype = torch.float32
    try:
        # The stored tensor is offset by one: the layer uses (weight + 1).
        weight = weights.get_tensor(f"{prefix}.weight") + 1
    finally:
        weights.dtype = dtype
    new = cls(weight, eps)
    new.dtype = dtype
    return new
= hidden_states * self.weight - return hidden_states.to(self.weight.dtype), residual + return hidden_states.to(self.dtype), residual def load_attention(config, prefix, weights): @@ -153,15 +158,11 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__( - self, - prefix: str, - config, - weights, - ): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim + self.causal = causal self.rotary_emb = PositionRotaryEmbedding.static( config=config, @@ -238,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module): cu_seqlen_prefill, max_s, self.softmax_scale, + causal=self.causal, ) # Decode else: @@ -295,11 +297,10 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, config, weights, causal: bool): super().__init__() - prefix = f"model.layers.{layer_id}" self.self_attn = FlashGemmaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal ) self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -351,30 +352,25 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights, causal: bool): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - embed_norm = config.hidden_size**0.5 - self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights - ) - self.embed_tokens.weight *= embed_norm - self.layers = nn.ModuleList( [ FlashGemmaLayer( - layer_id, - config, - weights, + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, + causal=causal, ) for layer_id in range(config.num_hidden_layers) ] ) 
def __init__(self, prefix, config, weights, causal: bool):
    """Load the token embedding, the decoder stack and the output head.

    Args:
        prefix: Optional checkpoint prefix of the enclosing model, or ``None``
            when the Gemma weights sit at the top level.
        config: Model configuration; ``hidden_size`` and
            ``tie_word_embeddings`` are read here.
        weights: Weight loader handed to every sub-module.
        causal: Forwarded to the decoder stack to select causal attention
            during prefill.
    """
    super().__init__()

    embed_norm = config.hidden_size**0.5
    # The decoder lives under "<prefix>.model". An untied output head sits next
    # to it ("lm_head" when there is no prefix), not inside the "model" subtree.
    # NOTE(review): "<prefix>.lm_head" for the prefixed case is assumed from the
    # unprefixed layout - confirm against a checkpoint with untied embeddings.
    if prefix is None:
        model_prefix = "model"
        lm_head_prefix = "lm_head"
    else:
        model_prefix = f"{prefix}.model"
        lm_head_prefix = f"{prefix}.lm_head"

    self.embed_tokens = TensorParallelEmbedding(
        prefix=f"{model_prefix}.embed_tokens", weights=weights
    )
    # Embeddings are scaled once at load time by sqrt(hidden_size).
    self.embed_tokens.weight *= embed_norm

    self.model = FlashGemmaModel(
        prefix=model_prefix, config=config, weights=weights, causal=causal
    )
    self.lm_head = SpeculativeHead.load(
        prefix=(
            f"{model_prefix}.embed_tokens"
            if config.tie_word_embeddings
            else lm_head_prefix
        ),
        config=config,
        weights=weights,
    )
class PaliGemmaForConditionalGeneration(nn.Module):
    """PaliGemma: a vision tower, a linear projector and a text model.

    Image features produced by the vision tower are projected and written over
    the embeddings of the image placeholder tokens before the text model runs.
    """

    def __init__(self, prefix, config, weights):
        """Load all sub-modules from ``weights``.

        Args:
            prefix: Optional checkpoint prefix; when empty or ``None`` the
                sub-modules are looked up at the top level.
            config: Model configuration providing ``vision_config``,
                ``text_config``, ``quantize``, ``speculator``, ``vocab_size``
                and ``pad_token_id``.
            weights: Weight loader handed to every sub-module.
        """
        super().__init__()
        config.vision_config.quantize = config.quantize
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        # Honour `prefix` like the vision tower and the language model do.
        self.multi_modal_projector = TensorParallelColumnLinear.load(
            config,
            prefix=(
                "multi_modal_projector.linear"
                if not prefix
                else f"{prefix}.multi_modal_projector.linear"
            ),
            weights=weights,
            bias=True,
        )

        self.vocab_size = config.vocab_size
        self.config = config

        text_config = config.text_config
        text_config.speculator = config.speculator
        text_config.quantize = config.quantize
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        # -1 is used when the configuration defines no padding token.
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        # Unused here
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the text model on token embeddings merged with image features.

        Args:
            input_ids: Token ids; positions equal to
                ``config.image_token_index`` receive image features.
            position_ids: Position ids. During prefill they are incremented
                **in place**, so the caller's tensor changes as well.
            cu_seqlen_prefill: Not ``None`` during prefill, ``None`` otherwise.
            kv_cache, block_tables, slots, input_lengths: Passed through to the
                text model unchanged.
            max_s: Maximum sequence length; one is added during prefill.
            prefill_cache_indices: Accepted but not used by this method.
            lm_head_indices: Optional indices selecting the hidden states that
                are fed to the output head.
            pixel_values: Optional image batch; when given, image features
                overwrite the placeholder embeddings.
            pixel_attention_mask: Accepted but not used by this method.
            image_sizes: Accepted but not used by this method.

        Returns:
            ``(logits, speculative_logits)`` as returned by the text model's
            ``lm_head``.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        # TODO This is odd but apparently pali gemma position ids start at 1.
        if cu_seqlen_prefill is not None:
            max_s += 1
            # NOTE(review): deliberately left as an in-place update; the caller
            # presumably derives decode positions from this tensor - confirm
            # before replacing it with an out-of-place addition.
            position_ids += 1

        if pixel_values is not None:
            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
            image_outputs = self.vision_tower(pixel_values)
            image_features = self.multi_modal_projector(image_outputs.last_hidden_state)

            # mask selecting the image placeholder tokens
            mask = input_ids == self.config.image_token_index

            # insert image features into input embeddings
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)

        return logits, speculative_logits
a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py new file mode 100644 index 00000000..f17d6562 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -0,0 +1,565 @@ +from typing import Optional, Tuple, Union + +import math +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + +from text_generation_server.layers.tensor_parallel import ( + TensorParallelEmbedding, + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.patch_embedding.weight = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False + ) + self.patch_embedding.bias = nn.Parameter( + weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False + ) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = TensorParallelEmbedding( + prefix=f"{prefix}.position_embedding", weights=weights + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, device=weights.device).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: 
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The number of heads and the embedding width kept on the module are divided
    by the size of ``weights.process_group``, so each rank works on its own
    slice of the heads.
    """

    def __init__(self, prefix, config, weights):
        """Load the q/k/v/out projections found under ``prefix``.

        Args:
            prefix: Checkpoint prefix of this attention block.
            config: Configuration providing ``hidden_size``,
                ``num_attention_heads`` and ``attention_dropout``.
            weights: Weight loader; its ``process_group`` size determines how
                many heads this rank keeps.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the head count.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.head_size = self.head_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # From here on both values describe this rank's shard only.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
        )
        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
        )
        self.out_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, heads * head_dim)`` to ``(bsz, heads, seq_len, head_dim)``."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Input of shape ``(bsz, tgt_len, channels)``.
            attention_mask: Optional additive mask of shape
                ``(bsz, 1, tgt_len, src_len)``.
            output_attentions: Accepted for interface compatibility; the
                attention weights are returned regardless of its value.

        Returns:
            ``(attn_output, attn_weights)``: the projected output and the
            attention probabilities of shape
            ``(bsz * num_heads, tgt_len, src_len)``.

        Raises:
            ValueError: If an intermediate tensor or the mask has an
                unexpected shape.
        """

        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        # Fold batch and heads together so one bmm handles every head.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        # scale post matmul
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast attention to fp32, then return to the incoming dtype
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        expected_size = (bsz * self.num_heads, tgt_len, self.head_size)
        if attn_output.size() != expected_size:
            # Report the shape that is actually being checked.
            raise ValueError(
                f"`attn_output` should be of size {expected_size}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
config.hidden_size + self.self_attn = SiglipAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.layer_norm1 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps + ) + self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.layer_norm2 = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + if output_attentions: + return hidden_states, attn_weights + return hidden_states, None + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(prefix, config, weights) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +import warnings + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. 
class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` according to its type.

        Args:
            module: The sub-module to initialise in place.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            # `math` is imported at the top of the file; numpy is not, so
            # `np.sqrt` raised NameError here.
            nn.init.normal_(module.position_embedding.weight, std=1 / math.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        # NOTE(review): `SiglipModel` is not defined in the part of the file
        # visible here - confirm it exists further down, otherwise reaching
        # this branch raises NameError.
        elif isinstance(module, SiglipModel):
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+ + Args: + config: SiglipConfig + """ + + def __init__(self, prefix, config: SiglipConfig, weights): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + prefix=f"{prefix}.layers.{i}", config=config, weights=weights + ) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[torch.Tensor] = None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + """ + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + hidden_states, _ = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, prefix, config: SiglipVisionConfig, weights): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) + self.encoder = SiglipEncoder( + prefix=f"{prefix}.encoder", config=config, weights=weights + ) + self.post_layernorm = nn.LayerNorm.load( + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + r""" + Returns: + + """ + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + # NOTE: up until this point, the code logits are exactly + # the same as the transformers code. The values evaluate + # slightly differently in our encoder layer.
+ encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + ) + last_hidden_state = encoder_outputs + post_last_hidden_state = self.post_layernorm(last_hidden_state) + + return BaseModelOutputWithPooling( + last_hidden_state=post_last_hidden_state, + # pooler_output=pooled_output, + # hidden_states=encoder_outputs, + ) diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 690957d0..b74b43ff 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -11,6 +11,18 @@ def load_text_model(prefix, config, weights, name=None): ) return FlashMistralForCausalLM(prefix, config, weights, name=name) + elif config.model_type == "gemma": + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + + return FlashGemmaForCausalLM(prefix, config, weights, causal=False) + elif config.model_type == "paligemma": + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + + return FlashGemmaForCausalLM(prefix, config, weights) else: raise RuntimeError(f"Unsupported model type {config.model_type}") @@ -24,5 +36,13 @@ def load_vision_model(prefix, config, weights): return CLIPVisionTransformer( prefix=f"{prefix}.vision_model", config=config, weights=weights ) + if config.model_type == "siglip_vision_model": + from text_generation_server.models.custom_modeling.siglip import ( + SiglipVisionTransformer, + ) + + return SiglipVisionTransformer( + prefix=f"vision_tower.vision_model", config=config, weights=weights + ) else: raise RuntimeError(f"Unsupported model type {config.model_type}") diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36351252..c029d8f3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ 
b/server/text_generation_server/models/flash_causal_lm.py @@ -133,6 +133,17 @@ class FlashCausalLMBatch(Batch): device: torch.device, ) -> "FlashCausalLMBatch": batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) + return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) + + @classmethod + def from_tokenized( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + batch_tokenized_inputs, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": position_ids = [] speculative_ids = [] cu_seqlen_prefill = [0] diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 9c00a056..53bfd064 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -3,12 +3,11 @@ import torch.distributed from opentelemetry import trace from typing import Optional -from transformers.models.gemma import GemmaTokenizerFast +from transformers import AutoConfig, AutoTokenizer from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, - GemmaConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -36,17 +35,15 @@ class FlashGemma(FlashCausalLM): else: raise NotImplementedError("FlashGemma is only available on GPU") - tokenizer = GemmaTokenizerFast.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left", trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, ) - config = GemmaConfig.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize @@ -59,7 +56,9 @@ class FlashGemma(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) 
- model = FlashGemmaForCausalLM(config, weights) + # TODO hardcoded + prefix = "language_model" + model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py new file mode 100644 index 00000000..d94b9526 --- /dev/null +++ b/server/text_generation_server/models/pali_gemma.py @@ -0,0 +1,123 @@ +import torch +import torch.distributed +from opentelemetry import trace +from typing import Optional, Tuple +from text_generation_server.models.vlm_causal_lm import ( + VlmCausalLM, + VlmCausalLMBatch, + image_text_replacement, + load_data_uri, + split, +) +from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, +) +from transformers import AutoProcessor, AutoConfig, AutoImageProcessor + +tracer = trace.get_tracer(__name__) + + +class PaliGemmaBatch(VlmCausalLMBatch): + @classmethod + def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + batch_inputs = [] + image_inputs = [] + max_truncation = 0 + for r in requests: + chunks = split(r.inputs) + full_text = "" + image_id = 0 + for chunk in chunks: + if chunk["type"] == "text": + full_text += "" + chunk["content"] + "\n" + elif chunk["type"] == "image": + image = chunk["content"] + # Should never receive URLs anymore, processing should be done + # on the Rust layer. + # This avoids making n queries per TP + # if image.startswith("https://") or image.startswith("http://"): + # image = processor.image_processor.fetch_images(image) + if image.startswith("data:"): + image = load_data_uri(image) + else: + raise RuntimeError( + "Cannot process input image not starting with data:" + ) + # TODO do_convert_RGB should be on by default ?
+ image = image.convert("RGB") + image_input = processor.image_processor(image, return_tensors="pt") + full_text += image_text_replacement(image_input, config, image_id) + image_inputs.append(image_input) + else: + raise RuntimeError(f"Invalid chunk type {chunk['type']}") + + batch_inputs.append(full_text) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=False, + )["input_ids"] + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "pixel_attention_mask" in image_input: + new_image_inputs["pixel_attention_mask"] = torch.cat( + [img["pixel_attention_mask"] for img in image_inputs], dim=0 + ) + if "image_sizes" in image_input: + new_image_inputs["image_sizes"] = torch.cat( + [img["image_sizes"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + else: + image_inputs = None + return batch_tokenized_inputs, image_inputs + + +class PaliGemma(VlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + + super().__init__( + config_cls=AutoConfig, + model_cls=PaliGemmaForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + @property + def batch_type(self): + return PaliGemmaBatch + + def get_layer_config(self, model) -> Tuple[int, int, int]: + return ( + len(model.text_model.model.layers), + model.text_model.model.num_key_value_heads, + model.text_model.model.head_size, + ) + + def max_past(self) 
-> Optional[int]: + return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 5394feb5..f0db89b2 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -15,6 +15,7 @@ from text_generation_server.models.flash_mistral import ( BaseFlashMistral, FlashMistralBatch, ) +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.cache_manager import ( get_cache_manager, ) @@ -80,6 +81,9 @@ def image_text_replacement(image_input, config, image_id) -> str: logger.info(f"Found {num_features} in image of resolution {height}x{width}") return "" * num_features + + elif config.model_type == "paligemma": + return "" * config.text_config.num_image_tokens else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -193,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch): max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=not config.model_type == "paligemma", )["input_ids"] if image_inputs: image_input = image_inputs[0] diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 0830656d..ae60fa63 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA: max_s, softmax_scale, window_size_left=-1, + causal=True, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -134,7 +135,7 @@ if HAS_FLASH_ATTN_V2_CUDA: 0.0, softmax_scale, False, - True, + causal, window_size_left, 0, False,