diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 453b4f61a..f168fd76b 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -14,8 +14,8 @@ Text Generation Inference enables serving optimized models. The following sectio - [Gemma](https://huggingface.co/google/gemma-7b) - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) -- [Gemma3](https://huggingface.co/collections/google/gemma-3) -- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3) +- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) +- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) diff --git a/router/src/validation.rs b/router/src/validation.rs index 87b28eb74..1119347dc 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -699,7 +699,7 @@ fn image_tokens( // TODO: prefer using the config to determine the number of features let num_mm_soft_tokens_per_image = 256; format!( - "\n\n<start_of_image>{:?}<end_of_image>\n\n", + "\n\n<start_of_image>{}<end_of_image>\n\n", "<image_soft_token>".repeat(num_mm_soft_tokens_per_image) ) } diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index cdcfe91b1..782d66e4e 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -205,7 +205,6 @@ class LoraWeights(AdapterWeights): lora_a_list = [None] * nlayers lora_b_list = [None] * nlayers - # import ipdb; ipdb.set_trace() for layer_id in range(nlayers): key = (layer_id, layer_type) if key not in target_to_layer: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 4f25cc192..505fbafab 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -38,6 +38,7 @@ def paged_attention( *, kv_scales: KVScales, softcap: Optional[float] = None, + window_size_left: Optional[int] = -1, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Copyright 2023 The vLLM team. All rights @@ -79,12 +80,15 @@ def paged_attention( sm_scale=softmax_scale, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, + window_size_left=window_size_left, ) elif ATTENTION == "flashdecoding": max_q = 1 max_k = max_s import flash_attn_2_cuda + window_size_right = -1 if window_size_left == -1 else 0 + # TODO fixme when flash contains the fix. 
# Number of splits is not correctly handled # by the current path @@ -109,8 +113,8 @@ def paged_attention( softmax_scale, False, # zero_tensors True, # causal - -1, # Window_left - -1, # Window right + window_size_left, # Window_left + window_size_right, # Window right softcap, False, # return softmax None, # generator @@ -253,6 +257,7 @@ def attention( sm_scale=softmax_scale, k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, + window_size_left=window_size_left, ) # If we are using flashdecoding or paged, we always use flash-attn for diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index d23451844..9479b6067 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -52,7 +52,6 @@ def use_prefill_with_paged_kv_state( page_size: int, kv_dtype: torch.dtype, q_dtype: torch.dtype, - window_left: int, ): """ Context manager to set the active flashinfer prefill state to the given @@ -95,7 +94,6 @@ def use_prefill_with_paged_kv_state( kv_data_type=kv_dtype, q_data_type=q_dtype, page_size=page_size, - window_left=-1 if window_left is None else window_left, ) yield finally: @@ -172,7 +170,6 @@ def use_decode_state( page_size: int, kv_cache_dtype: torch.dtype, q_dtype: torch.dtype, - window_left: int, ): """ Context manager to set the active flashinfer decoding state to the given @@ -209,7 +206,6 @@ def use_decode_state( page_size=page_size, data_type=kv_cache_dtype, q_data_type=q_dtype, - window_left=-1 if window_left is None else window_left, ) yield finally: diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 0fdc009ca..2ea9caa5b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -272,12 +272,12 @@ class ModelType(enum.Enum): GEMMA3 = { "type": "gemma3", "name": "Gemma3", - "url": "https://huggingface.co/collections/google/gemma-3", + "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d", } GEMMA3_TEXT = { "type": "gemma3_text", "name": "Gemma3 Text", - "url": "https://huggingface.co/collections/google/gemma-3", + "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d", } COHERE = { "type": "cohere", diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index ebf1b80eb..2554bd269 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -287,6 +287,7 @@ class FlashGemma2Attention(torch.nn.Module): max_s, softcap=self.softcap, kv_scales=self.kv_scales, + window_size_left=self.window_size, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py index 085f57ef1..70fe9a3db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py @@ -281,22 +281,12 @@ class FlashGemma3Attention(torch.nn.Module): padded_query = 
padded_query.transpose(1, 2).contiguous() padded_key = padded_key.transpose(1, 2).contiguous() padded_value = padded_value.transpose(1, 2).contiguous() - zeros_to_add = torch.zeros( - padded_key.size(0), - self.num_key_value_heads, - 1, - self.head_size, - dtype=padded_key.dtype, - device=padded_key.device, - ) - key_states = torch.cat([padded_key, zeros_to_add], dim=2) - value_states = torch.cat([padded_value, zeros_to_add], dim=2) # Compute attention attn_output = F.scaled_dot_product_attention( padded_query, - key_states, - value_states, + padded_key, + padded_value, attn_mask=attention_mask, scale=self.softmax_scale, enable_gqa=self.enable_gqa, @@ -327,6 +317,7 @@ class FlashGemma3Attention(torch.nn.Module): max_s, softcap=self.softcap, kv_scales=self.kv_scales, + window_size_left=self.window_size, ) return self.o_proj( @@ -513,6 +504,7 @@ class FlashGemma3Model(torch.nn.Module): max_s: int, adapter_data: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + attention_mask_local: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -525,25 +517,6 @@ class FlashGemma3Model(torch.nn.Module): position_ids, max_s, hidden_states.dtype ) - # apply sliding window mask if needed - if layer.self_attn.window_size > 0 and attention_mask is not None: - min_dtype = torch.finfo(hidden_states.dtype).min - # prefill may be larger than sliding window - effective_seq_len = max( - position_ids.shape[0], self.layers[i].self_attn.window_size - ) - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), - diagonal=-self.layers[i].self_attn.window_size, - ) - attention_mask = torch.where( - sliding_window_mask, min_dtype, attention_mask - ) - offset = max(0, position_ids.shape[0] - effective_seq_len) - attention_mask = attention_mask[ - :, :, offset : offset + effective_seq_len - ] - hidden_states, residual = layer( hidden_states, residual, @@ -556,7 +529,11 @@ class FlashGemma3Model(torch.nn.Module): seqlen, max_s, adapter_data, - attention_mask, + ( + attention_mask + if self.layers[i].self_attn.window_size == -1 + else attention_mask_local + ), ) hidden_states, _ = self.norm(hidden_states, residual) @@ -723,24 +700,6 @@ class Gemma3ForConditionalGeneration(nn.Module): config.pad_token_id if config.pad_token_id is not None else -1 ) - def get_image_token_mask(self, input_ids): - device = input_ids.device - - start_token_id = self.config.boi_token_index - K = self.config.mm_tokens_per_image - - mask = torch.zeros_like(input_ids, dtype=torch.bool, device=device) - start_positions = (input_ids == start_token_id).nonzero(as_tuple=True)[0] - mask_indices = start_positions.unsqueeze(1) + torch.arange( - 1, K + 1, device=device - ).unsqueeze(0) - - valid_mask = mask_indices < input_ids.size(0) - mask_indices = mask_indices[valid_mask] - mask[mask_indices] = True - - return mask - def get_attention_mask( self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask ): @@ -751,7 +710,7 @@ class Gemma3ForConditionalGeneration(nn.Module): batch_size = len(lengths) sequence_length = max(lengths) - target_length = max_s + target_length = sequence_length # Create the padding mask from the computed lengths. # pad_mask: [batch, sequence_length] where True indicates valid tokens. seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) @@ -847,7 +806,7 @@ class Gemma3ForConditionalGeneration(nn.Module): # # Determine the maximum sequence length (after padding) from query. 
# sequence_length = max(lengths) - # target_length = max_s + # target_length = sequence_length # # Create the padding mask from the computed lengths. # # pad_mask: [batch, sequence_length] where True indicates valid tokens. @@ -885,6 +844,26 @@ class Gemma3ForConditionalGeneration(nn.Module): # input_ids.device # ) + if attention_mask is not None: + min_dtype = torch.finfo(inputs_embeds.dtype).min + # prefill may be larger than sliding window + effective_seq_len = max( + position_ids.shape[0], self.config.text_config.sliding_window + ) + sliding_window_mask = torch.tril( + torch.ones_like(attention_mask, dtype=torch.bool), + diagonal=-self.config.text_config.sliding_window, + ) + attention_mask_local = torch.where( + sliding_window_mask, min_dtype, attention_mask + ) + offset = max(0, position_ids.shape[0] - effective_seq_len) + attention_mask_local = attention_mask_local[ + :, :, :, offset : offset + effective_seq_len + ] + else: + attention_mask_local = None + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, @@ -895,6 +874,7 @@ class Gemma3ForConditionalGeneration(nn.Module): seqlen=seqlen, max_s=max_s, attention_mask=attention_mask, + attention_mask_local=attention_mask_local, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 0fa172d03..7ad294f4b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -242,6 +242,7 @@ class MistralAttention(torch.nn.Module): seqlen, max_s, kv_scales=self.kv_scales, + window_size_left=self.max_past, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a45dd1e61..e2a3e5860 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -290,6 +290,7 @@ class MixtralAttention(torch.nn.Module): seqlen, max_s, kv_scales=self.kv_scales, + window_size_left=self.max_past, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 9d9562222..75d519e45 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -185,6 +185,7 @@ class Qwen2Attention(torch.nn.Module): seqlen, max_s, kv_scales=self.kv_scales, + window_size_left=self.max_past, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 5e090369b..9508cc4f8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -291,6 +291,7 @@ class Starcoder2Attention(torch.nn.Module): seqlen, max_s, kv_scales=self.kv_scales, + window_size_left=self.max_past, ) return self.o_proj( diff --git 
a/server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py b/server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py index 08e39a7c6..6bdf35c63 100644 --- a/server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py +++ b/server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py @@ -82,7 +82,7 @@ class Gemma3Processor(ProcessorMixin): do_rescale=False, resample=PILImageResampling.BILINEAR, ) - # import ipdb; ipdb.set_trace() + self.image_token_id = tokenizer.image_token_id image_tokens_expanded = "".join( [tokenizer.image_token] * num_mm_soft_tokens_per_image @@ -91,8 +91,6 @@ class Gemma3Processor(ProcessorMixin): f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" ) - # import ipdb; ipdb.set_trace() - self.image_processor = image_processor self.tokenizer = tokenizer self.chat_template = chat_template diff --git a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index e317c5b56..066de6a20 100644 --- a/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -633,7 +633,7 @@ class Qwen2_5VisionModel(nn.Module): config=config, weights=weights, ) - # import ipdb; ipdb.set_trace() + self.temporal_patch_size = config.temporal_patch_size self.spatial_patch_size = config.spatial_patch_size self.in_channels = config.in_channels diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e268af8b4..d3a83e271 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -83,24 +83,11 @@ from text_generation_server.models.metadata_kernels import ( tracer = trace.get_tracer(__name__) -# Will be set in init -SLIDING_WINDOW: Optional[int] = None - def small_power_of_2(n: int): return 1 << ((n - 1).bit_length() - 1) -def set_sliding_window(sliding_window: int): - global SLIDING_WINDOW - SLIDING_WINDOW = sliding_window - - -def get_sliding_windows() -> int: - global SLIDING_WINDOW - return SLIDING_WINDOW - - def init_cpu_threads_env(rank_id: int, world_size: int): import importlib.util @@ -1002,10 +989,8 @@ class FlashCausalLMBatch(Batch): self.slot_indices, ) - sliding_window = get_sliding_windows() position_ids = [] slot_indices = [] - prefill_cache_indices = [] all_prefill_logprobs = True no_prefill_logprobs = True prefill_cu_outlens = [0] @@ -1064,14 +1049,6 @@ class FlashCausalLMBatch(Batch): # Update cumulative_slot_tokens += len(request_slots) - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - sliding_window), - cumulative_length + input_length, - dtype=torch.int64, - ) - # Prefill logprobs is ignored if the request is done prefilling prefill_logprobs = r.prefill_logprobs and request_prefilling @@ -1085,9 +1062,6 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - if sliding_window is not None: - prefill_cache_indices.append(request_prefill_cache_indices) - ADAPTER_TO_INDEX = get_adapter_to_index() if ADAPTER_TO_INDEX: adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) @@ -1151,24 +1125,18 @@ class FlashCausalLMBatch(Batch): position_ids = 
torch.cat(position_ids) if slot_indices: slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) else: if position_ids: position_ids = position_ids[0] if slot_indices: slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] if not has_triton(): self.position_ids = position_ids.to(device) self.slot_indices = slot_indices.to(device) self.prefill_cu_outlens = prefill_cu_outlens - self.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) + self.prefill_cache_indices = None if all_prefill_logprobs: prefill_head_indices = None @@ -1306,9 +1274,7 @@ class FlashCausalLM(Model): if text_config is not None: config = text_config - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: + if getattr(config, "sliding_window", None) is None: config.sliding_window = None self.num_layers = config.num_hidden_layers @@ -2500,7 +2466,6 @@ class FlashCausalLM(Model): page_size=BLOCK_SIZE, kv_dtype=self.kv_cache_dtype, q_dtype=self.dtype, - window_left=self.sliding_window, ) else: assert input_lengths_tensor is not None @@ -2514,5 +2479,4 @@ class FlashCausalLM(Model): page_size=BLOCK_SIZE, kv_cache_dtype=self.kv_cache_dtype, q_dtype=self.dtype, - window_left=self.sliding_window, ) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index af4d1f082..da317a628 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -110,7 +110,7 @@ class Model(ABC): requires_padding=self.requires_padding, dtype=str(self.dtype), device_type=self.device.type, - window_size=self.sliding_window, + window_size=None, # Setting this parameter to None disables the sliding-window block logic. speculate=self.speculate, support_chunking=self.support_chunking, use_prefix_caching=PREFIX_CACHING,
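
For reference, here is a minimal standalone sketch of the sliding-window masking that this patch centralizes in `Gemma3ForConditionalGeneration.forward` (assumes only PyTorch; the helper name and toy sizes are illustrative, not part of the patch): start from the global causal mask, additionally mask keys that fall at least `sliding_window` positions behind the query, then slice the key dimension to the effective length.

```python
# Illustrative sketch only: mirrors the attention_mask_local construction above,
# using plain PyTorch. Mask convention: 0.0 = visible, finfo.min = masked.
import torch


def make_local_attention_mask(
    attention_mask: torch.Tensor,  # [batch, heads, q_len, kv_len] global causal mask
    seq_len: int,
    sliding_window: int,
) -> torch.Tensor:
    min_dtype = torch.finfo(attention_mask.dtype).min
    # Prefill may be longer than the sliding window.
    effective_seq_len = max(seq_len, sliding_window)
    # True where the key is at least `sliding_window` positions behind the query.
    out_of_window = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool),
        diagonal=-sliding_window,
    )
    local_mask = torch.where(out_of_window, min_dtype, attention_mask)
    # Keep only the last `effective_seq_len` key positions.
    offset = max(0, seq_len - effective_seq_len)
    return local_mask[:, :, :, offset : offset + effective_seq_len]


if __name__ == "__main__":
    seq_len, window = 8, 4
    min_val = torch.finfo(torch.float32).min
    # Global causal mask: 0 on/below the diagonal, min above it.
    causal = torch.triu(torch.full((1, 1, seq_len, seq_len), min_val), diagonal=1)
    local = make_local_attention_mask(causal, seq_len, window)
    # Each query now sees at most `window` keys (itself plus window - 1 predecessors).
    print((local[0, 0] == 0).sum(dim=-1))  # tensor([1, 2, 3, 4, 4, 4, 4, 4])
```

Full-attention layers keep the original `attention_mask`; only layers whose `self_attn.window_size` is not -1 receive the local variant, matching the per-layer selection added in `FlashGemma3Model.forward`.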