From c119ac4d1da88ba505421697aa722204a864176c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 May 2024 15:58:19 +0000 Subject: [PATCH] Fixed PaliGemma. --- router/src/config.rs | 20 +-- .../text_generation_server/models/__init__.py | 15 ++ .../custom_modeling/flash_gemma_modeling.py | 26 +-- .../flash_pali_gemma_modeling.py | 161 ++---------------- .../models/custom_modeling/siglip.py | 2 +- .../models/custom_modeling/vlm.py | 2 +- .../models/flash_gemma.py | 9 +- .../models/flash_pali_gemma.py | 50 ------ .../models/pali_gemma.py | 123 +++++++++++++ .../models/vlm_causal_lm.py | 149 +--------------- server/text_generation_server/server.py | 6 +- .../utils/flash_attn.py | 3 +- 12 files changed, 195 insertions(+), 371 deletions(-) delete mode 100644 server/text_generation_server/models/flash_pali_gemma.py create mode 100644 server/text_generation_server/models/pali_gemma.py diff --git a/router/src/config.rs b/router/src/config.rs index 3a7b775f..a98e3fff 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -100,7 +100,6 @@ impl LlavaNext { } #[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub struct ClipVisionModel { image_size: usize, @@ -108,7 +107,6 @@ pub struct ClipVisionModel { } #[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] pub struct Idefics2 {} @@ -119,18 +117,20 @@ impl Idefics2 { } #[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(tag = "model_type")] #[serde(rename_all = "snake_case")] -pub struct Paligemma {} +pub struct PaliTextConfig { + num_image_tokens: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct Paligemma { + text_config: PaliTextConfig, +} impl Paligemma { pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { - // TODO: improve to calculate based on height and width - // 224 = 256 image tokens - // 448 = 1024 image tokens - // 896 = 4096 image tokens - - 256 + self.text_config.num_image_tokens } } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index e9761dfe..ca956afc 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -64,6 +64,9 @@ try: from text_generation_server.models.flash_gemma import ( FlashGemma, ) + from text_generation_server.models.pali_gemma import ( + PaliGemma, + ) from text_generation_server.models.flash_santacoder import ( FlashSantacoderSharded, ) @@ -654,6 +657,18 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) + if model_type == "paligemma": + if FLASH_ATTENTION: + return PaliGemma( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + else: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == "llava_next": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index d2401327..0cda41f5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -153,15 +153,11 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__( - 
self, - prefix: str, - config, - weights, - ): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim + self.causal = causal self.rotary_emb = PositionRotaryEmbedding.static( config=config, @@ -238,6 +234,7 @@ class FlashGemmaAttention(torch.nn.Module): cu_seqlen_prefill, max_s, self.softmax_scale, + causal=self.causal, ) # Decode else: @@ -295,10 +292,10 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, causal: bool): super().__init__() self.self_attn = FlashGemmaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal ) self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -350,7 +347,7 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -362,6 +359,7 @@ class FlashGemmaModel(torch.nn.Module): prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights, + causal=causal, ) for layer_id in range(config.num_hidden_layers) ] @@ -378,7 +376,7 @@ class FlashGemmaModel(torch.nn.Module): def forward( self, - input_embeds: torch.Tensor, + inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -387,7 +385,7 @@ class FlashGemmaModel(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, ) -> torch.Tensor: - hidden_states = input_embeds + hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -416,7 +414,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, causal: bool): super().__init__() embed_norm = config.hidden_size**0.5 @@ -430,7 +428,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): ) self.embed_tokens.weight *= embed_norm - self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights) + self.model = FlashGemmaModel( + prefix=prefix, config=config, weights=weights, causal=causal + ) self.lm_head = SpeculativeHead.load( prefix=( f"{prefix}.embed_tokens" diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 9b3894bb..87398302 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,134 +19,14 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from text_generation_server.utils.layers import TensorParallelColumnLinear +from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, ) -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - GemmaConfig, -) -class VisionConfig(PretrainedConfig): - def __init__( - self, - hidden_size: int = 1152, - intermediate_size: int = 4304, - model_type: str = 
"siglip_vision_model", - num_attention_heads: int = 16, - num_hidden_layers: int = 27, - num_image_tokens: int = 256, - patch_size: int = 14, - projection_dim: int = 2048, - projector_hidden_act: str = "gelu_fast", - vision_use_head: bool = False, - vocab_size: int = 257152, - quantize: Optional[str] = None, - image_size: int = 224, - layer_norm_eps: float = 1e-06, - attention_dropout: float = 0.0, - hidden_act: str = "gelu_pytorch_tanh", - num_channels: int = 3, - **kwargs, - ): - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.model_type = model_type - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - self.num_image_tokens = num_image_tokens - self.patch_size = patch_size - self.projection_dim = projection_dim - self.projector_hidden_act = projector_hidden_act - self.vision_use_head = vision_use_head - self.vocab_size = vocab_size - self.quantize = quantize - self.image_size = image_size - self.layer_norm_eps = layer_norm_eps - self.attention_dropout = attention_dropout - self.hidden_act = hidden_act - self.num_channels = num_channels - - super().__init__(**kwargs) - - -class PaliGemmaConfig(PretrainedConfig): - model_type = "paligemma" - - def __init__( - self, - text_config: GemmaConfig, - vision_config: VisionConfig, - vocab_size: int = 257152, - image_token_index: int = 256000, - **kwargs, - ): - self.text_config = text_config - self.vision_config = vision_config - - self.vocab_size = vocab_size - self.image_token_index = image_token_index - - self.intermediate_size = text_config.intermediate_size - self.num_hidden_layers = text_config.num_hidden_layers - self.num_key_value_heads = text_config.num_key_value_heads - self.num_attention_heads = text_config.num_attention_heads - - super().__init__(**kwargs) - - def from_pretrained(pretrained_model_name_or_path, **kwargs): - vision_config = VisionConfig( - hidden_size=1152, - intermediate_size=4304, - model_type="siglip_vision_model", - num_attention_heads=16, - num_hidden_layers=27, - num_image_tokens=256, - patch_size=14, - projection_dim=2048, - projector_hidden_act="gelu_fast", - vision_use_head=False, - vocab_size=257152, - ) - - text_config = GemmaConfig.from_pretrained( - pretrained_model_name_or_path, - attention_bias=False, - attention_dropout=0.0, - bos_token_id=2, - eos_token_id=1, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation=None, - hidden_size=2048, - initializer_range=0.02, - intermediate_size=16384, - max_position_embeddings=8192, - model_type="gemma", - num_attention_heads=8, - num_hidden_layers=18, - num_image_tokens=256, - num_key_value_heads=1, - pad_token_id=0, - rms_norm_eps=1e-06, - rope_theta=10000.0, - torch_dtype="float32", - transformers_version="4.40.0.dev0", - use_cache=True, - vocab_size=257216, - **kwargs, - ) - - return PaliGemmaConfig( - text_config=text_config, - vision_config=vision_config, - **kwargs, - ) - - -class FlashPaliGemmaForConditionalGeneration(nn.Module): +class PaliGemmaForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.vision_config.quantize = config.quantize @@ -166,6 +46,9 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): self.vocab_size = config.vocab_size self.config = config + text_config = config.text_config + text_config.speculator = config.speculator + text_config.quantize = config.quantize self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", 
config=config.text_config, @@ -188,36 +71,28 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, - past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, - pixel_attention_mask=None, + # Unused here + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_sizes: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) + # TODO This is odd but apparently pali gemma position ids start at 1. + if cu_seqlen_prefill is not None: + max_s += 1 + position_ids += 1 - if pixel_values is not None and len(pixel_values) > 0: - # TODO: avoid these casts upstream - pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype) - + if pixel_values is not None: + pixel_values = pixel_values.to(dtype=inputs_embeds.dtype) image_outputs = self.vision_tower(pixel_values) image_features = self.multi_modal_projector(image_outputs.last_hidden_state) - # TODO: now we scale them? maybe we can do this up or downstream - scaled_image_features = image_features / ( - self.config.text_config.hidden_size**0.5 - ) - # mask where image or padding tokens - mask = input_ids == self.config.image_token_index | (input_ids == 2) + mask = input_ids == self.config.image_token_index # insert image features into input embeddings - # normalizer = torch.tensor( - # self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype - # ) - # inputs_embeds = inputs_embeds * normalizer - inputs_embeds[mask] = scaled_image_features.view( - -1, scaled_image_features.shape[-1] - ) + inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) - hidden_states = self.language_model.model( + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -230,6 +105,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) + logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/siglip.py b/server/text_generation_server/models/custom_modeling/siglip.py index b1f6a192..f17d6562 100644 --- a/server/text_generation_server/models/custom_modeling/siglip.py +++ b/server/text_generation_server/models/custom_modeling/siglip.py @@ -16,7 +16,7 @@ from transformers.modeling_outputs import ( ) from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig -from text_generation_server.utils.layers import ( +from text_generation_server.layers.tensor_parallel import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py index 936d0290..b74b43ff 100644 --- a/server/text_generation_server/models/custom_modeling/vlm.py +++ b/server/text_generation_server/models/custom_modeling/vlm.py @@ -16,7 +16,7 @@ def load_text_model(prefix, config, weights, name=None): FlashGemmaForCausalLM, ) - return FlashGemmaForCausalLM(prefix, config, weights) + return FlashGemmaForCausalLM(prefix, config, weights, causal=False) elif config.model_type == "paligemma": from 
text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index e317265b..53bfd064 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -3,8 +3,7 @@ import torch.distributed from opentelemetry import trace from typing import Optional -from transformers.models.gemma import GemmaTokenizerFast -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( @@ -36,14 +35,12 @@ class FlashGemma(FlashCausalLM): else: raise NotImplementedError("FlashGemma is only available on GPU") - tokenizer = GemmaTokenizerFast.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left", trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, ) config = AutoConfig.from_pretrained( @@ -61,7 +58,7 @@ class FlashGemma(FlashCausalLM): # TODO hardcoded prefix = "language_model" - model = FlashGemmaForCausalLM(prefix, config, weights) + model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( diff --git a/server/text_generation_server/models/flash_pali_gemma.py b/server/text_generation_server/models/flash_pali_gemma.py deleted file mode 100644 index d8b7f5ac..00000000 --- a/server/text_generation_server/models/flash_pali_gemma.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import torch.distributed -from opentelemetry import trace -from typing import Optional, Tuple -from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM -from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( - FlashPaliGemmaForConditionalGeneration, - PaliGemmaConfig, -) -from transformers import AutoProcessor - -tracer = trace.get_tracer(__name__) - - -class FlashPaliGemma(PaliVlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - use_medusa: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - "google/siglip-base-patch16-224", - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=PaliGemmaConfig, - model_cls=FlashPaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - use_medusa=use_medusa, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py new file mode 100644 index 00000000..d94b9526 --- /dev/null +++ b/server/text_generation_server/models/pali_gemma.py @@ -0,0 +1,123 @@ +import torch +import torch.distributed +from opentelemetry import trace +from typing import Optional, Tuple +from text_generation_server.models.vlm_causal_lm import ( + 
VlmCausalLM, + VlmCausalLMBatch, + image_text_replacement, + load_data_uri, + split, +) +from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, +) +from transformers import AutoProcessor, AutoConfig, AutoImageProcessor + +tracer = trace.get_tracer(__name__) + + +class PaliGemmaBatch(VlmCausalLMBatch): + @classmethod + def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): + batch_inputs = [] + image_inputs = [] + max_truncation = 0 + for r in requests: + chunks = split(r.inputs) + full_text = "" + image_id = 0 + for chunk in chunks: + if chunk["type"] == "text": + full_text += "" + chunk["content"] + "\n" + elif chunk["type"] == "image": + image = chunk["content"] + # Should never receive URLs anymore, processing should be done + # On the rust layer. + # This avoid making n queries per TP + # if image.startswith("https://") or image.startswith("http://"): + # image = processor.image_processor.fetch_images(image) + if image.startswith("data:"): + image = load_data_uri(image) + else: + raise RuntimeError( + "Cannot process input image not starting with data:" + ) + # TODO do_convert_RGB should be on by default ? + image = image.convert("RGB") + image_input = processor.image_processor(image, return_tensors="pt") + full_text += image_text_replacement(image_input, config, image_id) + image_inputs.append(image_input) + else: + raise RuntimeError(f"Invalid chunk type {chunk['type']}") + + batch_inputs.append(full_text) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=False, + )["input_ids"] + if image_inputs: + image_input = image_inputs[0] + new_image_inputs = { + "pixel_values": torch.cat( + [img["pixel_values"] for img in image_inputs], dim=0 + ), + } + if "pixel_attention_mask" in image_input: + new_image_inputs["pixel_attention_mask"] = torch.cat( + [img["pixel_attention_mask"] for img in image_inputs], dim=0 + ) + if "image_sizes" in image_input: + new_image_inputs["image_sizes"] = torch.cat( + [img["image_sizes"] for img in image_inputs], dim=0 + ) + image_inputs = new_image_inputs + else: + image_inputs = None + return batch_tokenized_inputs, image_inputs + + +class PaliGemma(VlmCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.processor = AutoProcessor.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + + super().__init__( + config_cls=AutoConfig, + model_cls=PaliGemmaForConditionalGeneration, + model_id=model_id, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + @property + def batch_type(self): + return PaliGemmaBatch + + def get_layer_config(self, model) -> Tuple[int, int, int]: + return ( + len(model.text_model.model.layers), + model.text_model.model.num_key_value_heads, + model.text_model.model.head_size, + ) + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 12f709df..6292b25f 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ 
b/server/text_generation_server/models/vlm_causal_lm.py @@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str: return "" * num_features elif config.model_type == "paligemma": - # TODO: use correct number of features - return "" * 256 + return "" * config.text_config.num_image_tokens else: raise RuntimeError(f"Unknown config {config.model_type} for multimodal") @@ -174,7 +173,7 @@ class VlmCausalLMBatch(FlashMistralBatch): image_id = 0 for chunk in chunks: if chunk["type"] == "text": - full_text += chunk["content"] + full_text += "" + chunk["content"] + "\n" elif chunk["type"] == "image": image = chunk["content"] # Should never receive URLs anymore, processing should be done @@ -198,7 +197,10 @@ class VlmCausalLMBatch(FlashMistralBatch): max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation + batch_inputs, + truncation=True, + max_length=max_truncation, + add_special_tokens=False, )["input_ids"] if image_inputs: image_input = image_inputs[0] @@ -376,142 +378,3 @@ class VlmCausalLM(BaseFlashMistral): ) logits = cuda_graph["logits"][:bs] return logits, speculative_logits - - -class PaliVlmCausalLMBatch(FlashCausalLMBatch): - pixel_values: Optional[List[torch.Tensor]] - pixel_attention_mask: Optional[List[torch.Tensor]] - image_sizes: Optional[List[Tuple[int, int]]] - - @classmethod - @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(PaliVlmCausalLMBatch, cls).concatenate(batches) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch - - @tracer.start_as_current_span("filter") - def filter(self, request_ids: List[int]): - batch = super().filter(request_ids) - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch - - @classmethod - def batch_tokenized_inputs(cls, requests, tokenizer, processor, config): - batch_inputs = [] - image_inputs = [] - text_inputs = [] - image_text_replacements = [] - max_truncation = 0 - for r in requests: - chunks = split(r.inputs) - full_text = "" - image_id = 0 - for chunk in chunks: - if chunk["type"] == "text": - full_text += chunk["content"] - text_inputs.append(chunk["content"]) - elif chunk["type"] == "image": - image = chunk["content"] - # Should never receive URLs anymore, processing should be done - # On the rust layer. 
- # This avoid making n queries per TP - # if image.startswith("https://") or image.startswith("http://"): - # image = processor.image_processor.fetch_images(image) - if image.startswith("data:"): - image = load_data_uri(image) - else: - raise RuntimeError( - "Cannot process input image not starting with data:" - ) - image_input = processor.image_processor(image, return_tensors="pt") - text_replacement = image_text_replacement( - image_input, config, image_id - ) - full_text += text_replacement - image_text_replacements.append(text_replacement) - image_inputs.append(image_input) - else: - raise RuntimeError(f"Invalid chunk type {chunk['type']}") - - batch_inputs.append(full_text) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, - truncation=True, - max_length=max_truncation, - add_special_tokens=False, - )["input_ids"] - - image_token = tokenizer.get_added_vocab()[""] - - # find the index of the first non-image token - for batch in batch_tokenized_inputs: - first_non_image = 0 - for i, token in enumerate(batch): - if token != image_token: - first_non_image = i - break - - # manually add the bos to the left of the text - batch_tokenized_inputs = [ - batch[:first_non_image] + [tokenizer.bos_token_id] + batch[first_non_image:] - for batch in batch_tokenized_inputs - ] - - if image_inputs: - image_input = image_inputs[0] - new_image_inputs = { - "pixel_values": torch.cat( - [img["pixel_values"] for img in image_inputs], dim=0 - ), - } - if "pixel_attention_mask" in image_input: - new_image_inputs["pixel_attention_mask"] = torch.cat( - [img["pixel_attention_mask"] for img in image_inputs], dim=0 - ) - if "image_sizes" in image_input: - new_image_inputs["image_sizes"] = torch.cat( - [img["image_sizes"] for img in image_inputs], dim=0 - ) - image_inputs = new_image_inputs - else: - image_inputs = None - return batch_tokenized_inputs, image_inputs - - @classmethod - def from_pb_processor( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - processor, - config, - dtype: torch.dtype, - device: torch.device, - ) -> "PaliVlmCausalLMBatch": - batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( - pb.requests, tokenizer, processor, config - ) - batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device) - if image_inputs is not None: - batch.pixel_values = image_inputs["pixel_values"].to(device=device) - if "pixel_attention_mask" in image_inputs: - batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to( - device=device - ) - else: - batch.pixel_attention_mask = None - if "image_sizes" in image_inputs: - batch.image_sizes = image_inputs["image_sizes"].to(device=device) - else: - batch.image_sizes = None - else: - batch.pixel_values = None - batch.pixel_attention_mask = None - batch.image_sizes = None - return batch diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 06138fa5..92126fe6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -14,9 +14,9 @@ from typing import List, Optional from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model +from text_generation_server.models.pali_gemma import PaliGemmaBatch from text_generation_server.models.vlm_causal_lm import ( VlmCausalLMBatch, - PaliVlmCausalLMBatch, ) from text_generation_server.pb import generate_pb2_grpc, 
generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor @@ -101,7 +101,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if self.model.batch_type in { IdeficsCausalLMBatch, VlmCausalLMBatch, - PaliVlmCausalLMBatch, + PaliGemmaBatch, }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb_processor( request.batch, @@ -126,7 +126,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if self.model.batch_type in { IdeficsCausalLMBatch, VlmCausalLMBatch, - PaliVlmCausalLMBatch, + PaliGemmaBatch, }: # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb_processor( request.batch, diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 0830656d..ae60fa63 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -116,6 +116,7 @@ if HAS_FLASH_ATTN_V2_CUDA: max_s, softmax_scale, window_size_left=-1, + causal=True, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -134,7 +135,7 @@ if HAS_FLASH_ATTN_V2_CUDA: 0.0, softmax_scale, False, - True, + causal, window_size_left, 0, False,
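
For reference, the heart of the new PaliGemmaForConditionalGeneration.forward above is the in-place scatter of projected SigLIP features into the token embeddings at the image-placeholder positions. Below is a minimal standalone sketch of that merge in plain PyTorch, not part of the applied diff; the shapes and ids come from the removed config defaults in this patch (hidden_size=2048, 256 image tokens at 224x224, image_token_index=256000, bos_token_id=2), and the text token ids are arbitrary placeholders.

# Standalone sketch of the image-feature merge performed in
# PaliGemmaForConditionalGeneration.forward; shapes and ids are illustrative.
import torch

hidden_size = 2048          # text hidden size (removed config default in this patch)
num_image_tokens = 256      # 224x224 SigLIP input -> 256 image tokens
image_token_index = 256000  # placeholder id (removed PaliGemmaConfig default)

# Prompt layout built by PaliGemmaBatch: image placeholders first, then text ids.
input_ids = torch.cat([
    torch.full((num_image_tokens,), image_token_index, dtype=torch.long),
    torch.tensor([2, 1350, 603, 736], dtype=torch.long),  # bos + arbitrary text ids
])
inputs_embeds = torch.randn(input_ids.shape[0], hidden_size)

# Stand-in for multi_modal_projector(vision_tower(pixel_values).last_hidden_state)
image_features = torch.randn(1, num_image_tokens, hidden_size)

# Same operation as in the forward pass above: overwrite placeholder positions
# of the token embeddings with the flattened image features.
mask = input_ids == image_token_index
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

assert int(mask.sum()) == num_image_tokens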