feat: load and query model

commit 5fd72ed06c (parent e3d765645a)
Author: drbh, 2024-05-08 16:21:44 -04:00; committed by Nicolas Patry
14 changed files with 1398 additions and 20 deletions

Binary file not shown (new image added, 66 KiB).

View File

@@ -0,0 +1,39 @@
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "Tinkering/test-bvhf",
        num_shard=1,
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"Where is the cow standing?\n![]({cow})"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    # TODO: update this! this is incorrect and just to show the current state of the test
    assert response.generated_text == ' - HDS'
    # assert response.generated_text == "\nbeach"
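For manual testing outside the pytest harness, an equivalent raw-HTTP call might look like the sketch below. The `http://localhost:8080` endpoint is an assumption about how the launcher's port is mapped locally; adjust as needed.

```python
# Hedged sketch: query a running text-generation-inference server directly.
import base64
import requests

with open("integration-tests/images/cow_beach.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# Same prompt format as the test: markdown image tag carrying a data URI.
prompt = f"Where is the cow standing?\n![](data:image/png;base64,{encoded})"
response = requests.post(
    "http://localhost:8080/generate",
    json={"inputs": prompt, "parameters": {"max_new_tokens": 20}},
)
print(response.json()["generated_text"])
```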

View File

@@ -118,6 +118,22 @@ impl Idefics2 {
     }
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {}
+
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        // TODO: improve to calculate based on height and width
+        // 224 = 256 image tokens
+        // 448 = 1024 image tokens
+        // 896 = 4096 image tokens
+        256
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -139,6 +155,7 @@ pub enum Config {
     Phi3,
     Llama,
     Baichuan,
+    Paligemma(Paligemma),
     Gemma,
     Cohere,
     Drbx,
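For now the feature count is hardcoded to 256, which matches the 224×224 checkpoints. The counts in the TODO comment follow directly from SigLIP's 14-pixel patches; a small sketch of the intended calculation, assuming square inputs and the patch size from the vision config added later in this commit:

```python
# Sketch of the resolution-based token count the TODO asks for,
# assuming square images and PaliGemma's 14-pixel SigLIP patches.
def number_of_image_tokens(resolution: int, patch_size: int = 14) -> int:
    return (resolution // patch_size) ** 2

assert number_of_image_tokens(224) == 256
assert number_of_image_tokens(448) == 1024
assert number_of_image_tokens(896) == 4096
```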

View File

@@ -540,6 +540,30 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
+        Some(Config::Paligemma(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         Some(Config::Idefics2(config)) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
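In words: the branch walks each markdown image tag in the prompt, keeps the fetched image URI in the stored input, and feeds the tokenizer one `<image>` placeholder per feature slot instead. A rough Python equivalent of the control flow (the regex and the fixed slot count here are illustrative, not the router's actual values):

```python
import re

IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")  # markdown image tags, as in the prompt format

def prepare_input(inputs: str, slots: int = 256):
    """Return (modified_inputs, tokenizer_query) the way the Paligemma branch does."""
    modified, query = [], []
    start = 0
    for m in IMAGE_RE.finditer(inputs):
        modified.append(inputs[start:m.start()])
        query.append(inputs[start:m.start()])
        modified.append(m.group())           # keep the image URI for later fetching
        query.append("<image>" * slots)      # tokenizer sees one placeholder per slot
        start = m.end()
    modified.append(inputs[start:])
    query.append(inputs[start:])
    return "".join(modified), "".join(query)
```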

View File

@@ -75,6 +75,7 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
+    from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

 except ImportError as e:
@@ -433,6 +434,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    if model_type == "paligemma":
+        return FlashPaliGemma(
+            model_id,
+            revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if model_type == "cohere":
         if FLASH_ATTENTION:
             return FlashCohere(

View File

@@ -295,9 +295,9 @@ class GemmaMLP(nn.Module):
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix or ''}model.layers.{layer_id}"
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -351,21 +351,30 @@ class FlashGemmaLayer(nn.Module):
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         embed_norm = config.hidden_size**0.5
+        pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=pvalue,
+            weights=weights,
+            # limit embed_tokens.weight size to the config.vocab_size
         )
+        self.embed_tokens.weight = torch.nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
+        )
+        # TODO: double check why this is needed
         self.embed_tokens.weight *= embed_norm
         self.layers = nn.ModuleList(
             [
                 FlashGemmaLayer(
+                    f"{prefix + '.' if prefix else ''}",
                     layer_id,
                     config,
                     weights,
@@ -374,7 +383,9 @@ class FlashGemmaModel(torch.nn.Module):
             ]
         )
         self.norm = GemmaFastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix + '.' if prefix else ''}model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )
         self.gradient_checkpointing = False
@@ -385,7 +396,8 @@ class FlashGemmaModel(torch.nn.Module):
     def forward(
         self,
-        input_ids: torch.Tensor,
+        # input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -394,8 +406,8 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -423,13 +435,15 @@ class FlashGemmaModel(torch.nn.Module):
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
+        self.config = config
-        self.model = FlashGemmaModel(config, weights)
+        self.model = FlashGemmaModel(prefix, config, weights)
+        prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
+        prefix = prefix if config.tie_word_embeddings else "lm_head"
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
+            prefix=prefix,
             weights=weights,
         )
@@ -445,8 +459,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
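The prefix threading above serves the nested case: inside PaliGemma, the decoder's weights live under `language_model.model.*` rather than `model.*`, while the `inputs_embeds` change lets a multimodal wrapper splice image features into the embeddings before running the decoder. A small sketch of the weight-naming scheme the code produces (paths are examples, assuming the trailing-dot convention used above):

```python
from typing import Optional

def layer_prefix(prefix: Optional[str], layer_id: int) -> str:
    # mirrors: prefix = f"{prefix or ''}model.layers.{layer_id}"
    return f"{prefix or ''}model.layers.{layer_id}"

assert layer_prefix(None, 0) == "model.layers.0"
# the PaliGemma wrapper ends up passing "language_model." (note the trailing dot)
assert layer_prefix("language_model.", 0) == "language_model.model.layers.0"
```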

View File

@@ -0,0 +1,264 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.layers import TensorParallelColumnLinear
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
GemmaConfig,
)
# TODO: prefer using the following config classes
# * instead of the hack inside of the gemma modeling file
class VisionConfig(PretrainedConfig):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
model_type: str,
num_attention_heads: int,
num_hidden_layers: int,
num_image_tokens: int,
patch_size: int,
projection_dim: int,
projector_hidden_act: str,
vision_use_head: bool,
vocab_size: int,
quantize: Optional[str] = None,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.model_type = model_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_image_tokens = num_image_tokens
self.patch_size = patch_size
self.projection_dim = projection_dim
self.projector_hidden_act = projector_hidden_act
self.vision_use_head = vision_use_head
self.vocab_size = vocab_size
self.quantize = quantize
class PaliTextConfig(PretrainedConfig):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
model_type: str,
num_attention_heads: int,
num_hidden_layers: int,
num_image_tokens: int,
num_key_value_heads: int,
torch_dtype: str,
vocab_size: int,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.model_type = model_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_image_tokens = num_image_tokens
self.num_key_value_heads = num_key_value_heads
self.torch_dtype = torch_dtype
self.vocab_size = vocab_size
class PaliGemmaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=257216,
hidden_size=2048,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
text_config=None,
vision_config=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.head_dim = head_dim
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.text_config = GemmaConfig(
hidden_size=2048,
intermediate_size=16384,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_image_tokens=256,
num_key_value_heads=1,
torch_dtype="float32",
vocab_size=257216,
)
self.vision_config = VisionConfig(
hidden_size=1152,
intermediate_size=4304,
model_type="siglip_vision_model",
num_attention_heads=16,
num_hidden_layers=27,
num_image_tokens=256,
patch_size=14,
projection_dim=2048,
projector_hidden_act="gelu_fast",
vision_use_head=False,
vocab_size=257152,
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class FlashPaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
).to(weights.device, weights.dtype)
self.multi_modal_projector = TensorParallelColumnLinear.load(
config,
prefix="multi_modal_projector.linear",
weights=weights,
bias=True,
).to(weights.device, weights.dtype)
self.vocab_size = config.vocab_size
self.config = config
self.language_model = load_text_model(
prefix=prefix,
config=config,
weights=weights,
).to(weights.device, weights.dtype)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pixel_attention_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
# merge text and images
if pixel_values is not None and len(pixel_values) > 0:
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
# TODO: make sure to handle the specialized attention mask correctly
inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
return logits, speculative_logits
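The masked assignment in `_merge_input_ids_with_image_features` only succeeds when the router emitted exactly one `<image>` placeholder per projected feature vector, which is what the error message hints at. A toy reproduction of the mechanism (hypothetical token id and sizes, not the real config):

```python
import torch

# Toy reproduction of the in-place merge: hypothetical image token id and
# hidden size, not PaliGemma's real values.
hidden_size, image_token_index = 8, 99
input_ids = torch.tensor([1, 99, 99, 42])        # two <image> placeholders
inputs_embeds = torch.zeros(4, hidden_size)
image_features = torch.ones(1, 2, hidden_size)   # (num_images, tokens_per_image, hidden)

mask = input_ids == image_token_index
# view(-1, hidden) flattens all image tokens; the count must match the mask exactly
inputs_embeds[mask] = image_features.view(-1, hidden_size)

assert inputs_embeds[1].sum() == hidden_size and inputs_embeds[0].sum() == 0
```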

View File

@@ -0,0 +1,578 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from text_generation_server.utils.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
# TODO: remove this hack! figure out why off by one
self.position_embedding.weight = torch.nn.Parameter(
self.position_embedding.weight[:256, :]
)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
patch_embeds = self.patch_embedding(
pixel_values
) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = (
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.head_size = self.head_dim
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=True,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
]
* 3,
dim=2,
)
key_states = self._shape(key_states, -1, bsz)
value_states = self._shape(value_states, -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_size)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
# scale post matmul
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(attn_weights.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_weights, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class SiglipMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load( # config.hidden_size, config.intermediate_size
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load( # config.intermediate_size, config.hidden_size
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module):
def __init__(self, prefix, config: SiglipConfig, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
)
self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
if output_attentions:
return hidden_states, attn_weights
print(hidden_states[0, 0, :5].tolist())
return hidden_states, None
class SiglipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads, batch_first=True
)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
import warnings
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
from torch.nn.init import _calculate_fan_in_and_fan_out
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
from transformers import PreTrainedModel
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = (
self.config.vision_config.hidden_size
if isinstance(self.config, SiglipConfig)
else self.config.hidden_size
)
nn.init.normal_(module.position_embedding.weight, std=1 / math.sqrt(width))  # math.sqrt: numpy is never imported in this file
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, SiglipModel):
logit_scale_init = torch.log(torch.tensor(1.0))
module.logit_scale.data.fill_(logit_scale_init)
module.logit_bias.data.zero_()
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, prefix, config: SiglipConfig, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
SiglipEncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[torch.Tensor] = None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
hidden_states, _ = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, prefix, config: SiglipVisionConfig, weights):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = SiglipEncoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
):
r"""
Returns:
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
# NOTE: up until this point, the code logits are exactly
# the same as the transformers code. The values evaluate
# slightly differently in our encoder layer.
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
)
last_hidden_state = encoder_outputs
post_last_hidden_state = self.post_layernorm(last_hidden_state)
return BaseModelOutputWithPooling(
last_hidden_state=post_last_hidden_state,
# pooler_output=pooled_output,
# hidden_states=encoder_outputs,
)

View File

@@ -11,6 +11,12 @@ def load_text_model(prefix, config, weights, name=None):
         )
         return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    elif config.model_type == "gemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
+        return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -24,5 +30,13 @@ def load_vision_model(prefix, config, weights):
         return CLIPVisionTransformer(
             prefix=f"{prefix}.vision_model", config=config, weights=weights
         )
+    if config.model_type == "siglip_vision_model":
+        from text_generation_server.models.custom_modeling.siglip import (
+            SiglipVisionTransformer,
+        )
+
+        return SiglipVisionTransformer(
+            prefix=f"vision_tower.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@@ -133,6 +133,17 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
@@ -207,6 +218,7 @@ class FlashCausalLMBatch(Batch):
             # Paged attention
             # Remove one as the first token does not have a past
             speculative_length = get_speculate()
+            speculative_length = 0 if speculative_length is None else speculative_length
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
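The point of the split: `from_pb` keeps the plain-text tokenization path, while multimodal subclasses can tokenize with image placeholders and then jump straight into the shared tensor-building code via `from_tokenized`. A schematic of the pattern (illustrative names, not TGI's real signatures):

```python
# Schematic of the refactor: subclasses tokenize their own way, then reuse the
# shared batch-building path. Names and signatures here are illustrative only.
class BaseBatch:
    @classmethod
    def from_pb(cls, requests, tokenizer):
        ids = [tokenizer(text) for text in requests]
        return cls.from_tokenized(ids)

    @classmethod
    def from_tokenized(cls, ids):
        batch = cls()
        batch.input_ids = ids  # the real method also builds positions, slots, blocks
        return batch


class ToyVlmBatch(BaseBatch):
    @classmethod
    def from_pb_processor(cls, requests, tokenizer, slots=2):
        # expand each image marker into placeholder tokens before tokenizing
        ids = [tokenizer(t.replace("[IMG]", "<image>" * slots)) for t in requests]
        return cls.from_tokenized(ids)


tokenize = lambda s: s.split()
print(ToyVlmBatch.from_pb_processor(["hello [IMG]"], tokenize).input_ids)
# [['hello', '<image><image>']]
```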

View File

@@ -4,6 +4,7 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 from transformers.models.gemma import GemmaTokenizerFast
+from transformers import AutoConfig, PretrainedConfig

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
@@ -19,15 +20,58 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)

-class FlashGemma(FlashCausalLM):
+class VisionConfig(PretrainedConfig):
     def __init__(
         self,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        model_type: str = "siglip_vision_model",
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        num_image_tokens: int = 256,
+        patch_size: int = 14,
+        projection_dim: int = 2048,
+        projector_hidden_act: str = "gelu_fast",
+        vision_use_head: bool = False,
+        vocab_size: int = 257152,
+        quantize: Optional[str] = None,
+        image_size: int = 224,
+        layer_norm_eps: float = 1e-06,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "gelu_pytorch_tanh",
+        num_channels: int = 3,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.model_type = model_type
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.num_image_tokens = num_image_tokens
+        self.patch_size = patch_size
+        self.projection_dim = projection_dim
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_use_head = vision_use_head
+        self.vocab_size = vocab_size
+        self.quantize = quantize
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+
+
+class BaseFlashGemma(FlashCausalLM):
+    def __init__(
+        self,
+        model_cls,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        prefix: Optional[str] = None,
+        config_cls=AutoConfig,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -49,9 +93,39 @@ class FlashGemma(FlashCausalLM):
         config = GemmaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+
+        is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config")
+
+        if is_vlm:
+            config.vision_config = VisionConfig(
+                hidden_size=1152,
+                intermediate_size=4304,
+                model_type="siglip_vision_model",
+                num_attention_heads=16,
+                num_hidden_layers=27,
+                num_image_tokens=256,
+                patch_size=14,
+                projection_dim=2048,
+                projector_hidden_act="gelu_fast",
+                vision_use_head=False,
+                vocab_size=257152,
+                quantize=quantize,
+            )
+
         config.quantize = quantize
         config.speculator = speculator
+        if is_vlm:
+            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
+            config.intermediate_size = config.text_config.get("intermediate_size")
+            config.model_type = config.text_config.get("model_type")
+            config.num_attention_heads = config.text_config.get("num_attention_heads")
+            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
+            config.num_image_tokens = config.text_config.get("num_image_tokens")
+            config.num_key_value_heads = config.text_config.get("num_key_value_heads")
+            config.torch_dtype = config.text_config.get("torch_dtype")
+            config.vocab_size = config.text_config.get("vocab_size")

         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -59,17 +133,49 @@ class FlashGemma(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashGemmaForCausalLM(config, weights)
+        model = model_cls(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma, self).__init__(
+
+        if is_vlm:
+            num_layers = config.num_hidden_layers
+            num_kv_heads = config.num_key_value_heads
+            head_size = config.intermediate_size
+        else:
+            num_layers = len(model.model.layers)
+            num_kv_heads = model.model.num_key_value_heads
+            head_size = model.model.head_size
+
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
         )
+
+
+class FlashGemma(BaseFlashGemma):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashGemma, self).__init__(
+            model_cls=FlashGemmaForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+            prefix=None,
+        )

View File

@@ -0,0 +1,54 @@
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional, Tuple

from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    FlashPaliGemmaForConditionalGeneration,
    PaliGemmaConfig,
    PaliTextConfig,
)
from transformers import AutoProcessor

tracer = trace.get_tracer(__name__)


class FlashPaliGemma(PaliVlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            # TODO: load in the correct processor based on the model_id
            "google/siglip-base-patch16-224",
            # "google/siglip-so400m-patch14-384",
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=PaliTextConfig,
            model_cls=FlashPaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            prefix="language_model",
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)

View File

@@ -15,6 +15,8 @@ from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
+from text_generation_server.models.flash_gemma import BaseFlashGemma
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
 )
@@ -80,6 +82,11 @@ def image_text_replacement(image_input, config, image_id) -> str:
         logger.info(f"Found {num_features} in image of resolution {height}x{width}")
         return "<image>" * num_features
+    # TODO: double check correct naming for model_type
+    elif config.model_type == "gemma":
+        # TODO: use correct number of features
+        return "<image>" * 256
+
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -371,3 +378,238 @@ class VlmCausalLM(BaseFlashMistral):
         )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
class PaliVlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
@classmethod
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
batch_inputs = []
image_inputs = []
max_truncation = 0
for r in requests:
chunks = split(r.inputs)
full_text = ""
image_id = 0
for chunk in chunks:
if chunk["type"] == "text":
full_text += chunk["content"]
elif chunk["type"] == "image":
image = chunk["content"]
# Should never receive URLs anymore; processing should be done
# on the Rust layer.
# This avoids making n queries per TP.
# if image.startswith("https://") or image.startswith("http://"):
# image = processor.image_processor.fetch_images(image)
if image.startswith("data:"):
image = load_data_uri(image)
else:
raise RuntimeError(
"Cannot process input image not starting with data:"
)
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "PaliVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
class PaliVlmCausalLM(BaseFlashGemma):
@property
def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
return PaliVlmCausalLMBatch
def forward(
self, batch: PaliVlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
# prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
)
# if batch.prefill_cache_indices is not None:
# batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
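The speculative branch above flattens each sequence's current token together with its speculative candidates into one long batch, offsetting positions and slots to match. A toy check of the index arithmetic (hypothetical values):

```python
import torch

# Each sequence contributes its current token plus its speculative candidates.
B, speculative_length = 2, 3
new_length = speculative_length + 1

input_ids = torch.tensor([10, 20])
speculative_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
print(new_input_ids)  # tensor([10, 11, 12, 13, 20, 21, 22, 23])

position_ids = torch.tensor([5, 9])
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
print(new_position_ids)  # tensor([ 5,  6,  7,  8,  9, 10, 11, 12])
```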

View File

@@ -14,7 +14,7 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@@ -98,6 +98,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
+            PaliVlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
@@ -122,6 +123,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
+            PaliVlmCausalLMBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,