Merge branch 'main' into mem_refine

Wang, Yi A 2024-06-27 19:12:42 -07:00
commit af16320e66
20 changed files with 6637 additions and 5678 deletions

Cargo.lock (generated)

@@ -3762,7 +3762,7 @@ dependencies = [
 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "average",
  "clap",
@@ -3783,7 +3783,7 @@ dependencies = [
 [[package]]
 name = "text-generation-client"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-trait",
  "base64 0.22.1",
@@ -3801,7 +3801,7 @@ dependencies = [
 [[package]]
 name = "text-generation-launcher"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3820,7 +3820,7 @@ dependencies = [
 [[package]]
 name = "text-generation-router"
-version = "2.0.5-dev0"
+version = "2.1.1-dev0"
 dependencies = [
  "async-stream",
  "axum 0.7.5",
@@ -3832,6 +3832,7 @@ dependencies = [
  "hf-hub",
  "image",
  "init-tracing-opentelemetry",
+ "itertools 0.10.5",
  "jsonschema",
  "metrics 0.21.1",
  "metrics-exporter-prometheus",
@@ -3854,8 +3855,6 @@ dependencies = [
  "tokio-stream",
  "tower-http",
  "tracing",
- "tracing-core",
- "tracing-log 0.2.0",
 "tracing-opentelemetry 0.21.0",
 "tracing-subscriber",
 "utoipa",


@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
+- [Gemma2](https://huggingface.co/google/gemma2-9b)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
 - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)


@@ -8,61 +8,61 @@
   "tokens": [
     {
       "id": 330,
-      "logprob": -0.13000488,
+      "logprob": -0.08660889,
       "special": false,
       "text": " A"
     },
     {
       "id": 13088,
-      "logprob": -0.6713867,
+      "logprob": -0.7089844,
       "special": false,
       "text": " chicken"
     },
     {
       "id": 349,
-      "logprob": -0.2980957,
+      "logprob": -0.32885742,
       "special": false,
       "text": " is"
     },
     {
       "id": 6398,
-      "logprob": -0.060638428,
+      "logprob": -0.05126953,
       "special": false,
       "text": " sitting"
     },
     {
       "id": 356,
-      "logprob": -0.27319336,
+      "logprob": -0.35229492,
       "special": false,
       "text": " on"
     },
     {
       "id": 264,
-      "logprob": -0.140625,
+      "logprob": -0.12561035,
       "special": false,
       "text": " a"
     },
     {
       "id": 17972,
-      "logprob": -0.040405273,
+      "logprob": -0.038085938,
       "special": false,
       "text": " pile"
     },
     {
       "id": 302,
-      "logprob": -0.0002708435,
+      "logprob": -0.00018656254,
       "special": false,
       "text": " of"
     },
     {
       "id": 2445,
-      "logprob": -0.095336914,
+      "logprob": -0.07293701,
       "special": false,
       "text": " money"
     },
     {
       "id": 28723,
-      "logprob": -0.0068359375,
+      "logprob": -0.004852295,
       "special": false,
       "text": "."
     }


@@ -8,115 +8,115 @@
   "tokens": [
     {
       "id": 415,
-      "logprob": -0.04421997,
+      "logprob": -0.039886475,
       "special": false,
       "text": " The"
     },
     {
       "id": 12072,
-      "logprob": -0.13500977,
+      "logprob": -0.1430664,
       "special": false,
       "text": " cow"
     },
     {
       "id": 349,
-      "logprob": -0.06750488,
+      "logprob": -0.056488037,
       "special": false,
       "text": " is"
     },
     {
       "id": 6328,
-      "logprob": -0.6352539,
+      "logprob": -0.6855469,
       "special": false,
       "text": " standing"
     },
     {
       "id": 356,
-      "logprob": -0.16186523,
+      "logprob": -0.1685791,
       "special": false,
       "text": " on"
     },
     {
       "id": 272,
-      "logprob": -0.5078125,
+      "logprob": -0.50097656,
       "special": false,
       "text": " the"
     },
     {
       "id": 10305,
-      "logprob": -0.017913818,
+      "logprob": -0.017303467,
       "special": false,
       "text": " beach"
     },
     {
       "id": 304,
-      "logprob": -1.5205078,
+      "logprob": -1.3564453,
       "special": false,
       "text": " and"
     },
     {
       "id": 272,
-      "logprob": -0.029174805,
+      "logprob": -0.017868042,
       "special": false,
       "text": " the"
     },
     {
       "id": 13088,
-      "logprob": -0.003479004,
+      "logprob": -0.0027103424,
       "special": false,
       "text": " chicken"
     },
     {
       "id": 349,
-      "logprob": -0.0035095215,
+      "logprob": -0.003156662,
       "special": false,
       "text": " is"
     },
     {
       "id": 6398,
-      "logprob": -0.3088379,
+      "logprob": -0.37304688,
       "special": false,
       "text": " sitting"
     },
     {
       "id": 356,
-      "logprob": -0.027755737,
+      "logprob": -0.034576416,
       "special": false,
       "text": " on"
     },
     {
       "id": 264,
-      "logprob": -0.31884766,
+      "logprob": -0.29418945,
       "special": false,
       "text": " a"
     },
     {
       "id": 17972,
-      "logprob": -0.047943115,
+      "logprob": -0.042877197,
       "special": false,
       "text": " pile"
     },
     {
       "id": 302,
-      "logprob": -0.0002925396,
+      "logprob": -0.00028443336,
       "special": false,
       "text": " of"
     },
     {
       "id": 2445,
-      "logprob": -0.02935791,
+      "logprob": -0.023223877,
       "special": false,
       "text": " money"
     },
     {
       "id": 28723,
-      "logprob": -0.031219482,
+      "logprob": -0.018157959,
       "special": false,
       "text": "."
     },
     {
       "id": 32002,
-      "logprob": -0.00034475327,
+      "logprob": -0.00018393993,
       "special": true,
       "text": "<end_of_utterance>"
     },


@@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { workspace = true }
+itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }


@@ -71,10 +71,12 @@ fn get_unpadded_features(
     let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
     let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
         let new_height = (height * current_width) / width;
-        (new_height, current_width)
+        let padding = (current_height - new_height) / 2;
+        (current_height - (2 * padding), current_width)
     } else {
         let new_width = (width * current_height) / height;
-        (current_height, new_width)
+        let padding = (current_width - new_width) / 2;
+        (current_height, current_width - (2 * padding))
     };
     let unpadded_features = current_height * current_width;
@@ -88,7 +90,9 @@ impl LlavaNext {
         let patch_size = self.vision_config.patch_size;
         assert!(image_size % patch_size == 0);
         let npatches = image_size / patch_size;
-        let (num_patch_height, num_patch_width) =
+        // Dimensions are intentionally swapped to be bug-compatible with
+        // upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+        let (num_patch_width, num_patch_height) =
             get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
         let (unpadded_features, newline_features) =
@@ -112,7 +116,7 @@ pub struct Idefics2 {}
 impl Idefics2 {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-        320
+        64
     }
 }
@@ -158,6 +162,7 @@ pub enum Config {
     Baichuan,
     Paligemma(Paligemma),
     Gemma,
+    Gemma2,
     Cohere,
     Drbx,
     Falcon,


@@ -61,6 +61,9 @@ pub struct HubTokenizerConfig {
     pub bos_token: Option<String>,
     #[serde(deserialize_with = "token_serde::deserialize")]
     pub eos_token: Option<String>,
+    pub tokenizer_class: Option<String>,
+    pub add_bos_token: Option<bool>,
+    pub add_eos_token: Option<bool>,
 }
 impl HubTokenizerConfig {
@@ -70,6 +73,25 @@ impl HubTokenizerConfig {
     }
 }
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "processor_class")]
+pub enum HubPreprocessorConfig {
+    Idefics2Processor(Idefics2Preprocessor),
+}
+impl HubPreprocessorConfig {
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
+    }
+}
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics2Preprocessor {
+    #[serde(default)]
+    do_image_splitting: bool,
+}
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,

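HubPreprocessorConfig is read from the model's preprocessor_config.json, with serde dispatching on the processor_class field and do_image_splitting falling back to false when absent. As a rough illustration (values made up, not copied from a real checkpoint), a file like the following would deserialize into the Idefics2Processor variant:

```python
import json

# Illustrative preprocessor_config.json contents; only the "processor_class"
# tag and the optional "do_image_splitting" flag matter to the router.
raw = '{"processor_class": "Idefics2Processor", "do_image_splitting": true}'
preprocessor_config = json.loads(raw)
assert preprocessor_config["processor_class"] == "Idefics2Processor"
```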

@@ -13,9 +13,11 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use text_generation_router::config::Config;
-use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
+use text_generation_router::{
+    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
+};
 use thiserror::Error;
-use tokenizers::Tokenizer;
+use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -214,6 +216,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
+        preprocessor_config_filename,
         processor_config_filename,
         model_info,
     ) = match api {
@@ -221,6 +224,7 @@ async fn main() -> Result<(), RouterError> {
            Some(local_path.join("tokenizer.json")),
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
            None,
        ),
@@ -237,6 +241,7 @@ async fn main() -> Result<(), RouterError> {
            };
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
            let processor_config_filename = api_repo.get("processor_config.json").await.ok();
            let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
@@ -249,6 +254,7 @@ async fn main() -> Result<(), RouterError> {
                tokenizer_filename,
                config_filename,
                tokenizer_config_filename,
+                preprocessor_config_filename,
                processor_config_filename,
                model_info,
            )
@@ -263,13 +269,12 @@ async fn main() -> Result<(), RouterError> {
                repo.get("tokenizer.json"),
                repo.get("config.json"),
                repo.get("tokenizer_config.json"),
+                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
                None,
            )
        }
    };
-    let tokenizer: Option<Tokenizer> =
-        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
    let config: Option<Config> = config_filename.and_then(|filename| {
        std::fs::read_to_string(filename)
            .ok()
@@ -300,6 +305,23 @@ async fn main() -> Result<(), RouterError> {
        HubTokenizerConfig::default()
    });
+    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+        let mut tokenizer = Tokenizer::from_file(filename).ok();
+        if let Some(tokenizer) = &mut tokenizer {
+            if let Some(class) = &tokenizer_config.tokenizer_class {
+                if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
+                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
+                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                        tokenizer.with_post_processor(post_processor);
+                    }
+                }
+            }
+        }
+        tokenizer
+    });
+    let preprocessor_config =
+        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
    let processor_config = processor_config_filename
        .and_then(HubProcessorConfig::from_file)
        .unwrap_or_default();
@@ -361,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
        ngrok_authtoken,
        ngrok_edge,
        tokenizer_config,
+        preprocessor_config,
        processor_config,
        messages_api_enabled,
        disable_grammar_support,
@@ -504,6 +527,77 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
    Some(tokenizer_config)
 }
+/// Create a post_processor for the LlamaTokenizer
+pub fn create_post_processor(
+    tokenizer: &Tokenizer,
+    tokenizer_config: &HubTokenizerConfig,
+) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
+    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
+    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
+    let bos_token = tokenizer_config.bos_token.as_ref();
+    let eos_token = tokenizer_config.eos_token.as_ref();
+    if add_bos_token && bos_token.is_none() {
+        panic!("add_bos_token = true but bos_token is None");
+    }
+    if add_eos_token && eos_token.is_none() {
+        panic!("add_eos_token = true but eos_token is None");
+    }
+    let mut single = Vec::new();
+    let mut pair = Vec::new();
+    let mut special_tokens = Vec::new();
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            let bos_token_id = tokenizer
+                .token_to_id(bos)
+                .expect("Should have found the bos token id");
+            special_tokens.push((bos.clone(), bos_token_id));
+            single.push(format!("{}:0", bos));
+            pair.push(format!("{}:0", bos));
+        }
+    }
+    single.push("$A:0".to_string());
+    pair.push("$A:0".to_string());
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            let eos_token_id = tokenizer
+                .token_to_id(eos)
+                .expect("Should have found the eos token id");
+            special_tokens.push((eos.clone(), eos_token_id));
+            single.push(format!("{}:0", eos));
+            pair.push(format!("{}:0", eos));
+        }
+    }
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            single.push(format!("{}:1", bos));
+        }
+    }
+    pair.push("$B:1".to_string());
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            pair.push(format!("{}:1", eos));
+        }
+    }
+    let post_processor = TemplateProcessing::builder()
+        .try_single(single)?
+        .try_pair(pair)?
+        .special_tokens(special_tokens)
+        .build()?;
+    Ok(post_processor)
+}
 #[derive(Debug, Error)]
 enum RouterError {
     #[error("Argument validation error: {0}")]
@@ -513,3 +607,36 @@ enum RouterError {
     #[error("Tokio runtime failed to start: {0}")]
     Tokio(#[from] std::io::Error),
 }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn test_create_post_processor() {
+        let tokenizer_config = HubTokenizerConfig {
+            add_bos_token: None,
+            add_eos_token: None,
+            bos_token: Some("<s>".to_string()),
+            eos_token: Some("</s>".to_string()),
+            chat_template: None,
+            tokenizer_class: None,
+            completion_template: None,
+        };
+        let tokenizer =
+            Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
+        let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
+        let expected = TemplateProcessing::builder()
+            .try_single("<s>:0 $A:0 <s>:1")
+            .unwrap()
+            .try_pair("<s>:0 $A:0 $B:1")
+            .unwrap()
+            .special_tokens(vec![("<s>".to_string(), 1)])
+            .build()
+            .unwrap();
+        assert_eq!(post_processor, expected);
+    }
+}

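For reference, the post-processor that create_post_processor builds for the default add_bos_token = true / add_eos_token = false case can be reproduced from Python with the tokenizers library. The sketch below is illustrative only; it mirrors the exact templates asserted in the test above, and the model id is the one the test uses:

```python
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing

# Any Llama-style tokenizer.json works the same way; this id comes from the test above.
tokenizer = Tokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
bos_id = tokenizer.token_to_id("<s>")

# Same single/pair templates as the `expected` post-processor in test_create_post_processor.
tokenizer.post_processor = TemplateProcessing(
    single="<s>:0 $A:0 <s>:1",
    pair="<s>:0 $A:0 $B:1",
    special_tokens=[("<s>", bos_id)],
)
```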

@@ -12,9 +12,9 @@ use crate::kserve::{
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
-    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
-    Usage, Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
+    HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
+    Token, TokenizeResponse, Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1423,6 +1423,7 @@ pub async fn run(
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     processor_config: HubProcessorConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
@@ -1636,6 +1637,7 @@ pub async fn run(
         validation_workers,
         tokenizer,
         config,
+        preprocessor_config,
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,


@@ -1,13 +1,16 @@
 /// Payload validation logic
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use crate::{
+    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{io::Reader as ImageReader, ImageFormat};
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
+use std::iter;
 use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -36,6 +39,7 @@ impl Validation {
         workers: usize,
         tokenizer: Option<Tokenizer>,
         config: Option<Config>,
+        preprocessor_config: Option<HubPreprocessorConfig>,
         max_best_of: usize,
         max_stop_sequences: usize,
         max_top_n_tokens: u32,
@@ -53,12 +57,18 @@ impl Validation {
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
                 let config_clone = config.clone();
+                let preprocessor_config_clone = preprocessor_config.clone();
                 let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                 senders.push(tokenizer_sender);
                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
+                    tokenizer_worker(
+                        tokenizer_clone,
+                        config_clone,
+                        preprocessor_config_clone,
+                        tokenizer_receiver,
+                    )
                 });
             }
@@ -422,13 +432,20 @@ async fn round_robin_task(
 fn tokenizer_worker(
     tokenizer: Tokenizer,
     config: Option<Config>,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
     while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
-                .send(prepare_input(inputs, truncate, &tokenizer, &config))
+                .send(prepare_input(
+                    inputs,
+                    truncate,
+                    &tokenizer,
+                    config.as_ref(),
+                    preprocessor_config.as_ref(),
+                ))
                 .unwrap_or(())
         })
     }
@@ -508,16 +525,67 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     }
 }
+fn image_tokens(
+    config: &Config,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
+    height: usize,
+    width: usize,
+) -> String {
+    use Config::*;
+    use HubPreprocessorConfig::*;
+    match config {
+        Idefics => "<image>".to_string(),
+        Idefics2(config) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            const IMAGE: &str = "<image>";
+            let slots = config.get_number_of_features(height, width);
+            let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len());
+            image_string.push_str(FAKE);
+            image_string.extend(iter::repeat(IMAGE).take(slots));
+            image_string.push_str(FAKE);
+            if matches!(
+                preprocessor_config,
+                Some(Idefics2Processor(Idefics2Preprocessor {
+                    do_image_splitting: true,
+                    ..
+                }))
+            ) {
+                image_string = image_string.repeat(5);
+            };
+            image_string
+        }
+        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+fn image_tokens_fixup(config: &Config, text: String) -> String {
+    match config {
+        Config::Idefics2(_) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            text.replace(&format!("{FAKE}{FAKE}"), FAKE)
+        }
+        _ => text,
+    }
+}
 /// Get input length and optionally truncate it
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
     tokenizer: &Tokenizer,
-    config: &Option<Config>,
+    config: Option<&Config>,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+    use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
@@ -529,88 +597,17 @@ fn prepare_input(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
                 input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
+                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
             if start != inputs.len() {
                 input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
                 tokenizer_query.push_str(&inputs[start..]);
             }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
+            tokenizer_query = image_tokens_fixup(config, tokenizer_query);
             (tokenizer_query, input_chunks)
         }
         _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
@@ -750,7 +747,7 @@ pub enum ValidationError {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::config::{PaliTextConfig, Paligemma};
+    use crate::config::{Idefics2, PaliTextConfig, Paligemma};
     use crate::default_parameters;
     use crate::tests::get_tokenizer;
@@ -769,6 +766,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -803,6 +801,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -836,6 +835,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -874,6 +874,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -941,6 +942,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequences,
             max_top_n_tokens,
@@ -1026,6 +1028,7 @@ mod tests {
             workers,
             tokenizer,
             Some(config),
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -1058,4 +1061,83 @@ mod tests {
             "Failed to process images",
         );
     }
+    #[tokio::test]
+    async fn test_idefics2_correct_n_fake_tokens() {
+        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
+        let tokenizer = Some(get_tokenizer().await);
+        let max_best_of = 2;
+        let max_stop_sequence = 3;
+        let max_top_n_tokens = 4;
+        let max_input_length = 5;
+        let max_total_tokens = 6;
+        let disable_grammar_support = true;
+        let workers = 1;
+        let config = Config::Idefics2(Idefics2 {});
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            Some(config),
+            Some(HubPreprocessorConfig::Idefics2Processor(
+                Idefics2Preprocessor {
+                    do_image_splitting: true,
+                },
+            )),
+            max_best_of,
+            max_stop_sequence,
+            max_top_n_tokens,
+            max_input_length,
+            max_total_tokens,
+            disable_grammar_support,
+        );
+        let (encoding, chunks) = match validation
+            .tokenize(
+                format!(
+                    "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
+                    PIXEL_GIF, PIXEL_GIF
+                ),
+                None,
+            )
+            .await
+        {
+            Ok(Some((encoding, chunks))) => (encoding, chunks),
+            _ => panic!("Unexpected tokenization failure"),
+        };
+        assert!(
+            chunks
+                == vec![
+                    Chunk::Text("test".to_string()).into(),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    })
+                    .into(),
+                    Chunk::Image(Image {
+                        data: pixel_data.clone(),
+                        mimetype: "image/gif".to_string()
+                    })
+                    .into()
+                ],
+            "Failed to process images",
+        );
+        // Verify the number of fake tokens:
+        //
+        // - Two images surrounded/separated by a fake token = 3.
+        // - Both are split in 5 subimages, separated by a fake token: 2 * 4
+        //
+        // Fake tokens get split up by the testing tokenizer, but we don't care.
+        assert_eq!(
+            encoding
+                .get_tokens()
+                .iter()
+                .filter(|t| *t == "fake")
+                .count(),
+            11
+        );
+    }
 }

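The expected count of 11 fake tokens can be checked outside the router as well. The sketch below (Python, not part of this change) simply mirrors image_tokens and image_tokens_fixup for two Idefics2 images with image splitting enabled:

```python
FAKE = "<fake_token_around_image>"
IMAGE = "<image>"

def idefics2_image_string(slots=64, do_image_splitting=True):
    # One image: FAKE + 64 image slots + FAKE; repeated 5 times when the
    # preprocessor splits the image into 5 sub-images.
    image_str = FAKE + IMAGE * slots + FAKE
    return image_str * 5 if do_image_splitting else image_str

def fixup(text):
    # Adjacent fake tokens collapse into one, as in image_tokens_fixup.
    return text.replace(FAKE + FAKE, FAKE)

text = fixup(idefics2_image_string() + idefics2_image_string())
# 3 fake tokens around/between the two images + 2 * 4 between sub-images = 11.
print(text.count(FAKE))  # 11
```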

@@ -68,6 +68,9 @@ try:
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
+    from text_generation_server.models.flash_gemma2 import (
+        FlashGemma2,
+    )
     from text_generation_server.models.pali_gemma import (
         PaliGemma,
     )
@@ -102,6 +105,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
     __all__.append(FlashGemma)
+    __all__.append(FlashGemma2)
     __all__.append(FlashCohere)
 MAMBA_AVAILABLE = True
@@ -143,6 +147,11 @@ class ModelType(enum.Enum):
         "name": "Gemma",
         "url": "https://huggingface.co/google/gemma-7b",
     }
+    GEMMA2 = {
+        "type": "gemma2",
+        "name": "Gemma2",
+        "url": "https://huggingface.co/google/gemma2-9b",
+    }
     COHERE = {
         "type": "cohere",
         "name": "Cohere",
@@ -630,6 +639,27 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+    elif model_type == GEMMA2:
+        if FLASH_ATTENTION:
+            return FlashGemma2(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
     if model_type == COHERE:
         if FLASH_ATTENTION:


@@ -0,0 +1,500 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
class Gemma2Config(PretrainedConfig):
def __init__(
self,
vocab_size=256128,
hidden_size=3072,
intermediate_size=24576,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=16,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.head_dim = head_dim
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
weights.dtype = dtype
new = cls(weight, eps)
new.dtype = dtype
return new
# perform the multiplication in full precision and downcast after
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = hidden_states * self.weight
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq", "marlin"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.head_dim
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
class FlashGemma2Attention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
self.causal = causal
if is_sliding:
self.window_size = config.sliding_window
else:
self.window_size = -1
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
# self.softmax_scale = self.head_size**-0.5
self.softmax_scale = config.query_pre_attn_scalar**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
)
# Decode
else:
paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class Gemma2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.pre_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.post_feedforward_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
normed_attn_res_output = normed_attn_res_output + res
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGemma2Layer(
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
causal=causal,
is_sliding=layer_id % 2 == 0,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
super().__init__()
embed_norm = config.hidden_size**0.5
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.embed_tokens.weight *= embed_norm
self.model = FlashGemma2Model(
prefix=prefix, config=config, weights=weights, causal=causal
)
self.lm_head = SpeculativeHead.load(
prefix=(
f"{prefix}.embed_tokens"
if config.tie_word_embeddings
else f"{prefix}.lm_head"
),
config=config,
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
input_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

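The new Gemma2 layers alternate between sliding-window and global attention: FlashGemma2Layer passes is_sliding = layer_id % 2 == 0, and FlashGemma2Attention uses config.sliding_window for sliding layers and -1 (no window) for the others. A minimal sketch of that pattern, with made-up layer count and window size:

```python
def window_sizes(num_hidden_layers, sliding_window):
    # Even layers use the configured sliding window, odd layers attend globally (-1).
    return [sliding_window if layer_id % 2 == 0 else -1 for layer_id in range(num_hidden_layers)]

print(window_sizes(6, 4096))  # [4096, -1, 4096, -1, 4096, -1]
```

The other Gemma2-specific details visible above are the RMSNorm weights loaded as weight + 1, the embedding scaling by hidden_size ** 0.5, and the extra pre/post feed-forward layernorms around the MLP.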

@@ -375,8 +375,6 @@ class FlashGemmaModel(torch.nn.Module):
             prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
-        self.gradient_checkpointing = False
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads


@@ -39,7 +39,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Args:
         image_size (`tuple`):
-            The size of the input image in the format (width, height).
+            The size of the input image in the format (height, width).
         grid_pinpoints (`List`):
             A list containing possible resolutions. Each item in the list should be a tuple or list
             of the form `(height, width)`.
@@ -47,7 +47,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
             The size of each image patch.
     Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
+        tuple: The shape of the image patch grid in the format (height, width).
     """
     if not isinstance(grid_pinpoints, list):
         raise ValueError("grid_pinpoints should be a list of tuples or lists")
@@ -230,7 +230,10 @@ class LlavaNextForConditionalGeneration(nn.Module):
                     raise ValueError(
                         "The number of patches is not consistent with the image size."
                     )
-                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                     image_sizes[image_idx],
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,


@@ -28,8 +28,12 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
-import text_generation_server.models.globals as tgi_globals
+from text_generation_server.models.globals import (
+    MEM_POOL,
+    CUDA_GRAPHS,
+    get_adapter_to_index,
+    MODEL_ID,
+)
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -233,7 +237,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
-            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+            ADAPTER_TO_INDEX = get_adapter_to_index()
+            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
             adapter_indices_list.append(torch.full((input_length,), adapter_index))
             adapter_set.add(adapter_index)
@@ -499,9 +504,8 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])
-            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
-                self.requests[idx].adapter_id, 0
-            )
+            ADAPTER_TO_INDEX = get_adapter_to_index()
+            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
             remaining_tokens = (
@@ -1017,7 +1021,7 @@ class FlashCausalLM(Model):
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
-                f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
             logger.info(


@@ -0,0 +1,75 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)


@@ -44,3 +44,8 @@ ADAPTER_TO_INDEX: Dict[str, int] = None
 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
     ADAPTER_TO_INDEX = adapter_to_index
+def get_adapter_to_index():
+    global ADAPTER_TO_INDEX
+    return ADAPTER_TO_INDEX

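A getter is useful here because a name imported with `from ... import NAME` is bound once at import time and never sees a later set_adapter_to_index call, whereas calling get_adapter_to_index() always reads the current module global (MODEL_ID, by contrast, is now imported directly). A small standalone illustration of that Python behaviour, not taken from this change:

```python
ADAPTER_TO_INDEX = None

def set_adapter_to_index(mapping):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = mapping

def get_adapter_to_index():
    return ADAPTER_TO_INDEX

snapshot = ADAPTER_TO_INDEX          # behaves like importing the name directly
set_adapter_to_index({"adapter": 1})
print(snapshot)                      # None: the early binding is stale
print(get_adapter_to_index())        # {'adapter': 1}
```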

@@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch):
                     # TODO do_convert_RGB should be on by default ?
                     image = image.convert("RGB")
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    full_text += image_text_replacement(image_input, config, image_id)
+                    full_text += image_text_replacement(
+                        processor, image_input, config, image_id
+                    )
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")


@@ -1,3 +1,4 @@
+from itertools import repeat
 import torch
 from PIL import Image
 from io import BytesIO
@@ -15,6 +16,9 @@ from text_generation_server.models.flash_mistral import (
 tracer = trace.get_tracer(__name__)
+IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
+IDEFICS2_IMAGE_TOKEN = "<image>"
 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -22,7 +26,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Args:
         image_size (`tuple`):
-            The size of the input image in the format (width, height).
+            The size of the input image in the format (height, width).
         grid_pinpoints (`List`):
             A list containing possible resolutions. Each item in the list should be a tuple or list
             of the form `(height, width)`.
@@ -39,15 +43,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size
-def image_text_replacement(image_input, config, image_id) -> str:
+def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     if config.model_type == "idefics2":
-        # TODO technically depends on image splitting which is not implemented.
-        num_features = 320
-        return (
-            "<fake_token_around_image>"
-            + "<image>" * num_features
-            + "<fake_token_around_image>"
-        )
+        image_seq_len = 64
+        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
+        if processor.image_processor.do_image_splitting:
+            image_str *= 5
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
@@ -64,20 +66,35 @@ def image_text_replacement(image_input, config, image_id) -> str:
     raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+def image_text_replacement_fixup(config, text: str) -> str:
+    if config.model_type == "idefics2":
+        return text.replace(
+            f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
+        )
+    return text
 def get_unpadded_features(
-    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+    original_height: int,
+    original_width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
 ) -> Tuple[int, int]:
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
-    aspect_ratio: float = width / height
+    aspect_ratio: float = original_width / original_height
     current_aspect_ratio: float = current_width / current_height
     if aspect_ratio > current_aspect_ratio:
-        new_height = (height * current_width) // width
-        current_height = new_height
+        new_height = (original_height * current_width) // original_width
+        padding = (current_height - new_height) // 2
+        current_height = current_height - (2 * padding)
     else:
-        new_width = (width * current_height) // height
-        current_width = new_width
+        new_width = (original_width * current_height) // original_height
+        padding = (current_width - new_width) // 2
+        current_width = current_width - (2 * padding)
     unpadded_features = current_height * current_width
     newline_features = current_height
@@ -96,7 +113,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
     npatches = image_size // patch_size
-    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+    # Dimensions are intentionally swapped to be bug-compatible with
+    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
         [height, width],
         image_grid_pinpoints,
         image_size,
@@ -168,9 +187,13 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(image_inputs, config, image_id)
+                    full_text += image_text_replacement(
+                        processor, image_inputs, config, image_id
+                    )
                     image_id += 1
+            full_text = image_text_replacement_fixup(config, full_text)
             batch_inputs.append(full_text)
             max_truncation = max(max_truncation, r.truncate)
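For llava_next, the padding-aware unpadded-feature count above can differ from the old computation by one row (or column) of patches whenever the cropped dimension difference is odd. A self-contained sketch mirroring the new get_unpadded_features, with made-up sizes:

```python
def unpadded_features(original_height, original_width, npatches, num_patch_height, num_patch_width):
    # Mirrors the padding-aware formula from get_unpadded_features above.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding
    return current_height * current_width, current_height

# With a 48x48 patch grid and a 350x500 (height x width) image, the old code kept
# new_height = 33 rows directly, while the padding-aware version keeps
# 48 - 2 * ((48 - 33) // 2) = 34 rows: 34 * 48 unpadded features plus 34 newline features.
print(unpadded_features(350, 500, 24, 2, 2))  # (1632, 34)
```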