Don't break what's not broken.

This commit is contained in:
Nicolas Patry 2024-05-14 09:20:45 +00:00
parent 9b9614cea3
commit ebbe7edca4
5 changed files with 70 additions and 456 deletions

View File

@ -75,7 +75,6 @@ try:
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.models.flash_dbrx import FlashDbrx
from text_generation_server.models.flash_pali_gemma import FlashPaliGemma
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
@ -434,16 +433,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "paligemma":
return FlashPaliGemma(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "cohere":
if FLASH_ATTENTION:
return FlashCohere(

View File

@ -39,9 +39,6 @@ from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
# TODO: used for debugging; to avoid breaking during warmup
count = 0
class GemmaConfig(PretrainedConfig):
def __init__(
@ -66,8 +63,6 @@ class GemmaConfig(PretrainedConfig):
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
**kwargs,
):
self.vocab_size = vocab_size
@ -91,8 +86,6 @@ class GemmaConfig(PretrainedConfig):
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.quantize = quantize
self.use_medusa = use_medusa
super().__init__(
pad_token_id=pad_token_id,
@ -106,7 +99,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
weight = weights.get_tensor(f"{prefix}.weight") + 1
return cls(weight, eps)
# perform the multiplication in full precision and downcast after
@ -117,7 +110,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * (self.weight.float() + 1.0)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.weight.dtype), residual
@ -159,107 +152,6 @@ def _load_gqa(config, prefix: str, weights):
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class GemmaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer("inv_freq", None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, dtype=torch.int64, device=x.device
).float()
/ self.dim
)
)
position_ids = position_ids.unsqueeze(0)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
# import ipdb
# ipdb.set_trace()
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class FlashGemmaAttention(torch.nn.Module):
def __init__(
self,
@ -271,16 +163,9 @@ class FlashGemmaAttention(torch.nn.Module):
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
# self._rotary_emb = PositionRotaryEmbedding.static(
# config=config,
# dim=self.head_size,
# base=config.rope_theta,
# device=weights.device,
# )
self.rotary_emb = GemmaRotaryEmbedding(
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
device=weights.device,
)
@ -297,21 +182,7 @@ class FlashGemmaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
# TODO: prefer this implementation after debugging
# self.query_key_value = load_attention(config, prefix, weights)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
)
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
@ -323,7 +194,6 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
def forward(
self,
@ -337,55 +207,21 @@ class FlashGemmaAttention(torch.nn.Module):
input_lengths,
max_s,
):
global count
# TODO: replace with better implementation after debugging
tgt_len, src_len = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = (
query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2)
)
key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
value_states = (
value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2)
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# reshape for flash/paged attention
kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2)
# cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, None
)
# replace the kv with the rotated kv
kv[:, 0] = key_states.squeeze(0).transpose(0, 1)
query = query_states.squeeze(0).transpose(0, 1)
# TODO: remove after debugging
# ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0")
# ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1))
# ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 0],ref_key_states.squeeze(0).transpose(0,1))
# ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0")
# ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1))
if count > 0:
import ipdb
# ipdb.set_trace()
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
@ -459,9 +295,8 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
prefix = f"{prefix or ''}model.layers.{layer_id}"
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -521,38 +356,18 @@ class FlashGemmaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
embed_norm = config.hidden_size**0.5
pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens"
self.embed_tokens = TensorParallelEmbedding(
prefix=pvalue,
weights=weights,
)
self.embed_tokens.weight = torch.nn.Parameter(
self.embed_tokens.weight[: config.vocab_size, : config.hidden_size]
)
# TODO: avoid making a copy of the embedding matrix. added for debugging
self.unscaled_embed_tokens = torch.nn.Parameter(
self.embed_tokens.weight.clone()
)
self.embed_tokens.weight *= embed_norm
self.layers = nn.ModuleList(
[
FlashGemmaLayer(
f"{prefix + '.' if prefix else ''}",
layer_id,
config,
weights,
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = GemmaFastRMSNorm.load(
prefix=f"{prefix + '.' if prefix else ''}model.norm",
weights=weights,
eps=config.rms_norm_eps,
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
@ -563,8 +378,7 @@ class FlashGemmaModel(torch.nn.Module):
def forward(
self,
# input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
input_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@ -573,17 +387,16 @@ class FlashGemmaModel(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
global count
hidden_states = inputs_embeds
hidden_states = input_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
# TODO: prefer a single rotary embedding implementation after debugging
cos, sin = self.layers[i].self_attn.rotary_emb(
hidden_states,
position_ids,
)
hidden_states, residual = layer(
hidden_states,
residual,
@ -598,20 +411,33 @@ class FlashGemmaModel(torch.nn.Module):
)
hidden_states, _ = self.norm(hidden_states, residual)
count += 1 # for debugging; to avoid breaking during warmup
return hidden_states
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.model = FlashGemmaModel(prefix, config, weights)
prefix = f"{prefix + '.' if prefix else ''}model.embed_tokens"
prefix = prefix if config.tie_word_embeddings else "lm_head"
embed_norm = config.hidden_size**0.5
if prefix is None:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights)
self.lm_head = SpeculativeHead.load(
config,
prefix=prefix,
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
@ -627,9 +453,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.model.embed_tokens(input_ids)
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,

View File

@ -166,7 +166,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
self.vocab_size = config.vocab_size
self.config = config
self.language_model = load_text_model(
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
@ -191,9 +191,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
pixel_attention_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = torch.nn.functional.embedding(
input_ids, self.language_model.model.unscaled_embed_tokens
)
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# TODO: avoid these casts upstream
@ -211,17 +209,14 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
mask = input_ids == self.config.image_token_index | (input_ids == 2)
# insert image features into input embeddings
# normalizer = torch.tensor(
# self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
# )
# inputs_embeds = inputs_embeds * normalizer
inputs_embeds[mask] = scaled_image_features.view(
-1, scaled_image_features.shape[-1]
)
# NOTE: scale back up since we dont normalize inside the model like transformers
# TODO: simplify all the rescaling
normalizer = torch.tensor(
self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype
)
inputs_embeds = inputs_embeds * normalizer
hidden_states = self.language_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
@ -233,10 +228,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module):
max_s=max_s,
)
if input_ids.size(0) != 3000:
# import ipdb; ipdb.set_trace()
pass
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)

View File

@ -4,12 +4,11 @@ import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers.models.gemma import GemmaTokenizerFast
from transformers import AutoConfig, PretrainedConfig
from transformers import AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
GemmaConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
@ -20,58 +19,15 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__)
class VisionConfig(PretrainedConfig):
class FlashGemma(FlashCausalLM):
def __init__(
self,
hidden_size: int = 1152,
intermediate_size: int = 4304,
model_type: str = "siglip_vision_model",
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
num_image_tokens: int = 256,
patch_size: int = 14,
projection_dim: int = 2048,
projector_hidden_act: str = "gelu_fast",
vision_use_head: bool = False,
vocab_size: int = 257152,
quantize: Optional[str] = None,
image_size: int = 224,
layer_norm_eps: float = 1e-06,
attention_dropout: float = 0.0,
hidden_act: str = "gelu_pytorch_tanh",
num_channels: int = 3,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.model_type = model_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.num_image_tokens = num_image_tokens
self.patch_size = patch_size
self.projection_dim = projection_dim
self.projector_hidden_act = projector_hidden_act
self.vision_use_head = vision_use_head
self.vocab_size = vocab_size
self.quantize = quantize
self.image_size = image_size
self.layer_norm_eps = layer_norm_eps
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.num_channels = num_channels
class BaseFlashGemma(FlashCausalLM):
def __init__(
self,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
prefix: Optional[str] = None,
config_cls=AutoConfig,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -90,12 +46,13 @@ class BaseFlashGemma(FlashCausalLM):
from_slow=False,
)
config = config_cls.from_pretrained(
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
import ipdb
is_vlm = hasattr(config, "vision_config")
ipdb.set_trace()
config = config.text_config
config.quantize = quantize
config.speculator = speculator
@ -106,44 +63,19 @@ class BaseFlashGemma(FlashCausalLM):
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = model_cls(prefix, config, weights)
# TODO hardcoded
prefix = "language_model"
model = FlashGemmaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
num_layers = config.num_hidden_layers
num_kv_heads = config.num_key_value_heads
head_size = config.intermediate_size
super().__init__(
super(FlashGemma, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_size=head_size,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
class FlashGemma(BaseFlashGemma):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashGemma, self).__init__(
model_cls=FlashGemmaForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
prefix=None,
)

View File

@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.flash_gemma import BaseFlashGemma
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
get_cache_manager,
@ -516,126 +515,3 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch):
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
class PaliVlmCausalLM(BaseFlashGemma):
@property
def batch_type(self) -> Type[PaliVlmCausalLMBatch]:
return PaliVlmCausalLMBatch
def forward(
self, batch: PaliVlmCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Add Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
# prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
)
# if batch.prefill_cache_indices is not None:
# batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits