Fixed PaliGemma.

Nicolas Patry 2024-05-14 15:58:19 +00:00
parent 67e833cedb
commit c119ac4d1d
12 changed files with 195 additions and 371 deletions


@@ -100,7 +100,6 @@ impl LlavaNext {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
     image_size: usize,
@@ -108,7 +107,6 @@ pub struct ClipVisionModel {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
@@ -119,18 +117,20 @@ impl Idefics2 {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
-pub struct Paligemma {}
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
 
 impl Paligemma {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-        // TODO: improve to calculate based on height and width
-        // 224 = 256 image tokens
-        // 448 = 1024 image tokens
-        // 896 = 4096 image tokens
-        256
+        self.text_config.num_image_tokens
     }
 }
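
Note: the hard-coded 256 removed here only held for 224x224 checkpoints; the deleted comments hint at the general rule. A minimal sketch of that rule, assuming a SigLIP vision tower with 14x14 patches (the router now simply reads the value from the checkpoint's text_config instead of computing it):

# Illustrative only: shows where 256 / 1024 / 4096 come from for 224 / 448 / 896
# checkpoints, assuming 14x14 patches. The function name is hypothetical.
def expected_image_tokens(image_size: int, patch_size: int = 14) -> int:
    return (image_size // patch_size) ** 2

assert expected_image_tokens(224) == 256
assert expected_image_tokens(448) == 1024
assert expected_image_tokens(896) == 4096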


@@ -64,6 +64,9 @@ try:
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
+    from text_generation_server.models.pali_gemma import (
+        PaliGemma,
+    )
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
     )
@@ -654,6 +657,18 @@ def get_model(
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == "paligemma":
+        if FLASH_ATTENTION:
+            return PaliGemma(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
    if model_type == "llava_next":
        if FLASH_ATTENTION:


@@ -153,15 +153,11 @@ def _load_gqa(config, prefix: str, weights):
 class FlashGemmaAttention(torch.nn.Module):
-    def __init__(
-        self,
-        prefix: str,
-        config,
-        weights,
-    ):
+    def __init__(self, prefix: str, config, weights, causal: bool):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
+        self.causal = causal
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -238,6 +234,7 @@ class FlashGemmaAttention(torch.nn.Module):
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
+                causal=self.causal,
            )
        # Decode
        else:
@@ -295,10 +292,10 @@ class GemmaMLP(nn.Module):
 
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
         self.self_attn = FlashGemmaAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
         )
         self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -350,7 +347,7 @@ class FlashGemmaLayer(nn.Module):
 
 class FlashGemmaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
 
         process_group = weights.process_group
@@ -362,6 +359,7 @@ class FlashGemmaModel(torch.nn.Module):
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
+                    causal=causal,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@@ -378,7 +376,7 @@ class FlashGemmaModel(torch.nn.Module):
     def forward(
         self,
-        input_embeds: torch.Tensor,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -387,7 +385,7 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        hidden_states = input_embeds
+        hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -416,7 +414,7 @@ class FlashGemmaModel(torch.nn.Module):
 
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, causal: bool):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
@@ -430,7 +428,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         self.embed_tokens.weight *= embed_norm
 
-        self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
+        self.model = FlashGemmaModel(
+            prefix=prefix, config=config, weights=weights, causal=causal
+        )
         self.lm_head = SpeculativeHead.load(
             prefix=(
                 f"{prefix}.embed_tokens"


@@ -19,134 +19,14 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.layers import TensorParallelColumnLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
 )
-from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    GemmaConfig,
-)
 
 
-class VisionConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int = 1152,
-        intermediate_size: int = 4304,
-        model_type: str = "siglip_vision_model",
-        num_attention_heads: int = 16,
-        num_hidden_layers: int = 27,
-        num_image_tokens: int = 256,
-        patch_size: int = 14,
-        projection_dim: int = 2048,
-        projector_hidden_act: str = "gelu_fast",
-        vision_use_head: bool = False,
-        vocab_size: int = 257152,
-        quantize: Optional[str] = None,
-        image_size: int = 224,
-        layer_norm_eps: float = 1e-06,
-        attention_dropout: float = 0.0,
-        hidden_act: str = "gelu_pytorch_tanh",
-        num_channels: int = 3,
-        **kwargs,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-        self.image_size = image_size
-        self.layer_norm_eps = layer_norm_eps
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.num_channels = num_channels
-
-        super().__init__(**kwargs)
-
-
-class PaliGemmaConfig(PretrainedConfig):
-    model_type = "paligemma"
-
-    def __init__(
-        self,
-        text_config: GemmaConfig,
-        vision_config: VisionConfig,
-        vocab_size: int = 257152,
-        image_token_index: int = 256000,
-        **kwargs,
-    ):
-        self.text_config = text_config
-        self.vision_config = vision_config
-
-        self.vocab_size = vocab_size
-        self.image_token_index = image_token_index
-
-        self.intermediate_size = text_config.intermediate_size
-        self.num_hidden_layers = text_config.num_hidden_layers
-        self.num_key_value_heads = text_config.num_key_value_heads
-        self.num_attention_heads = text_config.num_attention_heads
-
-        super().__init__(**kwargs)
-
-    def from_pretrained(pretrained_model_name_or_path, **kwargs):
-        vision_config = VisionConfig(
-            hidden_size=1152,
-            intermediate_size=4304,
-            model_type="siglip_vision_model",
-            num_attention_heads=16,
-            num_hidden_layers=27,
-            num_image_tokens=256,
-            patch_size=14,
-            projection_dim=2048,
-            projector_hidden_act="gelu_fast",
-            vision_use_head=False,
-            vocab_size=257152,
-        )
-
-        text_config = GemmaConfig.from_pretrained(
-            pretrained_model_name_or_path,
-            attention_bias=False,
-            attention_dropout=0.0,
-            bos_token_id=2,
-            eos_token_id=1,
-            head_dim=256,
-            hidden_act="gelu_pytorch_tanh",
-            hidden_activation=None,
-            hidden_size=2048,
-            initializer_range=0.02,
-            intermediate_size=16384,
-            max_position_embeddings=8192,
-            model_type="gemma",
-            num_attention_heads=8,
-            num_hidden_layers=18,
-            num_image_tokens=256,
-            num_key_value_heads=1,
-            pad_token_id=0,
-            rms_norm_eps=1e-06,
-            rope_theta=10000.0,
-            torch_dtype="float32",
-            transformers_version="4.40.0.dev0",
-            use_cache=True,
-            vocab_size=257216,
-            **kwargs,
-        )
-
-        return PaliGemmaConfig(
-            text_config=text_config,
-            vision_config=vision_config,
-            **kwargs,
-        )
-
-
-class FlashPaliGemmaForConditionalGeneration(nn.Module):
+class PaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
@@ -166,6 +46,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         self.vocab_size = config.vocab_size
         self.config = config
 
+        text_config = config.text_config
+        text_config.speculator = config.speculator
+        text_config.quantize = config.quantize
         self.text_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
@@ -188,36 +71,28 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
         pixel_values: torch.FloatTensor = None,
-        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
-        pixel_attention_mask=None,
+        # Unused here
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
+        # TODO This is odd but apparently pali gemma position ids start at 1.
+        if cu_seqlen_prefill is not None:
+            max_s += 1
+            position_ids += 1
 
-        if pixel_values is not None and len(pixel_values) > 0:
-            # TODO: avoid these casts upstream
-            pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype)
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
             image_outputs = self.vision_tower(pixel_values)
             image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
 
-            # TODO: now we scale them? maybe we can do this up or downstream
-            scaled_image_features = image_features / (
-                self.config.text_config.hidden_size**0.5
-            )
-
             # mask where image or padding tokens
-            mask = input_ids == self.config.image_token_index | (input_ids == 2)
+            mask = input_ids == self.config.image_token_index
 
             # insert image features into input embeddings
-            # normalizer = torch.tensor(
-            #     self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
-            # )
-            # inputs_embeds = inputs_embeds * normalizer
-            inputs_embeds[mask] = scaled_image_features.view(
-                -1, scaled_image_features.shape[-1]
-            )
+            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
 
-        hidden_states = self.language_model.model(
+        hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
@@ -230,6 +105,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
 
-        logits, speculative_logits = self.language_model.lm_head(hidden_states)
+        logits, speculative_logits = self.text_model.lm_head(hidden_states)
 
         return logits, speculative_logits
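
Note: the heart of this forward pass is the splice of vision features into the token embeddings: every position whose id equals image_token_index has its embedding overwritten, in order, by a projected vision feature. A small self-contained sketch (shapes and the token id are assumptions for illustration, not the repo's values):

# Sketch of the masked splice used above.
import torch

hidden_size = 8
image_token_index = 9                                  # assumed placeholder id
input_ids = torch.tensor([9, 9, 9, 1, 5, 7])           # three image slots, then text
inputs_embeds = torch.zeros(6, hidden_size)
image_features = torch.randn(1, 3, hidden_size)        # one image -> three features

mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, hidden_size)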


@@ -16,7 +16,7 @@ from transformers.modeling_outputs import (
 )
 from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
 
-from text_generation_server.utils.layers import (
+from text_generation_server.layers.tensor_parallel import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,


@@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None):
            FlashGemmaForCausalLM,
        )
 
-        return FlashGemmaForCausalLM(prefix, config, weights)
+        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
    elif config.model_type == "paligemma":
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,


@@ -3,8 +3,7 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 
-from transformers.models.gemma import GemmaTokenizerFast
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
@@ -36,14 +35,12 @@ class FlashGemma(FlashCausalLM):
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")
 
-        tokenizer = GemmaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
-            use_fast=True,
-            from_slow=False,
         )
 
         config = AutoConfig.from_pretrained(
@@ -61,7 +58,7 @@ class FlashGemma(FlashCausalLM):
         # TODO hardcoded
         prefix = "language_model"
-        model = FlashGemmaForCausalLM(prefix, config, weights)
+        model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
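
Note: swapping GemmaTokenizerFast for AutoTokenizer lets the tokenizer class be resolved from the checkpoint's own tokenizer files, so the fast/slow flags are no longer needed and PaliGemma repositories load the same way. A minimal usage sketch (the model id is an assumption for illustration):

# AutoTokenizer picks the class from tokenizer_config.json in the repo.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2b",        # assumed checkpoint; PaliGemma repos resolve the same way
    padding_side="left",
    truncation_side="left",
)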


@@ -1,50 +0,0 @@
-import torch
-import torch.distributed
-
-from opentelemetry import trace
-from typing import Optional, Tuple
-
-from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
-from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
-    FlashPaliGemmaForConditionalGeneration,
-    PaliGemmaConfig,
-)
-from transformers import AutoProcessor
-
-tracer = trace.get_tracer(__name__)
-
-
-class FlashPaliGemma(PaliVlmCausalLM):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        self.processor = AutoProcessor.from_pretrained(
-            "google/siglip-base-patch16-224",
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-
-        super().__init__(
-            config_cls=PaliGemmaConfig,
-            model_cls=FlashPaliGemmaForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)


@@ -0,0 +1,123 @@
+import torch
+import torch.distributed
+from opentelemetry import trace
+from typing import Optional, Tuple
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLM,
+    VlmCausalLMBatch,
+    image_text_replacement,
+    load_data_uri,
+    split,
+)
+from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+    PaliGemmaForConditionalGeneration,
+)
+from transformers import AutoProcessor, AutoConfig, AutoImageProcessor
+
+tracer = trace.get_tracer(__name__)
+
+
+class PaliGemmaBatch(VlmCausalLMBatch):
+    @classmethod
+    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
+        batch_inputs = []
+        image_inputs = []
+        max_truncation = 0
+        for r in requests:
+            chunks = split(r.inputs)
+            full_text = ""
+            image_id = 0
+            for chunk in chunks:
+                if chunk["type"] == "text":
+                    full_text += "<bos>" + chunk["content"] + "\n"
+                elif chunk["type"] == "image":
+                    image = chunk["content"]
+                    # Should never receive URLs anymore, processing should be done
+                    # On the rust layer.
+                    # This avoid making n queries per TP
+                    # if image.startswith("https://") or image.startswith("http://"):
+                    #     image = processor.image_processor.fetch_images(image)
+                    if image.startswith("data:"):
+                        image = load_data_uri(image)
+                    else:
+                        raise RuntimeError(
+                            "Cannot process input image not starting with data:"
+                        )
+                    # TODO do_convert_RGB should be on by default ?
+                    image = image.convert("RGB")
+                    image_input = processor.image_processor(image, return_tensors="pt")
+                    full_text += image_text_replacement(image_input, config, image_id)
+                    image_inputs.append(image_input)
+                else:
+                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
+
+            batch_inputs.append(full_text)
+            max_truncation = max(max_truncation, r.truncate)
+
+        batch_tokenized_inputs = tokenizer(
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=False,
+        )["input_ids"]
+        if image_inputs:
+            image_input = image_inputs[0]
+            new_image_inputs = {
+                "pixel_values": torch.cat(
+                    [img["pixel_values"] for img in image_inputs], dim=0
+                ),
+            }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
+        else:
+            image_inputs = None
+        return batch_tokenized_inputs, image_inputs
+
+
+class PaliGemma(VlmCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.processor = AutoProcessor.from_pretrained(
+            model_id,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+
+        super().__init__(
+            config_cls=AutoConfig,
+            model_cls=PaliGemmaForConditionalGeneration,
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    @property
+    def batch_type(self):
+        return PaliGemmaBatch
+
+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.text_model.model.layers),
+            model.text_model.model.num_key_value_heads,
+            model.text_model.model.head_size,
+        )
+
+    def max_past(self) -> Optional[int]:
+        return getattr(self.model.text_model, "max_past", None)
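
Note: the batch class above builds the PaliGemma prompt by hand: each image contributes one "<image>" placeholder per image token, each text chunk is wrapped as "<bos>" + text + "\n", and tokenization runs with add_special_tokens=False so no extra BOS/EOS is injected. A minimal sketch of the resulting string for an image-then-text request, assuming a 224px checkpoint with 256 image tokens (the helper name is hypothetical; the server composes this inside batch_tokenized_inputs):

def build_paligemma_prompt(text: str, num_image_tokens: int = 256) -> str:
    # Placeholders first (image chunk), then the bos-prefixed text chunk.
    return "<image>" * num_image_tokens + "<bos>" + text + "\n"

prompt = build_paligemma_prompt("caption en")
assert prompt.startswith("<image>" * 256)
assert prompt.endswith("<bos>caption en\n")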


@@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
        return "<image>" * num_features
    elif config.model_type == "paligemma":
-        # TODO: use correct number of features
-        return "<image>" * 256
+        return "<image>" * config.text_config.num_image_tokens
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -174,7 +173,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
-                    full_text += chunk["content"]
+                    full_text += "<bos>" + chunk["content"] + "\n"
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
@@ -198,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch):
            max_truncation = max(max_truncation, r.truncate)
 
        batch_tokenized_inputs = tokenizer(
-            batch_inputs, truncation=True, max_length=max_truncation
+            batch_inputs,
+            truncation=True,
+            max_length=max_truncation,
+            add_special_tokens=False,
        )["input_ids"]
        if image_inputs:
            image_input = image_inputs[0]
@@ -376,142 +378,3 @@ class VlmCausalLM(BaseFlashMistral):
            )
            logits = cuda_graph["logits"][:bs]
            return logits, speculative_logits
-
-
-class PaliVlmCausalLMBatch(FlashCausalLMBatch):
-    pixel_values: Optional[List[torch.Tensor]]
-    pixel_attention_mask: Optional[List[torch.Tensor]]
-    image_sizes: Optional[List[Tuple[int, int]]]
-
-    @classmethod
-    @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches):
-        batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
-
-    @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int]):
-        batch = super().filter(request_ids)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
-
-    @classmethod
-    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
-        batch_inputs = []
-        image_inputs = []
-        text_inputs = []
-        image_text_replacements = []
-        max_truncation = 0
-        for r in requests:
-            chunks = split(r.inputs)
-            full_text = ""
-            image_id = 0
-            for chunk in chunks:
-                if chunk["type"] == "text":
-                    full_text += chunk["content"]
-                    text_inputs.append(chunk["content"])
-                elif chunk["type"] == "image":
-                    image = chunk["content"]
-                    # Should never receive URLs anymore, processing should be done
-                    # On the rust layer.
-                    # This avoid making n queries per TP
-                    # if image.startswith("https://") or image.startswith("http://"):
-                    #     image = processor.image_processor.fetch_images(image)
-                    if image.startswith("data:"):
-                        image = load_data_uri(image)
-                    else:
-                        raise RuntimeError(
-                            "Cannot process input image not starting with data:"
-                        )
-                    image_input = processor.image_processor(image, return_tensors="pt")
-                    text_replacement = image_text_replacement(
-                        image_input, config, image_id
-                    )
-                    full_text += text_replacement
-                    image_text_replacements.append(text_replacement)
-                    image_inputs.append(image_input)
-                else:
-                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
-
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=False,
-        )["input_ids"]
-
-        image_token = tokenizer.get_added_vocab()["<image>"]
-        # find the index of the first non-image token
-        for batch in batch_tokenized_inputs:
-            first_non_image = 0
-            for i, token in enumerate(batch):
-                if token != image_token:
-                    first_non_image = i
-                    break
-
-        # manually add the bos to the left of the text
-        batch_tokenized_inputs = [
-            batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:]
-            for batch in batch_tokenized_inputs
-        ]
-
-        if image_inputs:
-            image_input = image_inputs[0]
-            new_image_inputs = {
-                "pixel_values": torch.cat(
-                    [img["pixel_values"] for img in image_inputs], dim=0
-                ),
-            }
-            if "pixel_attention_mask" in image_input:
-                new_image_inputs["pixel_attention_mask"] = torch.cat(
-                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
-                )
-            if "image_sizes" in image_input:
-                new_image_inputs["image_sizes"] = torch.cat(
-                    [img["image_sizes"] for img in image_inputs], dim=0
-                )
-            image_inputs = new_image_inputs
-        else:
-            image_inputs = None
-
-        return batch_tokenized_inputs, image_inputs
-
-    @classmethod
-    def from_pb_processor(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        processor,
-        config,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "PaliVlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
-            pb.requests, tokenizer, processor, config
-        )
-        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-        if image_inputs is not None:
-            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            if "pixel_attention_mask" in image_inputs:
-                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
-                    device=device
-                )
-            else:
-                batch.pixel_attention_mask = None
-            if "image_sizes" in image_inputs:
-                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
-            else:
-                batch.image_sizes = None
-        else:
-            batch.pixel_values = None
-            batch.pixel_attention_mask = None
-            batch.image_sizes = None
-        return batch


@@ -14,9 +14,9 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
+from text_generation_server.models.pali_gemma import PaliGemmaBatch
 from text_generation_server.models.vlm_causal_lm import (
     VlmCausalLMBatch,
-    PaliVlmCausalLMBatch,
 )
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
@@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
-            PaliVlmCausalLMBatch,
+            PaliGemmaBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
@@ -126,7 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
-            PaliVlmCausalLMBatch,
+            PaliGemmaBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,


@@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
        max_s,
        softmax_scale,
        window_size_left=-1,
+        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
@@ -134,7 +135,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
            0.0,
            softmax_scale,
            False,
-            True,
+            causal,
            window_size_left,
            0,
            False,
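
Note: making causal a keyword with a default of True keeps every existing caller of this attention wrapper unchanged, while the Gemma path above can now opt out of the triangular mask. A rough sketch of the semantics, using PyTorch's scaled_dot_product_attention as a stand-in for the flash-attention kernel (not the repo's actual wrapper):

# Stand-in sketch: is_causal plays the role of the new causal argument.
import torch
import torch.nn.functional as F

def attention_sketch(q, k, v, softmax_scale, causal=True):
    return F.scaled_dot_product_attention(q, k, v, scale=softmax_scale, is_causal=causal)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim), assumed shapes
out_causal = attention_sketch(q, k, v, softmax_scale=0.125)                # old default
out_prefix = attention_sketch(q, k, v, softmax_scale=0.125, causal=False)  # PaliGemma prefill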