Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Don't break what's not broken.

This commit is contained in:
parent 9b9614cea3
commit ebbe7edca4
@@ -75,7 +75,6 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA

 except ImportError as e:
@@ -434,16 +433,6 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )

-    if model_type == "paligemma":
-        return FlashPaliGemma(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     if model_type == "cohere":
         if FLASH_ATTENTION:
             return FlashCohere(
@@ -39,9 +39,6 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )

-# TODO: used for debugging; to avoid breaking during warmup
-count = 0
-

 class GemmaConfig(PretrainedConfig):
     def __init__(
@@ -66,8 +63,6 @@ class GemmaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -91,8 +86,6 @@ class GemmaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.quantize = quantize
-        self.use_medusa = use_medusa

         super().__init__(
             pad_token_id=pad_token_id,
@@ -106,7 +99,7 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
-        weight = weights.get_tensor(f"{prefix}.weight")
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)

     # perform the multiplication in full precision and downcast after
@@ -117,7 +110,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = hidden_states * (self.weight.float() + 1.0)
+        hidden_states = hidden_states * self.weight
         return hidden_states.to(self.weight.dtype), residual

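Note on the two GemmaFastRMSNorm hunks above: Gemma stores the RMSNorm scale so that the effective multiplier is (weight + 1), and the change folds that +1 into the tensor once at load time instead of adding 1.0 on every forward call. A minimal, self-contained sketch of why the two variants agree (tensor sizes and the eps value here are illustrative, not taken from the repository):

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 16, dtype=torch.float32)
    w = torch.randn(16)  # stand-in for the raw checkpoint weight
    eps = 1e-6

    def rms_scale(x):
        # normalize by the root-mean-square of the last dimension
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps)

    # old path: keep the raw weight and add 1.0 in every forward call
    out_old = rms_scale(x) * (w.float() + 1.0)

    # new path: add 1 once when loading, then multiply by the stored weight
    w_loaded = w + 1
    out_new = rms_scale(x) * w_loaded

    assert torch.allclose(out_old, out_new)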
@@ -159,107 +152,6 @@ def _load_gqa(config, prefix: str, weights):
     )


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-class GemmaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.int64, device=x.device
-                    ).float()
-                    / self.dim
-                )
-            )
-
-        position_ids = position_ids.unsqueeze(0)
-
-        inv_freq_expanded = (
-            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        )
-
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = (
-            device_type
-            if isinstance(device_type, str) and device_type != "mps"
-            else "cpu"
-        )
-        with torch.autocast(device_type=device_type, enabled=False):
-
-            # import ipdb
-            # ipdb.set_trace()
-            freqs = (
-                inv_freq_expanded.float() @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class FlashGemmaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -271,16 +163,9 @@ class FlashGemmaAttention(torch.nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim

-        # self._rotary_emb = PositionRotaryEmbedding.static(
-        #     config=config,
-        #     dim=self.head_size,
-        #     base=config.rope_theta,
-        #     device=weights.device,
-        # )
-
-        self.rotary_emb = GemmaRotaryEmbedding(
-            dim=self.head_size,
-            max_position_embeddings=config.max_position_embeddings,
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_size,
             base=config.rope_theta,
             device=weights.device,
         )
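For reference, both rotary implementations touched in the hunk above compute the same RoPE transform that the removed rotate_half/apply_rotary_pos_emb helpers spelled out. A self-contained sketch of that rotation (toy shapes, base 10000 assumed):

    import torch

    def rotate_half(x):
        # swap the two halves of the last dimension with a sign flip
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(q, k, cos, sin):
        # classic RoPE: q' = q*cos + rotate_half(q)*sin, same for k
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    # toy shapes: [batch, heads, seq_len, head_dim]
    q = torch.randn(1, 8, 5, 64)
    k = torch.randn(1, 1, 5, 64)
    positions = torch.arange(5).float()
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
    freqs = torch.outer(positions, inv_freq)       # [seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)        # [seq_len, head_dim]
    cos, sin = emb.cos()[None, None], emb.sin()[None, None]
    q_rot, k_rot = apply_rotary(q, k, cos, sin)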
@@ -297,21 +182,7 @@ class FlashGemmaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        # TODO: prefer this implementation after debugging
-        # self.query_key_value = load_attention(config, prefix, weights)
-
-        self.k_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=False,
-        )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
-        )
-        self.q_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
-        )
+        self.query_key_value = load_attention(config, prefix, weights)

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -323,7 +194,6 @@ class FlashGemmaAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

     def forward(
         self,
@@ -337,55 +207,21 @@ class FlashGemmaAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        global count
-
-        # TODO: replace with better implementation after debugging
-        tgt_len, src_len = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = (
-            query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
-        )
-        key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
-        value_states = (
-            value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
-        )
-
-        # reshape for flash/paged attention
-        kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
-
-        # cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, None
-        )
-
-        # replace the kv with the rotated kv
-        kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
-        query = query_states.squeeze(0).transpose(0, 1)
-
-        # TODO: remove after debugging
-        # ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
-
-        # ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
-
-        # ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
-
-        if count > 0:
-            import ipdb
-
-            # ipdb.set_trace()
-
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
         paged_attention.reshape_and_cache(
-            kv[:, 0],
-            kv[:, 1],
-            kv_cache[0],
-            kv_cache[1],
-            slots,
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

         # output tensor
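The rewritten attention forward above goes back to a single fused query_key_value projection and splits its output by head counts before applying rotary embeddings and writing into the paged KV cache. A sketch of the split arithmetic with made-up sizes (the real values come from the model config):

    import torch

    # hypothetical sizes, for illustration only
    num_heads, num_kv_heads, head_size = 16, 8, 256
    seq_len = 4

    # a fused qkv projection emits query, key and value features in one matmul
    qkv = torch.randn(seq_len, (num_heads + 2 * num_kv_heads) * head_size)

    query, kv = qkv.split(
        [head_size * num_heads, 2 * head_size * num_kv_heads], dim=1
    )
    query = query.view(-1, num_heads, head_size)   # [seq, num_heads, head_size]
    kv = kv.view(-1, 2, num_kv_heads, head_size)   # [seq, 2, num_kv_heads, head_size]

    assert kv[:, 0].shape == (seq_len, num_kv_heads, head_size)  # keys
    assert kv[:, 1].shape == (seq_len, num_kv_heads, head_size)  # values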
@@ -459,9 +295,8 @@ class GemmaMLP(nn.Module):


 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"{prefix or ''}model.layers.{layer_id}"
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -521,38 +356,18 @@ class FlashGemmaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        embed_norm = config.hidden_size**0.5
-        pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=pvalue,
-            weights=weights,
-        )
-        self.embed_tokens.weight = torch.nn.Parameter(
-            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
-        )
-
-        # TODO: avoid making a copy of the embedding matrix. added for debugging
-        self.unscaled_embed_tokens = torch.nn.Parameter(
-            self.embed_tokens.weight.clone()
-        )
-
-        self.embed_tokens.weight *= embed_norm
-
         self.layers = nn.ModuleList(
             [
                 FlashGemmaLayer(
-                    f"{prefix + '.' if prefix else ''}",
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaFastRMSNorm.load(
-            prefix=f"{prefix + '.' if prefix else ''}model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )

         self.gradient_checkpointing = False
@@ -563,8 +378,7 @@ class FlashGemmaModel(torch.nn.Module):

     def forward(
         self,
-        # input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
+        input_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -573,17 +387,16 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        global count
-        hidden_states = inputs_embeds
+        hidden_states = input_embeds
+
+        # Get rotary cos and sin for this forward
+        # Avoid to index in each layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )

         residual = None
         for i, layer in enumerate(self.layers):
-            # TODO: prefer a single rotary embedding implementation after debugging
-            cos, sin = self.layers[i].self_attn.rotary_emb(
-                hidden_states,
-                position_ids,
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
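The model forward above now builds the rotary cos/sin tables once per call from position_ids and hands the same tensors to every layer, instead of recomputing them inside the layer loop. A rough sketch of that hoisting (the get_cos_sin helper below is a stand-in, not the TGI API):

    import torch

    def get_cos_sin(position_ids, dim=64, base=10000.0):
        # stand-in for a rotary cos/sin table keyed by absolute position
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        freqs = torch.outer(position_ids.float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    position_ids = torch.arange(16)
    cos, sin = get_cos_sin(position_ids)  # computed once, before the layer loop
    for layer_idx in range(4):
        # every layer consumes the same (cos, sin) pair
        assert cos.shape == (16, 64) and sin.shape == (16, 64)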
@@ -598,20 +411,33 @@ class FlashGemmaModel(torch.nn.Module):
             )

         hidden_states, _ = self.norm(hidden_states, residual)
-        count += 1  # for debugging; to avoid breaking during warmup
+
         return hidden_states


 class FlashGemmaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.config = config
-        self.model = FlashGemmaModel(prefix, config, weights)
-        prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
-        prefix = prefix if config.tie_word_embeddings else "lm_head"
+
+        embed_norm = config.hidden_size**0.5
+        if prefix is None:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+        self.embed_tokens.weight *= embed_norm
+
+        self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
         self.lm_head = SpeculativeHead.load(
-            config,
-            prefix=prefix,
+            prefix=(
+                f"{prefix}.embed_tokens"
+                if config.tie_word_embeddings
+                else f"{prefix}.lm_head"
+            ),
+            config=config,
             weights=weights,
         )
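The constructor change above moves the input embedding, and its one-time sqrt(hidden_size) scaling, out of FlashGemmaModel and into FlashGemmaForCausalLM, and loads the lm_head from the embedding weights when config.tie_word_embeddings is set. A rough sketch of the idea with stand-in tensors (the plain nn.Embedding/nn.Linear modules are illustrative; TGI uses its tensor-parallel wrappers):

    import torch
    from torch import nn

    hidden_size, vocab_size = 256, 1000
    checkpoint_embed = torch.randn(vocab_size, hidden_size)  # stand-in for the stored embed_tokens tensor
    tie_word_embeddings = True

    # input side: Gemma scales token embeddings by sqrt(hidden_size),
    # folded into the table once at load time
    embed_tokens = nn.Embedding(vocab_size, hidden_size)
    embed_tokens.weight = nn.Parameter(checkpoint_embed * hidden_size**0.5)

    # output side: with tied embeddings the head reads the same checkpoint
    # tensor (unscaled); otherwise it reads a separate lm_head tensor
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    head_weight = (
        checkpoint_embed if tie_word_embeddings else torch.randn(vocab_size, hidden_size)
    )
    lm_head.weight = nn.Parameter(head_weight)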
@@ -627,9 +453,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.model.embed_tokens(input_ids)
+        input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            inputs_embeds,
+            input_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,
@@ -166,7 +166,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         self.vocab_size = config.vocab_size
         self.config = config

-        self.language_model = load_text_model(
+        self.text_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
             weights=weights,
@@ -191,9 +191,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         pixel_attention_mask=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = torch.nn.functional.embedding(
-            input_ids, self.language_model.model.unscaled_embed_tokens
-        )
+        inputs_embeds = self.text_model.embed_tokens(input_ids)

         if pixel_values is not None and len(pixel_values) > 0:
             # TODO: avoid these casts upstream
@@ -211,17 +209,14 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             mask = input_ids == self.config.image_token_index | (input_ids == 2)

             # insert image features into input embeddings
+            # normalizer = torch.tensor(
+            #     self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
+            # )
+            # inputs_embeds = inputs_embeds * normalizer
             inputs_embeds[mask] = scaled_image_features.view(
                 -1, scaled_image_features.shape[-1]
             )

-            # NOTE: scale back up since we dont normalize inside the model like transformers
-            # TODO: simplify all the rescaling
-            normalizer = torch.tensor(
-                self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds * normalizer
-
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
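The PaliGemma hunk above keeps the masked insertion of projected image features into the text embedding sequence and drops the extra post-insertion rescaling (now only left as a commented-out normalizer). A sketch of the insertion itself, with hypothetical shapes:

    import torch

    hidden_size, seq_len, num_image_tokens = 64, 10, 4
    image_token_index = 99  # hypothetical placeholder id

    input_ids = torch.tensor([[99, 99, 99, 99, 5, 6, 7, 8, 9, 10]])
    inputs_embeds = torch.randn(1, seq_len, hidden_size)
    scaled_image_features = torch.randn(1, num_image_tokens, hidden_size)

    # replace the placeholder image-token positions with the projected image features
    mask = input_ids == image_token_index
    inputs_embeds[mask] = scaled_image_features.view(-1, scaled_image_features.shape[-1])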
@@ -4,12 +4,11 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 from transformers.models.gemma import GemmaTokenizerFast
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
     FlashGemmaForCausalLM,
-    GemmaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -20,58 +19,15 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)


-class VisionConfig(PretrainedConfig):
+class FlashGemma(FlashCausalLM):
     def __init__(
         self,
-        hidden_size: int = 1152,
-        intermediate_size: int = 4304,
-        model_type: str = "siglip_vision_model",
-        num_attention_heads: int = 16,
-        num_hidden_layers: int = 27,
-        num_image_tokens: int = 256,
-        patch_size: int = 14,
-        projection_dim: int = 2048,
-        projector_hidden_act: str = "gelu_fast",
-        vision_use_head: bool = False,
-        vocab_size: int = 257152,
-        quantize: Optional[str] = None,
-        image_size: int = 224,
-        layer_norm_eps: float = 1e-06,
-        attention_dropout: float = 0.0,
-        hidden_act: str = "gelu_pytorch_tanh",
-        num_channels: int = 3,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-        self.image_size = image_size
-        self.layer_norm_eps = layer_norm_eps
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.num_channels = num_channels
-
-
-class BaseFlashGemma(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        prefix: Optional[str] = None,
-        config_cls=AutoConfig,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -90,12 +46,13 @@ class BaseFlashGemma(FlashCausalLM):
             from_slow=False,
         )

-        config = config_cls.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        import ipdb

-        is_vlm = hasattr(config, "vision_config")
-
+        ipdb.set_trace()
+        config = config.text_config
         config.quantize = quantize
         config.speculator = speculator

@@ -106,44 +63,19 @@ class BaseFlashGemma(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = model_cls(prefix, config, weights)
+        # TODO hardcoded
+        prefix = "language_model"
+        model = FlashGemmaForCausalLM(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)
-
-        num_layers = config.num_hidden_layers
-        num_kv_heads = config.num_key_value_heads
-        head_size = config.intermediate_size
-
-        super().__init__(
+        super(FlashGemma, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_key_value_heads,
+            head_size=model.model.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
         )
-
-
-class FlashGemma(BaseFlashGemma):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashGemma, self).__init__(
-            model_cls=FlashGemmaForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-            prefix=None,
-        )
@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
-from text_generation_server.models.flash_gemma import BaseFlashGemma
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -516,126 +515,3 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
-
-
-class PaliVlmCausalLM(BaseFlashGemma):
-    @property
-    def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
-        return PaliVlmCausalLMBatch
-
-    def forward(
-        self, batch: PaliVlmCausalLMBatch
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Model Forward
-        if batch.speculative_ids is not None:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-            speculative_ids = batch.speculative_ids
-
-            B, speculative_length = speculative_ids.shape
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat(
-                [input_ids.unsqueeze(-1), speculative_ids], dim=1
-            ).reshape(-1)
-            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
-            arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (
-                position_ids.unsqueeze(-1).expand(B, new_length) + arange
-            ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-
-            # Add Copy the block tables for all members
-            block_tables = (
-                block_tables.unsqueeze(1)
-                .expand(B, new_length, -1)
-                .reshape(B * new_length, -1)
-                .contiguous()
-            )
-            max_s = max_s + speculative_length
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-        else:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        # Try to find an associated cuda graph
-        bs = input_ids.shape[0]
-        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
-        if sorted_padded_bs:
-            # Get associated cuda graph
-            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
-        else:
-            cuda_graph = None
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                # prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-            )
-            # if batch.prefill_cache_indices is not None:
-            #     batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            return logits, speculative_logits
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
-        cuda_graph["slots"].fill_(-1)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-
-        # Replay the graph
-        cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        speculative_logits = (
-            cuda_graph["speculative_logits"][:bs]
-            if cuda_graph["speculative_logits"] is not None
-            else None
-        )
-        logits = cuda_graph["logits"][:bs]
-        return logits, speculative_logits