Don't break what's not broken.

Author: Nicolas Patry
Date: 2024-05-14 09:20:45 +00:00
Commit: ebbe7edca4 (parent: 9b9614cea3)
5 changed files with 70 additions and 456 deletions

View File

@@ -75,7 +75,6 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
@@ -434,16 +433,6 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
-    if model_type == "paligemma":
-        return FlashPaliGemma(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
     if model_type == "cohere":
         if FLASH_ATTENTION:
             return FlashCohere(

View File

@@ -39,9 +39,6 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )

-# TODO: used for debugging; to avoid breaking during warmup
-count = 0

 class GemmaConfig(PretrainedConfig):
     def __init__(
@@ -66,8 +63,6 @@ class GemmaConfig(PretrainedConfig):
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -91,8 +86,6 @@ class GemmaConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        self.quantize = quantize
-        self.use_medusa = use_medusa

         super().__init__(
             pad_token_id=pad_token_id,
@@ -106,7 +99,7 @@ class GemmaConfig(PretrainedConfig):
 class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
-        weight = weights.get_tensor(f"{prefix}.weight")
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
         return cls(weight, eps)

     # perform the multiplication in full precision and downcast after
@@ -117,7 +110,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = hidden_states * (self.weight.float() + 1.0)
+        hidden_states = hidden_states * self.weight
         return hidden_states.to(self.weight.dtype), residual
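Aside (not part of the commit): the two GemmaFastRMSNorm hunks above fold Gemma's "+ 1" weight offset into load time instead of applying it on every forward pass. A minimal standalone sketch of why the two forms give the same result; the function names and toy tensors below are invented for illustration and are not TGI code.

import torch

def rmsnorm_offset_at_runtime(x, stored_weight, eps=1e-6):
    # Old behaviour: normalize, then scale by (stored_weight + 1) inside forward.
    h = x.to(torch.float32)
    variance = h.pow(2).mean(-1, keepdim=True)
    h = h * torch.rsqrt(variance + eps)
    return (h * (stored_weight.float() + 1.0)).to(x.dtype)

def rmsnorm_offset_at_load(x, stored_weight, eps=1e-6):
    # New behaviour: add 1 once when the weight is loaded, then scale normally.
    weight = stored_weight + 1
    h = x.to(torch.float32)
    variance = h.pow(2).mean(-1, keepdim=True)
    h = h * torch.rsqrt(variance + eps)
    return (h * weight).to(x.dtype)

torch.manual_seed(0)
x = torch.randn(2, 8)
w = torch.randn(8)  # Gemma stores the norm scale minus one; the effective scale is w + 1
assert torch.allclose(rmsnorm_offset_at_runtime(x, w), rmsnorm_offset_at_load(x, w))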
@@ -159,107 +152,6 @@ def _load_gqa(config, prefix: str, weights):
     )

-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-class GemmaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.register_buffer("inv_freq", None, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.inv_freq is None:
-            self.inv_freq = 1.0 / (
-                self.base
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.int64, device=x.device
-                    ).float()
-                    / self.dim
-                )
-            )
-        position_ids = position_ids.unsqueeze(0)
-        inv_freq_expanded = (
-            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        )
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = (
-            device_type
-            if isinstance(device_type, str) and device_type != "mps"
-            else "cpu"
-        )
-        with torch.autocast(device_type=device_type, enabled=False):
-            # import ipdb
-            # ipdb.set_trace()
-            freqs = (
-                inv_freq_expanded.float() @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class FlashGemmaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -271,16 +163,9 @@ class FlashGemmaAttention(torch.nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim

-        # self._rotary_emb = PositionRotaryEmbedding.static(
-        #     config=config,
-        #     dim=self.head_size,
-        #     base=config.rope_theta,
-        #     device=weights.device,
-        # )
-        self.rotary_emb = GemmaRotaryEmbedding(
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
             dim=self.head_size,
-            max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             device=weights.device,
         )
@@ -297,21 +182,7 @@ class FlashGemmaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )

-        # TODO: prefer this implementation after debugging
-        # self.query_key_value = load_attention(config, prefix, weights)
-        self.k_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=False,
-        )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
-        )
-        self.q_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
-        )
+        self.query_key_value = load_attention(config, prefix, weights)

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -323,7 +194,6 @@ class FlashGemmaAttention(torch.nn.Module):
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

     def forward(
         self,
@@ -337,55 +207,21 @@ class FlashGemmaAttention(torch.nn.Module):
         input_lengths,
         max_s,
     ):
-        global count
-        # TODO: replace with better implementation after debugging
-        tgt_len, src_len = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-        query_states = (
-            query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
-        )
-        key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
-        value_states = (
-            value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
         )
-        # reshape for flash/paged attention
-        kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
-        # cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, None
-        )
-        # replace the kv with the rotated kv
-        kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
-        query = query_states.squeeze(0).transpose(0, 1)
-        # TODO: remove after debugging
-        # ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
-        # ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
-        # ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
-        # ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
-        if count > 0:
-            import ipdb
-            # ipdb.set_trace()
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

         paged_attention.reshape_and_cache(
-            kv[:, 0],
-            kv[:, 1],
-            kv_cache[0],
-            kv_cache[1],
-            slots,
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

         # output tensor
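Aside (not part of the commit): the forward hunk above swaps three separate q/k/v projections for the fused query_key_value layer followed by a split. A small sketch of why the two are equivalent when the three weight matrices are concatenated; layer names and shapes are invented for illustration and are not TGI code.

import torch

hidden_size, num_heads, num_kv_heads, head_size = 32, 4, 1, 8
torch.manual_seed(0)

q_proj = torch.nn.Linear(hidden_size, num_heads * head_size, bias=False)
k_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_size, bias=False)
v_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_size, bias=False)

# Fuse by stacking the three weight matrices along the output dimension.
qkv_proj = torch.nn.Linear(
    hidden_size, (num_heads + 2 * num_kv_heads) * head_size, bias=False
)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
    )

hidden_states = torch.randn(3, hidden_size)
query, key, value = qkv_proj(hidden_states).split(
    [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
    dim=1,
)
assert torch.allclose(query, q_proj(hidden_states), atol=1e-6)
assert torch.allclose(key, k_proj(hidden_states), atol=1e-6)
assert torch.allclose(value, v_proj(hidden_states), atol=1e-6)

Fusing the projections runs one matmul per layer instead of three without changing the output.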
@@ -459,9 +295,8 @@ class GemmaMLP(nn.Module):
 class FlashGemmaLayer(nn.Module):
-    def __init__(self, prefix, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"{prefix or ''}model.layers.{layer_id}"
         self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
@@ -521,38 +356,18 @@ class FlashGemmaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        embed_norm = config.hidden_size**0.5
-        pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=pvalue,
-            weights=weights,
-        )
-        self.embed_tokens.weight = torch.nn.Parameter(
-            self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
-        )
-        # TODO: avoid making a copy of the embedding matrix. added for debugging
-        self.unscaled_embed_tokens = torch.nn.Parameter(
-            self.embed_tokens.weight.clone()
-        )
-        self.embed_tokens.weight *= embed_norm
         self.layers = nn.ModuleList(
             [
                 FlashGemmaLayer(
-                    f"{prefix + '.' if prefix else ''}",
-                    layer_id,
-                    config,
-                    weights,
+                    prefix=f"{prefix}.layers.{layer_id}",
+                    config=config,
+                    weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaFastRMSNorm.load(
-            prefix=f"{prefix + '.' if prefix else ''}model.norm",
-            weights=weights,
-            eps=config.rms_norm_eps,
+            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )

         self.gradient_checkpointing = False
@@ -563,8 +378,7 @@ class FlashGemmaModel(torch.nn.Module):
     def forward(
         self,
-        # input_ids: torch.Tensor,
-        inputs_embeds: torch.Tensor,
+        input_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -573,17 +387,16 @@ class FlashGemmaModel(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
-        global count
-        hidden_states = inputs_embeds
+        hidden_states = input_embeds
+        # Get rotary cos and sin for this forward
+        # Avoid to index in each layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )

         residual = None
         for i, layer in enumerate(self.layers):
-            # TODO: prefer a single rotary embedding implementation after debugging
-            cos, sin = self.layers[i].self_attn.rotary_emb(
-                hidden_states,
-                position_ids,
-            )
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -598,20 +411,33 @@ class FlashGemmaModel(torch.nn.Module):
         )
         hidden_states, _ = self.norm(hidden_states, residual)
-        count += 1  # for debugging; to avoid breaking during warmup

         return hidden_states


 class FlashGemmaForCausalLM(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.config = config
-        self.model = FlashGemmaModel(prefix, config, weights)
-        prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
-        prefix = prefix if config.tie_word_embeddings else "lm_head"
+        embed_norm = config.hidden_size**0.5
+        if prefix is None:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix=f"{prefix}.embed_tokens", weights=weights
+        )
+        self.embed_tokens.weight *= embed_norm
+        self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
         self.lm_head = SpeculativeHead.load(
-            config,
-            prefix=prefix,
+            prefix=(
+                f"{prefix}.embed_tokens"
+                if config.tie_word_embeddings
+                else f"{prefix}.lm_head"
+            ),
+            config=config,
             weights=weights,
         )
@@ -627,9 +453,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.model.embed_tokens(input_ids)
+        input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
-            inputs_embeds,
+            input_embeds,
             position_ids,
             cu_seqlen_prefill,
             kv_cache,

View File

@@ -166,7 +166,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         self.vocab_size = config.vocab_size
         self.config = config

-        self.language_model = load_text_model(
+        self.text_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
             weights=weights,
@@ -191,9 +191,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
         past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
         pixel_attention_mask=None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = torch.nn.functional.embedding(
-            input_ids, self.language_model.model.unscaled_embed_tokens
-        )
+        inputs_embeds = self.text_model.embed_tokens(input_ids)

         if pixel_values is not None and len(pixel_values) > 0:
             # TODO: avoid these casts upstream
@@ -211,17 +209,14 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             mask = input_ids == self.config.image_token_index | (input_ids == 2)

             # insert image features into input embeddings
-            # normalizer = torch.tensor(
-            #     self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
-            # )
-            # inputs_embeds = inputs_embeds * normalizer
             inputs_embeds[mask] = scaled_image_features.view(
                 -1, scaled_image_features.shape[-1]
             )

-        # NOTE: scale back up since we dont normalize inside the model like transformers
-        # TODO: simplify all the rescaling
-        normalizer = torch.tensor(
-            self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
-        )
-        inputs_embeds = inputs_embeds * normalizer
         hidden_states = self.language_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
@@ -233,10 +228,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
             max_s=max_s,
         )

-        if input_ids.size(0) != 3000:
-            # import ipdb; ipdb.set_trace()
-            pass
-
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.language_model.lm_head(hidden_states)
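Aside (not part of the commit): with FlashGemmaForCausalLM now multiplying embed_tokens.weight by hidden_size ** 0.5 once at load time, the PaliGemma wrapper above no longer rescales inputs_embeds on every forward pass. A minimal sketch of that equivalence; the variable names and sizes are invented for illustration and are not TGI code.

import torch

hidden_size, vocab_size = 16, 32
torch.manual_seed(0)
weight = torch.randn(vocab_size, hidden_size)
input_ids = torch.randint(0, vocab_size, (4,))
embed_norm = hidden_size ** 0.5

# Old style: look up unscaled embeddings, then rescale at runtime.
runtime_scaled = torch.nn.functional.embedding(input_ids, weight) * embed_norm

# New style: scale the embedding weight once, then do a plain lookup.
load_time_scaled = torch.nn.functional.embedding(input_ids, weight * embed_norm)

assert torch.allclose(runtime_scaled, load_time_scaled, atol=1e-6)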

View File

@@ -4,12 +4,11 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 from transformers.models.gemma import GemmaTokenizerFast
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
     FlashGemmaForCausalLM,
-    GemmaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -20,58 +19,15 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)


-class VisionConfig(PretrainedConfig):
+class FlashGemma(FlashCausalLM):
     def __init__(
         self,
-        hidden_size: int = 1152,
-        intermediate_size: int = 4304,
-        model_type: str = "siglip_vision_model",
-        num_attention_heads: int = 16,
-        num_hidden_layers: int = 27,
-        num_image_tokens: int = 256,
-        patch_size: int = 14,
-        projection_dim: int = 2048,
-        projector_hidden_act: str = "gelu_fast",
-        vision_use_head: bool = False,
-        vocab_size: int = 257152,
-        quantize: Optional[str] = None,
-        image_size: int = 224,
-        layer_norm_eps: float = 1e-06,
-        attention_dropout: float = 0.0,
-        hidden_act: str = "gelu_pytorch_tanh",
-        num_channels: int = 3,
-    ):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.model_type = model_type
-        self.num_attention_heads = num_attention_heads
-        self.num_hidden_layers = num_hidden_layers
-        self.num_image_tokens = num_image_tokens
-        self.patch_size = patch_size
-        self.projection_dim = projection_dim
-        self.projector_hidden_act = projector_hidden_act
-        self.vision_use_head = vision_use_head
-        self.vocab_size = vocab_size
-        self.quantize = quantize
-        self.image_size = image_size
-        self.layer_norm_eps = layer_norm_eps
-        self.attention_dropout = attention_dropout
-        self.hidden_act = hidden_act
-        self.num_channels = num_channels
-
-
-class BaseFlashGemma(FlashCausalLM):
-    def __init__(
-        self,
-        model_cls,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
         speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        prefix: Optional[str] = None,
-        config_cls=AutoConfig,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -90,12 +46,13 @@ class BaseFlashGemma(FlashCausalLM):
             from_slow=False,
         )

-        config = config_cls.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
-        is_vlm = hasattr(config, "vision_config")
+        import ipdb
+
+        ipdb.set_trace()
+        config = config.text_config
         config.quantize = quantize
         config.speculator = speculator
@@ -106,44 +63,19 @@ class BaseFlashGemma(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = model_cls(prefix, config, weights)
+        # TODO hardcoded
+        prefix = "language_model"
+        model = FlashGemmaForCausalLM(prefix, config, weights)

         torch.distributed.barrier(group=self.process_group)
-
-        num_layers = config.num_hidden_layers
-        num_kv_heads = config.num_key_value_heads
-        head_size = config.intermediate_size
-        super().__init__(
+        super(FlashGemma, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            num_layers=num_layers,
-            num_kv_heads=num_kv_heads,
-            head_size=head_size,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_key_value_heads,
+            head_size=model.model.head_size,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
         )
-
-
-class FlashGemma(BaseFlashGemma):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        super(FlashGemma, self).__init__(
-            model_cls=FlashGemmaForCausalLM,
-            model_id=model_id,
-            revision=revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-            prefix=None,
-        )

View File

@@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     FlashMistralBatch,
 )
-from text_generation_server.models.flash_gemma import BaseFlashGemma
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.cache_manager import (
     get_cache_manager,
@@ -516,126 +515,3 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
             batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch
-
-
-class PaliVlmCausalLM(BaseFlashGemma):
-    @property
-    def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
-        return PaliVlmCausalLMBatch
-
-    def forward(
-        self, batch: PaliVlmCausalLMBatch
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Model Forward
-        if batch.speculative_ids is not None:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-            speculative_ids = batch.speculative_ids
-
-            B, speculative_length = speculative_ids.shape
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat(
-                [input_ids.unsqueeze(-1), speculative_ids], dim=1
-            ).reshape(-1)
-            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
-            arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (
-                position_ids.unsqueeze(-1).expand(B, new_length) + arange
-            ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-
-            # Add Copy the block tables for all members
-            block_tables = (
-                block_tables.unsqueeze(1)
-                .expand(B, new_length, -1)
-                .reshape(B * new_length, -1)
-                .contiguous()
-            )
-            max_s = max_s + speculative_length
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-        else:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        # Try to find an associated cuda graph
-        bs = input_ids.shape[0]
-        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
-        if sorted_padded_bs:
-            # Get associated cuda graph
-            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
-        else:
-            cuda_graph = None
-
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                # prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-            )
-            # if batch.prefill_cache_indices is not None:
-            #     batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            return logits, speculative_logits
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
-        cuda_graph["slots"].fill_(-1)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-
-        # Replay the graph
-        cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        speculative_logits = (
-            cuda_graph["speculative_logits"][:bs]
-            if cuda_graph["speculative_logits"] is not None
-            else None
-        )
-        logits = cuda_graph["logits"][:bs]
-        return logits, speculative_logits