feat: improve config and refactor

drbh 2024-05-09 01:25:59 +00:00 committed by Nicolas Patry
parent 5fd72ed06c
commit b07b53efba
9 changed files with 72 additions and 158 deletions

View File

@@ -33,7 +33,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
     cow = get_cow_beach()
     inputs = f"Where is the cow standing?\n![]({cow})"
     response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
     # TODO: update this! this is incorrect and just to show the current state of the test
-    assert response.generated_text == ' - HDS'
+    assert response.generated_text == " - HDS"
     # assert response.generated_text == "\nbeach"

View File

@@ -129,7 +129,7 @@ impl Paligemma {
         // 224 = 256 image tokens
         // 448 = 1024 image tokens
         // 896 = 4096 image tokens
         256
     }
 }
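The image-token counts in the comments above follow directly from the SigLIP patch grid: with the 14-pixel patches used by the PaliGemma vision tower (see the VisionConfig defaults later in this diff), a square input of side S produces (S / 14)² image tokens. A minimal sketch of that arithmetic, for illustration only:

def num_image_tokens(image_size: int, patch_size: int = 14) -> int:
    # Each axis is split into image_size // patch_size patches; the token
    # count is the square of that grid dimension.
    side = image_size // patch_size
    return side * side

assert num_image_tokens(224) == 256
assert num_image_tokens(448) == 1024
assert num_image_tokens(896) == 4096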

View File

@@ -63,6 +63,8 @@ class GemmaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -86,6 +88,8 @@ class GemmaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.quantize = quantize
+        self.use_medusa = use_medusa
         super().__init__(
             pad_token_id=pad_token_id,
@@ -407,7 +411,7 @@ class FlashGemmaModel(torch.nn.Module):
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -459,7 +463,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.model.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
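For context on the change above: with FlashGemmaModel consuming inputs_embeds directly and FlashGemmaForCausalLM reaching the embedding table through self.model.embed_tokens, a multimodal wrapper can compute the token embeddings itself and splice projected image features into them before the text model runs. A rough sketch of that call pattern, using simplified names rather than the actual classes in this file:

import torch

def multimodal_forward(language_model, vision_tower, input_ids, pixel_values, image_token_id):
    # Embed the text tokens via the inner model's embedding table
    # (what `self.model.embed_tokens(input_ids)` exposes above).
    inputs_embeds = language_model.model.embed_tokens(input_ids)
    if pixel_values is not None:
        # Project image patches into the text embedding space and overwrite
        # the placeholder <image> positions with those features.
        image_embeds = vision_tower(pixel_values).view(-1, inputs_embeds.shape[-1])
        inputs_embeds[input_ids == image_token_id] = image_embeds.to(inputs_embeds.dtype)
    # The text model consumes the merged embeddings directly.
    return language_model.model(inputs_embeds)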

View File

@@ -29,124 +29,11 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
 )
-# TODO: prefer using the following config classes
-# * instead of the hack inside of the gemma modeling file
-class VisionConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        patch_size: int,
-        projection_dim: int,
-        projector_hidden_act: str,
-        vision_use_head: bool,
-        vocab_size: int,
-        quantize: Optional[str] = None,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-class PaliTextConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        num_key_value_heads: int,
-        torch_dtype: str,
-        vocab_size: int,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.num_key_value_heads = num_key_value_heads
-        self.torch_dtype = torch_dtype
-        self.vocab_size = vocab_size
 class PaliGemmaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=257216,
-        hidden_size=2048,
-        intermediate_size=24576,
-        num_hidden_layers=28,
-        num_attention_heads=16,
-        num_key_value_heads=16,
-        head_dim=256,
-        hidden_act="gelu_pytorch_tanh",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=2,
-        eos_token_id=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        attention_bias=False,
-        attention_dropout=0.0,
-        text_config=None,
-        vision_config=None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.head_dim = head_dim
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.text_config = GemmaConfig(
-            hidden_size=2048,
-            intermediate_size=16384,
-            model_type="gemma",
-            num_attention_heads=8,
-            num_hidden_layers=18,
-            num_image_tokens=256,
-            num_key_value_heads=1,
-            torch_dtype="float32",
-            vocab_size=257216,
-        )
-        self.vision_config = VisionConfig(
+    model_type = "paligemma"
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        vision_config = VisionConfig(
             hidden_size=1152,
             intermediate_size=4304,
             model_type="siglip_vision_model",
@@ -160,20 +47,58 @@ class PaliGemmaConfig(PretrainedConfig):
             vocab_size=257152,
         )
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
+        return GemmaConfig.from_pretrained(
+            pretrained_model_name_or_path, vision_config=vision_config, **kwargs
         )
+class VisionConfig(PretrainedConfig):
+    def __init__(
+        self,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        model_type: str = "siglip_vision_model",
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        num_image_tokens: int = 256,
+        patch_size: int = 14,
+        projection_dim: int = 2048,
+        projector_hidden_act: str = "gelu_fast",
+        vision_use_head: bool = False,
+        vocab_size: int = 257152,
+        quantize: Optional[str] = None,
+        image_size: int = 224,
+        layer_norm_eps: float = 1e-06,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "gelu_pytorch_tanh",
+        num_channels: int = 3,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.model_type = model_type
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.num_image_tokens = num_image_tokens
+        self.patch_size = patch_size
+        self.projection_dim = projection_dim
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_use_head = vision_use_head
+        self.vocab_size = vocab_size
+        self.quantize = quantize
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+        super().__init__(**kwargs)
 class FlashPaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
@@ -256,7 +181,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
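After this refactor the vision-side defaults live on VisionConfig itself, and PaliGemmaConfig.from_pretrained simply builds such a VisionConfig and hands it to GemmaConfig.from_pretrained as vision_config. A small sketch of what the defaults give you, assuming VisionConfig is imported from the modeling file above:

# Illustrative only: every argument now has a default, so the vision config
# can be built without repeating literals at the call site.
vision_config = VisionConfig()
assert vision_config.patch_size == 14
assert vision_config.image_size == 224
assert vision_config.num_image_tokens == 256  # (224 // 14) ** 2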

View File

@@ -16,6 +16,12 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )
+        return FlashGemmaForCausalLM(prefix, config, weights)
+    elif config.model_type == "paligemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
         return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")

View File

@@ -90,41 +90,20 @@ class BaseFlashGemma(FlashCausalLM):
             from_slow=False,
         )
-        config = GemmaConfig.from_pretrained(
+        config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
-        is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config")
+        is_vlm = hasattr(config, "vision_config")
-        if is_vlm:
-            config.vision_config = VisionConfig(
-                hidden_size=1152,
-                intermediate_size=4304,
-                model_type="siglip_vision_model",
-                num_attention_heads=16,
-                num_hidden_layers=27,
-                num_image_tokens=256,
-                patch_size=14,
-                projection_dim=2048,
-                projector_hidden_act="gelu_fast",
-                vision_use_head=False,
-                vocab_size=257152,
-                quantize=quantize,
-            )
         config.quantize = quantize
         config.speculator = speculator
         if is_vlm:
-            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
             config.intermediate_size = config.text_config.get("intermediate_size")
-            config.model_type = config.text_config.get("model_type")
             config.num_attention_heads = config.text_config.get("num_attention_heads")
             config.num_hidden_layers = config.text_config.get("num_hidden_layers")
-            config.num_image_tokens = config.text_config.get("num_image_tokens")
             config.num_key_value_heads = config.text_config.get("num_key_value_heads")
-            config.torch_dtype = config.text_config.get("torch_dtype")
-            config.vocab_size = config.text_config.get("vocab_size")
         torch.distributed.barrier(group=self.process_group)
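The fields promoted above are read with .get(...), i.e. text_config is treated as a plain dict on the loaded config rather than a nested config object, and only the handful of hyperparameters the flash Gemma code needs are copied to the top level. A minimal sketch of that promotion step, assuming config is any object carrying a text_config dict:

def promote_text_config(config):
    # Copy the text-model hyperparameters the flash attention path expects
    # from the nested dict onto the top-level config object.
    for key in (
        "intermediate_size",
        "num_attention_heads",
        "num_hidden_layers",
        "num_key_value_heads",
    ):
        setattr(config, key, config.text_config.get(key))
    return config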

View File

@@ -6,7 +6,6 @@ from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     FlashPaliGemmaForConditionalGeneration,
     PaliGemmaConfig,
-    PaliTextConfig,
 )
 from transformers import AutoProcessor
@@ -32,7 +31,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
         )
         super().__init__(
-            config_cls=PaliTextConfig,
+            config_cls=PaliGemmaConfig,
             model_cls=FlashPaliGemmaForConditionalGeneration,
             model_id=model_id,
             revision=revision,

View File

@@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
         logger.info(f"Found {num_features} in image of resolution {height}x{width}")
         return "<image>" * num_features
-    # TODO: double check correct naming for model_type
-    elif config.model_type == "gemma":
+    elif config.model_type == "paligemma":
         # TODO: use correct number of features
         return "<image>" * 256
     else:
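The paligemma branch above expands each image reference into one <image> placeholder per image token (hard-coded to 256 for now, per the TODO), which the modeling code later overwrites with projected vision features. A tiny illustration of the resulting prompt text, with a made-up image URL:

# Illustrative only: the markdown-style image reference in the prompt is
# swapped for 256 "<image>" placeholder tokens.
text = "Where is the cow standing?\n![](https://example.com/cow_beach.png)"
replaced = text.replace("![](https://example.com/cow_beach.png)", "<image>" * 256)
assert replaced.count("<image>") == 256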

View File

@@ -14,7 +14,10 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLMBatch,
+    PaliVlmCausalLMBatch,
+)
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch