diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py
index b3415d26..94b7b89d 100644
--- a/integration-tests/models/test_flash_pali_gemma.py
+++ b/integration-tests/models/test_flash_pali_gemma.py
@@ -33,7 +33,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
     cow = get_cow_beach()
     inputs = f"Where is the cow standing?\n![]({cow})"
     response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
-
+
     # TODO: update this! this is incorrect and just to show the current state of the test
-    assert response.generated_text == ' - HDS'
-    # assert response.generated_text == "\nbeach"
\ No newline at end of file
+    assert response.generated_text == " - HDS"
+    # assert response.generated_text == "\nbeach"
diff --git a/router/src/config.rs b/router/src/config.rs
index 296d0352..3a7b775f 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -129,7 +129,7 @@ impl Paligemma {
         // 224 = 256 image tokens
         // 448 = 1024 image tokens
         // 896 = 4096 image tokens
-
+
         256
     }
 }
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 602a29cc..0954d857 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -63,6 +63,8 @@ class GemmaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -86,6 +88,8 @@ class GemmaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.quantize = quantize
+        self.use_medusa = use_medusa
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -407,7 +411,7 @@ class FlashGemmaModel(torch.nn.Module):
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
-
+
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -459,7 +463,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.model.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index e2bc9807..3e091ebd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -29,124 +29,11 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
 )
 
 
-# TODO: prefer using the following config classes
-# * instead of the hack inside of the gemma modeling file
-class VisionConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        patch_size: int,
-        projection_dim: int,
-        projector_hidden_act: str,
-        vision_use_head: bool,
-        vocab_size: int,
-        quantize: Optional[str] = None,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-
-
-class PaliTextConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        num_key_value_heads: int,
-        torch_dtype: str,
-        vocab_size: int,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.num_key_value_heads = num_key_value_heads
-        self.torch_dtype = torch_dtype
-        self.vocab_size = vocab_size
-
-
 class PaliGemmaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=257216,
-        hidden_size=2048,
-        intermediate_size=24576,
-        num_hidden_layers=28,
-        num_attention_heads=16,
-        num_key_value_heads=16,
-        head_dim=256,
-        hidden_act="gelu_pytorch_tanh",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=2,
-        eos_token_id=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        attention_bias=False,
-        attention_dropout=0.0,
-        text_config=None,
-        vision_config=None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.head_dim = head_dim
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
+    model_type = "paligemma"
 
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-
-        self.text_config = GemmaConfig(
-            hidden_size=2048,
-            intermediate_size=16384,
-            model_type="gemma",
-            num_attention_heads=8,
-            num_hidden_layers=18,
-            num_image_tokens=256,
-            num_key_value_heads=1,
-            torch_dtype="float32",
-            vocab_size=257216,
-        )
-
-        self.vision_config = VisionConfig(
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        vision_config = VisionConfig(
             hidden_size=1152,
             intermediate_size=4304,
             model_type="siglip_vision_model",
@@ -160,20 +47,58 @@ class PaliGemmaConfig(PretrainedConfig):
             vocab_size=257152,
         )
 
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
+        return GemmaConfig.from_pretrained(
+            pretrained_model_name_or_path, vision_config=vision_config, **kwargs
         )
 
 
+class VisionConfig(PretrainedConfig):
+    def __init__(
+        self,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        model_type: str = "siglip_vision_model",
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        num_image_tokens: int = 256,
+        patch_size: int = 14,
+        projection_dim: int = 2048,
+        projector_hidden_act: str = "gelu_fast",
+        vision_use_head: bool = False,
+        vocab_size: int = 257152,
+        quantize: Optional[str] = None,
+        image_size: int = 224,
+        layer_norm_eps: float = 1e-06,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "gelu_pytorch_tanh",
+        num_channels: int = 3,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.model_type = model_type
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.num_image_tokens = num_image_tokens
+        self.patch_size = patch_size
+        self.projection_dim = projection_dim
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_use_head = vision_use_head
+        self.vocab_size = vocab_size
+        self.quantize = quantize
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+
+        super().__init__(**kwargs)
+
+
 class FlashPaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
@@ -256,7 +181,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )
 
-
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
diff --git a/server/text_generation_server/models/custom_modeling/vlm.py b/server/text_generation_server/models/custom_modeling/vlm.py
index 9ae4add7..936d0290 100644
--- a/server/text_generation_server/models/custom_modeling/vlm.py
+++ b/server/text_generation_server/models/custom_modeling/vlm.py
@@ -16,6 +16,12 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )
 
+        return FlashGemmaForCausalLM(prefix, config, weights)
+    elif config.model_type == "paligemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
         return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index e547e267..d60d49de 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -90,41 +90,20 @@ class BaseFlashGemma(FlashCausalLM):
             from_slow=False,
         )
 
-        config = GemmaConfig.from_pretrained(
+        config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
 
-        is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config")
-
-        if is_vlm:
-            config.vision_config = VisionConfig(
-                hidden_size=1152,
-                intermediate_size=4304,
-                model_type="siglip_vision_model",
-                num_attention_heads=16,
-                num_hidden_layers=27,
-                num_image_tokens=256,
-                patch_size=14,
-                projection_dim=2048,
-                projector_hidden_act="gelu_fast",
-                vision_use_head=False,
-                vocab_size=257152,
-                quantize=quantize,
-            )
+        is_vlm = hasattr(config, "vision_config")
 
         config.quantize = quantize
         config.speculator = speculator
 
         if is_vlm:
-            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
             config.intermediate_size = config.text_config.get("intermediate_size")
-            config.model_type = config.text_config.get("model_type")
             config.num_attention_heads = config.text_config.get("num_attention_heads")
             config.num_hidden_layers = config.text_config.get("num_hidden_layers")
-            config.num_image_tokens = config.text_config.get("num_image_tokens")
             config.num_key_value_heads = config.text_config.get("num_key_value_heads")
-            config.torch_dtype = config.text_config.get("torch_dtype")
-            config.vocab_size = config.text_config.get("vocab_size")
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/models/flash_pali_gemma.py b/server/text_generation_server/models/flash_pali_gemma.py
index 16955f8c..e3ab8bbf 100644
--- a/server/text_generation_server/models/flash_pali_gemma.py
+++ b/server/text_generation_server/models/flash_pali_gemma.py
@@ -6,7 +6,6 @@ from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     FlashPaliGemmaForConditionalGeneration,
     PaliGemmaConfig,
-    PaliTextConfig,
 )
 from transformers import AutoProcessor
 
@@ -32,7 +31,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
         )
 
         super().__init__(
-            config_cls=PaliTextConfig,
+            config_cls=PaliGemmaConfig,
             model_cls=FlashPaliGemmaForConditionalGeneration,
             model_id=model_id,
             revision=revision,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 3ec6c319..b2dbd0a3 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
         logger.info(f"Found {num_features} in image of resolution {height}x{width}")
         return "<image>" * num_features
 
-    # TODO: double check correct naming for model_type
-    elif config.model_type == "gemma":
+    elif config.model_type == "paligemma":
         # TODO: use correct number of features
         return "<image>" * 256
     else:
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 94cf1417..06138fa5 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -14,7 +14,10 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLMBatch,
+    PaliVlmCausalLMBatch,
+)
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch