Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

feat: load and query model
Parent: e3d765645a
Commit: 5fd72ed06c
BIN  integration-tests/images/cow_beach.png  (new file)
Binary file not shown. Size: 66 KiB.
integration-tests/models/test_flash_pali_gemma.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "Tinkering/test-bvhf",
        num_shard=1,
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    # TODO: update this! this is incorrect and just to show the current state of the test
    assert response.generated_text == " - HDS"
    # assert response.generated_text == "\nbeach"
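A possible follow-up once the TODO above is resolved, sketched here for illustration only: the encoded image would be embedded into the prompt itself. The markdown-style image reference and the snapshot assertion are assumptions, not part of this commit.

# Hypothetical completed test (not part of this commit). Assumes the router
# recognises markdown-style image references, as it does for other VLMs.
async def test_flash_pali_gemma_with_image(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"![]({cow})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
    # Expected output is a placeholder assumption here.
    assert response == response_snapshot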
@@ -118,6 +118,22 @@ impl Idefics2 {
     }
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "model_type")]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {}
+
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        // TODO: improve to calculate based on height and width
+        // 224 = 256 image tokens
+        // 448 = 1024 image tokens
+        // 896 = 4096 image tokens
+
+        256
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -139,6 +155,7 @@ pub enum Config {
     Phi3,
     Llama,
     Baichuan,
+    Paligemma(Paligemma),
     Gemma,
     Cohere,
     Drbx,
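For reference, the hard-coded 256 matches a 224x224 input with 14x14 patches. A sketch of the resolution-based calculation the TODO hints at, written in Python for brevity; the helper itself is hypothetical, and the patch size of 14 is taken from the vision config later in this commit.

def get_number_of_features(height: int, width: int, patch_size: int = 14) -> int:
    # Image tokens form the patch grid: (side // patch_size) ** 2 for square inputs.
    # 224 -> 16 * 16 = 256, 448 -> 1024, 896 -> 4096, matching the comments above.
    return (height // patch_size) * (width // patch_size)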
@@ -540,6 +540,30 @@ fn prepare_input(
             inputs = modified_inputs;
             tokenizer_query
         }
+        Some(Config::Paligemma(config)) => {
+            let mut modified_inputs = String::with_capacity(inputs.len());
+            let mut tokenizer_query = String::with_capacity(inputs.len());
+            let mut start = 0;
+            for chunk in RE.find_iter(&inputs) {
+                let chunk_start = chunk.start();
+                let chunk_end = chunk.end();
+                if chunk_start != start {
+                    modified_inputs.push_str(&inputs[start..chunk_start]);
+                    tokenizer_query.push_str(&inputs[start..chunk_start]);
+                }
+                let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                let slots = config.get_number_of_features(height, width);
+                tokenizer_query.push_str(&"<image>".repeat(slots));
+                modified_inputs.push_str(&image_uri);
+                start = chunk_end;
+            }
+            if start != inputs.len() - 1 {
+                modified_inputs.push_str(&inputs[start..]);
+                tokenizer_query.push_str(&inputs[start..]);
+            }
+            inputs = modified_inputs;
+            tokenizer_query
+        }
         Some(Config::Idefics2(config)) => {
             let mut modified_inputs = String::with_capacity(inputs.len());
             let mut tokenizer_query = String::with_capacity(inputs.len());
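The same expansion expressed as a small Python sketch, illustrative only: image references in the raw input are kept for the model, while the tokenizer sees one `<image>` placeholder per slot. The regex and helper names below are assumptions, not the router's actual definitions.

import re

# Hypothetical stand-ins for the router's RE and get_number_of_features.
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")

def prepare_paligemma_input(inputs: str, slots: int = 256):
    modified_inputs, tokenizer_query, start = [], [], 0
    for match in IMAGE_RE.finditer(inputs):
        if match.start() != start:
            modified_inputs.append(inputs[start:match.start()])
            tokenizer_query.append(inputs[start:match.start()])
        modified_inputs.append(match.group(0))       # keep the image URI for the model
        tokenizer_query.append("<image>" * slots)    # reserve one placeholder per image slot
        start = match.end()
    modified_inputs.append(inputs[start:])
    tokenizer_query.append(inputs[start:])
    return "".join(modified_inputs), "".join(tokenizer_query)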
@@ -75,6 +75,7 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
+    from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 
 except ImportError as e:
@@ -433,6 +434,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    if model_type == "paligemma":
+        return FlashPaliGemma(
+            model_id,
+            revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if model_type == "cohere":
         if FLASH_ATTENTION:
             return FlashCohere(
@@ -295,9 +295,9 @@ class GemmaMLP(nn.Module):
 
 
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix or ''}model.layers.{layer_id}"
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -351,21 +351,30 @@ class FlashGemmaLayer(nn.Module):
 
 
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
         embed_norm = config.hidden_size**0.5
+        pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=pvalue,
+            weights=weights,
         )
+        # limit embed_tokens.weight size to the config.vocab_size
+        self.embed_tokens.weight = torch.nn.Parameter(
+            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
+        )
+
+        # TODO: double check why this is needed
         self.embed_tokens.weight *= embed_norm
 
         self.layers = nn.ModuleList(
             [
                 FlashGemmaLayer(
+                    f"{prefix + '.' if prefix else ''}",
                     layer_id,
                     config,
                     weights,
@@ -374,7 +383,9 @@ class FlashGemmaModel(torch.nn.Module):
             ]
         )
         self.norm = GemmaFastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix=f"{prefix + '.' if prefix else ''}model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )
 
         self.gradient_checkpointing = False
@@ -385,7 +396,8 @@ class FlashGemmaModel(torch.nn.Module):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        # input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -394,8 +406,8 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -423,13 +435,15 @@ class FlashGemmaModel(torch.nn.Module):
 
 
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.model = FlashGemmaModel(config, weights)
+        self.config = config
+        self.model = FlashGemmaModel(prefix, config, weights)
+        prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
+        prefix = prefix if config.tie_word_embeddings else "lm_head"
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
+            prefix=prefix,
             weights=weights,
         )
 
@@ -445,8 +459,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            input_ids,
+            inputs_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
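A quick sketch of what the new `prefix` plumbing buys: the same Gemma modeling code can now read its weights either from a standalone Gemma checkpoint or from the `language_model.*` namespace of a PaliGemma checkpoint. The helper and printed names below are illustrative only.

def layer_prefix(prefix, layer_id):
    # Mirrors the f"{prefix or ''}model.layers.{layer_id}" pattern above.
    return f"{prefix or ''}model.layers.{layer_id}"

# Standalone Gemma: weights live under "model.*"
print(layer_prefix(None, 0))               # model.layers.0
# Inside PaliGemma, the text model is loaded with a "language_model." prefix.
print(layer_prefix("language_model.", 0))  # language_model.model.layers.0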
server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils.layers import TensorParallelColumnLinear
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    GemmaConfig,
)


# TODO: prefer using the following config classes
# * instead of the hack inside of the gemma modeling file
class VisionConfig(PretrainedConfig):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        model_type: str,
        num_attention_heads: int,
        num_hidden_layers: int,
        num_image_tokens: int,
        patch_size: int,
        projection_dim: int,
        projector_hidden_act: str,
        vision_use_head: bool,
        vocab_size: int,
        quantize: Optional[str] = None,
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.model_type = model_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_image_tokens = num_image_tokens
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.projector_hidden_act = projector_hidden_act
        self.vision_use_head = vision_use_head
        self.vocab_size = vocab_size
        self.quantize = quantize


class PaliTextConfig(PretrainedConfig):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        model_type: str,
        num_attention_heads: int,
        num_hidden_layers: int,
        num_image_tokens: int,
        num_key_value_heads: int,
        torch_dtype: str,
        vocab_size: int,
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.model_type = model_type
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_image_tokens = num_image_tokens
        self.num_key_value_heads = num_key_value_heads
        self.torch_dtype = torch_dtype
        self.vocab_size = vocab_size


class PaliGemmaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=257216,
        hidden_size=2048,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        text_config=None,
        vision_config=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        self.text_config = GemmaConfig(
            hidden_size=2048,
            intermediate_size=16384,
            model_type="gemma",
            num_attention_heads=8,
            num_hidden_layers=18,
            num_image_tokens=256,
            num_key_value_heads=1,
            torch_dtype="float32",
            vocab_size=257216,
        )

        self.vision_config = VisionConfig(
            hidden_size=1152,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            num_image_tokens=256,
            patch_size=14,
            projection_dim=2048,
            projector_hidden_act="gelu_fast",
            vision_use_head=False,
            vocab_size=257152,
        )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class FlashPaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize

        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        ).to(weights.device, weights.dtype)

        self.multi_modal_projector = TensorParallelColumnLinear.load(
            config,
            prefix="multi_modal_projector.linear",
            weights=weights,
            bias=True,
        ).to(weights.device, weights.dtype)

        self.vocab_size = config.vocab_size
        self.config = config

        self.language_model = load_text_model(
            prefix=prefix,
            config=config,
            weights=weights,
        ).to(weights.device, weights.dtype)
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
        # Let's pray we have enabled enough slots !
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
            )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        pixel_attention_mask=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        if pixel_values is not None:
            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)

        # merge text and images
        if pixel_values is not None and len(pixel_values) > 0:
            image_outputs = self.vision_tower(pixel_values)
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.multi_modal_projector(selected_image_feature)
            # TODO: make sure to handle the specialized attention mask correctly
            inputs_embeds = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, input_ids
            )

        hidden_states = self.language_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.language_model.lm_head(hidden_states)

        return logits, speculative_logits
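A toy illustration of the mask fill done by `_merge_input_ids_with_image_features` above, with made-up dimensions and an assumed image token id:

import torch

hidden, image_token_index = 4, 99          # assumed values for this sketch
input_ids = torch.tensor([1, 99, 99, 7])   # two image placeholder tokens
inputs_embeds = torch.zeros(4, hidden)
image_features = torch.ones(1, 2, hidden)  # one image projected to 2 tokens

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, hidden)
# Rows 1 and 2 now hold the projected vision features; rows 0 and 3 keep the text embeddings.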
server/text_generation_server/models/custom_modeling/siglip.py (new file, 578 lines)
@@ -0,0 +1,578 @@
from typing import Optional, Tuple, Union

import math
import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
    _create_4d_causal_attention_mask,
    _prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
)
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

from text_generation_server.utils.layers import (
    TensorParallelEmbedding,
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        self.patch_embedding.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        # TODO: remove this hack! figure out why off by one
        self.position_embedding.weight = torch.nn.Parameter(
            self.position_embedding.weight[:256, :]
        )
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        patch_embeds = self.patch_embedding(
            pixel_values
        )  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings


class SiglipTextEmbeddings(nn.Module):
    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.head_size = self.head_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=True,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, _ = hidden_states.size()
        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_size * self.num_heads,
            ]
            * 3,
            dim=2,
        )
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_size)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        # scale post matmul
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.bmm(attn_weights, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class SiglipMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(  # config.hidden_size, config.intermediate_size
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(  # config.intermediate_size, config.hidden_size
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    def __init__(self, prefix, config: SiglipConfig, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        if output_attentions:
            return hidden_states, attn_weights
        print(hidden_states[0, 0, :5].tolist())
        return hidden_states, None


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


import warnings


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


from torch.nn.init import _calculate_fan_in_and_fan_out


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


from transformers import PreTrainedModel


class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, SiglipVisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipConfig
    """

    def __init__(self, prefix, config: SiglipConfig, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                SiglipEncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
        """

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states, _ = encoder_layer(
                hidden_states,
                attention_mask,
                output_attentions=output_attentions,
            )

        return hidden_states


class SiglipVisionTransformer(nn.Module):
    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = SiglipEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Returns:

        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        # NOTE: up until this point, the code logits are exactly
        # the same as the transformers code. The values evaluate
        # slightly differently in our encoder layer.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
        )
        last_hidden_state = encoder_outputs
        post_last_hidden_state = self.post_layernorm(last_hidden_state)

        return BaseModelOutputWithPooling(
            last_hidden_state=post_last_hidden_state,
            # pooler_output=pooled_output,
            # hidden_states=encoder_outputs,
        )
@@ -11,6 +11,12 @@ def load_text_model(prefix, config, weights, name=None):
         )
 
         return FlashMistralForCausalLM(prefix, config, weights, name=name)
+    elif config.model_type == "gemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
+        return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
 
@@ -24,5 +30,13 @@ def load_vision_model(prefix, config, weights):
         return CLIPVisionTransformer(
             prefix=f"{prefix}.vision_model", config=config, weights=weights
         )
+    if config.model_type == "siglip_vision_model":
+        from text_generation_server.models.custom_modeling.siglip import (
+            SiglipVisionTransformer,
+        )
+
+        return SiglipVisionTransformer(
+            prefix=f"vision_tower.vision_model", config=config, weights=weights
+        )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -133,6 +133,17 @@ class FlashCausalLMBatch(Batch):
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
+    @classmethod
+    def from_tokenized(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        batch_tokenized_inputs,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
         position_ids = []
         speculative_ids = []
         cu_seqlen_prefill = [0]
@@ -207,6 +218,7 @@ class FlashCausalLMBatch(Batch):
         # Paged attention
         # Remove one as the first token des not have a past
         speculative_length = get_speculate()
+        speculative_length = 0 if speculative_length is None else speculative_length
         total_tokens = input_length + max_new_tokens - 1 + speculative_length
         needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
         blocks += needed_blocks
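Splitting the batch construction into `batch_tokenized_inputs` plus `from_tokenized` lets a multimodal batch run its own preprocessing first and then reuse the shared flash-attention batch building. A minimal sketch of that pattern; the subclass, the "[IMG]" marker, and the placeholder expansion are assumptions for illustration only.

class PaliVlmCausalLMBatch(FlashCausalLMBatch):
    @classmethod
    def from_pb(cls, pb, tokenizer, dtype, device):
        # Hypothetical: expand each request's image markers into "<image>"
        # placeholder tokens before tokenizing, then hand the token ids to
        # the shared flash batch builder via from_tokenized.
        texts = [r.inputs.replace("[IMG]", "<image>" * 256) for r in pb.requests]
        batch_tokenized_inputs = tokenizer(texts)["input_ids"]
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)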
@@ -4,6 +4,7 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 from transformers.models.gemma import GemmaTokenizerFast
+from transformers import AutoConfig, PretrainedConfig
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
@@ -19,15 +20,58 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
 
 
-class FlashGemma(FlashCausalLM):
+class VisionConfig(PretrainedConfig):
     def __init__(
         self,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        model_type: str = "siglip_vision_model",
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        num_image_tokens: int = 256,
+        patch_size: int = 14,
+        projection_dim: int = 2048,
+        projector_hidden_act: str = "gelu_fast",
+        vision_use_head: bool = False,
+        vocab_size: int = 257152,
+        quantize: Optional[str] = None,
+        image_size: int = 224,
+        layer_norm_eps: float = 1e-06,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "gelu_pytorch_tanh",
+        num_channels: int = 3,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.model_type = model_type
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.num_image_tokens = num_image_tokens
+        self.patch_size = patch_size
+        self.projection_dim = projection_dim
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_use_head = vision_use_head
+        self.vocab_size = vocab_size
+        self.quantize = quantize
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+
+
+class BaseFlashGemma(FlashCausalLM):
+    def __init__(
+        self,
+        model_cls,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        prefix: Optional[str] = None,
+        config_cls=AutoConfig,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -49,9 +93,39 @@ class FlashGemma(FlashCausalLM):
         config = GemmaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+
+        is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config")
+
+        if is_vlm:
+            config.vision_config = VisionConfig(
+                hidden_size=1152,
+                intermediate_size=4304,
+                model_type="siglip_vision_model",
+                num_attention_heads=16,
+                num_hidden_layers=27,
+                num_image_tokens=256,
+                patch_size=14,
+                projection_dim=2048,
+                projector_hidden_act="gelu_fast",
+                vision_use_head=False,
+                vocab_size=257152,
+                quantize=quantize,
+            )
+
         config.quantize = quantize
         config.speculator = speculator
+
+        if is_vlm:
+            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
+            config.intermediate_size = config.text_config.get("intermediate_size")
+            config.model_type = config.text_config.get("model_type")
+            config.num_attention_heads = config.text_config.get("num_attention_heads")
+            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
+            config.num_image_tokens = config.text_config.get("num_image_tokens")
+            config.num_key_value_heads = config.text_config.get("num_key_value_heads")
+            config.torch_dtype = config.text_config.get("torch_dtype")
+            config.vocab_size = config.text_config.get("vocab_size")
+
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -59,17 +133,49 @@ class FlashGemma(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)
 
-        model = FlashGemmaForCausalLM(config, weights)
+        model = model_cls(prefix, config, weights)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashGemma, self).__init__(
+
+        if is_vlm:
+            num_layers = config.num_hidden_layers
+            num_kv_heads = config.num_key_value_heads
+            head_size = config.intermediate_size
+        else:
+            num_layers = len(model.model.layers)
+            num_kv_heads = model.model.num_key_value_heads
+            head_size = model.model.head_size
+
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
-            head_size=model.model.head_size,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
         )
+
+
+class FlashGemma(BaseFlashGemma):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        super(FlashGemma, self).__init__(
+            model_cls=FlashGemmaForCausalLM,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+            prefix=None,
+        )
54
server/text_generation_server/models/flash_pali_gemma.py
Normal file
@ -0,0 +1,54 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional, Tuple
from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
    FlashPaliGemmaForConditionalGeneration,
    PaliGemmaConfig,
    PaliTextConfig,
)
from transformers import AutoProcessor

tracer = trace.get_tracer(__name__)


class FlashPaliGemma(PaliVlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            # TODO: load in the correct processor based on the model_id
            "google/siglip-base-patch16-224",
            # "google/siglip-so400m-patch14-384",
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        super().__init__(
            config_cls=PaliTextConfig,
            model_cls=FlashPaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
            prefix="language_model",
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)
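The processor is hard-coded to a SigLIP checkpoint for now (see the TODO above). A hedged sketch of how it could instead be resolved from the model_id, assuming the checkpoint bundles its own preprocessor configuration as PaliGemma-style repos typically do; the fallback values are illustrative, not part of this commit:

from transformers import AutoProcessor

def load_processor(model_id: str, revision=None, trust_remote_code=False):
    # Prefer the processor shipped with the checkpoint itself.
    try:
        return AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
    except (OSError, ValueError):
        # Fall back to the SigLIP processor used above if the repo does not
        # provide a processor configuration.
        return AutoProcessor.from_pretrained(
            "google/siglip-base-patch16-224",
            revision=revision,
            trust_remote_code=trust_remote_code,
        )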
@ -15,6 +15,8 @@ from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    FlashMistralBatch,
)
from text_generation_server.models.flash_gemma import BaseFlashGemma
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
@ -80,6 +82,11 @@ def image_text_replacement(image_input, config, image_id) -> str:
        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
        return "<image>" * num_features

    # TODO: double check correct naming for model_type
    elif config.model_type == "gemma":
        # TODO: use correct number of features
        return "<image>" * 256
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
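The gemma branch above returns a fixed 256 image tokens. A hedged sketch of how the count could be derived from the processed image size instead, assuming a SigLIP-style vision tower with patch size 14, where the token count is the square of image_size divided by patch_size (224 gives 16x16 = 256, 448 gives 32x32 = 1024, 896 gives 64x64 = 4096):

def gemma_image_tokens(image_size: int, patch_size: int = 14) -> int:
    # Each patch becomes one image token; a square image of side image_size
    # therefore yields (image_size // patch_size) ** 2 tokens.
    return (image_size // patch_size) ** 2

assert gemma_image_tokens(224) == 256
assert gemma_image_tokens(448) == 1024
assert gemma_image_tokens(896) == 4096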
@ -371,3 +378,238 @@ class VlmCausalLM(BaseFlashMistral):
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits


class PaliVlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # On the rust layer.
                    # This avoid making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "pixel_attention_mask" in image_input:
                new_image_inputs["pixel_attention_mask"] = torch.cat(
                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
                )
            if "image_sizes" in image_input:
                new_image_inputs["image_sizes"] = torch.cat(
                    [img["image_sizes"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "PaliVlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
        return batch


class PaliVlmCausalLM(BaseFlashGemma):
    @property
    def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
        return PaliVlmCausalLMBatch

    def forward(
        self, batch: PaliVlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        # Try to find an associated cuda graph
        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                # prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
            )
            # if batch.prefill_cache_indices is not None:
            #     batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.pixel_attention_mask is not None:
                batch.pixel_attention_mask = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
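The speculative-decoding branch in the forward pass above expands every per-sequence tensor so that each draft token gets its own row. A small hedged illustration of that expansion with toy tensors, detached from the real batch objects:

import torch

B, speculative_length = 2, 3          # two sequences, three draft tokens each
new_length = speculative_length + 1   # the accepted token plus the drafts

input_ids = torch.tensor([10, 20])    # last token of each sequence
speculative_ids = torch.tensor([[11, 12, 13],
                                [21, 22, 23]])
position_ids = torch.tensor([5, 9])   # next position per sequence

# Interleave the current token with its draft tokens, then flatten.
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
# -> tensor([10, 11, 12, 13, 20, 21, 22, 23])

# Give each of the new_length tokens its own consecutive position.
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
# -> tensor([5, 6, 7, 8, 9, 10, 11, 12])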
@ -14,7 +14,7 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
@ -98,6 +98,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliVlmCausalLMBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
@ -122,6 +123,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
            PaliVlmCausalLMBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
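The server changes above route PaliVlmCausalLMBatch through the processor-aware deserialization path rather than the plain tokenizer path. A hedged sketch of that dispatch pattern, using simplified stand-ins for the real request and batch objects:

def build_batch(model, request_batch, tokenizer, processor, config, dtype, device,
                multimodal_batch_types=frozenset()):
    # Multimodal batch classes also need the image processor and model config;
    # text-only batches are built from the tokenizer alone.
    if model.batch_type in multimodal_batch_types:
        return model.batch_type.from_pb_processor(
            request_batch, tokenizer, processor, config, dtype, device
        )
    return model.batch_type.from_pb(request_batch, tokenizer, dtype, device)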