Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit: b07b53efba
Parent: 5fd72ed06c

feat: improve config and refactor
@@ -33,7 +33,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
     cow = get_cow_beach()
     inputs = f"Where is the cow standing?\n"
     response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

     # TODO: update this! this is incorrect and just to show the current state of the test
-    assert response.generated_text == ' - HDS'
+    assert response.generated_text == " - HDS"
     # assert response.generated_text == "\nbeach"
@@ -129,7 +129,7 @@ impl Paligemma {
         // 224 = 256 image tokens
        // 448 = 1024 image tokens
        // 896 = 4096 image tokens

        256
    }
 }
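Note: the comments above relate the supported input resolutions to image-token counts. A minimal sketch of that relationship, assuming the 14x14 SigLIP patch size introduced later in this commit (VisionConfig.patch_size = 14):

# Sketch, not part of the commit: image tokens = (image_size // patch_size) ** 2.
def num_image_tokens(image_size: int, patch_size: int = 14) -> int:
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side

assert num_image_tokens(224) == 256
assert num_image_tokens(448) == 1024
assert num_image_tokens(896) == 4096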
@@ -63,6 +63,8 @@ class GemmaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -86,6 +88,8 @@ class GemmaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.quantize = quantize
+        self.use_medusa = use_medusa

         super().__init__(
             pad_token_id=pad_token_id,
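Note: the two hunks above add quantize and use_medusa as optional fields on GemmaConfig. A minimal usage sketch; the quantize value is an illustrative placeholder, not taken from this commit:

# Sketch, not part of the commit: the new kwargs default to None, so existing
# call sites are unaffected; they are simply stored on the config instance.
config = GemmaConfig(quantize="bitsandbytes", use_medusa=None)
assert config.quantize == "bitsandbytes"
assert config.use_medusa is None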
@@ -407,7 +411,7 @@ class FlashGemmaModel(torch.nn.Module):
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -459,7 +463,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.model.embed_tokens(input_ids)
         hidden_states = self.model(
             inputs_embeds,
             position_ids,
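Note: the one-line change above (self.embed_tokens -> self.model.embed_tokens) lets a caller reach the embedding table through the inner model and pass inputs_embeds in itself. A toy, self-contained sketch of that pattern; module names and shapes are made up and are not TGI code:

# Sketch, not part of the commit: a wrapper can embed the text tokens itself,
# splice in image embeddings, and only then run the transformer body.
import torch
from torch import nn

class ToyTextModel(nn.Module):
    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.body = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # The body consumes embeddings, not token ids.
        return self.body(inputs_embeds)

model = ToyTextModel()
with torch.no_grad():
    input_ids = torch.tensor([3, 1, 2])
    inputs_embeds = model.embed_tokens(input_ids)   # text embeddings
    inputs_embeds[1] = torch.zeros(8)               # pretend this is an image embedding
    hidden_states = model(inputs_embeds)
print(hidden_states.shape)  # torch.Size([3, 8])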
@@ -29,124 +29,11 @@ from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
 )


-# TODO: prefer using the following config classes
-# * instead of the hack inside of the gemma modeling file
-class VisionConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        patch_size: int,
-        projection_dim: int,
-        projector_hidden_act: str,
-        vision_use_head: bool,
-        vocab_size: int,
-        quantize: Optional[str] = None,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-
-
-class PaliTextConfig(PretrainedConfig):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        model_type: str,
-        num_attention_heads: int,
-        num_hidden_layers: int,
-        num_image_tokens: int,
-        num_key_value_heads: int,
-        torch_dtype: str,
-        vocab_size: int,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.num_key_value_heads = num_key_value_heads
-        self.torch_dtype = torch_dtype
-        self.vocab_size = vocab_size
-
-
 class PaliGemmaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=257216,
-        hidden_size=2048,
-        intermediate_size=24576,
-        num_hidden_layers=28,
-        num_attention_heads=16,
-        num_key_value_heads=16,
-        head_dim=256,
-        hidden_act="gelu_pytorch_tanh",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=2,
-        eos_token_id=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        attention_bias=False,
-        attention_dropout=0.0,
-        text_config=None,
-        vision_config=None,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.head_dim = head_dim
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-
-        self.text_config = GemmaConfig(
-            hidden_size=2048,
-            intermediate_size=16384,
-            model_type="gemma",
-            num_attention_heads=8,
-            num_hidden_layers=18,
-            num_image_tokens=256,
-            num_key_value_heads=1,
-            torch_dtype="float32",
-            vocab_size=257216,
-        )
-
-        self.vision_config = VisionConfig(
+    model_type = "paligemma"
+
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        vision_config = VisionConfig(
             hidden_size=1152,
             intermediate_size=4304,
             model_type="siglip_vision_model",
@@ -160,20 +47,58 @@ class PaliGemmaConfig(PretrainedConfig):
             vocab_size=257152,
         )

-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
+        return GemmaConfig.from_pretrained(
+            pretrained_model_name_or_path, vision_config=vision_config, **kwargs
         )


+class VisionConfig(PretrainedConfig):
+    def __init__(
+        self,
+        hidden_size: int = 1152,
+        intermediate_size: int = 4304,
+        model_type: str = "siglip_vision_model",
+        num_attention_heads: int = 16,
+        num_hidden_layers: int = 27,
+        num_image_tokens: int = 256,
+        patch_size: int = 14,
+        projection_dim: int = 2048,
+        projector_hidden_act: str = "gelu_fast",
+        vision_use_head: bool = False,
+        vocab_size: int = 257152,
+        quantize: Optional[str] = None,
+        image_size: int = 224,
+        layer_norm_eps: float = 1e-06,
+        attention_dropout: float = 0.0,
+        hidden_act: str = "gelu_pytorch_tanh",
+        num_channels: int = 3,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.model_type = model_type
+        self.num_attention_heads = num_attention_heads
+        self.num_hidden_layers = num_hidden_layers
+        self.num_image_tokens = num_image_tokens
+        self.patch_size = patch_size
+        self.projection_dim = projection_dim
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_use_head = vision_use_head
+        self.vocab_size = vocab_size
+        self.quantize = quantize
+        self.image_size = image_size
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.num_channels = num_channels
+
+        super().__init__(**kwargs)


 class FlashPaliGemmaForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize

         self.vision_tower = load_vision_model(
             prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
             config=config.vision_config,
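Note: VisionConfig now carries SigLIP-style defaults, and PaliGemmaConfig.from_pretrained attaches one to the text config loaded from the hub. A minimal sketch of the defaults, with values read from the hunk above; the import path is assumed from the other hunks in this commit:

# Sketch, not part of the commit: the no-argument constructor yields the
# 224px SigLIP settings used in from_pretrained above.
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import VisionConfig

vision_config = VisionConfig()
assert vision_config.patch_size == 14
assert vision_config.num_image_tokens == 256
assert vision_config.image_size == 224
assert vision_config.vision_use_head is False

# PaliGemmaConfig.from_pretrained builds such a VisionConfig and returns
# GemmaConfig.from_pretrained(model_id, vision_config=vision_config, **kwargs).
# config = PaliGemmaConfig.from_pretrained("<paligemma checkpoint>")  # placeholder id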
@@ -256,7 +181,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )

-
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
@@ -16,6 +16,12 @@ def load_text_model(prefix, config, weights, name=None):
             FlashGemmaForCausalLM,
         )

+        return FlashGemmaForCausalLM(prefix, config, weights)
+    elif config.model_type == "paligemma":
+        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+            FlashGemmaForCausalLM,
+        )
+
         return FlashGemmaForCausalLM(prefix, config, weights)
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
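Note: with the added branch, load_text_model resolves both the plain Gemma and the PaliGemma text tower to the same flash Gemma implementation. A condensed sketch of the resulting dispatch; the surrounding "gemma" branch condition is assumed from context rather than shown in the hunk, and the function is renamed to mark it as illustrative:

# Sketch, not part of the commit: condensed shape of the dispatch after this
# change; both model types return the same flash Gemma causal LM.
def load_text_model_sketch(prefix, config, weights, name=None):
    if config.model_type in ("gemma", "paligemma"):
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights)
    else:
        raise RuntimeError(f"Unsupported model type {config.model_type}")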
@@ -90,41 +90,20 @@ class BaseFlashGemma(FlashCausalLM):
             from_slow=False,
         )

-        config = GemmaConfig.from_pretrained(
+        config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )

-        is_vlm = hasattr(config, "vision_config") and hasattr(config, "text_config")
+        is_vlm = hasattr(config, "vision_config")

-        if is_vlm:
-            config.vision_config = VisionConfig(
-                hidden_size=1152,
-                intermediate_size=4304,
-                model_type="siglip_vision_model",
-                num_attention_heads=16,
-                num_hidden_layers=27,
-                num_image_tokens=256,
-                patch_size=14,
-                projection_dim=2048,
-                projector_hidden_act="gelu_fast",
-                vision_use_head=False,
-                vocab_size=257152,
-                quantize=quantize,
-            )
-
         config.quantize = quantize
         config.speculator = speculator

         if is_vlm:
-            config.num_hidden_layers = config.text_config.get("num_hidden_layers")
             config.intermediate_size = config.text_config.get("intermediate_size")
-            config.model_type = config.text_config.get("model_type")
             config.num_attention_heads = config.text_config.get("num_attention_heads")
             config.num_hidden_layers = config.text_config.get("num_hidden_layers")
-            config.num_image_tokens = config.text_config.get("num_image_tokens")
             config.num_key_value_heads = config.text_config.get("num_key_value_heads")
-            config.torch_dtype = config.text_config.get("torch_dtype")
-            config.vocab_size = config.text_config.get("vocab_size")

         torch.distributed.barrier(group=self.process_group)

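Note: the surviving lines above copy the text-model dimensions out of config.text_config (a plain dict here, hence .get) onto the top-level config. An equivalent condensed sketch of that promotion:

# Sketch, not part of the commit: promote the text-model fields the flash
# Gemma modeling code reads from the top-level config.
if hasattr(config, "vision_config"):
    for key in (
        "intermediate_size",
        "num_attention_heads",
        "num_hidden_layers",
        "num_key_value_heads",
    ):
        setattr(config, key, config.text_config.get(key))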
@@ -6,7 +6,6 @@ from text_generation_server.models.vlm_causal_lm import PaliVlmCausalLM
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     FlashPaliGemmaForConditionalGeneration,
     PaliGemmaConfig,
-    PaliTextConfig,
 )
 from transformers import AutoProcessor

@@ -32,7 +31,7 @@ class FlashPaliGemma(PaliVlmCausalLM):
         )

         super().__init__(
-            config_cls=PaliTextConfig,
+            config_cls=PaliGemmaConfig,
             model_cls=FlashPaliGemmaForConditionalGeneration,
             model_id=model_id,
             revision=revision,
@@ -83,8 +83,7 @@ def image_text_replacement(image_input, config, image_id) -> str:
         logger.info(f"Found {num_features} in image of resolution {height}x{width}")
         return "<image>" * num_features

-    # TODO: double check correct naming for model_type
-    elif config.model_type == "gemma":
+    elif config.model_type == "paligemma":
         # TODO: use correct number of features
         return "<image>" * 256
     else:
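Note: for PaliGemma the prompt is currently padded with a fixed 256 image placeholders (see the TODO retained above). A toy check of that expansion with a stand-in config object:

# Sketch, not part of the commit: stand-in config; the real call sites pass
# the loaded model config and an actual image input.
from types import SimpleNamespace

def paligemma_replacement(config) -> str:
    if config.model_type == "paligemma":
        return "<image>" * 256  # TODO in the diff: use correct number of features
    raise ValueError("other model types elided in this sketch")

prompt = paligemma_replacement(SimpleNamespace(model_type="paligemma"))
prompt += "Where is the cow standing?\n"
assert prompt.count("<image>") == 256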
@@ -14,7 +14,10 @@ from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model
-from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, PaliVlmCausalLMBatch
+from text_generation_server.models.vlm_causal_lm import (
+    VlmCausalLMBatch,
+    PaliVlmCausalLMBatch,
+)
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch