From ebbe7edca467e22724000aed9d5e79eae03ae490 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 May 2024 09:20:45 +0000 Subject: [PATCH] Don't break what's not broken. --- .../text_generation_server/models/__init__.py | 11 - .../custom_modeling/flash_gemma_modeling.py | 276 ++++-------------- .../flash_pali_gemma_modeling.py | 21 +- .../models/flash_gemma.py | 94 +----- .../models/vlm_causal_lm.py | 124 -------- 5 files changed, 70 insertions(+), 456 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index c4c6342a..e9761dfe 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -75,7 +75,6 @@ try: from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx - from text_generation_server.models.flash_pali_gemma import FlashPaliGemma from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA except ImportError as e: @@ -434,16 +433,6 @@ def get_model( trust_remote_code=trust_remote_code, ) - if model_type == "paligemma": - return FlashPaliGemma( - model_id, - revision, - quantize=quantize, - use_medusa=use_medusa, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - if model_type == "cohere": if FLASH_ATTENTION: return FlashCohere( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index e1061f45..d2401327 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -39,9 +39,6 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) -# TODO: used for debugging; to avoid breaking during warmup -count = 0 - class GemmaConfig(PretrainedConfig): def __init__( @@ -66,8 +63,6 @@ class GemmaConfig(PretrainedConfig): rope_scaling=None, attention_bias=False, attention_dropout=0.0, - quantize: Optional[str] = None, - use_medusa: Optional[str] = None, **kwargs, ): self.vocab_size = vocab_size @@ -91,8 +86,6 @@ class GemmaConfig(PretrainedConfig): self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.quantize = quantize - self.use_medusa = use_medusa super().__init__( pad_token_id=pad_token_id, @@ -106,7 +99,7 @@ class GemmaConfig(PretrainedConfig): class GemmaFastRMSNorm(FastRMSNorm): @classmethod def load(cls, prefix, weights, eps=1e-6): - weight = weights.get_tensor(f"{prefix}.weight") + weight = weights.get_tensor(f"{prefix}.weight") + 1 return cls(weight, eps) # perform the multiplication in full precision and downcast after @@ -117,7 +110,7 @@ class GemmaFastRMSNorm(FastRMSNorm): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = hidden_states * (self.weight.float() + 1.0) + hidden_states = hidden_states * self.weight return hidden_states.to(self.weight.dtype), residual @@ -159,107 +152,6 @@ def _load_gqa(config, prefix: str, weights): ) -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.register_buffer("inv_freq", None, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if self.inv_freq is None: - self.inv_freq = 1.0 / ( - self.base - ** ( - torch.arange( - 0, self.dim, 2, dtype=torch.int64, device=x.device - ).float() - / self.dim - ) - ) - - position_ids = position_ids.unsqueeze(0) - - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = ( - device_type - if isinstance(device_type, str) and device_type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - - # import ipdb - # ipdb.set_trace() - freqs = ( - inv_freq_expanded.float() @ position_ids_expanded.float() - ).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class FlashGemmaAttention(torch.nn.Module): def __init__( self, @@ -271,16 +163,9 @@ class FlashGemmaAttention(torch.nn.Module): self.num_heads = config.num_attention_heads self.head_size = config.head_dim - # self._rotary_emb = PositionRotaryEmbedding.static( - # config=config, - # dim=self.head_size, - # base=config.rope_theta, - # device=weights.device, - # ) - - self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, dim=self.head_size, - max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, device=weights.device, ) @@ -297,21 +182,7 @@ class FlashGemmaAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - # TODO: prefer this implementation after debugging - # self.query_key_value = load_attention(config, prefix, weights) - - self.k_proj = TensorParallelColumnLinear.load( - config, - prefix=f"{prefix}.k_proj", - weights=weights, - bias=False, - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=False - ) - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=False - ) + self.query_key_value = load_attention(config, prefix, weights) self.o_proj = TensorParallelRowLinear.load( config, @@ -323,7 +194,6 @@ class FlashGemmaAttention(torch.nn.Module): self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads def forward( self, @@ -337,55 +207,21 @@ class FlashGemmaAttention(torch.nn.Module): input_lengths, max_s, ): - global count - - # TODO: replace with better implementation after debugging - tgt_len, src_len = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = ( - query_states.unsqueeze(0).view(1, tgt_len, 8, 256).transpose(1, 2) - ) - key_states = key_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2) - value_states = ( - value_states.unsqueeze(0).view(1, tgt_len, 1, 256).transpose(1, 2) + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - # reshape for flash/paged attention - kv = torch.cat([key_states, value_states], dim=1).transpose(0, 2) - - # cos2, sin2 = self.rotary_emb(value_states, position_ids, seq_len=None) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, None - ) - - # replace the kv with the rotated kv - kv[:, 0] = key_states.squeeze(0).transpose(0, 1) - query = query_states.squeeze(0).transpose(0, 1) - - # TODO: remove after debugging - # ipdb> ref_query_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_query_states.pt").to("cuda:0") - # ipdb> torch.allclose(query,ref_query_states.squeeze(0).transpose(0,1)) - - # ipdb> ref_key_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_key_states.pt").to("cuda:0") - # ipdb> torch.allclose(kv[:, 
0],ref_key_states.squeeze(0).transpose(0,1)) - - # ipdb> ref_value_states = torch.load("/home/ubuntu/Projects/new-model-addition-palma/ref_value_states.pt").to("cuda:0") - # ipdb> torch.allclose(kv[:, 1],ref_value_states.squeeze(0).transpose(0,1)) - - if count > 0: - import ipdb - - # ipdb.set_trace() + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) paged_attention.reshape_and_cache( - kv[:, 0], - kv[:, 1], - kv_cache[0], - kv_cache[1], - slots, + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots ) # output tensor @@ -459,9 +295,8 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - prefix = f"{prefix or ''}model.layers.{layer_id}" self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -521,38 +356,18 @@ class FlashGemmaModel(torch.nn.Module): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - embed_norm = config.hidden_size**0.5 - pvalue = f"{prefix + '.' if prefix else ''}model.embed_tokens" - self.embed_tokens = TensorParallelEmbedding( - prefix=pvalue, - weights=weights, - ) - self.embed_tokens.weight = torch.nn.Parameter( - self.embed_tokens.weight[: config.vocab_size, : config.hidden_size] - ) - - # TODO: avoid making a copy of the embedding matrix. added for debugging - self.unscaled_embed_tokens = torch.nn.Parameter( - self.embed_tokens.weight.clone() - ) - - self.embed_tokens.weight *= embed_norm - self.layers = nn.ModuleList( [ FlashGemmaLayer( - f"{prefix + '.' if prefix else ''}", - layer_id, - config, - weights, + prefix=f"{prefix}.layers.{layer_id}", + config=config, + weights=weights, ) for layer_id in range(config.num_hidden_layers) ] ) self.norm = GemmaFastRMSNorm.load( - prefix=f"{prefix + '.' if prefix else ''}model.norm", - weights=weights, - eps=config.rms_norm_eps, + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -563,8 +378,7 @@ class FlashGemmaModel(torch.nn.Module): def forward( self, - # input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, + input_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], @@ -573,17 +387,16 @@ class FlashGemmaModel(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, ) -> torch.Tensor: - global count - hidden_states = inputs_embeds + hidden_states = input_embeds + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) residual = None for i, layer in enumerate(self.layers): - # TODO: prefer a single rotary embedding implementation after debugging - cos, sin = self.layers[i].self_attn.rotary_emb( - hidden_states, - position_ids, - ) - hidden_states, residual = layer( hidden_states, residual, @@ -598,20 +411,33 @@ class FlashGemmaModel(torch.nn.Module): ) hidden_states, _ = self.norm(hidden_states, residual) - count += 1 # for debugging; to avoid breaking during warmup + return hidden_states class FlashGemmaForCausalLM(torch.nn.Module): def __init__(self, prefix, config, weights): super().__init__() - self.config = config - self.model = FlashGemmaModel(prefix, config, weights) - prefix = f"{prefix + '.' 
if prefix else ''}model.embed_tokens" - prefix = prefix if config.tie_word_embeddings else "lm_head" + + embed_norm = config.hidden_size**0.5 + if prefix is None: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.model = FlashGemmaModel(prefix=prefix, config=config, weights=weights) self.lm_head = SpeculativeHead.load( - config, - prefix=prefix, + prefix=( + f"{prefix}.embed_tokens" + if config.tie_word_embeddings + else f"{prefix}.lm_head" + ), + config=config, weights=weights, ) @@ -627,9 +453,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): max_s: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = self.model.embed_tokens(input_ids) + input_embeds = self.embed_tokens(input_ids) hidden_states = self.model( - inputs_embeds, + input_embeds, position_ids, cu_seqlen_prefill, kv_cache, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 23d4ae22..9b3894bb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -166,7 +166,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): self.vocab_size = config.vocab_size self.config = config - self.language_model = load_text_model( + self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, weights=weights, @@ -191,9 +191,7 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, pixel_attention_mask=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - inputs_embeds = torch.nn.functional.embedding( - input_ids, self.language_model.model.unscaled_embed_tokens - ) + inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # TODO: avoid these casts upstream @@ -211,17 +209,14 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): mask = input_ids == self.config.image_token_index | (input_ids == 2) # insert image features into input embeddings + # normalizer = torch.tensor( + # self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype + # ) + # inputs_embeds = inputs_embeds * normalizer inputs_embeds[mask] = scaled_image_features.view( -1, scaled_image_features.shape[-1] ) - # NOTE: scale back up since we dont normalize inside the model like transformers - # TODO: simplify all the rescaling - normalizer = torch.tensor( - self.config.text_config.hidden_size**0.5, dtype=inputs_embeds.dtype - ) - inputs_embeds = inputs_embeds * normalizer - hidden_states = self.language_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, @@ -233,10 +228,6 @@ class FlashPaliGemmaForConditionalGeneration(nn.Module): max_s=max_s, ) - if input_ids.size(0) != 3000: - # import ipdb; ipdb.set_trace() - pass - if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.language_model.lm_head(hidden_states) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 03b7bac3..48690ad5 100644 --- 
a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -4,12 +4,11 @@ import torch.distributed from opentelemetry import trace from typing import Optional from transformers.models.gemma import GemmaTokenizerFast -from transformers import AutoConfig, PretrainedConfig +from transformers import AutoConfig from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( FlashGemmaForCausalLM, - GemmaConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -20,58 +19,15 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -class VisionConfig(PretrainedConfig): +class FlashGemma(FlashCausalLM): def __init__( self, - hidden_size: int = 1152, - intermediate_size: int = 4304, - model_type: str = "siglip_vision_model", - num_attention_heads: int = 16, - num_hidden_layers: int = 27, - num_image_tokens: int = 256, - patch_size: int = 14, - projection_dim: int = 2048, - projector_hidden_act: str = "gelu_fast", - vision_use_head: bool = False, - vocab_size: int = 257152, - quantize: Optional[str] = None, - image_size: int = 224, - layer_norm_eps: float = 1e-06, - attention_dropout: float = 0.0, - hidden_act: str = "gelu_pytorch_tanh", - num_channels: int = 3, - ): - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.model_type = model_type - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - self.num_image_tokens = num_image_tokens - self.patch_size = patch_size - self.projection_dim = projection_dim - self.projector_hidden_act = projector_hidden_act - self.vision_use_head = vision_use_head - self.vocab_size = vocab_size - self.quantize = quantize - self.image_size = image_size - self.layer_norm_eps = layer_norm_eps - self.attention_dropout = attention_dropout - self.hidden_act = hidden_act - self.num_channels = num_channels - - -class BaseFlashGemma(FlashCausalLM): - def __init__( - self, - model_cls, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - prefix: Optional[str] = None, - config_cls=AutoConfig, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -90,12 +46,13 @@ class BaseFlashGemma(FlashCausalLM): from_slow=False, ) - config = config_cls.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) + import ipdb - is_vlm = hasattr(config, "vision_config") - + ipdb.set_trace() + config = config.text_config config.quantize = quantize config.speculator = speculator @@ -106,44 +63,19 @@ class BaseFlashGemma(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) - model = model_cls(prefix, config, weights) + # TODO hardcoded + prefix = "language_model" + model = FlashGemmaForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) - - num_layers = config.num_hidden_layers - num_kv_heads = config.num_key_value_heads - head_size = config.intermediate_size - - super().__init__( + super(FlashGemma, self).__init__( model=model, tokenizer=tokenizer, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + 
head_size=model.model.head_size, dtype=dtype, device=device, rank=rank, world_size=world_size, ) - - -class FlashGemma(BaseFlashGemma): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - use_medusa: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashGemma, self).__init__( - model_cls=FlashGemmaForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - use_medusa=use_medusa, - dtype=dtype, - trust_remote_code=trust_remote_code, - prefix=None, - ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index a539b6ad..12f709df 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -15,7 +15,6 @@ from text_generation_server.models.flash_mistral import ( BaseFlashMistral, FlashMistralBatch, ) -from text_generation_server.models.flash_gemma import BaseFlashGemma from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.cache_manager import ( get_cache_manager, @@ -516,126 +515,3 @@ class PaliVlmCausalLMBatch(FlashCausalLMBatch): batch.pixel_attention_mask = None batch.image_sizes = None return batch - - -class PaliVlmCausalLM(BaseFlashGemma): - @property - def batch_type(self) -> Type[PaliVlmCausalLMBatch]: - return PaliVlmCausalLMBatch - - def forward( - self, batch: PaliVlmCausalLMBatch - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Model Forward - if batch.speculative_ids is not None: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - speculative_ids = batch.speculative_ids - - B, speculative_length = speculative_ids.shape - new_length = speculative_length + 1 - new_input_ids = torch.cat( - [input_ids.unsqueeze(-1), speculative_ids], dim=1 - ).reshape(-1) - arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0) - arange_int = arange.to(dtype=torch.int32) - new_position_ids = ( - position_ids.unsqueeze(-1).expand(B, new_length) + arange - ).view(-1) - slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) - input_lengths = ( - input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int - ).view(-1) - - # Add Copy the block tables for all members - block_tables = ( - block_tables.unsqueeze(1) - .expand(B, new_length, -1) - .reshape(B * new_length, -1) - .contiguous() - ) - max_s = max_s + speculative_length - - input_ids = new_input_ids - position_ids = new_position_ids - else: - input_ids = batch.input_ids - position_ids = batch.position_ids - cu_seqlen_prefill = batch.cu_seqlen_prefill - kv_cache = get_cache_manager().kv_cache - block_tables = batch.block_tables_tensor - slots = batch.slots[batch.slot_indices] - input_lengths = batch.input_lengths_tensor - max_s = batch.max_seqlen - lm_head_indices = batch.prefill_head_indices - - if cu_seqlen_prefill is None and self.max_past() is not None: - # In decode, not prefill, we're actually overwriting the KV-cache - # in a circular buffer mode. - # This makes sure the max_s for the decode pass is correct. 
- max_s = min(self.max_past(), max_s) - - bs = input_ids.shape[0] - # Try to find an associated cuda graph - bs = input_ids.shape[0] - sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) - if sorted_padded_bs: - # Get associated cuda graph - cuda_graph = self.cuda_graphs[sorted_padded_bs[0]] - else: - cuda_graph = None - if cu_seqlen_prefill is not None or cuda_graph is None: - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - # prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - pixel_values=batch.pixel_values, - ) - # if batch.prefill_cache_indices is not None: - # batch.prefill_cache_indices = None - if batch.pixel_values is not None: - batch.pixel_values = None - if batch.pixel_attention_mask is not None: - batch.pixel_attention_mask = None - if batch.image_sizes is not None: - batch.image_sizes = None - return logits, speculative_logits - - # Copy inputs to the static inputs of the cuda graph - # Static inputs are potentially padded - cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids - cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids - cuda_graph["block_tables"][ - : block_tables.shape[0], : block_tables.shape[1] - ] = block_tables - cuda_graph["slots"].fill_(-1) - cuda_graph["slots"][: slots.shape[0]] = slots - cuda_graph["input_lengths"].zero_() - cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - - # Replay the graph - cuda_graph["graph"].replay() - - # Slice output to the correct shape - speculative_logits = ( - cuda_graph["speculative_logits"][:bs] - if cuda_graph["speculative_logits"] is not None - else None - ) - logits = cuda_graph["logits"][:bs] - return logits, speculative_logits
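
Note (not part of the patch; shapes and names below are illustrative): the Gemma changes above mainly restore two pieces of the upstream flash implementation — folding Gemma's `1 + weight` RMSNorm scale into the stored weight once at load time (`GemmaFastRMSNorm.load`), and going back to a fused `query_key_value` projection that is split by head counts instead of three separate q/k/v projections. A minimal sketch, assuming 8 query heads, 1 KV head and head size 256 (the hardcoded debug shapes the patch removes):

```python
import torch

# 1) RMSNorm with the "+1" folded in at load time: the forward pass can then
#    multiply by the stored weight directly, as GemmaFastRMSNorm does above.
raw_weight = torch.randn(2048)       # what the checkpoint stores
folded_weight = raw_weight + 1.0     # what the layer keeps after load()

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Compute in float32 and downcast afterwards, mirroring the forward above.
    dtype = hidden_states.dtype
    hs = hidden_states.to(torch.float32)
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)
    return (hs * weight.float()).to(dtype)

# 2) Fused QKV: one matmul produces Q, K and V, then the output is split by
#    head counts, as FlashGemmaAttention.forward does above.
num_heads, num_kv_heads, head_size = 8, 1, 256
hidden = torch.randn(4, num_heads * head_size)   # [tokens, hidden_size]
wqkv = torch.randn((num_heads + 2 * num_kv_heads) * head_size, hidden.shape[-1])

qkv = hidden @ wqkv.T
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_kv_heads], dim=1
)
query = query.view(-1, num_heads, head_size)      # [tokens, 8, 256]
kv = kv.view(-1, 2, num_kv_heads, head_size)      # kv[:, 0] = K, kv[:, 1] = V
```

The rotary embedding is then applied in place to `query` and `kv[:, 0]` before `paged_attention.reshape_and_cache`, matching the simplified attention path in `flash_gemma_modeling.py`.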