Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00

Commit c119ac4d1d ("Fixed PaliGemma."), parent 67e833cedb.
@@ -100,7 +100,6 @@ impl LlavaNext {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
     image_size: usize,
@@ -108,7 +107,6 @@ pub struct ClipVisionModel {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
 
@@ -119,18 +117,20 @@ impl Idefics2 {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
-pub struct Paligemma {}
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
 
 impl Paligemma {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-        // TODO: improve to calculate based on height and width
-        // 224 = 256 image tokens
-        // 448 = 1024 image tokens
-        // 896 = 4096 image tokens
-
-        256
+        self.text_config.num_image_tokens
     }
 }
 
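The constants dropped above follow a simple relationship: with SigLIP's 14-pixel patches (the patch size appears in the vision config removed later in this diff), the number of image tokens is (image_size / patch_size) squared, which is the value `num_image_tokens` now carries per checkpoint. A quick sanity check with an illustrative helper (not part of the repo):

```python
# Reproduces the removed comments: 224 -> 256, 448 -> 1024, 896 -> 4096 image tokens.
def expected_image_tokens(image_size: int, patch_size: int = 14) -> int:
    side = image_size // patch_size  # patches per edge
    return side * side               # one token per patch

assert expected_image_tokens(224) == 256
assert expected_image_tokens(448) == 1024
assert expected_image_tokens(896) == 4096
```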
@@ -64,6 +64,9 @@ try:
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
+    from text_generation_server.models.pali_gemma import (
+        PaliGemma,
+    )
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
@@ -654,6 +657,18 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == "paligemma":
+        if FLASH_ATTENTION:
+            return PaliGemma(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
 
     if model_type == "llava_next":
         if FLASH_ATTENTION:
@@ -153,15 +153,11 @@ def _load_gqa(config, prefix: str, weights):
 
 
 class FlashGemmaAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
+        self.causal = causal
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -238,6 +234,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
+                causal=self.causal,
             )
         # Decode
         else:
@@ -295,10 +292,10 @@ class GemmaMLP(nn.Module):
 
 
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
         self.self_attn = FlashGemmaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
         )
         self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
@@ -350,7 +347,7 @@ class FlashGemmaLayer(nn.Module):
 
 
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
        super().__init__()
 
         process_group = weights.process_group
@@ -362,6 +359,7 @@ class FlashGemmaModel(torch.nn.Module):
                     prefix=f"{prefix}.layers.{layer_id}",
                     config=config,
                     weights=weights,
+                    causal=causal,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -378,7 +376,7 @@ class FlashGemmaModel(torch.nn.Module):
 
     def forward(
         self,
-        input_embeds: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -387,7 +385,7 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = input_embeds
+        hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -416,7 +414,7 @@ class FlashGemmaModel(torch.nn.Module):
 
 
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
@@ -430,7 +428,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         self.embed_tokens.weight *= embed_norm
 
-        self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
+        self.model = FlashGemmaModel(
+            prefix=prefix, config=config, weights=weights, causal=causal
+        )
         self.lm_head = SpeculativeHead.load(
             prefix=(
                 f"{prefix}.embed_tokens"
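These hunks thread a single `causal: bool` from `FlashGemmaForCausalLM` down through `FlashGemmaModel` and `FlashGemmaLayer` into `FlashGemmaAttention`, where it is finally forwarded to the attention call. A stripped-down sketch of the pattern with hypothetical Toy* classes (not the real modules): the standalone FlashGemma path constructs the stack with `causal=True`, while the Gemma tower embedded in PaliGemma is built with `causal=False` so the image+prompt prefix can attend bidirectionally.

```python
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, causal: bool):
        super().__init__()
        self.causal = causal  # stored once, forwarded to the attention kernel at forward time

class ToyLayer(nn.Module):
    def __init__(self, causal: bool):
        super().__init__()
        self.self_attn = ToyAttention(causal=causal)

class ToyModel(nn.Module):
    def __init__(self, causal: bool):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(causal=causal) for _ in range(2))

decoder = ToyModel(causal=True)     # plain FlashGemma: standard causal decoding
prefix_lm = ToyModel(causal=False)  # Gemma inside PaliGemma: full attention over the prefix
```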
@@ -19,134 +19,14 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.layers import TensorParallelColumnLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
 )
-from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    GemmaConfig,
-)
 
 
-class VisionConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int = 1152,
-        intermediate_size: int = 4304,
-        model_type: str = "siglip_vision_model",
-        num_attention_heads: int = 16,
-        num_hidden_layers: int = 27,
-        num_image_tokens: int = 256,
-        patch_size: int = 14,
-        projection_dim: int = 2048,
-        projector_hidden_act: str = "gelu_fast",
-        vision_use_head: bool = False,
-        vocab_size: int = 257152,
-        quantize: Optional[str] = None,
-        image_size: int = 224,
-        layer_norm_eps: float = 1e-06,
-        attention_dropout: float = 0.0,
-        hidden_act: str = "gelu_pytorch_tanh",
-        num_channels: int = 3,
-        **kwargs,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-        self.image_size = image_size
-        self.layer_norm_eps = layer_norm_eps
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.num_channels = num_channels
-
-        super().__init__(**kwargs)
-
-
-class PaliGemmaConfig(PretrainedConfig):
-    model_type = "paligemma"
-
-    def __init__(
-        self,
-        text_config: GemmaConfig,
-        vision_config: VisionConfig,
-        vocab_size: int = 257152,
-        image_token_index: int = 256000,
-        **kwargs,
-    ):
-        self.text_config = text_config
-        self.vision_config = vision_config
-
-        self.vocab_size = vocab_size
-        self.image_token_index = image_token_index
-
-        self.intermediate_size = text_config.intermediate_size
-        self.num_hidden_layers = text_config.num_hidden_layers
-        self.num_key_value_heads = text_config.num_key_value_heads
-        self.num_attention_heads = text_config.num_attention_heads
-
-        super().__init__(**kwargs)
-
-    def from_pretrained(pretrained_model_name_or_path, **kwargs):
-        vision_config = VisionConfig(
-            hidden_size=1152,
-            intermediate_size=4304,
-            model_type="siglip_vision_model",
-            num_attention_heads=16,
-            num_hidden_layers=27,
-            num_image_tokens=256,
-            patch_size=14,
-            projection_dim=2048,
-            projector_hidden_act="gelu_fast",
-            vision_use_head=False,
-            vocab_size=257152,
-        )
-
-        text_config = GemmaConfig.from_pretrained(
-            pretrained_model_name_or_path,
-            attention_bias=False,
-            attention_dropout=0.0,
-            bos_token_id=2,
-            eos_token_id=1,
-            head_dim=256,
-            hidden_act="gelu_pytorch_tanh",
-            hidden_activation=None,
-            hidden_size=2048,
-            initializer_range=0.02,
-            intermediate_size=16384,
-            max_position_embeddings=8192,
-            model_type="gemma",
-            num_attention_heads=8,
-            num_hidden_layers=18,
-            num_image_tokens=256,
-            num_key_value_heads=1,
-            pad_token_id=0,
-            rms_norm_eps=1e-06,
-            rope_theta=10000.0,
-            torch_dtype="float32",
-            transformers_version="4.40.0.dev0",
-            use_cache=True,
-            vocab_size=257216,
-            **kwargs,
-        )
-
-        return PaliGemmaConfig(
-            text_config=text_config,
-            vision_config=vision_config,
-            **kwargs,
-        )
-
-
-class FlashPaliGemmaForConditionalGeneration(nn.Module):
+class PaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
@@ -166,6 +46,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         self.vocab_size = config.vocab_size
         self.config = config
 
+        text_config = config.text_config
+        text_config.speculator = config.speculator
+        text_config.quantize = config.quantize
         self.text_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
@@ -188,36 +71,28 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
-        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-        pixel_attention_mask=None,
+        # Unused here
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
+        # TODO This is odd but apparently pali gemma position ids start at 1.
+        if cu_seqlen_prefill is not None:
+            max_s += 1
+            position_ids += 1
 
-        if pixel_values is not None and len(pixel_values) > 0:
-            # TODO: avoid these casts upstream
-            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
-
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
             image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
-            # TODO: now we scale them? maybe we can do this up or downstream
-            scaled_image_features = image_features / (
-                self.config.text_config.hidden_size**0.5
-            )
-
             # mask where image or padding tokens
-            mask = input_ids == self.config.image_token_index | (input_ids == 2)
+            mask = input_ids == self.config.image_token_index
 
             # insert image features into input embeddings
-            # normalizer = torch.tensor(
-            #     self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
-            # )
-            # inputs_embeds = inputs_embeds * normalizer
-            inputs_embeds[mask] = scaled_image_features.view(
-                -1, scaled_image_features.shape[-1]
-            )
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
 
-        hidden_states = self.language_model.model(
+        hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -230,6 +105,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
 
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
 
         return logits, speculative_logits
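The rewritten forward keeps the merge step simple: the vision tower's projected features are written into the embedding positions that hold the image placeholder token, and the rest of the sequence keeps its text embeddings. A toy version of that scatter, with made-up shapes and token ids purely for illustration:

```python
import torch

hidden_size = 8
image_token_index = 4                                    # hypothetical placeholder id
input_ids = torch.tensor([4, 4, 4, 11, 7, 9])            # three image slots, then text
inputs_embeds = torch.zeros(len(input_ids), hidden_size)  # pretend embed_tokens output
image_features = torch.randn(1, 3, hidden_size)           # one image -> three "patch" tokens

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, hidden_size)

assert inputs_embeds[:3].abs().sum() > 0   # image slots now hold vision features
assert inputs_embeds[3:].abs().sum() == 0  # text slots untouched by the scatter
```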
@@ -16,7 +16,7 @@ from transformers.modeling_outputs import (
 )
 from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
 
-from text_generation_server.utils.layers import (
+from text_generation_server.layers.tensor_parallel import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
@@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )
 
-        return FlashGemmaForCausalLM(prefix, config, weights)
+        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
     elif config.model_type == "paligemma":
         from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
             FlashGemmaForCausalLM,
@@ -3,8 +3,7 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
-from transformers.models.gemma import GemmaTokenizerFast
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
@@ -36,14 +35,12 @@ class FlashGemma(FlashCausalLM):
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")
 
-        tokenizer = GemmaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
         )
 
         config = AutoConfig.from_pretrained(
@@ -61,7 +58,7 @@ class FlashGemma(FlashCausalLM):
 
         # TODO hardcoded
         prefix = "language_model"
-        model = FlashGemmaForCausalLM(prefix, config, weights)
+        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
(file removed)
@@ -1,50 +0,0 @@
-import torch
-import torch.distributed
-from opentelemetry import trace
-from typing import Optional, Tuple
-from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
-from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
-    FlashPaliGemmaForConditionalGeneration,
-    PaliGemmaConfig,
-)
-from transformers import AutoProcessor
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashPaliGemma(PaliVlmCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.processor = AutoProcessor.from_pretrained(
-            "google/siglip-base-patch16-224",
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-
-        super().__init__(
-            config_cls=PaliGemmaConfig,
-            model_cls=FlashPaliGemmaForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
server/text_generation_server/models/pali_gemma.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+import torch
+import torch.distributed
+from opentelemetry import trace
+from typing import Optional, Tuple
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLM,
+    VlmCausalLMBatch,
+    image_text_replacement,
+    load_data_uri,
+    split,
+)
+from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+    PaliGemmaForConditionalGeneration,
+)
+from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+
+tracer = trace.get_tracer(__name__)
+
+
+class PaliGemmaBatch(VlmCausalLMBatch):
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+        batch_inputs = []
+        image_inputs = []
+        max_truncation = 0
+        for r in requests:
+            chunks = split(r.inputs)
+            full_text = ""
+            image_id = 0
+            for chunk in chunks:
+                if chunk["type"] == "text":
+                    full_text += "<bos>" + chunk["content"] + "\n"
+                elif chunk["type"] == "image":
+                    image = chunk["content"]
+                    # Should never receive URLs anymore, processing should be done
+                    # On the rust layer.
+                    # This avoid making n queries per TP
+                    # if image.startswith("https://") or image.startswith("http://"):
+                    #     image = processor.image_processor.fetch_images(image)
+                    if image.startswith("data:"):
+                        image = load_data_uri(image)
+                    else:
+                        raise RuntimeError(
+                            "Cannot process input image not starting with data:"
+                        )
+                    # TODO do_convert_RGB should be on by default ?
+                    image = image.convert("RGB")
+                    image_input = processor.image_processor(image, return_tensors="pt")
+                    full_text += image_text_replacement(image_input, config, image_id)
+                    image_inputs.append(image_input)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+
+            batch_inputs.append(full_text)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=False,
+        )["input_ids"]
+        if image_inputs:
+            image_input = image_inputs[0]
+            new_image_inputs = {
+                "pixel_values": torch.cat(
+                    [img["pixel_values"] for img in image_inputs], dim=0
+                ),
+            }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
+        else:
+            image_inputs = None
+        return batch_tokenized_inputs, image_inputs
+
+
+class PaliGemma(VlmCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.processor = AutoProcessor.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+
+        super().__init__(
+            config_cls=AutoConfig,
+            model_cls=PaliGemmaForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    @property
+    def batch_type(self):
+        return PaliGemmaBatch
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.text_model.model.layers),
+            model.text_model.model.num_key_value_heads,
+            model.text_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
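For a request containing one image followed by text, the batch class above therefore builds a prompt of repeated image placeholders, a manually inserted `<bos>`, the text, and a trailing newline, and tokenizes it with `add_special_tokens=False` so no second BOS is added. A rough sketch of the resulting string; the token count and prompt text are example assumptions, not values from the repo:

```python
num_image_tokens = 256       # e.g. a 224px PaliGemma checkpoint (text_config.num_image_tokens)
text = "caption en"          # hypothetical user prompt
prompt = "<image>" * num_image_tokens + "<bos>" + text + "\n"
# tokenizer(prompt, add_special_tokens=False) yields the ids whose <image> slots
# are later overwritten with the vision tower's embeddings (see the modeling diff above).
```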
@@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
         return "<image>" * num_features
 
     elif config.model_type == "paligemma":
-        # TODO: use correct number of features
-        return "<image>" * 256
+        return "<image>" * config.text_config.num_image_tokens
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
@@ -174,7 +173,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
             image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
-                    full_text += chunk["content"]
+                    full_text += "<bos>" + chunk["content"] + "\n"
                 elif chunk["type"] == "image":
                     image = chunk["content"]
                     # Should never receive URLs anymore, processing should be done
@@ -198,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
             max_truncation = max(max_truncation, r.truncate)
 
         batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=False,
         )["input_ids"]
         if image_inputs:
             image_input = image_inputs[0]
@@ -376,142 +378,3 @@ class VlmCausalLM(BaseFlashMistral):
             )
             logits = cuda_graph["logits"][:bs]
             return logits, speculative_logits
-
-
-class PaliVlmCausalLMBatch(FlashCausalLMBatch):
-    pixel_values: Optional[List[torch.Tensor]]
-    pixel_attention_mask: Optional[List[torch.Tensor]]
-    image_sizes: Optional[List[Tuple[int, int]]]
-
-    @classmethod
-    @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches):
-        batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
-
-    @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]):
-        batch = super().filter(request_ids)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
-
-    @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
-        batch_inputs = []
-        image_inputs = []
-        text_inputs = []
-        image_text_replacements = []
-        max_truncation = 0
-        for r in requests:
-            chunks = split(r.inputs)
-            full_text = ""
-            image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                    text_inputs.append(chunk["content"])
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    text_replacement = image_text_replacement(
-                        image_input, config, image_id
-                    )
-                    full_text += text_replacement
-                    image_text_replacements.append(text_replacement)
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=False,
-        )["input_ids"]
-
-        image_token = tokenizer.get_added_vocab()["<image>"]
-
-        # find the index of the first non-image token
-        for batch in batch_tokenized_inputs:
-            first_non_image = 0
-            for i, token in enumerate(batch):
-                if token != image_token:
-                    first_non_image = i
-                    break
-
-        # manually add the bos to the left of the text
-        batch_tokenized_inputs = [
-            batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:]
-            for batch in batch_tokenized_inputs
-        ]
-
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
-        return batch_tokenized_inputs, image_inputs
-
-    @classmethod
-    def from_pb_processor(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        processor,
-        config,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "PaliVlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
-        )
-        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-        if image_inputs is not None:
-            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            if "pixel_attention_mask" in image_inputs:
-                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
-                    device=device
-                )
-            else:
-                batch.pixel_attention_mask = None
-            if "image_sizes" in image_inputs:
-                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
-            else:
-                batch.image_sizes = None
-        else:
-            batch.pixel_values = None
-            batch.pixel_attention_mask = None
-            batch.image_sizes = None
-        return batch
@@ -14,9 +14,9 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.models.pali_gemma import PaliGemmaBatch
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLMBatch,
-    PaliVlmCausalLMBatch,
 )
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
@@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
-            PaliVlmCausalLMBatch,
+            PaliGemmaBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
@@ -126,7 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if self.model.batch_type in {
             IdeficsCausalLMBatch,
             VlmCausalLMBatch,
-            PaliVlmCausalLMBatch,
+            PaliGemmaBatch,
         }:  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb_processor(
                 request.batch,
@@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
    ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -134,7 +135,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             window_size_left,
             0,
             False,
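The attention wrapper (only partially visible in this hunk) now takes `causal` as a keyword defaulting to `True`, so existing callers keep the causal mask while the PaliGemma prefix path can turn it off. As a stand-in illustration of what that flag controls, here is the same toggle expressed with PyTorch's built-in scaled-dot-product attention rather than the flash-attention kernel used above:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 5, 8)  # (batch, heads, seq_len, head_dim), toy sizes
decode_style = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # masked future
prefix_style = F.scaled_dot_product_attention(q, k, v, is_causal=False)  # full attention
print(decode_style.shape, prefix_style.shape)
```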