Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit ebbe7edca4: "Don't break what's not broken."
Parent: 9b9614cea3
@@ -75,7 +75,6 @@ try:
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.models.flash_dbrx import FlashDbrx
from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

except ImportError as e:
@@ -434,16 +433,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)

if model_type == "paligemma":
return FlashPaliGemma(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if model_type == "cohere":
if FLASH_ATTENTION:
return FlashCohere(
@@ -39,9 +39,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)

# TODO: used for debugging; to avoid breaking during warmup
count = 0


class GemmaConfig(PretrainedConfig):
def __init__(
@@ -66,8 +63,6 @@ class GemmaConfig(PretrainedConfig):
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -91,8 +86,6 @@ class GemmaConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.quantize = quantize
self.use_medusa = use_medusa

super().__init__(
pad_token_id=pad_token_id,
@@ -106,7 +99,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
weight = weights.get_tensor(f"{prefix}.weight") + 1
return cls(weight, eps)

# perform the multiplication in full precision and downcast after
@@ -117,7 +110,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * (self.weight.float() + 1.0)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.weight.dtype), residual

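Note on the two GemmaFastRMSNorm hunks above: Gemma defines its RMSNorm output as x_norm * (1 + w), so adding 1 to the stored weight once at load time lets the forward pass multiply by self.weight directly. A minimal equivalence sketch in plain PyTorch (not the TGI classes), assuming float32 inputs:

    import torch

    def rmsnorm_reference(x, w, eps=1e-6):
        # normalize, then scale by (1 + w), as in upstream Gemma
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return x * (1.0 + w)

    def rmsnorm_folded(x, w_plus_one, eps=1e-6):
        # same result when the +1 was folded into the stored weight at load time
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return x * w_plus_one

    x = torch.randn(2, 8)
    w = torch.randn(8)
    assert torch.allclose(rmsnorm_reference(x, w), rmsnorm_folded(x, w + 1))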
@@ -159,107 +152,6 @@ def _load_gqa(config, prefix: str, weights):
)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class GemmaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, dtype=torch.int64, device=x.device
).float()
/ self.dim
)
)

position_ids = position_ids.unsqueeze(0)

inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)

position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):

# import ipdb
# ipdb.set_trace()
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class FlashGemmaAttention(torch.nn.Module):
def __init__(
self,
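The block removed above is the eager Hugging Face rotary-embedding path that had been copied in for debugging. As a reminder of what rotate_half and apply_rotary_pos_emb compute, here is a small self-contained sketch in plain PyTorch with arbitrary toy shapes:

    import torch

    def rotate_half(x):
        # swap the two halves of the last dim, negating the second half
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def rope_cos_sin(positions, head_dim, base=10000.0):
        # one frequency per channel pair, duplicated to cover the full head_dim
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(positions.float(), inv_freq)   # [seq, head_dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)            # [seq, head_dim]
        return emb.cos(), emb.sin()

    # toy shapes: batch=1, heads=2, seq=5, head_dim=8
    q = torch.randn(1, 2, 5, 8)
    k = torch.randn(1, 2, 5, 8)
    cos, sin = rope_cos_sin(torch.arange(5), head_dim=8)
    cos, sin = cos[None, None], sin[None, None]            # broadcast over batch and heads
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 5, 8]) twice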
@@ -271,16 +163,9 @@ class FlashGemmaAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim

# self._rotary_emb = PositionRotaryEmbedding.static(
# config=config,
# dim=self.head_size,
# base=config.rope_theta,
# device=weights.device,
# )

self.rotary_emb = GemmaRotaryEmbedding(
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
device=weights.device,
)
@@ -297,21 +182,7 @@ class FlashGemmaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)

# TODO: prefer this implementation after debugging
# self.query_key_value = load_attention(config, prefix, weights)

self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
)
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
)
self.query_key_value = load_attention(config, prefix, weights)

self.o_proj = TensorParallelRowLinear.load(
config,
@@ -323,7 +194,6 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads

def forward(
self,
@@ -337,55 +207,21 @@ class FlashGemmaAttention(torch.nn.Module):
input_lengths,
max_s,
):
global count

# TODO: replace with better implementation after debugging
tgt_len, src_len = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = (
query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
)
key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
value_states = (
value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

# reshape for flash/paged attention
kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)

# cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, None
)

# replace the kv with the rotated kv
kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
query = query_states.squeeze(0).transpose(0, 1)

# TODO: remove after debugging
# ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
# ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))

# ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))

# ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))

if count > 0:
import ipdb

# ipdb.set_trace()
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)

# output tensor
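The hunk above swaps the temporary per-projection q_proj/k_proj/v_proj path (with its hard-coded 8 query heads, 1 key/value head and 256 head size) back to the fused query_key_value projection used elsewhere in TGI. Conceptually, the fused output is split by width into a query block and a packed key/value block; a shape-only sketch with a random stand-in for the projection output:

    import torch

    num_heads, num_kv_heads, head_size = 8, 1, 256
    num_tokens = 7  # flash/paged attention uses a flattened [num_tokens, ...] layout

    # stand-in for self.query_key_value(hidden_states)
    qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_size)

    query, kv = qkv.split(
        [head_size * num_heads, 2 * head_size * num_kv_heads],
        dim=1,
    )
    query = query.view(-1, num_heads, head_size)   # [num_tokens, 8, 256]
    kv = kv.view(-1, 2, num_kv_heads, head_size)   # [num_tokens, 2, 1, 256]; kv[:, 0] is K, kv[:, 1] is V
    print(query.shape, kv.shape)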
@@ -459,9 +295,8 @@ class GemmaMLP(nn.Module):


class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
prefix = f"{prefix or ''}model.layers.{layer_id}"
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@@ -521,38 +356,18 @@ class FlashGemmaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
embed_norm = config.hidden_size**0.5
pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
self.embed_tokens = TensorParallelEmbedding(
prefix=pvalue,
weights=weights,
)
self.embed_tokens.weight = torch.nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
)

# TODO: avoid making a copy of the embedding matrix. added for debugging
self.unscaled_embed_tokens = torch.nn.Parameter(
self.embed_tokens.weight.clone()
)

self.embed_tokens.weight *= embed_norm

self.layers = nn.ModuleList(
[
FlashGemmaLayer(
f"{prefix + '.' if prefix else ''}",
layer_id,
config,
weights,
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = GemmaFastRMSNorm.load(
prefix=f"{prefix + '.' if prefix else ''}model.norm",
weights=weights,
eps=config.rms_norm_eps,
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)

self.gradient_checkpointing = False
@@ -563,8 +378,7 @@ class FlashGemmaModel(torch.nn.Module):

def forward(
self,
# input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
input_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -573,17 +387,16 @@ class FlashGemmaModel(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
global count
hidden_states = inputs_embeds
hidden_states = input_embeds

# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)

residual = None
for i, layer in enumerate(self.layers):
# TODO: prefer a single rotary embedding implementation after debugging
cos, sin = self.layers[i].self_attn.rotary_emb(
hidden_states,
position_ids,
)

hidden_states, residual = layer(
hidden_states,
residual,
@@ -598,20 +411,33 @@ class FlashGemmaModel(torch.nn.Module):
)

hidden_states, _ = self.norm(hidden_states, residual)
count += 1  # for debugging; to avoid breaking during warmup

return hidden_states


class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.model = FlashGemmaModel(prefix, config, weights)
prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
prefix = prefix if config.tie_word_embeddings else "lm_head"

embed_norm = config.hidden_size**0.5
if prefix is None:
prefix = "model"
else:
prefix = f"{prefix}.model"

self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm

self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
self.lm_head = SpeculativeHead.load(
config,
prefix=prefix,
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
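The hunk above moves the token embedding from FlashGemmaModel into FlashGemmaForCausalLM, folds Gemma's sqrt(hidden_size) input scaling into the embedding weight at load time, and points SpeculativeHead at the embed_tokens weights when tie_word_embeddings is set. A rough sketch of that wiring with ordinary nn.Embedding/nn.Linear standing in for the tensor-parallel layers (names and sizes here are made up):

    import torch
    import torch.nn as nn

    class TinyGemmaForCausalLM(nn.Module):
        def __init__(self, vocab_size=32, hidden_size=16, tie_word_embeddings=True):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            if tie_word_embeddings:
                # the head reads the same checkpoint tensor as the embedding;
                # here it keeps its own unscaled copy, mimicking a separate weight load
                with torch.no_grad():
                    self.lm_head.weight.copy_(self.embed_tokens.weight)
            # Gemma scales input embeddings by sqrt(hidden_size); fold it in once
            with torch.no_grad():
                self.embed_tokens.weight *= hidden_size**0.5

        def forward(self, input_ids, transformer=lambda x: x):
            inputs_embeds = self.embed_tokens(input_ids)  # already scaled
            hidden_states = transformer(inputs_embeds)    # stand-in for the decoder stack
            return self.lm_head(hidden_states)

    logits = TinyGemmaForCausalLM()(torch.tensor([[1, 2, 3]]))
    print(logits.shape)  # torch.Size([1, 3, 32])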
@@ -627,9 +453,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.model.embed_tokens(input_ids)
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
@@ -166,7 +166,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
self.vocab_size = config.vocab_size
self.config = config

self.language_model = load_text_model(
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
@@ -191,9 +191,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pixel_attention_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = torch.nn.functional.embedding(
input_ids, self.language_model.model.unscaled_embed_tokens
)
inputs_embeds = self.text_model.embed_tokens(input_ids)

if pixel_values is not None and len(pixel_values) > 0:
# TODO: avoid these casts upstream
@@ -211,17 +209,14 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
mask = input_ids == self.config.image_token_index | (input_ids == 2)

# insert image features into input embeddings
# normalizer = torch.tensor(
# self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
# )
# inputs_embeds = inputs_embeds * normalizer
inputs_embeds[mask] = scaled_image_features.view(
-1, scaled_image_features.shape[-1]
)

# NOTE: scale back up since we dont normalize inside the model like transformers
# TODO: simplify all the rescaling
normalizer = torch.tensor(
self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
)
inputs_embeds = inputs_embeds * normalizer

hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
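The hunk above replaces the debugging lookup against unscaled_embed_tokens with the regular text_model.embed_tokens, writes the projected image features into the image-token positions, and then applies the sqrt(hidden_size) normalizer, since this flash implementation does not rescale inside the model the way transformers does. A toy sketch of the mask-and-insert step only; the sizes and image_token_index value below are made up:

    import torch

    hidden_size = 16
    image_token_index = 9                      # hypothetical placeholder id
    input_ids = torch.tensor([9, 9, 5, 7, 3])  # two image slots followed by text tokens
    inputs_embeds = torch.randn(5, hidden_size)
    image_features = torch.randn(2, hidden_size)  # one projected feature per image token

    mask = input_ids == image_token_index
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])

    # rescale, mirroring Gemma's sqrt(hidden_size) embedding normalizer
    normalizer = torch.tensor(hidden_size**0.5, dtype=inputs_embeds.dtype)
    inputs_embeds = inputs_embeds * normalizer
    print(inputs_embeds.shape)  # torch.Size([5, 16])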
@@ -233,10 +228,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
max_s=max_s,
)

if input_ids.size(0) != 3000:
# import ipdb; ipdb.set_trace()
pass

if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
@@ -4,12 +4,11 @@ import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers.models.gemma import GemmaTokenizerFast
from transformers import AutoConfig, PretrainedConfig
from transformers import AutoConfig

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
GemmaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@@ -20,58 +19,15 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__)


class VisionConfig(PretrainedConfig):
class FlashGemma(FlashCausalLM):
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
model_type: str = "siglip_vision_model",
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
num_image_tokens: int = 256,
patch_size: int = 14,
projection_dim: int = 2048,
projector_hidden_act: str = "gelu_fast",
vision_use_head: bool = False,
vocab_size: int = 257152,
quantize: Optional[str] = None,
image_size: int = 224,
layer_norm_eps: float = 1e-06,
attention_dropout: float = 0.0,
hidden_act: str = "gelu_pytorch_tanh",
num_channels: int = 3,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.model_type = model_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_image_tokens = num_image_tokens
self.patch_size = patch_size
self.projection_dim = projection_dim
self.projector_hidden_act = projector_hidden_act
self.vision_use_head = vision_use_head
self.vocab_size = vocab_size
self.quantize = quantize
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.num_channels = num_channels


class BaseFlashGemma(FlashCausalLM):
def __init__(
self,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
prefix: Optional[str] = None,
config_cls=AutoConfig,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -90,12 +46,13 @@ class BaseFlashGemma(FlashCausalLM):
from_slow=False,
)

config = config_cls.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
import ipdb

is_vlm = hasattr(config, "vision_config")

ipdb.set_trace()
config = config.text_config
config.quantize = quantize
config.speculator = speculator

@@ -106,44 +63,19 @@ class BaseFlashGemma(FlashCausalLM):
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)

model = model_cls(prefix, config, weights)
# TODO hardcoded
prefix = "language_model"
model = FlashGemmaForCausalLM(prefix, config, weights)

torch.distributed.barrier(group=self.process_group)

num_layers = config.num_hidden_layers
num_kv_heads = config.num_key_value_heads
head_size = config.intermediate_size

super().__init__(
super(FlashGemma, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_size=head_size,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


class FlashGemma(BaseFlashGemma):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashGemma, self).__init__(
model_cls=FlashGemmaForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
prefix=None,
)
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.flash_gemma import BaseFlashGemma
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
get_cache_manager,
@@ -516,126 +515,3 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch


class PaliVlmCausalLM(BaseFlashGemma):
@property
def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
return PaliVlmCausalLMBatch

def forward(
self, batch: PaliVlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices

speculative_ids = batch.speculative_ids

B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)

# Add Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length

input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices

if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)

bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
# prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
)
# if batch.prefill_cache_indices is not None:
# batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits

# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

# Replay the graph
cuda_graph["graph"].replay()

# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
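The removed PaliVlmCausalLM.forward above follows the standard TGI speculative-decoding pattern: when speculative_ids are present, every sequence is expanded to speculative_length + 1 rows (the accepted token plus the speculated candidates), and position_ids, slots and input_lengths are expanded to match. A stand-alone sketch of that index bookkeeping with toy values:

    import torch

    B, speculative_length = 2, 3
    new_length = speculative_length + 1

    input_ids = torch.tensor([11, 22])                 # one current token per sequence
    speculative_ids = torch.arange(B * speculative_length).view(B, speculative_length)
    position_ids = torch.tensor([5, 9])
    slots = torch.tensor([100, 200])
    input_lengths = torch.tensor([6, 10], dtype=torch.int32)

    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    arange = torch.arange(new_length).unsqueeze(0)
    arange_int = arange.to(dtype=torch.int32)

    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)
    new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).reshape(-1)
    new_input_lengths = (
        input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
    ).reshape(-1)

    print(new_input_ids.tolist())     # [11, 0, 1, 2, 22, 3, 4, 5]
    print(new_position_ids.tolist())  # [5, 6, 7, 8, 9, 10, 11, 12]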