From d4da0d4d97e020968ed3f3c79559d6e545339d87 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:04:44 +0200 Subject: [PATCH] use custom vllm with kv_head_mapping --- Dockerfile | 2 +- router/src/infer.rs | 78 ++++++++++-------- router/src/lib.rs | 17 +++- server/Makefile-vllm | 4 +- .../custom_modeling/flash_cohere_modeling.py | 81 +++---------------- .../custom_modeling/flash_dbrx_modeling.py | 6 +- .../custom_modeling/flash_gemma_modeling.py | 6 +- .../custom_modeling/flash_llama_modeling.py | 6 +- .../custom_modeling/flash_mixtral_modeling.py | 6 +- .../custom_modeling/flash_neox_modeling.py | 5 +- .../custom_modeling/flash_phi_modeling.py | 6 +- .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_rw_modeling.py | 17 +++- .../flash_santacoder_modeling.py | 5 +- .../flash_starcoder2_modeling.py | 5 +- .../models/flash_causal_lm.py | 5 ++ .../models/flash_cohere.py | 3 +- .../utils/paged_attention.py | 6 +- 18 files changed, 140 insertions(+), 124 deletions(-) diff --git a/Dockerfile b/Dockerfile index 360a6d2c..0bc5f8d9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /op COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages # Install vllm/flash-attention dependencies -RUN pip install einops py-cpuinfo prometheus_client --no-cache-dir +RUN pip install einops --no-cache-dir # Install server COPY proto proto diff --git a/router/src/infer.rs b/router/src/infer.rs index e5517511..075e76d8 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,8 +1,8 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; use crate::{ - ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig, - Message, PrefillToken, Queue, Token, + ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, + HubTokenizerConfig, Message, PrefillToken, Queue, Token, }; use futures::future::try_join_all; use minijinja::{Environment, ErrorKind, Template}; @@ -86,7 +86,18 @@ impl Infer { let chat_template = tokenizer_config .chat_template - .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)); + .and_then(|t| match t { + ChatTemplateVersions::Single(template) => Some(template), + ChatTemplateVersions::Multiple(templates) => templates + .into_iter() + .find(|t| t.name == "default") + .map(|t| t.template), + }) + .map(|t| { + // .strip() is not supported in minijinja + let t = t.replace(".strip()", " | trim"); + ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token) + }); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); @@ -1099,7 +1110,7 @@ mod tests { ChatTemplateTestItem { name: "_base", chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1110,7 +1121,7 @@ mod tests { ChatTemplateTestItem { name: "blenderbot", chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ 
' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1121,7 +1132,7 @@ mod tests { ChatTemplateTestItem { name: "blenderbot_small", chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1132,7 +1143,7 @@ mod tests { ChatTemplateTestItem { name: "bloom", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1143,7 +1154,7 @@ mod tests { ChatTemplateTestItem { name: "gpt_neox", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ + input: ChatTemplateInputs { messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), @@ -1154,38 +1165,37 @@ mod tests { ChatTemplateTestItem { name: "gpt2", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat.clone(), - add_generation_prompt: false, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" + target: "Hello, how are you?<|endoftext|>I'm doing great. 
How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", }, ChatTemplateTestItem { name: "llama", // NOTE: the `.strip()` has been replaced with `| trim` in the following template chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content | trim + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat_with_system.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some(""), + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some(""), }, - target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + target: "[INST] <>\nYou are a friendly chatbot who always responds in the style of a pirate\n<>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, ChatTemplateTestItem { name: "whisper", chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", - input: ChatTemplateInputs{ - messages: example_chat.clone(), - add_generation_prompt: true, - bos_token: Some(""), - eos_token: Some("<|endoftext|>"), + input: ChatTemplateInputs { + messages: example_chat.clone(), + add_generation_prompt: true, + bos_token: Some(""), + eos_token: Some("<|endoftext|>"), }, - target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" - } - + target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>", + }, ]; #[allow(unused_variables)] // name is unused @@ -1211,7 +1221,7 @@ mod tests { messages: example_chat_with_system.clone(), add_generation_prompt: false, bos_token: Some(""), - eos_token: Some("") + eos_token: Some(""), }, target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHello, how are you?<|assistant|>\nI'm doing great. 
How can I help you today?<|user|>\nI'd like to show off how chat templating works!", }, @@ -1237,7 +1247,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>" + target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate<|user|>\nHow many helicopters can a human eat in one sitting?<|assistant|>", }, ChatTemplateTestItem { name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", @@ -1259,7 +1269,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + target: "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]", }, ChatTemplateTestItem { name: "mistralai/Mixtral-8x7B-Instruct-v0.1", @@ -1276,7 +1286,7 @@ mod tests { name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", input: ChatTemplateInputs { - messages: example_chat.clone(), + messages: example_chat.clone(), add_generation_prompt: false, bos_token: Some(""), eos_token: Some(""), @@ -1360,7 +1370,7 @@ mod tests { bos_token: Some(""), eos_token: Some(""), }, - target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + target: "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!", }, ChatTemplateTestItem { name: "internlm/internlm2-chat-7b", @@ -1443,7 +1453,7 @@ mod tests { eos_token: Some(""), }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. 
How can I help you today?### Instruction: I'd like to show off how chat templating works!", - } + }, ]; #[allow(unused_variables)] // name is unused diff --git a/router/src/lib.rs b/router/src/lib.rs index 5415a956..ea0179ea 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -48,9 +48,22 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Clone, Deserialize, Default)] +#[derive(Debug, Clone, Deserialize)] +pub struct ChatTemplate { + name: String, + template: String, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum ChatTemplateVersions { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Clone, Deserialize, Default)] pub struct HubTokenizerConfig { - pub chat_template: Option, + pub chat_template: Option, pub completion_template: Option, #[serde(deserialize_with = "token_serde::deserialize")] pub bos_token: Option, diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 17660e8b..338ef8c8 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,10 +1,10 @@ vllm-cuda: # Clone vllm pip install -U ninja packaging --no-cache-dir - git clone https://github.com/vllm-project/vllm.git vllm + git clone https://github.com/OlivierDehaene/vllm.git vllm build-vllm-cuda: vllm-cuda - cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c + cd vllm && git fetch && git checkout 2b42fc4826258c961f22b5172ada41a9e0c87686 cd vllm && python setup.py build install-vllm-cuda: build-vllm-cuda diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index c9d87972..8df7e075 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -23,7 +23,6 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn @@ -47,14 +46,18 @@ else: class CohereLayerNorm(nn.Module): def __init__(self, prefix, weights, eps): super().__init__() - weight = weights.get_tensor(f"{prefix}.weight") + weight = weights.get_sharded(f"{prefix}.weight", dim=0) self.weight = nn.Parameter(weight) # Fake weights self.ones = weight.new_ones(weight.shape[1]) self.eps = eps def forward(self, hidden_states): - if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + # if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + if True: + hidden_states = hidden_states.reshape( + -1, self.weight.shape[0], self.weight.shape[1] + ) input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) mean = hidden_states.mean(-1, keepdim=True) @@ -62,6 +65,7 @@ class CohereLayerNorm(nn.Module): variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps) hidden_states = self.weight.to(torch.float32) * hidden_states + hidden_states = hidden_states.view(-1, self.weight.shape[1]) return hidden_states.to(input_dtype) ( @@ -95,64 +99,6 @@ class CohereLayerNorm(nn.Module): return hidden_states -class CohereConfig(PretrainedConfig): - def __init__( - self, - vocab_size=256000, - hidden_size=8192, - intermediate_size=22528, - num_hidden_layers=40, - num_attention_heads=64, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=8192, - 
initializer_range=0.02, - layer_norm_eps=1e-5, - use_cache=True, - pad_token_id=0, - bos_token_id=5, - eos_token_id=255001, - pretraining_tp=1, - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - logit_scale=1.0, - use_qk_norm=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.logit_scale = logit_scale - self.use_qk_norm = use_qk_norm - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) @@ -236,23 +182,16 @@ class FlashCohereAttention(torch.nn.Module): self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: - rank = weights.process_group.rank() self.q_norm = CohereLayerNorm( prefix=f"{prefix}.q_norm", weights=weights, eps=config.layer_norm_eps, ) - self.q_norm.weight.data = self.q_norm.weight[ - self.num_heads * rank : self.num_heads * (rank + 1) - ] self.k_norm = CohereLayerNorm( prefix=f"{prefix}.k_norm", weights=weights, eps=config.layer_norm_eps, ) - self.k_norm.weight.data = self.k_norm.weight[ - self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1) - ] else: self.q_norm = None self.k_norm = None @@ -263,6 +202,10 @@ class FlashCohereAttention(torch.nn.Module): weights=weights, bias=config.attention_bias, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -322,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 92423d89..d04ce39e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -375,6 +375,10 @@ class DbrxAttention(torch.nn.Module): weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -430,7 +434,7 @@ class DbrxAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git 
a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index e66c56d1..bd7596db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -188,6 +188,10 @@ class FlashGemmaAttention(torch.nn.Module): weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -240,7 +244,7 @@ class FlashGemmaAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 64ff6a85..3a269fc0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -176,6 +176,10 @@ class FlashLlamaAttention(torch.nn.Module): weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -228,7 +232,7 @@ class FlashLlamaAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 52ac8fa4..89eb8f43 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -222,6 +222,10 @@ class MixtralAttention(torch.nn.Module): weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -281,7 +285,7 @@ class MixtralAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ad8933a2..ee062d3d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -120,6 +120,9 @@ class FlashNeoxAttention(torch.nn.Module): self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, @@ -165,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0], kv_cache[0], kv_cache[1], - self.num_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git 
a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 48f54e25..cfe447a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -140,6 +140,10 @@ class FlashPhiAttention(torch.nn.Module): weights=weights, bias=True, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -202,7 +206,7 @@ class FlashPhiAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index a8268220..94023b33 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -104,6 +104,10 @@ class Qwen2Attention(torch.nn.Module): weights=weights, bias=False, ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -163,7 +167,7 @@ class Qwen2Attention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 3ac912f4..a9127d1f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -151,6 +151,15 @@ class FlashRWAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + if self.num_heads_kv == 1: + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) + else: + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + def forward( self, hidden_states, @@ -204,7 +213,7 @@ class FlashRWAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_heads_kv, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, @@ -263,6 +272,10 @@ class FlashRWLargeAttention(torch.nn.Module): config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) + self.kv_head_mapping = torch.arange( + 0, self.num_groups, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_heads) + def forward( self, hidden_states, @@ -319,7 +332,7 @@ class FlashRWLargeAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_groups, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 63b458b2..bbb603a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -241,6 +241,9 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_head_mapping = torch.zeros( + self.num_heads, dtype=torch.int32, device=weights.device + ) def forward( self, @@ -289,7 +292,7 @@ class FlashMQAttention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 63395099..ed77af78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -190,6 +190,9 @@ class Starcoder2Attention(torch.nn.Module): bias=config.use_bias, ) self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) def forward( self, @@ -249,7 +252,7 @@ class Starcoder2Attention(torch.nn.Module): query, kv_cache[0], kv_cache[1], - self.num_key_value_heads, + self.kv_head_mapping, self.softmax_scale, block_tables, input_lengths, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 57dd8704..02ba704a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -165,6 +165,11 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping[r.id] = i tokenized_input = tokenized_input[-r.truncate :] + if ( + tokenized_input[0] == tokenizer.bos_token_id + and tokenized_input[1] == tokenizer.bos_token_id + ): + tokenized_input = tokenized_input[1:] input_length = len(tokenized_input) input_lengths.append(input_length) diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index 0c64a036..ebdf3793 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -3,13 +3,12 @@ import torch.distributed from opentelemetry import trace from typing import Optional -from transformers import AutoTokenizer +from transformers.models.cohere import AutoTokenizer, CohereConfig from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( FlashCohereForCausalLM, CohereConfig, -) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 09e426ae..18e605b0 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -21,7 +21,7 @@ def attention( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - num_key_value_heads: torch.Tensor, + kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, input_lengths: torch.Tensor, @@ -60,7 +60,7 @@ def attention( query, key_cache, value_cache, - num_key_value_heads, + kv_head_mapping, softmax_scale, block_tables, input_lengths, @@ -92,7 +92,7 @@ def attention( query, 
             key_cache,
             value_cache,
-            num_key_value_heads,
+            kv_head_mapping,
             softmax_scale,
             block_tables,
             input_lengths,
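
The thread running through the modeling changes above: paged_attention.attention now takes a per-query-head kv_head_mapping tensor instead of a key/value head count, which is the argument the pinned vLLM fork's paged attention kernels are now passed. Each attention module builds the mapping once at load time with arange(num_key_value_heads).repeat_interleave(num_groups), or zeros(num_heads) for pure multi-query layers. The sketch below is illustrative only; the helper name and head counts are not from the patch:

import torch

def make_kv_head_mapping(num_heads: int, num_key_value_heads: int) -> torch.Tensor:
    # Mirrors the pattern added to the attention modules: query head i reads the
    # KV cache of head kv_head_mapping[i].
    num_groups = num_heads // num_key_value_heads
    return torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(
        num_groups
    )

print(make_kv_head_mapping(8, 8))  # MHA: tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
print(make_kv_head_mapping(8, 2))  # GQA: tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
print(make_kv_head_mapping(8, 1))  # MQA: tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)

In tensor-parallel runs the modules build this from the per-rank head counts and place it on weights.device; the sketch leaves both out.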
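On the router side, HubTokenizerConfig.chat_template becomes ChatTemplateVersions, so tokenizer_config.json files that ship a list of named templates deserialize alongside the plain-string form; the router picks the entry named "default" and rewrites the Python-ism .strip() into minijinja's | trim filter before compiling. The following is a Python restatement of that selection logic for illustration only; the actual implementation is the Rust shown in router/src/infer.rs above:

from typing import Optional, Union

def resolve_chat_template(raw: Union[str, list, None]) -> Optional[str]:
    # tokenizer_config.json may carry either a single template string or a list
    # of {"name": ..., "template": ...} entries; prefer the one named "default".
    if raw is None:
        return None
    if isinstance(raw, str):
        template = raw
    else:
        template = next(
            (t["template"] for t in raw if t.get("name") == "default"), None
        )
        if template is None:
            return None
    # minijinja does not implement str.strip(); the trim filter is the equivalent.
    return template.replace(".strip()", " | trim")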
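The Cohere-specific changes drop the vendored CohereConfig, load the q_norm/k_norm weights with get_sharded(dim=0) instead of slicing by rank afterwards, and make CohereLayerNorm normalize each head independently, which is what makes row-sharding the (num_heads, head_size) weight safe. A simplified sketch of that per-head normalization, with shapes spelled out; it is not the patched class itself:

import torch

def per_head_layernorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x: (num_tokens, num_heads * head_size); weight: (num_heads, head_size).
    # Each head is normalized over its own head_size slice, so sharding heads
    # across ranks does not change the result.
    num_heads, head_size = weight.shape
    input_dtype = x.dtype
    x = x.reshape(-1, num_heads, head_size).to(torch.float32)
    mean = x.mean(-1, keepdim=True)
    variance = (x - mean).pow(2).mean(-1, keepdim=True)
    x = (x - mean) * torch.rsqrt(variance + eps) * weight.to(torch.float32)
    return x.reshape(-1, num_heads * head_size).to(input_dtype)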