Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Commit a1b65e5919 (parent 2b2f4dee94)

fix: fix CohereForAI/c4ai-command-r-plus (#1707)

@Narsil @drbh this will update flash attention v2 and vllm. You will need to re-install them.
@@ -502,6 +502,9 @@ fn shard_manager(
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
+    // Remove LOG_LEVEL if present
+    envs.retain(|(name, _)| name != "LOG_LEVEL");
+
     // Max total tokens
     envs.push(("MAX_TOTAL_TOKENS".into(), max_total_tokens.to_string().into()));
 
@@ -594,6 +597,7 @@ fn shard_manager(
     tracing::info!("Starting shard");
     let mut p = match Command::new("text-generation-server")
         .args(shard_args)
+        .env_clear()
         .envs(envs)
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
@@ -832,6 +836,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
+    // Remove LOG_LEVEL if present
+    envs.retain(|(name, _)| name != "LOG_LEVEL");
+
     // Disable progress bar
     envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
 
@@ -866,6 +873,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
     tracing::info!("Starting download process.");
    let mut download_process = match Command::new("text-generation-server")
        .args(download_args)
+       .env_clear()
        .envs(envs)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
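Note on the launcher hunks above: the child processes are now started from a cleared environment and only receive the variables the launcher explicitly collected, with LOG_LEVEL filtered out. A minimal Python sketch of the same pattern with the standard subprocess module (the variable names and value are illustrative, not from the diff):

# Minimal sketch (assumption: illustrative names, not the actual launcher code).
import os
import subprocess

# Copy the current process env and drop entries we do not want the child to inherit.
envs = {k: v for k, v in os.environ.items() if k != "LOG_LEVEL"}
envs["MAX_TOTAL_TOKENS"] = "4096"  # hypothetical value

# Passing env= replaces the child's environment entirely, which is the
# subprocess equivalent of Command::env_clear().envs(envs) in Rust.
subprocess.run(["env"], env=envs, check=True)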
@@ -3,8 +3,8 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
-    Message, PrefillToken, Queue, Token,
+    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
+    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
 };
 use futures::future::try_join_all;
 use minijinja::{Environment, ErrorKind, Template};
@@ -97,7 +97,18 @@ impl Infer {
 
         let chat_template = tokenizer_config
             .chat_template
-            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
+            .and_then(|t| match t {
+                ChatTemplateVersions::Single(template) => Some(template),
+                ChatTemplateVersions::Multiple(templates) => templates
+                    .into_iter()
+                    .find(|t| t.name == "default")
+                    .map(|t| t.template),
+            })
+            .map(|t| {
+                // .strip() is not supported in minijinja
+                let t = t.replace(".strip()", " | trim");
+                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
+            });
 
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
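The router now rewrites `.strip()` calls found in Hub chat templates into the `trim` filter, because minijinja does not support Python string methods. A small sketch using jinja2 as a stand-in for minijinja (the template text is illustrative) showing that the rewritten form renders the same result:

# Minimal sketch (jinja2 stands in for minijinja; template text is illustrative).
from jinja2 import Environment

env = Environment()

python_style = "{{ content.strip() }}"   # relies on a Python string method
filter_style = "{{ content | trim }}"    # portable filter form

# The same textual rewrite the router applies:
rewritten = python_style.replace(".strip()", " | trim")
assert rewritten == "{{ content | trim }}"

print(env.from_string(filter_style).render(content="  hello  "))  # -> "hello"
print(env.from_string(rewritten).render(content="  hello  "))     # -> "hello"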
@@ -1110,7 +1121,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "_base",
                 chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1121,7 +1132,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "blenderbot",
                 chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1132,7 +1143,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "blenderbot_small",
                 chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1143,7 +1154,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "bloom",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1154,7 +1165,7 @@ mod tests {
             ChatTemplateTestItem {
                 name: "gpt_neox",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
@@ -1165,38 +1176,37 @@ mod tests {
             ChatTemplateTestItem {
                 name: "gpt2",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
                     eos_token: Some("<|endoftext|>"),
                 },
-                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
+                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
             },
             ChatTemplateTestItem {
                 name: "llama",
                 // NOTE: the `.strip()` has been replaced with `| trim` in the following template
                 chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat_with_system.clone(),
                     add_generation_prompt: true,
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
+                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
             },
             ChatTemplateTestItem {
                 name: "whisper",
                 chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
-                input: ChatTemplateInputs{
+                input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: true,
                     bos_token: Some(""),
                     eos_token: Some("<|endoftext|>"),
                 },
-                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
-            }
+                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
+            },
         ];
 
         #[allow(unused_variables)] // name is unused
@@ -1222,7 +1232,7 @@ mod tests {
                     messages: example_chat_with_system.clone(),
                     add_generation_prompt: false,
                     bos_token: Some(""),
-                    eos_token: Some("</s>")
+                    eos_token: Some("</s>"),
                 },
                 target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
             },
@@ -1248,7 +1258,7 @@ mod tests {
                     bos_token: Some(""),
                     eos_token: Some("</s>"),
                 },
-                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
+                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
             },
             ChatTemplateTestItem {
                 name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
@@ -1270,7 +1280,7 @@ mod tests {
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
+                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
             },
             ChatTemplateTestItem {
                 name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -1287,7 +1297,7 @@ mod tests {
                 name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
                 chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                 input: ChatTemplateInputs {
                     messages: example_chat.clone(),
                     add_generation_prompt: false,
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
@@ -1371,7 +1381,7 @@ mod tests {
                     bos_token: Some("<s>"),
                     eos_token: Some("</s>"),
                 },
-                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
+                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
             },
             ChatTemplateTestItem {
                 name: "internlm/internlm2-chat-7b",
@@ -1454,7 +1464,7 @@ mod tests {
                     eos_token: Some("</s>"),
                 },
                 target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
-            }
+            },
         ];
 
         #[allow(unused_variables)] // name is unused
@@ -49,9 +49,22 @@ pub struct HubModelInfo {
     pub pipeline_tag: Option<String>,
 }
 
-#[derive(Clone, Deserialize, Default)]
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+pub struct ChatTemplate {
+    name: String,
+    template: String,
+}
+
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+#[serde(untagged)]
+pub enum ChatTemplateVersions {
+    Single(String),
+    Multiple(Vec<ChatTemplate>),
+}
+
+#[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubTokenizerConfig {
-    pub chat_template: Option<String>,
+    pub chat_template: Option<ChatTemplateVersions>,
     pub completion_template: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub bos_token: Option<String>,
@@ -978,7 +991,10 @@ mod tests {
         let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
 
         // check that we successfully parsed the tokens
-        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.chat_template,
+            Some(ChatTemplateVersions::Single("test".to_string()))
+        );
         assert_eq!(
             config.bos_token,
             Some("<|begin▁of▁sentence|>".to_string())
@@ -1010,7 +1026,10 @@ mod tests {
         let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
 
         // check that we successfully parsed the tokens
-        assert_eq!(config.chat_template, Some("test".to_string()));
+        assert_eq!(
+            config.chat_template,
+            Some(ChatTemplateVersions::Single("test".to_string()))
+        );
         assert_eq!(
             config.bos_token,
             Some("<|begin▁of▁sentence|>".to_string())
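For reference, the two `chat_template` shapes that the new untagged ChatTemplateVersions enum accepts, sketched with plain Python dicts (the entry names other than "default" and the template bodies are made up for illustration):

# Minimal sketch (illustrative data, not taken from a real tokenizer_config.json).
single = {"chat_template": "{{ bos_token }}{% for m in messages %}{{ m['content'] }}{% endfor %}"}
multiple = {
    "chat_template": [
        {"name": "default", "template": "{{ bos_token }}...default template..."},
        {"name": "tool_use", "template": "{{ bos_token }}...tool-use template..."},
    ]
}


def resolve_chat_template(config):
    """Mirror the router's selection logic: a plain string maps to
    ChatTemplateVersions::Single, a list of {name, template} objects maps to
    ChatTemplateVersions::Multiple and the "default" entry is picked."""
    value = config.get("chat_template")
    if isinstance(value, str):
        return value
    if isinstance(value, list):
        for entry in value:
            if entry.get("name") == "default":
                return entry["template"]
    return None


assert resolve_chat_template(single).startswith("{{ bos_token }}")
assert "default template" in resolve_chat_template(multiple)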
@@ -17,9 +17,6 @@ gen-server:
	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch text_generation_server/pb/__init__.py
 
-install-megablocks:
-	pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
-
 install: gen-server
	pip install pip --upgrade
	pip install -r requirements.txt
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
+flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
 flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
 
 
@@ -1,10 +1,10 @@
 vllm-cuda:
	# Clone vllm
	pip install -U ninja packaging --no-cache-dir
-	git clone https://github.com/vllm-project/vllm.git vllm
+	git clone https://github.com/OlivierDehaene/vllm.git vllm
 
 build-vllm-cuda: vllm-cuda
-	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
+	cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
	cd vllm && python setup.py build
 
 install-vllm-cuda: build-vllm-cuda
@@ -43,7 +43,7 @@ class CacheManager:
         ]
         self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
         self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int32
+            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)
 
    def allocate(
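The only change above is the dtype of the slot table. A small sketch of why int64 is the safer choice for slot ids (toy sizes, not the real CacheManager):

# Minimal sketch (toy sizes; not the real CacheManager).
import torch

num_blocks, block_size = 8, 4

# One id per KV-cache slot; int64 matches what most PyTorch indexing ops expect
# and leaves headroom when num_blocks * block_size grows very large.
slots = torch.arange(0, num_blocks * block_size, dtype=torch.int64).view(
    num_blocks, block_size
)

kv = torch.zeros(num_blocks * block_size, 2)  # toy flat cache, one row per slot
needed = slots[:3].flatten()                  # slots belonging to the first 3 blocks
kv[needed] = 1.0                              # int64 ids can be used directly as indices
assert kv.sum().item() == 3 * block_size * 2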
@@ -23,10 +23,10 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -34,65 +34,106 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+    FastLayerNorm,
 )
 
-
-class CohereConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=256000,
-        hidden_size=8192,
-        intermediate_size=22528,
-        num_hidden_layers=40,
-        num_attention_heads=64,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=8192,
-        initializer_range=0.02,
-        layer_norm_eps=1e-5,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=5,
-        eos_token_id=255001,
-        pretraining_tp=1,
-        tie_word_embeddings=True,
-        rope_theta=10000.0,
-        attention_bias=False,
-        attention_dropout=0.0,
-        logit_scale=1.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.logit_scale = logit_scale
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+else:
+    dropout_layer_norm = None
+
+
+class CohereRotary(PositionRotaryEmbedding):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        # Such controlflows may add some overhead.
+        if IS_CUDA_SYSTEM:
+            import rotary_emb
+
+            q1 = query[..., ::2]
+            q2 = query[..., 1::2]
+
+            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+            k1 = key[..., ::2]
+            k2 = key[..., 1::2]
+
+            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        elif IS_ROCM_SYSTEM:
+            from vllm import pos_encoding_ops
+
+            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
+            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
+
+            head_size = query.shape[-1]
+
+            # Inplace operation, updating query and key.
+            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+        else:
+            raise ValueError(
+                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
+            )
+
+
+class CohereLayerNorm(nn.Module):
+    def __init__(self, prefix, weights, eps):
+        super().__init__()
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        self.weight = nn.Parameter(weight)
+        # Fake weights
+        self.ones = weight.new_ones(weight.shape[1])
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            hidden_states = hidden_states.reshape(
+                -1, self.weight.shape[0], self.weight.shape[1]
+            )
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            mean = hidden_states.mean(-1, keepdim=True)
+            hidden_states_minus_mean = hidden_states - mean
+            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+            hidden_states = self.weight.to(torch.float32) * hidden_states
+            hidden_states = hidden_states.view(-1, self.weight.shape[1])
+            return hidden_states.to(input_dtype)
+
+        (
+            hidden_states,
+            *rest,
+        ) = dropout_layer_norm.dropout_add_ln_fwd(
+            hidden_states,
+            None,
+            self.ones,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0.0,
+            self.eps,
+            1.0,
+            0,
+            None,
+            False,
+            False,
+        )
+
+        # Required to apply one weight matrix per head
+        hidden_states = hidden_states.view(
+            -1, self.weight.shape[0], self.weight.shape[1]
+        )
+        hidden_states = self.weight * hidden_states
+        hidden_states = hidden_states.view(-1, self.weight.shape[1])
+
+        return hidden_states
 
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
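The slow path of the new CohereLayerNorm above normalizes each head with its own weight vector (a bias-free LayerNorm applied per head). A standalone sketch of that math, so the reshape to (tokens, num_heads, head_size) is easier to follow (random toy shapes, eager PyTorch only):

# Minimal sketch (random shapes, eager PyTorch): per-head LayerNorm without bias,
# mirroring the normalization math in CohereLayerNorm.forward above.
import torch

tokens, num_heads, head_size, eps = 5, 3, 8, 1e-5
weight = torch.randn(num_heads, head_size)          # one weight vector per head
x = torch.randn(tokens, num_heads * head_size)

h = x.reshape(-1, num_heads, head_size).to(torch.float32)
mean = h.mean(-1, keepdim=True)
var = (h - mean).pow(2).mean(-1, keepdim=True)
h = (h - mean) * torch.rsqrt(var + eps)
h = weight.to(torch.float32) * h                    # scale each head independently
out = h.view(-1, num_heads * head_size)

# Same result as a bias-free layer_norm over head_size, scaled per head afterwards.
ref = torch.nn.functional.layer_norm(
    x.view(-1, num_heads, head_size), (head_size,), eps=eps
) * weight
assert torch.allclose(out, ref.view(-1, num_heads * head_size), atol=1e-5)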
@@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
+        self.rotary_emb = CohereRotary.static(
             config=config,
             dim=self.head_size,
             base=config.rope_theta,
@@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        self.use_qk_norm = config.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = CohereLayerNorm(
+                prefix=f"{prefix}.q_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+            self.k_norm = CohereLayerNorm(
+                prefix=f"{prefix}.k_norm",
+                weights=weights,
+                eps=config.layer_norm_eps,
+            )
+        else:
+            self.q_norm = None
+            self.k_norm = None
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        query, kv = qkv.split(
+        query, key, value = qkv.split(
             [
                 self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
             ],
             dim=1,
         )
 
+        if self.use_qk_norm:
+            query = query.reshape(-1, self.head_size)
+            key = key.reshape(-1, self.head_size)
+            query = self.q_norm(query.contiguous())
+            key = self.k_norm(key.contiguous())
+
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+        key = key.view(-1, self.num_key_value_heads, self.head_size)
+        value = value.view(-1, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        self.rotary_emb(query, key, cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
         # output tensor
         attn_output = torch.empty_like(query)
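The attention change above replaces the packed `kv` tensor with separate `key`/`value` tensors so that query/key norm can be applied per head before RoPE. A shape-only sketch of that split (toy dimensions, no custom kernels):

# Minimal sketch (toy dimensions, no custom kernels): splitting a fused QKV
# projection into query/key/value, following the shapes used in the new forward above.
import torch

tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
qkv = torch.randn(tokens, (num_heads + 2 * num_kv_heads) * head_size)

query, key, value = qkv.split(
    [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
    dim=1,
)

# An optional qk-norm would operate on (tokens * heads, head_size) rows here.
query = query.reshape(-1, head_size)
key = key.reshape(-1, head_size)

query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
print(query.shape, key.shape, value.shape)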
@@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             flash_attn.attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module):
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = FastLayerNorm.load_no_bias(
             prefix="model.norm", weights=weights, eps=config.layer_norm_eps
         )
 
@@ -16,14 +16,13 @@
 import torch
 import torch.distributed
 
-import numpy as np
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
 
+from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     FastLinear,
@@ -37,14 +36,6 @@ from text_generation_server.utils.layers import (
 )
 from text_generation_server.utils.log import log_once
 
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Dbrx: megablocks is not installed")
-    HAS_MEGABLOCKS = False
-
 
 class DbrxAttentionConfig(PretrainedConfig):
     def __init__(
@@ -531,18 +522,6 @@ def round_up(x: torch.Tensor, value: int):
 
 
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
-
     def __init__(self, prefix, config: DbrxConfig, weights):
         super().__init__()
         self.moe_normalize_expert_weights = (
@@ -572,241 +551,40 @@ class BlockSparseMoE(nn.Module):
         )
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
-        self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.wv1 = torch.cat([w1, v1], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
         self.process_group = weights.process_group
 
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(
-            gate_logits, self.top_k, self.moe_normalize_expert_weights
-        )
-
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and v1,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.v1.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-
-        return x.view(*input_shape)
-
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                weights,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            weights.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        if self.moe_normalize_expert_weights:
-            weights = weights / torch.norm(
-                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
-            )
-        weights = weights.to(x.dtype)
-
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
-
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-
-        # Sum experts
-        out = out.sum(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.wv1,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=self.moe_normalize_expert_weights,
+            inplace=True,
+        )
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
 
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 
 
 class DenseMoE(nn.Module):
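Both the DBRX and Mixtral MoE layers now route through vLLM's fused_moe kernel instead of the megablocks code paths deleted above. A CPU-only sketch of what that fused call computes, written in plain PyTorch so it runs without vLLM (toy shapes; the weight layout follows the diff, with the gate and up projections concatenated along the ffn dimension):

# Minimal sketch (plain PyTorch, CPU): a reference top-k MoE forward equivalent in
# spirit to fused_moe(x, w13, w2, router_logits, top_k, renormalize=..., inplace=True).
# This is not the vLLM kernel itself.
import torch
import torch.nn.functional as F

tokens, hidden, ffn, n_experts, top_k = 5, 16, 32, 4, 2
x = torch.randn(tokens, hidden)
w13 = torch.randn(n_experts, 2 * ffn, hidden)   # gate (w1) and up (w3 / v1) stacked
w2 = torch.randn(n_experts, hidden, ffn)        # down projection, transposed as in the diff
router_logits = torch.randn(tokens, n_experts)

weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
topk_w, topk_idx = torch.topk(weights, top_k, dim=-1)
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize=True behaviour

out = torch.zeros_like(x)
for e in range(n_experts):
    token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)
    if token_ids.numel() == 0:
        continue
    h = x[token_ids] @ w13[e].t()                     # (n_tok, 2 * ffn)
    gate, up = h.chunk(2, dim=-1)
    h = (F.silu(gate) * up) @ w2[e].t()               # (n_tok, hidden)
    out[token_ids] += topk_w[token_ids, slot].unsqueeze(-1).to(x.dtype) * h

print(out.shape)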
@@ -24,6 +24,7 @@ import torch.distributed
 import numpy as np
 
 from torch import nn
+from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -41,14 +42,6 @@ from text_generation_server.utils.layers import (
     get_linear,
 )
 
-HAS_MEGABLOCKS = True
-try:
-    import stk
-    import megablocks.ops as ops
-except ImportError:
-    logger.warning("Mixtral: megablocks is not installed")
-    HAS_MEGABLOCKS = False
-
 
 class MixtralConfig(PretrainedConfig):
     model_type = "mixtral"
@@ -321,18 +314,6 @@ def round_up(x: torch.Tensor, value: int):
 
 
 class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
-
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
         self.hidden_dim = config.hidden_size
@@ -357,236 +338,40 @@ class BlockSparseMoE(nn.Module):
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
         # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
-        self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
-        self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
-
-        self.offsets = None
-        self.offsets_block_rows = 0
+        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
+            self.num_experts, self.ffn_dim, self.hidden_dim
+        )
+        self.w13 = torch.cat([w1, w3], dim=1)
+        self.w2 = (
+            _load_experts(config, f"{prefix}.experts", "w2", weights)
+            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
         self.process_group = weights.process_group
 
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-
-    def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim % self.blocking == 0
-
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim // self.blocking
-        if self.offsets is None or block_rows > self.offsets_block_rows:
-            self.offsets = torch.arange(
-                0,
-                block_rows * blocks_per_row + 1,
-                blocks_per_row,
-                dtype=torch.int32,
-                device=x.device,
-            )
-            self.offsets_block_rows = block_rows
-            offsets = self.offsets
-        else:
-            offsets = self.offsets[: block_rows + 1]
-
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(
-            padded_bins, self.blocking, block_rows, blocks_per_row
-        )
-
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            False,
-            False,
-            False,
-        )
-
-    def indices_and_padded_bins(self, selected_experts: torch.Tensor):
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        # selected_experts = selected_experts.int()
-
-        # returns bin_ids == num of experts for this sequence ? == unique selected experts?
-        # and indices == how to sort tokens?
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-        # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
-        # indices => [14, 32, 33, ...] => [num_tokens * top_k]
-
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-        # tokens_per_expert => [3, 0, 2, ...] => [num_experts]
-
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-
-        # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
-        # padded_tokens_per_expert => [128, O, 128, ...]
-
-        # Cumulative selected experts per token
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-        # padded_bins => [128, 128, 256, ...]
-
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        # bins => [3, 3, 5, ...]
-
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-
-    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        selected_experts, weights = select_experts(gate_logits, self.top_k)
-
-        (
-            indices,
-            bin_ids,
-            bins,
-            padded_bins,
-            _,
-        ) = self.indices_and_padded_bins(selected_experts)
-
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
-
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and w3,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
-            * stk.ops.sdd(x, self.w3.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-
-        # Permute back and remove padding
-        # (sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        ).view(*input_shape)
-
-        if self.process_group.size() > 1:
-            torch.distributed.all_reduce(x, group=self.process_group)
-
-        return x.view(*input_shape)
-
-    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-
-        if self.top_k < self.num_experts:
-            _, not_selected_experts = torch.topk(
-                all_probs,
-                self.num_experts - self.top_k,
-                largest=False,
-                sorted=False,
-                dim=1,
-            )
-            # Mask not selected experts
-            all_probs.scatter_(1, not_selected_experts, 0)
-
-        # Re-normalize
-        weights = all_probs / all_probs.sum(dim=1, keepdim=True)
-        weights = weights.to(x.dtype)
-
-        # Expand to [num_experts, sequence_length, model_dim]
-        x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
-
-        # Permute to [num_experts, model_dim, ffn_dim]
-        w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-        w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
-            0, 2, 1
-        )
-
-        inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
-
-        out = torch.bmm(
-            inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
-        )
-        # Mask not selected experts
-        out *= weights.t().view(self.num_experts, -1, 1)
-
-        # Sum experts
-        out = out.sum(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(x)
+        out = fused_moe(
+            x,
+            self.w13,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
 
         # Reduce sum
         if self.process_group.size() > 1:
             torch.distributed.all_reduce(out, group=self.process_group)
 
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256 and HAS_MEGABLOCKS:
-            return self.sparse_forward(x)
-        # This is faster when there is not a lot of tokens
-        return self.dense_forward(x)
+        return out.view(*x.shape)
 
 
 class DenseMoE(nn.Module):
|
class DenseMoE(nn.Module):
|
||||||
|
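For reference, the bookkeeping that the removed megablocks comments describe (sort the token-to-expert assignments, histogram them, round the per-expert counts up to the kernel block size, and take cumulative sums as bin bounds) can be reproduced with plain torch ops. A minimal sketch on made-up inputs, not part of the diff:

import torch

selected_experts = torch.tensor([2, 0, 0, 3, 2, 0])          # num_tokens * top_k expert ids
indices = torch.argsort(selected_experts)                     # gather order grouping tokens by expert
tokens_per_expert = torch.bincount(selected_experts, minlength=4)  # tensor([3, 0, 2, 1])
blocking = 4                                                   # block size of the sparse kernel
padded = (tokens_per_expert + blocking - 1) // blocking * blocking  # round each count up to a block
padded_bins = torch.cumsum(padded, 0)                          # padded end offset per expert
bins = torch.cumsum(tokens_per_expert, 0)                      # unpadded bin bounds
print(indices, tokens_per_expert, padded_bins, bins)

With these inputs the bin bounds come out as [3, 3, 5, 6], which lines up with the example values in the removed comments.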
@@ -169,6 +169,11 @@ class FlashCausalLMBatch(Batch):
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
+            if (
+                tokenized_input[0] == tokenizer.bos_token_id
+                and tokenized_input[1] == tokenizer.bos_token_id
+            ):
+                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)
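The added check guards against a doubled beginning-of-sequence token, e.g. when the prompt already starts with the BOS token and the tokenizer prepends another one. A standalone illustration with a made-up token id:

bos_token_id = 5                      # assumed id, for illustration only
tokenized_input = [5, 5, 21, 87, 3]   # prompt that ended up with two leading BOS tokens
if (
    tokenized_input[0] == bos_token_id
    and tokenized_input[1] == bos_token_id
):
    tokenized_input = tokenized_input[1:]
assert tokenized_input == [5, 21, 87, 3]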
@@ -694,7 +699,7 @@ class FlashCausalLM(Model):
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
@@ -3,12 +3,11 @@ import torch.distributed

from opentelemetry import trace
from typing import Optional
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
    FlashCohereForCausalLM,
-    CohereConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
@@ -32,7 +31,7 @@ class FlashCohere(FlashCausalLM):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashCohere is only available on GPU")

@@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM):
            from_slow=False,
        )

-        config = CohereConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
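With Cohere support available upstream in transformers, the server-side CohereConfig import is dropped and the generic AutoConfig path is used instead. A hedged usage sketch (the model id is the one this PR targets; network access and a transformers release that knows the "cohere" architecture are assumed):

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "CohereForAI/c4ai-command-r-plus", revision=None, trust_remote_code=False
)
print(config.model_type)  # expected: "cohere"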
@@ -385,7 +385,7 @@ class BaseFlashMistral(FlashCausalLM):
    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
@@ -88,6 +88,9 @@ def attention(
            out,
            cu_seqlens,
            cu_seqlens,
+            None,
+            None,
+            None,
            max_s,
            max_s,
            0.0,
@@ -19,7 +19,6 @@ from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-from text_generation_server.utils.log import log_once

HAS_AWQ = True
try:
@@ -35,12 +34,6 @@ except Exception:
    HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-#     V2 = False
-#     log_once(
-#         logger.warning,
-#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
-#     )

if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
@@ -174,6 +167,8 @@ class EETQLinear(nn.Module):
    ) -> None:
        super().__init__()
        device = weight.device
+        if weight.dtype != torch.float16:
+            weight = weight.to(dtype=torch.float16)
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

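The two added lines make the EETQ path tolerant of checkpoints loaded in other dtypes (e.g. bfloat16), presumably because the int8 quantization helper expects float16 input. The preprocessing can be checked with torch alone; quant_weights itself is left out of this sketch since it needs the EETQ extension:

import torch

weight = torch.randn(8, 4, dtype=torch.bfloat16)   # e.g. a tensor from a bf16 checkpoint
if weight.dtype != torch.float16:
    weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
assert weight.dtype == torch.float16 and weight.shape == (4, 8)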
@@ -1,8 +1,7 @@
import torch

# vllm imports
-from vllm import cache_ops
+from vllm._C import cache_ops, ops
-from vllm import attention_ops

_PARTITION_SIZE = 512

@@ -14,7 +13,7 @@ def reshape_and_cache(
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def attention(
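The updated vllm cache op takes two extra trailing arguments; judging from the call above they appear to select the kv-cache dtype ("auto" meaning follow the model dtype) and a scale factor of 1.0. If other call sites need to keep the old five-argument shape, a thin wrapper is one option (a sketch, assuming the vllm._C module from the updated wheel is importable):

from vllm._C import cache_ops  # requires the updated vllm build this PR depends on


def reshape_and_cache_compat(key, value, key_cache, value_cache, slots):
    # Forward to the new signature with the defaults used in the diff above.
    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)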
@@ -54,9 +53,9 @@ def attention(
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
-    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
            out,
            query,
            key_cache,
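The V1/V2 choice is now additionally capped by context length: V1 is only considered for contexts up to 8192 tokens, and even then only when V2's extra reduction step would not pay off. Restated as a standalone predicate (a sketch; the variable names are kept from the diff):

def use_paged_attention_v1(
    max_s: int, max_num_partitions: int, num_seqs: int, num_heads: int
) -> bool:
    # Mirrors the updated condition: short context, and either a single partition
    # or enough (sequence, head) pairs to keep the V1 kernel busy.
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)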
@@ -68,6 +67,8 @@ def attention(
            block_size,
            max_s,
            None,
+            "auto",
+            1.0,
        )
    else:
        # Run PagedAttention V2.
@@ -83,7 +84,7 @@ def attention(
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,
@@ -98,4 +99,6 @@ def attention(
            block_size,
            max_s,
            None,
+            "auto",
+            1.0,
        )