mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
use custom vllm with kv_head_mapping
This commit is contained in:
parent
0604c5cb83
commit
d4da0d4d97
@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /op
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install vllm/flash-attention dependencies
|
||||
RUN pip install einops py-cpuinfo prometheus_client --no-cache-dir
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
@ -1,8 +1,8 @@
|
||||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
|
||||
Message, PrefillToken, Queue, Token,
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
||||
};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
@ -86,7 +86,18 @@ impl Infer {
|
||||
|
||||
let chat_template = tokenizer_config
|
||||
.chat_template
|
||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||
.and_then(|t| match t {
|
||||
ChatTemplateVersions::Single(template) => Some(template),
|
||||
ChatTemplateVersions::Multiple(templates) => templates
|
||||
.into_iter()
|
||||
.find(|t| t.name == "default")
|
||||
.map(|t| t.template),
|
||||
})
|
||||
.map(|t| {
|
||||
// .strip() is not supported in minijinja
|
||||
let t = t.replace(".strip()", " | trim");
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
});
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
@ -1099,7 +1110,7 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "_base",
|
||||
chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
@ -1110,7 +1121,7 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "blenderbot",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
@ -1121,7 +1132,7 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "blenderbot_small",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
@ -1132,7 +1143,7 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "bloom",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
@ -1143,7 +1154,7 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "gpt_neox",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
@ -1154,38 +1165,37 @@ mod tests {
|
||||
ChatTemplateTestItem {
|
||||
name: "gpt2",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "llama",
|
||||
// NOTE: the `.strip()` has been replaced with `| trim` in the following template
|
||||
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "whisper",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
}
|
||||
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
@ -1211,7 +1221,7 @@ mod tests {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>")
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
@ -1237,7 +1247,7 @@ mod tests {
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
|
||||
@ -1259,7 +1269,7 @@ mod tests {
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
@ -1360,7 +1370,7 @@ mod tests {
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "internlm/internlm2-chat-7b",
|
||||
@ -1443,7 +1453,7 @@ mod tests {
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
|
@ -48,9 +48,22 @@ pub struct HubModelInfo {
|
||||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChatTemplate {
|
||||
name: String,
|
||||
template: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatTemplateVersions {
|
||||
Single(String),
|
||||
Multiple(Vec<ChatTemplate>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<String>,
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub completion_template: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub bos_token: Option<String>,
|
||||
|
@ -1,10 +1,10 @@
|
||||
vllm-cuda:
|
||||
# Clone vllm
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone https://github.com/vllm-project/vllm.git vllm
|
||||
git clone https://github.com/OlivierDehaene/vllm.git vllm
|
||||
|
||||
build-vllm-cuda: vllm-cuda
|
||||
cd vllm && git fetch && git checkout b7782002e1da25de77e0b1890ff8b72dd4df917c
|
||||
cd vllm && git fetch && git checkout 2b42fc4826258c961f22b5172ada41a9e0c87686
|
||||
cd vllm && python setup.py build
|
||||
|
||||
install-vllm-cuda: build-vllm-cuda
|
||||
|
@ -23,7 +23,6 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
@ -47,14 +46,18 @@ else:
|
||||
class CohereLayerNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps):
|
||||
super().__init__()
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
self.weight = nn.Parameter(weight)
|
||||
# Fake weights
|
||||
self.ones = weight.new_ones(weight.shape[1])
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
# if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
if True:
|
||||
hidden_states = hidden_states.reshape(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
@ -62,6 +65,7 @@ class CohereLayerNorm(nn.Module):
|
||||
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
(
|
||||
@ -95,64 +99,6 @@ class CohereLayerNorm(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CohereConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=8192,
|
||||
intermediate_size=22528,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=5,
|
||||
eos_token_id=255001,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
logit_scale=1.0,
|
||||
use_qk_norm=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.logit_scale = logit_scale
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
@ -236,23 +182,16 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
rank = weights.process_group.rank()
|
||||
self.q_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.q_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.q_norm.weight.data = self.q_norm.weight[
|
||||
self.num_heads * rank : self.num_heads * (rank + 1)
|
||||
]
|
||||
self.k_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.k_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.k_norm.weight.data = self.k_norm.weight[
|
||||
self.num_key_value_heads * rank : self.num_key_value_heads * (rank + 1)
|
||||
]
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
@ -263,6 +202,10 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -322,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -375,6 +375,10 @@ class DbrxAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -430,7 +434,7 @@ class DbrxAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -188,6 +188,10 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -240,7 +244,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -176,6 +176,10 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -228,7 +232,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -222,6 +222,10 @@ class MixtralAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -281,7 +285,7 @@ class MixtralAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -120,6 +120,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||
)
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -165,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
qkv[:, 0],
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -140,6 +140,10 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -202,7 +206,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -104,6 +104,10 @@ class Qwen2Attention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -163,7 +167,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -151,6 +151,15 @@ class FlashRWAttention(torch.nn.Module):
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
|
||||
if self.num_heads_kv == 1:
|
||||
self.kv_head_mapping = torch.zeros(
|
||||
self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
else:
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -204,7 +213,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_heads_kv,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
@ -263,6 +272,10 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_groups, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -319,7 +332,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_groups,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -241,6 +241,9 @@ class FlashMQAttention(torch.nn.Module):
|
||||
self.c_proj = load_row(
|
||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||
)
|
||||
self.kv_head_mapping = torch.zeros(
|
||||
self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -289,7 +292,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -190,6 +190,9 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -249,7 +252,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.num_key_value_heads,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
@ -165,6 +165,11 @@ class FlashCausalLMBatch(Batch):
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
if (
|
||||
tokenized_input[0] == tokenizer.bos_token_id
|
||||
and tokenized_input[1] == tokenizer.bos_token_id
|
||||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
@ -3,13 +3,12 @@ import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.cohere import AutoTokenizer, CohereConfig
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||
FlashCohereForCausalLM,
|
||||
CohereConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
|
@ -21,7 +21,7 @@ def attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
num_key_value_heads: torch.Tensor,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
@ -60,7 +60,7 @@ def attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_key_value_heads,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
@ -92,7 +92,7 @@ def attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_key_value_heads,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
|
Loading…
Reference in New Issue
Block a user