use custom vllm with kv_head_mapping

OlivierDehaene 2024-04-09 19:04:44 +02:00
parent 0604c5cb83
commit d4da0d4d97
18 changed files with 140 additions and 124 deletions

View File

@@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /op
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 # Install vllm/flash-attention dependencies
-RUN pip install einops py-cpuinfo prometheus_client --no-cache-dir
+RUN pip install einops --no-cache-dir
 # Install server
 COPY proto proto

View File

@@ -1,8 +1,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
-    Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
+    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -86,7 +86,18 @@ impl Infer {
         let chat_template = tokenizer_config
             .chat_template
-            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+            .and_then(|t| match t {
+                ChatTemplateVersions::Single(template) => Some(template),
+                ChatTemplateVersions::Multiple(templates) => templates
+                    .into_iter()
+                    .find(|t| t.name == "default")
+                    .map(|t| t.template),
+            })
+            .map(|t| {
+                // .strip() is not supported in minijinja
+                let t = t.replace(".strip()", " | trim");
+                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+            });
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
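Note: the `| trim` rewrite above exists because minijinja does not support Python method calls such as `.strip()` inside templates, while it does provide the equivalent `trim` filter. A minimal Python sketch of the same string substitution (illustration only; the router does this in Rust):

```python
def sanitize_chat_template(template: str) -> str:
    # Mirrors the Rust `t.replace(".strip()", " | trim")` above.
    return template.replace(".strip()", " | trim")

snippet = "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
print(sanitize_chat_template(snippet))
# {{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
```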
@@ -1099,7 +1110,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "_base",
                 chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1110,7 +1121,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "blenderbot",
                 chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1121,7 +1132,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "blenderbot_small",
                 chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1132,7 +1143,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "bloom",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1143,7 +1154,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "gpt_neox",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1154,38 +1165,37 @@ mod tests {
             ChatTemplateTestItem {
                 name: "gpt2",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
                     eos_token: Some("<|endoftext|>"),
                 },
-                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
+                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
             },
             ChatTemplateTestItem {
                 name: "llama",
                 // NOTE: the `.strip()` has been replaced with `| trim` in the following template
                 chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat_with_system.clone(),
                     add_generation_prompt: true,
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
+                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
             },
             ChatTemplateTestItem {
                 name: "whisper",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: true,
                     bos_token: Some(""),
                     eos_token: Some("<|endoftext|>"),
                 },
-                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
-            }
+                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
+            },
         ];
 
         #[allow(unused_variables)] // name is unused
@@ -1211,7 +1221,7 @@ mod tests {
                     messages: example_chat_with_system.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
-                    eos_token: Some("</s>")
+                    eos_token: Some("</s>"),
                 },
                 target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
             },
@@ -1237,7 +1247,7 @@ mod tests {
                     bos_token: Some(""),
                     eos_token: Some("</s>"),
                 },
-                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
+                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
             },
             ChatTemplateTestItem {
                 name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
@@ -1259,7 +1269,7 @@ mod tests {
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
+                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
             },
             ChatTemplateTestItem {
                 name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -1276,7 +1286,7 @@ mod tests {
                 name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
                 chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                 input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
@@ -1360,7 +1370,7 @@ mod tests {
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
+                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
             },
             ChatTemplateTestItem {
                 name: "internlm/internlm2-chat-7b",
@@ -1443,7 +1453,7 @@ mod tests {
                     eos_token: Some("</s>"),
                 },
                 target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
-            }
+            },
         ];
 
         #[allow(unused_variables)] // name is unused

View File

@@ -48,9 +48,22 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatTemplate {
+    name: String,
+    template: String,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(untagged)]
+pub enum ChatTemplateVersions {
+    Single(String),
+    Multiple(Vec<ChatTemplate>),
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
-    pub chat_template: Option<String>,
+    pub chat_template: Option<ChatTemplateVersions>,
     pub completion_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
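Note: `chat_template` in `tokenizer_config.json` is shipped either as a single string or as a list of named templates; the untagged enum above accepts both shapes, and `Infer::new` picks the entry named `"default"` when a list is given. A rough Python equivalent, for illustration only:

```python
import json

# Two shapes the field can take in tokenizer_config.json (example payloads).
single = json.loads('{"chat_template": "{% for m in messages %}{{ m.content }}{% endfor %}"}')
multiple = json.loads(
    '{"chat_template": [{"name": "default", "template": "{{ messages[0].content }}"},'
    ' {"name": "tool_use", "template": "..."}]}'
)

def resolve_chat_template(config: dict):
    # Single(String) is used as-is; Multiple picks the entry named "default".
    value = config.get("chat_template")
    if isinstance(value, str):
        return value
    if isinstance(value, list):
        return next((t["template"] for t in value if t["name"] == "default"), None)
    return None

print(resolve_chat_template(single))
print(resolve_chat_template(multiple))
```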

View File

@@ -1,10 +1,10 @@
 vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/vllm-project/vllm.git vllm
+	git clone https://github.com/OlivierDehaene/vllm.git vllm
 
 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c
+	cd vllm && git fetch && git checkout 2b42fc4826258c961f22b5172ada41a9e0c87686
 	cd vllm && python setup.py build
 
 install-vllm-cuda: build-vllm-cuda

View File

@@ -23,7 +23,6 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
@@ -47,14 +46,18 @@ else:
 class CohereLayerNorm(nn.Module):
     def __init__(self, prefix, weights, eps):
         super().__init__()
-        weight = weights.get_tensor(f"{prefix}.weight")
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         self.weight = nn.Parameter(weight)
         # Fake weights
         self.ones = weight.new_ones(weight.shape[1])
         self.eps = eps
 
     def forward(self, hidden_states):
-        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+        # if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+        if True:
+            hidden_states = hidden_states.reshape(
+                -1, self.weight.shape[0], self.weight.shape[1]
+            )
             input_dtype = hidden_states.dtype
             hidden_states = hidden_states.to(torch.float32)
             mean = hidden_states.mean(-1, keepdim=True)
@@ -62,6 +65,7 @@ class CohereLayerNorm(nn.Module):
             variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
             hidden_states = self.weight.to(torch.float32) * hidden_states
+            hidden_states = hidden_states.view(-1, self.weight.shape[1])
             return hidden_states.to(input_dtype)
 
         (
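Note: because the q/k norm weight is now loaded with `get_sharded(..., dim=0)`, it stays two-dimensional (heads on this shard x head size), and the reshape above applies the normalization independently per head before flattening back. A standalone sketch of that computation, with assumed shapes:

```python
import torch

# Sketch only: 2-D norm weight (heads, head_size), flattened activations coming in.
def per_head_layernorm(hidden_states, weight, eps=1e-5):
    x = hidden_states.reshape(-1, weight.shape[0], weight.shape[1])
    input_dtype = x.dtype
    x = x.to(torch.float32)
    mean = x.mean(-1, keepdim=True)
    variance = (x - mean).pow(2).mean(-1, keepdim=True)
    x = (x - mean) * torch.rsqrt(variance + eps)
    x = weight.to(torch.float32) * x            # one scale vector per head
    x = x.view(-1, weight.shape[1])             # back to (tokens * heads, head_size)
    return x.to(input_dtype)

heads, head_size = 8, 64
out = per_head_layernorm(torch.randn(4, heads * head_size), torch.ones(heads, head_size))
print(out.shape)  # torch.Size([32, 64])
```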
@@ -95,64 +99,6 @@ class CohereLayerNorm(nn.Module):
             return hidden_states
 
 
-class CohereConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=256000,
-        hidden_size=8192,
-        intermediate_size=22528,
-        num_hidden_layers=40,
-        num_attention_heads=64,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        layer_norm_eps=1e-5,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=5,
-        eos_token_id=255001,
-        pretraining_tp=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        attention_bias=False,
-        attention_dropout=0.0,
-        logit_scale=1.0,
-        use_qk_norm=False,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.logit_scale = logit_scale
-        self.use_qk_norm = use_qk_norm
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
@@ -236,23 +182,16 @@ class FlashCohereAttention(torch.nn.Module):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
-            rank = weights.process_group.rank()
             self.q_norm = CohereLayerNorm(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
-            self.q_norm.weight.data = self.q_norm.weight[
-                self.num_heads * rank : self.num_heads * (rank + 1)
-            ]
             self.k_norm = CohereLayerNorm(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
             )
-            self.k_norm.weight.data = self.k_norm.weight[
-                self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
-            ]
         else:
             self.q_norm = None
             self.k_norm = None
@@ -263,6 +202,10 @@ class FlashCohereAttention(torch.nn.Module):
             weights=weights,
             bias=config.attention_bias,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -322,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
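Note: a minimal sketch of the mapping built above, with made-up head counts. Entry `i` tells the paged-attention kernel which KV-cache head query head `i` should read; with grouped-query attention, one KV head is shared by `num_groups` consecutive query heads:

```python
import torch

# Example: 8 query heads sharing 2 KV heads (hypothetical sizes).
num_heads, num_key_value_heads = 8, 2
num_groups = num_heads // num_key_value_heads
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```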

View File

@@ -375,6 +375,10 @@ class DbrxAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -430,7 +434,7 @@ class DbrxAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -188,6 +188,10 @@ class FlashGemmaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -240,7 +244,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -176,6 +176,10 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -228,7 +232,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -222,6 +222,10 @@ class MixtralAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -281,7 +285,7 @@ class MixtralAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -120,6 +120,9 @@ class FlashNeoxAttention(torch.nn.Module):
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
+        self.kv_head_mapping = torch.arange(
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )
 
     def forward(
         self,
@@ -165,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
-                self.num_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -140,6 +140,10 @@ class FlashPhiAttention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -202,7 +206,7 @@ class FlashPhiAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -104,6 +104,10 @@ class Qwen2Attention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -163,7 +167,7 @@ class Qwen2Attention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -151,6 +151,15 @@ class FlashRWAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
+        if self.num_heads_kv == 1:
+            self.kv_head_mapping = torch.zeros(
+                self.num_heads, dtype=torch.int32, device=weights.device
+            )
+        else:
+            self.kv_head_mapping = torch.arange(
+                0, self.num_heads, dtype=torch.int32, device=weights.device
+            )
 
     def forward(
         self,
         hidden_states,
@@ -204,7 +213,7 @@ class FlashRWAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_heads_kv,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
@@ -263,6 +272,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
+        self.kv_head_mapping = torch.arange(
+            0, self.num_groups, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_heads)
 
     def forward(
         self,
         hidden_states,
@@ -319,7 +332,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_groups,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
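Note: the two Falcon branches above cover multi-query attention (a single shared KV head, so every query head maps to head 0) and the plain multi-head case (identity mapping). A small sketch with assumed head counts:

```python
import torch

# Hypothetical sizes: 4 query heads, 1 KV head (MQA).
num_heads, num_heads_kv = 4, 1
if num_heads_kv == 1:
    kv_head_mapping = torch.zeros(num_heads, dtype=torch.int32)
else:
    kv_head_mapping = torch.arange(0, num_heads, dtype=torch.int32)
print(kv_head_mapping)  # tensor([0, 0, 0, 0], dtype=torch.int32)
```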

View File

@@ -241,6 +241,9 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
+        self.kv_head_mapping = torch.zeros(
+            self.num_heads, dtype=torch.int32, device=weights.device
+        )
 
     def forward(
         self,
@@ -289,7 +292,7 @@ class FlashMQAttention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -190,6 +190,9 @@ class Starcoder2Attention(torch.nn.Module):
             bias=config.use_bias,
         )
         self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -249,7 +252,7 @@ class Starcoder2Attention(torch.nn.Module):
                 query,
                 kv_cache[0],
                 kv_cache[1],
-                self.num_key_value_heads,
+                self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
                 input_lengths,

View File

@@ -165,6 +165,11 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping[r.id] = i
 
             tokenized_input = tokenized_input[-r.truncate :]
+            if (
+                tokenized_input[0] == tokenizer.bos_token_id
+                and tokenized_input[1] == tokenizer.bos_token_id
+            ):
+                tokenized_input = tokenized_input[1:]
 
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
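Note: this guards against a doubled BOS token, since some chat templates already prepend `bos_token` to the prompt text and tokenization can then produce two BOS ids at the start of the sequence. A sketch with hypothetical token ids:

```python
bos_token_id = 1
tokenized_input = [1, 1, 9038, 2501, 263]  # assumed ids; the first two are duplicated BOS

if tokenized_input[0] == bos_token_id and tokenized_input[1] == bos_token_id:
    tokenized_input = tokenized_input[1:]

print(tokenized_input)  # [1, 9038, 2501, 263]
```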

View File

@@ -3,13 +3,12 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
 
-from transformers import AutoTokenizer
+from transformers.models.cohere import AutoTokenizer, CohereConfig
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
     FlashCohereForCausalLM,
-    CohereConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,

View File

@@ -21,7 +21,7 @@ def attention(
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
-    num_key_value_heads: torch.Tensor,
+    kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
     input_lengths: torch.Tensor,
@@ -60,7 +60,7 @@ def attention(
             query,
             key_cache,
             value_cache,
-            num_key_value_heads,
+            kv_head_mapping,
             softmax_scale,
             block_tables,
             input_lengths,
@@ -92,7 +92,7 @@ def attention(
             query,
             key_cache,
             value_cache,
-            num_key_value_heads,
+            kv_head_mapping,
             softmax_scale,
             block_tables,
             input_lengths,
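Note: semantically, the new argument replaces a head count with an explicit per-query-head index into the KV cache, which is what the custom vllm kernel expects. A small illustration (not the CUDA kernel itself) of how such a mapping expands grouped KV heads to one per query head:

```python
import torch

# Hypothetical sizes: 3 tokens, 2 KV heads, 4 query heads, head size 4.
num_tokens, num_kv_heads, head_size = 3, 2, 4
kv_head_mapping = torch.tensor([0, 0, 1, 1], dtype=torch.int32)

key = torch.randn(num_tokens, num_kv_heads, head_size)
# Gather along the head dimension: query head i reads KV head kv_head_mapping[i].
key_per_query_head = key.index_select(1, kv_head_mapping.to(torch.int64))
print(key_per_query_head.shape)  # torch.Size([3, 4, 4])
```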