From 215ed3ad52651f76ca4326713ba9e4e5107323e5 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 Aug 2024 09:11:40 -0400 Subject: [PATCH 01/72] fix: attempt forward on flash attn2 to check hardware support (#2335) * fix: attempt forward on flash attn2 to check hardware support * fix: warn window_size_left when using flash attn 1 * fix: prefer version check over test op and avoid window_size_left if not flash attn2 * fix: improve condtional and error message * fix: update sliding window conditional * fix: simplify changes and revert model changes * fix: avoid changing conditional * fix: typo tweak --- server/text_generation_server/layers/attention/cuda.py | 4 ++++ server/text_generation_server/models/__init__.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index dff742dc..2b898831 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -172,6 +172,10 @@ def paged_attention( try: + is_ampere_or_newer = major >= 8 and minor >= 0 + if not is_ampere_or_newer: + raise ImportError("FlashAttention only supports Ampere GPUs or newer.") + import flash_attn_2_cuda V2 = True diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3dc24159..ae791ef8 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -484,6 +484,9 @@ def get_model( ) sliding_window = config_dict.get("sliding_window", -1) + if max_input_tokens is not None and max_input_tokens <= sliding_window: + sliding_window = -1 + if ( (sliding_window is not None and sliding_window != -1) and not SUPPORTS_WINDOWING From dd47a3dac411a3b9896bada2983f5ec7014c1922 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 5 Aug 2024 12:36:44 -0400 Subject: [PATCH 02/72] feat: include local lora adapter loading docs (#2359) --- docs/source/conceptual/lora.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md index 08df767c..cfc2109b 100644 --- a/docs/source/conceptual/lora.md +++ b/docs/source/conceptual/lora.md @@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia ``` +additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example: + +```bash +LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter +``` + +note it's possible to mix adapter_ids with adapter_id=adapter_path e.g. 
+ +```bash +LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/ +``` + In the server logs, you will see the following message: ```txt From 29b8d19cdf83d7fd26a4f17015228363a7163522 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 07:49:53 -0400 Subject: [PATCH 03/72] fix: return the out tensor rather then the functions return value (#2361) --- server/text_generation_server/layers/attention/cuda.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 2b898831..96b654d0 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -292,8 +292,7 @@ else: ) out = torch.empty_like(q) - - return flash_attn_cuda.fwd( + flash_attn_cuda.fwd( q, k, v, @@ -309,4 +308,5 @@ else: False, 0, None, - )[0] + ) + return out From e11f5f1c383831a226e26bf3560d6cdade0ee914 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 07:51:32 -0400 Subject: [PATCH 04/72] feat: implement a templated endpoint for visibility into chat requests (#2333) * feat: implement a templated endpoint for visibility into chat requests * feat: improve to tokenize too * fix: adjust return type * feat: simplify prepare_chat_input logic and adjust start stop chars --- router/src/lib.rs | 6 ++ router/src/server.rs | 210 ++++++++++++++++++++++++++++++------------- 2 files changed, 156 insertions(+), 60 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 14bb8270..386b0556 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse { pub details: Option
, } +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatTokenizeResponse { + pub(crate) tokenize_response: TokenizeResponse, + pub(crate) templated_text: String, +} + #[derive(Serialize, ToSchema)] #[serde(transparent)] pub(crate) struct TokenizeResponse(Vec); diff --git a/router/src/server.rs b/router/src/server.rs index dcbaa2ad..7655182a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,6 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; +use crate::ChatTokenizeResponse; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -22,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -115,6 +116,107 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/chat_tokenize", + request_body = ChatRequest, + responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse)) +)] +async fn get_chat_tokenize( + Extension(infer): Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + metrics::counter!("tgi_request_count").increment(1); + + let ChatRequest { + model, + max_tokens, + messages, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + temperature, + response_format, + .. 
+ } = req; + + let tool_prompt = tool_prompt.unwrap_or_default(); + let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; + + let generate_request = GenerateRequest { + inputs, + parameters: GenerateParameters { + best_of: None, + temperature, + repetition_penalty: None, + frequency_penalty: None, + top_k: None, + top_p: None, + typical_p: None, + do_sample: true, + max_new_tokens: max_tokens, + return_full_text: None, + stop: stop.unwrap_or_default(), + truncate: None, + watermark: false, + details: false, + decoder_input_details: !stream, + seed, + top_n_tokens: None, + grammar: _grammar, + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), + }, + }; + + let input = generate_request.inputs.clone(); + let encoding = infer.tokenize(generate_request).await?; + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); + + let resp = ChatTokenizeResponse { + tokenize_response: TokenizeResponse(tokens), + templated_text: input, + }; + Ok((HeaderMap::new(), Json(resp))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } +} + #[utoipa::path( get, tag = "Text Generation Inference", @@ -1034,63 +1136,14 @@ async fn chat_completions( Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - - // response_format and tools are mutually exclusive - if response_format.is_some() && tools.as_ref().is_some() { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Grammar and tools are mutually exclusive".to_string(), - error_type: "grammar and tools".to_string(), - }), - )); - } - - // extract tool grammar if present - let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { - Ok(grammar) => grammar, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; - - // determine the appropriate arguments for apply_chat_template - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - - let (tools_grammar_prompt, grammar) = match response_format { - Some(response_format) => (None, Some(response_format)), - None => ( - tools_grammar_prompt.clone(), - tools_grammar_prompt.map(|(grammar, _)| grammar.clone()), - ), - }; - - // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { - Ok(inputs) => inputs, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; + let (inputs, grammar, tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + 
tool_choice, + &tool_prompt, + messages, + )?; // build the request passing some parameters let generate_request = GenerateRequest { @@ -1360,8 +1413,11 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = - String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); SimpleToken { id, text, @@ -2036,6 +2092,7 @@ async fn start( } let info_routes = Router::new() .route("/", get(health)) + .route("/chat_tokenize", post(get_chat_tokenize)) .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) @@ -2332,3 +2389,36 @@ fn create_post_processor( Ok(post_processor) } + +type PreparedInput = (String, Option, Option); + +fn prepare_chat_input( + infer: &Infer, + response_format: Option, + tools: Option>, + tool_choice: ToolChoice, + tool_prompt: &str, + messages: Vec, +) -> Result { + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + if let Some(format) = response_format { + let inputs = infer.apply_chat_template(messages, None)?; + return Ok((inputs, Some(format), None)); + } + + // if tools are set, apply the tool grammar and then the chat template + let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; + let grammar = tool_grammar + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + let tools_grammar_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); + let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + Ok((inputs, grammar, tool_grammar)) +} From f8a5b381fe9e755be476d0ab5b20826766ba40cc Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:09:50 -0400 Subject: [PATCH 05/72] feat: prefer stop over eos_token to align with openai finish_reason (#2344) --- router/src/lib.rs | 11 ++++++++++- router/src/server.rs | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 386b0556..a956b058 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -619,7 +619,7 @@ impl ChatCompletion { message, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), }], usage: Usage { prompt_tokens: details.prefill.len() as u32, @@ -1117,6 +1117,15 @@ impl std::fmt::Display for FinishReason { } } +impl FinishReason { + pub fn format(&self, use_stop: bool) -> String { + match self { + FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(), + _ => self.to_string(), + } + } +} + #[derive(Serialize, ToSchema)] pub(crate) struct BestOfSequence { #[schema(example = "test")] diff --git a/router/src/server.rs b/router/src/server.rs index 7655182a..4b6fe50c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1021,7 +1021,7 @@ async fn completions( total_tokens += details.prefill.len() as u32 + details.generated_tokens; Ok(CompletionComplete { - finish_reason: details.finish_reason.to_string(), + finish_reason: details.finish_reason.format(true), index: index as u32, logprobs: None, text: generation.generated_text, @@ -1212,7 +1212,7 @@ async fn chat_completions( tool_calls, current_time, logprobs, - stream_token.details.map(|d| d.finish_reason.to_string()), + stream_token.details.map(|d| 
d.finish_reason.format(true)), ), )) .unwrap_or_else(|e| { From 1768c00b9f124fcf92f513165da5228f355f3ea1 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:10:19 -0400 Subject: [PATCH 06/72] feat: return the generated text when parsing fails (#2353) --- router/src/server.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 4b6fe50c..1d1cd36a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1246,9 +1246,13 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - let gen_text_value: Value = serde_json::from_str(&generation.generated_text) - .map_err(|e| InferError::ToolError(e.to_string()))?; - + let gen_text_value: Value = + serde_json::from_str(&generation.generated_text).map_err(|e| { + InferError::ToolError(format!( + "Failed to parse generated text: {} {:?}", + e, generation.generated_text + )) + })?; let function = gen_text_value.get("function").ok_or(InferError::ToolError( "No function found in generated text".to_string(), ))?; From a64d407d64de1ba168523df5391863b1f85c0824 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 13:33:22 -0400 Subject: [PATCH 07/72] fix: default num_ln_in_parallel_attn to one if not supplied (#2364) --- .../models/custom_modeling/flash_rw_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 708641e7..0691da9b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -473,7 +473,9 @@ class FlashRWLayer(nn.Module): class FlashRWLayerNorm(nn.Module): def __init__(self, config, prefix: str, weights): super().__init__() - self.num_ln = config.num_ln_in_parallel_attn + # Falcon2 includes the number of layer norms in the config + # in the case no number of layer norms is provided, we default to 1 + self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1) if self.num_ln == 1: self.input_ln = FastLayerNorm.load( From 133015f40821706b1eaf9943aa3c9aa477d0c614 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 6 Aug 2024 15:25:30 -0400 Subject: [PATCH 08/72] fix: prefer original layernorm names for 180B (#2365) --- .../models/custom_modeling/flash_rw_modeling.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 0691da9b..fc002082 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -382,8 +382,13 @@ class FlashRWLayer(nn.Module): prefix = f"{prefix}.h.{layer_id}" + # NOTE: Falcon 180B uses the ln_attn prefix + ln_prefix = "input_layernorm" + if config.num_hidden_layers == 80: + ln_prefix = "ln_attn" + self.input_layernorm = FastLayerNorm.load( - prefix=f"{prefix}.input_layernorm", + prefix=f"{prefix}.{ln_prefix}", weights=weights, eps=config.layer_norm_epsilon, ) @@ -477,6 +482,10 @@ class FlashRWLayerNorm(nn.Module): # in the case no number of layer norms is provided, we default to 1 self.num_ln = getattr(config, "num_ln_in_parallel_attn", 1) + # Falcon 180B uses the ln_attn prefix and has 2 layer norms + if config.num_hidden_layers 
== 80: + self.num_ln = 2 + if self.num_ln == 1: self.input_ln = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", From 8094ecfc9ef22c838fa7d49db4af8301539619e3 Mon Sep 17 00:00:00 2001 From: almersawi <43927639+almersawi@users.noreply.github.com> Date: Thu, 8 Aug 2024 03:45:23 +0400 Subject: [PATCH 09/72] fix: fix num_ln_in_parallel_attn attribute name typo in RWConfig (#2350) Co-authored-by: Islam Almersawi --- .../models/custom_modeling/flash_rw_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fc002082..10f995a3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -94,7 +94,7 @@ class RWConfig(PretrainedConfig): else kwargs.pop("n_head", 8) ) self.layer_norm_epsilon = layer_norm_epsilon - self.num_ln_in_parallel_attention = num_ln_in_prallel_attention + self.num_ln_in_parallel_attn = num_ln_in_prallel_attention self.initializer_range = initializer_range self.use_cache = use_cache self.hidden_dropout = hidden_dropout From 21267f3ca3f121302b86c1702cc2da6091164c55 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Aug 2024 21:32:37 -0400 Subject: [PATCH 10/72] add gptj modeling in TGI #2366 (CI RUN) (#2372) * add gptj modeling Signed-off-by: Wang, Yi A * fix: update docs for model addition * fix: adjust syntax typo * fix: adjust syntax typo again --------- Signed-off-by: Wang, Yi A Co-authored-by: Wang, Yi A --- docs/source/supported_models.md | 1 + router/src/config.rs | 1 + .../text_generation_server/models/__init__.py | 43 ++ .../custom_modeling/flash_gptj_modeling.py | 405 ++++++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index bc124f31..b78104df 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -32,6 +32,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) - [Gpt2](https://huggingface.co/openai-community/gpt2) - [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) +- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b) - [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal) diff --git a/router/src/config.rs b/router/src/config.rs index 7737165e..5d0be9c8 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -153,6 +153,7 @@ pub enum Config { Bloom, Mpt, Gpt2, + Gptj, GptNeox, Phi, #[serde(rename = "phi-msft")] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ae791ef8..1f9c7526 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -132,6 +132,9 @@ try: from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( FlashGPT2ForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_gptj_modeling import ( + FlashGPTJForCausalLM, + ) from text_generation_server.models.custom_modeling.idefics2 import ( Idefics2ForConditionalGeneration, ) @@ -294,6 +297,11 @@ class ModelType(enum.Enum): "name": "Gpt Neox", "url": "https://huggingface.co/EleutherAI/gpt-neox-20b", } + GPTJ = { + "type": "gptj", + 
"name": "Gptj", + "url": "https://huggingface.co/EleutherAI/gpt-j-6b", + } IDEFICS = { "type": "idefics", "name": "Idefics", @@ -641,6 +649,41 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == GPTJ: + if FLASH_ATTENTION: + try: + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTJForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + ) + except RuntimeError as e: + # Lots of legacy models with various weight names. + log_master(logger.warning, f"Couldn't load flash gptj variant: {e}") + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J")) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: from text_generation_server.models.custom_modeling.flash_neox_modeling import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py new file mode 100644 index 00000000..eb667384 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -0,0 +1,405 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from typing import Optional, List, Tuple + +from text_generation_server.layers.attention import ( + paged_attention, + attention, + reshape_and_cache, +) +from text_generation_server.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + SpeculativeHead, + get_linear, +) +from text_generation_server.layers.rotary import ( + PositionRotaryEmbedding, +) +from text_generation_server.layers.layernorm import ( + FastLayerNorm, +) +from text_generation_server.utils.import_utils import SYSTEM + + +def load_attention(config, prefix: str, weights): + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def load_row(config, prefix: str, weights, bias: bool): + weight = weights.get_weights_row(prefix) + + if bias and weights.process_group.rank() == 0: + # Rank is only on the first rank process + bias = weights.get_tensor(f"{prefix}.bias") + else: + bias = None + + linear = get_linear(weight, bias) + return TensorParallelRowLinear(linear, process_group=weights.process_group) + + +class GPTJRotary(PositionRotaryEmbedding): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # Such controlflows may add some overhead. + if SYSTEM == "cuda": + import rotary_emb + + q1 = query[..., ::2] + q2 = query[..., 1::2] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + + k1 = key[..., ::2] + k2 = key[..., 1::2] + + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + elif SYSTEM == "rocm": + from vllm._C import ops + + # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems. + # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773 + + head_size = query.shape[-1] + + # Inplace operation, updating query and key. + ops.rotary_embedding(query, key, head_size, cos, sin, False) + elif SYSTEM == "ipex": + import intel_extension_for_pytorch as ipex + + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), False + ) + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." 
+ ) + + +class FlashGPTJAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + + self.head_size = self.hidden_size // self.num_heads + self.softmax_scale = self.head_size**-0.5 + self.rotary_dim = config.rotary_dim + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + + self.query_key_value = load_attention( + config, + prefix=prefix, + weights=weights, + ) + + self.o_proj = load_row( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=False, + ) + + self.kv_head_mapping = torch.arange( + 0, self.num_heads, dtype=torch.int32, device=weights.device + ) + + self.rotary_emb = GPTJRotary.static( + config=config, + dim=self.rotary_dim, + base=10000, + device=weights.device, + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + query, key, value = self.query_key_value(hidden_states).split( + self.head_size * self.num_heads, dim=1 + ) + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + + # Compute rotary embeddings on rotary_ndims + if self.rotary_dim is not None: + self.rotary_emb( + query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin + ) + else: + self.rotary_emb(query, key, cos, sin) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attn_output = attention( + query, + key, + value, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + attn_output = paged_attention( + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GPTJMLP(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + act = config.activation_function + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + + self.fc_in = TensorParallelColumnLinear.load( + config, prefix=f"{prefix}.fc_in", weights=weights, bias=True + ) + + self.fc_out = load_row( + config, + prefix=f"{prefix}.fc_out", + weights=weights, + bias=True, + ) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + return self.fc_out(hidden_states) + + +class FlashGPTJLayer(nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.self_attn = FlashGPTJAttention( + prefix=f"{prefix}.attn", config=config, weights=weights + ) + self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastLayerNorm.load( + prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + hidden_states, residual = 
self.input_layernorm(hidden_states, residual) + # Self Attention + attn_output = self.self_attn( + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + feed_forward_hidden_states = self.mlp(hidden_states) + + return attn_output + feed_forward_hidden_states, residual + + +class FlashGPTJModel(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + self.config = config + + self.wte = TensorParallelEmbedding(prefix=f"{prefix}.wte", weights=weights) + self.layers = nn.ModuleList( + [ + FlashGPTJLayer( + prefix=( + f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" + ), + config=config, + weights=weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + + self.ln_f = FastLayerNorm.load( + prefix="ln_f" if not prefix else f"{prefix}.ln_f", + weights=weights, + eps=config.layer_norm_epsilon, + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def forward( + self, + input_ids: Optional[torch.LongTensor], + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.ln_f(hidden_states, residual) + + return hidden_states + + +class FlashGPTJForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights): + super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = FlashGPTJModel(prefix, config, weights) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor] = None, + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices=prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits From a379d5536bb2de55154dc09c3a1f24ce58cb7df5 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Aug 2024 23:14:02 -0400 Subject: [PATCH 11/72] Fix the prefix for OPT model in opt_modelling.py #2370 (CI RUN) (#2371) * Fix the bug * fix: run lints * fix: small syntax tweak --------- Co-authored-by: Sadra Barikbin --- integration-tests/models/test_opt.py | 19 
++++++++++++++ .../models/custom_modeling/opt_modeling.py | 25 ++++++++++--------- 2 files changed, 32 insertions(+), 12 deletions(-) create mode 100644 integration-tests/models/test_opt.py diff --git a/integration-tests/models/test_opt.py b/integration-tests/models/test_opt.py new file mode 100644 index 00000000..cbeb6376 --- /dev/null +++ b/integration-tests/models/test_opt.py @@ -0,0 +1,19 @@ +import pytest + + +@pytest.fixture(scope="module") +def opt_sharded_handle(launcher): + with launcher("facebook/opt-6.7b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def opt_sharded(opt_sharded_handle): + await opt_sharded_handle.health(300) + return opt_sharded_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +async def test_opt(opt_sharded): + pass diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 84a1c069..bd440321 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") + weights.get_tensor( + f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight" + ) ) def forward( @@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"{prefix}.decoder.layers.{layer_id}" + prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel): self.max_target_positions = config.max_position_embeddings self.vocab_size = config.vocab_size + prefix = prefix + "." if prefix else "" + self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.decoder.embed_tokens", weights=weights + prefix=f"{prefix}decoder.embed_tokens", weights=weights ) self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( config, - prefix=f"{prefix}.decoder.project_out", + prefix=f"{prefix}decoder.project_out", weights=weights, bias=False, ) @@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel): if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( config, - prefix=f"{prefix}.decoder.project_in", + prefix=f"{prefix}decoder.project_in", weights=weights, bias=False, ) @@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None @@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel): def __init__(self, prefix, config, weights): super().__init__(config) - if not prefix: - prefix = "model" - else: - prefix = f"{prefix}.model" - self.model = OPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights + config, + prefix=f"{prefix + '.' 
if prefix else ''}decoder.embed_tokens", + weights=weights, ) def forward( From 82d19d7723c085a0f0bd37494ee2e2c41e1323ca Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 11:14:06 -0400 Subject: [PATCH 12/72] Pr 2374 ci branch (#2378) * Update __init__.py Fix issue with NoneType comparison for max_input_tokens and sliding_window - Add default values for max_input_tokens and sliding_window to handle None cases. - Ensure the comparison between max_input_tokens and sliding_window is handled correctly to prevent TypeError. - This change addresses the error: TypeError: '<=' not supported between instances of 'int' and 'NoneType'. * Update __init__.py Handle NoneType in sliding_window comparison to fix TypeError in __init__.py by ensuring the comparison logic accounts for NoneType values, preventing errors and improving code robustness. * fix: syntax/style tweak --------- Co-authored-by: Praz --- server/text_generation_server/models/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1f9c7526..da14d083 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -490,7 +490,12 @@ def get_model( raise RuntimeError( "Sharding is currently not supported with `exl2` quantization" ) - sliding_window = config_dict.get("sliding_window", -1) + + sliding_window = ( + config_dict.get("sliding_window") + if config_dict.get("sliding_window") is not None + else -1 + ) if max_input_tokens is not None and max_input_tokens <= sliding_window: sliding_window = -1 From 689b1abbf68cd929f41b72b06cc9e44b266fed53 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 9 Aug 2024 00:08:52 +0800 Subject: [PATCH 13/72] fix EleutherAI/gpt-neox-20b does not work in tgi (#2346) Signed-off-by: Wang, Yi A --- .../models/custom_modeling/flash_neox_modeling.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b1b03ad7..67237d5c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + # Compute rotary embeddings on rotary_ndims + query_rot = qkv[:, 0][..., : self.rotary_dim] + query_pass = qkv[:, 0][..., self.rotary_dim :] + key_rot = qkv[:, 1][..., : self.rotary_dim] + key_pass = qkv[:, 1][..., self.rotary_dim :] + # Inplace rotary - self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) + self.rotary_emb(query_rot, key_rot, cos, sin) + qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) + qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) From 2ca598063443d0bf4ba3b1f0fb3dbc17e60f8a67 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 12:30:29 -0400 Subject: [PATCH 14/72] Pr 2337 ci branch (#2379) * hotfix: fix xpu crash brought by code refine. 
torch.xpu rely on import ipex Signed-off-by: Wang, Yi A * reable gemma2 in xpu Signed-off-by: Wang, Yi A * fix in regression in ipex flashattention Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A Co-authored-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 7 ++++++- server/text_generation_server/utils/import_utils.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e0956b26..d7cf780a 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -2,6 +2,7 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen +from typing import Optional SUPPORTS_WINDOWING = False @@ -15,11 +16,12 @@ def attention( softmax_scale, window_size_left=-1, causal=True, + softcap: Optional[float] = None, ): out = torch.empty_like(q) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - return ipex.llm.functional.varlen_attention( + ipex.llm.functional.varlen_attention( q, k, v, @@ -36,6 +38,8 @@ def attention( None, ) + return out + def reshape_and_cache( key: torch.Tensor, @@ -58,6 +62,7 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + softcap: Optional[float] = None, ): out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 7c053014..782b4f15 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -56,6 +56,8 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): get_free_memory = get_cuda_free_memory elif is_ipex_available(): SYSTEM = "ipex" + import intel_extension_for_pytorch # noqa: F401 + if hasattr(torch, "xpu") and torch.xpu.is_available(): empty_cache = torch.xpu.empty_cache synchronize = torch.xpu.synchronize From f8521900601578a070ab5bb4275cbb2cd45b8e01 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Aug 2024 14:08:56 -0400 Subject: [PATCH 15/72] fix: prefer hidden_activation over hidden_act in gemma2 (#2381) --- .../models/custom_modeling/flash_gemma2_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index de86f514..54d212e6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -265,7 +265,7 @@ class FlashGemma2Attention(torch.nn.Module): class Gemma2MLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() - act = config.hidden_act + act = config.hidden_activation self.act = ( ACT2FN[act] if "gelu" not in act From cb3ae30284ada6d15822a4ccde9156b8e93ef2b6 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Thu, 8 Aug 2024 22:06:57 +0200 Subject: [PATCH 16/72] Update Quantization docs and minor doc fix. (#2368) * Update Quantization docs and minor doc fix. 
* update readme with latest quants info * Apply suggestions from code review Co-authored-by: Pedro Cuenca * up --------- Co-authored-by: Pedro Cuenca --- README.md | 2 + docs/openapi.json | 2 +- .../source/basic_tutorials/preparing_model.md | 2 +- .../basic_tutorials/visual_language_models.md | 2 +- docs/source/conceptual/quantization.md | 59 +++++++++++-------- 5 files changed, 41 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index a88e0437..cf7f1d22 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan - [GPT-Q](https://arxiv.org/abs/2210.17323) - [EETQ](https://github.com/NetEase-FuXi/EETQ) - [AWQ](https://github.com/casper-hansen/AutoAWQ) + - [Marlin](https://github.com/IST-DASLab/marlin) + - [fp8]() - [Safetensors](https://github.com/huggingface/safetensors) weight loading - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) diff --git a/docs/openapi.json b/docs/openapi.json index ed9b0b96..9d281a48 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2080,4 +2080,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/docs/source/basic_tutorials/preparing_model.md b/docs/source/basic_tutorials/preparing_model.md index 71ca5598..456ade44 100644 --- a/docs/source/basic_tutorials/preparing_model.md +++ b/docs/source/basic_tutorials/preparing_model.md @@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects. ## Quantization -TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) +TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). 
To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) ## RoPE Scaling diff --git a/docs/source/basic_tutorials/visual_language_models.md b/docs/source/basic_tutorials/visual_language_models.md index 3770db0b..f152a2f0 100644 --- a/docs/source/basic_tutorials/visual_language_models.md +++ b/docs/source/basic_tutorials/visual_language_models.md @@ -84,7 +84,7 @@ print(chat) ``` -or with OpenAi's library: +or with OpenAI's [client library](https://github.com/openai/openai-python): ```python from openai import OpenAI diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 8f26fdba..7507687f 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -1,6 +1,40 @@ # Quantization -TGI offers GPTQ and bits-and-bytes quantization to quantize large language models. +TGI offers many quantization schemes to run LLMs effectively and fast based on your use-case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization. + +To leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights. Whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly. + +We recommend using the official quantization scripts for creating your quants: +1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py) +2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py) +3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md) + +For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. + +## Quantization with bitsandbytes + +bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing โ€“ weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. + +8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much. +In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below ๐Ÿ‘‡ + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes +``` + +4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. + +In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below ๐Ÿ‘‡ + +```bash +docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 +``` + +You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). + +Use `eetq` or `fp8` for other quantization schemes. 
+ +In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset. ## Quantization with GPTQ @@ -35,25 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num- You can learn more about the quantization options by running `text-generation-server quantize --help`. If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). -You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). - -## Quantization with bitsandbytes - -bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing โ€“ weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. - -8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much. -In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below ๐Ÿ‘‡ - -```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes -``` - -4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load. - -In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below ๐Ÿ‘‡ - -```bash -docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4 -``` - -You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). \ No newline at end of file From 6d06473cf48e19e7382b27940f993a5f48c83997 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 04:54:32 -0400 Subject: [PATCH 17/72] Pr 2352 ci branch (#2382) * Fix unsigned integer underflow Passing --max-batch-size to the launcher actually had no effect because after a few requests the max_size passed to State::next_batch would underflow becoming a largo positive number. In the scheduler, as soon as the cached batch size reached the max_batch_size the max_size passed to next_batch becomes 0. Since the only check in that funcion is ``` if Some(batch_requests.len()) == max_size { break; } ``` and it's called after the `batch_requests.len()` has become 1, it doesn't do anything to prevent more than 0 requests from being batched. Now we have cached batch in the server that is large than max_batch_size and `max_size - batch_size as usize` underflows. 
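To make the failure mode concrete, here is a small self-contained Rust sketch (illustrative values only, not taken from the router code) of why clamping with `saturating_sub` fixes the unsigned subtraction described above:

```rust
fn main() {
    // Hypothetical values: the cached batch has already grown past the
    // configured limit, so the remaining capacity is conceptually negative.
    let max_batch_size: usize = 4;
    let batch_size: usize = 6;

    // Plain `max_batch_size - batch_size` would panic in debug builds and wrap
    // to a huge positive value in release builds, silently disabling the limit.
    let max_size = max_batch_size.saturating_sub(batch_size);

    assert_eq!(max_size, 0); // clamped to zero: no room left for new requests
    println!("remaining capacity: {max_size}");
}
```

The diff below applies this clamping in both the v3 backend and the v2 scheduler, rejects `--max-batch-size 0` at startup, and makes both queues return no batch when `max_size` is zero.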
Signed-off-by: Max de Bayser * fix: update v3 scheduler and ensure max_batch_size > 0 --------- Signed-off-by: Max de Bayser Co-authored-by: Max de Bayser --- backends/v3/src/backend.rs | 3 ++- backends/v3/src/main.rs | 8 ++++++++ backends/v3/src/queue.rs | 7 +++++++ router/src/infer/v2/queue.rs | 7 +++++++ router/src/infer/v2/scheduler.rs | 4 ++-- 5 files changed, 26 insertions(+), 3 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index d82355de..6b3e0526 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -168,7 +168,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 21952e66..471ddb5a 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> { } } + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + let (backend, _backend_info) = connect_backend( max_input_tokens, max_total_tokens, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 9427bd60..b457389c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -226,6 +226,13 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 0b51645a..696cbfc8 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -205,6 +205,13 @@ impl State { } } + if let Some(max_size) = max_size { + if max_size == 0 { + tracing::debug!("No capacity"); + return None; + } + } + // Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 3d6c36cf..cc333674 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -161,8 +161,8 @@ pub(crate) async fn batching_task( }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); - + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) From 7830de1566df365e6cb9ce0a955e8e2ac1b28ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 11:42:00 +0200 Subject: [PATCH 18/72] Add FlashInfer support (#2354) This change adds support for FlashInfer. FlashInfer can be enabled using `FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`. Since this functionality is currently only for testing, FlashInfer is not installed anywhere yet. 
The FlashInfer API is quite different from FlashAttention/vLLM in that it requires more global bookkeeping: * A wrapper class needs to be contstructed (which we just call *state*). Since this is fairly expensive (due to pinned host memory allocation), we only do this once in a FlashCausalLM instance or for each CUDA Graph size. * Each model forward call needs to be wrapped in `begin_forward` and `end_forward`. This sets up data structures that can be reused for all calls to attention for that forward call. When calling attention, we need access to the state object. To avoid passing an argument down the call chain (which would require changes to all models), we use a context variable. Each model forward call is wrapped using a context manager that does all the bookkeeping for such a call: * Set the context variable to the forward call's state. * Call `begin_forward` on the state. * Yield. * Call `end_forward` on the state. * Reset the context variable. We cannot use a single shared global variable for this, since e.g. CUDA Graphs of different sizes each have their own state. --- .../layers/attention/common.py | 4 +- .../layers/attention/cuda.py | 46 ++++- .../layers/attention/flash_infer.py | 164 +++++++++++++++++ .../models/flash_causal_lm.py | 171 ++++++++++++++---- .../text_generation_server/models/globals.py | 5 + 5 files changed, 346 insertions(+), 44 deletions(-) create mode 100644 server/text_generation_server/layers/attention/flash_infer.py diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index bd0717ce..b986a082 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER import torch from typing import Optional -if FLASH_DECODING: +if FLASH_DECODING or FLASH_INFER: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 96b654d0..1b8e9209 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,6 +1,10 @@ import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.models.globals import ( + FLASH_DECODING, + BLOCK_SIZE, + FLASH_INFER, +) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -23,7 +27,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -72,7 +76,16 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. 
- if FLASH_DECODING: + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import decode_state + + return decode_state.get().forward( + query.contiguous(), + paged_kv_cache=(key_cache, value_cache), + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + elif FLASH_DECODING: max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -206,7 +219,32 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if V2: +if FLASH_INFER: + + def attention( + q, + k, + v, + cu_seqlens, + max_s, + softmax_scale, + window_size_left=-1, + causal=True, + softcap=0.0, + ): + from text_generation_server.layers.attention.flash_infer import prefill_state + + return prefill_state.get().forward( + q, + k, + v, + causal=causal, + window_left=window_size_left, + logits_soft_cap=softcap, + sm_scale=softmax_scale, + ) + +elif V2: def attention( q, diff --git a/server/text_generation_server/layers/attention/flash_infer.py b/server/text_generation_server/layers/attention/flash_infer.py new file mode 100644 index 00000000..56b53b2c --- /dev/null +++ b/server/text_generation_server/layers/attention/flash_infer.py @@ -0,0 +1,164 @@ +from typing import Optional +from contextvars import ContextVar +from contextlib import contextmanager + +import flashinfer +import torch + +prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar( + "prefill_state" +) + +decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar( + "decode_state" +) + +workspace: Optional[torch.Tensor] = None + + +def get_workspace(device): + """Get shared flashinfer workspace.""" + global workspace + if workspace is None: + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) + return workspace + + +def create_prefill_state( + *, + device: torch.device, +): + """Create a prefill state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, kv_layout="NHD", use_cuda_graph=False + ) + + +@contextmanager +def use_prefill_state( + *, + state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper, + cu_seqlens: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer prefill state to the given + `state` and parameters. This state will be used by all calls to the + `attention` function while the context manager is active. + """ + + token = prefill_state.set(state) + try: + state.begin_forward( + qo_indptr=cu_seqlens, + kv_indptr=cu_seqlens, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + prefill_state.reset(token) + + +def create_decode_state( + *, + device: torch.device, + num_heads: int, + num_kv_heads: int, +): + """Create a decode state.""" + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=False, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +def create_decode_state_cuda_graphs( + *, + device: torch.device, + block_tables: torch.Tensor, + block_tables_ptr: torch.Tensor, + last_page_len: torch.Tensor, + num_heads: int, + num_kv_heads: int, +): + """ + Create a decode state for use with CUDA Graphs. `block_tables`, + `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are + therefore stored as part of the state. 
+ """ + workspace_buffer = get_workspace(device) + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout="NHD", + use_cuda_graph=True, + paged_kv_indices_buffer=block_tables, + paged_kv_indptr_buffer=block_tables_ptr, + paged_kv_last_page_len_buffer=last_page_len, + use_tensor_cores=num_heads // num_kv_heads > 4, + ) + + +@contextmanager +def use_decode_state( + *, + state: flashinfer.BatchDecodeWithPagedKVCacheWrapper, + input_lengths: torch.Tensor, + block_tables: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_size: int, + page_size: int, + query_dtype: str = "float16", +): + """ + Context manager to set the active flashinfer decoding state to the given + `state` and parameters. This state will be used by all calls to the + `paged_attention` function while the context manager is active. + """ + indptr = torch.zeros( + input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32 + ) + # Round up to page size and then calculate the cumulative sum to get + # the indices into the block table. + torch.add(input_lengths, page_size - 1, out=indptr[1:]) + indptr[1:].div_(page_size, rounding_mode="floor") + indptr[1:].cumsum_(-1) + + # Get the lengths of the last page in a block. + last_page_len = torch.empty( + input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device + ) + torch.sub(input_lengths, 1, out=last_page_len) + last_page_len.remainder_(page_size) + last_page_len += 1 + + token = decode_state.set(state) + + try: + state.begin_forward( + indptr=indptr, + indices=block_tables, + last_page_len=last_page_len, + num_qo_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=page_size, + q_data_type=query_dtype, + ) + yield + finally: + state.end_forward() + if token is not None: + decode_state.reset(token) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 36bb2662..12aa7dcd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext import math import os import time @@ -15,7 +16,7 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Iterable, Optional, Tuple, List, Type, Dict +from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -40,6 +41,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, FLASH_DECODING, + FLASH_INFER, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -907,6 +909,7 @@ class FlashCausalLM(Model): config.sliding_window = None self.num_layers = config.num_hidden_layers + self.num_heads = config.num_attention_heads # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -935,6 +938,21 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_prefill_state, + create_decode_state, + ) + + self.prefill_state = create_prefill_state(device=device) + + if not CUDA_GRAPHS: + self.decode_state = create_decode_state( + device=device, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + super().__init__( model_id=model_id, 
model=model, @@ -972,7 +990,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING: + if FLASH_DECODING or FLASH_INFER: self.kv_cache = [ ( torch.empty( @@ -1044,38 +1062,66 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph + if FLASH_INFER: + from text_generation_server.layers.attention.flash_infer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables.view(-1), + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + self.cuda_graphs[bs]["state"] = state + else: + state = None + torch.cuda.synchronize() # Run once outside to warmup - self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=None, - kv_cache=self.kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, - input_lengths=input_lengths_, - max_s=max_s, - prefill_cache_indices=None, - lm_head_indices=None, - ) - torch.cuda.synchronize() - - with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + input_lengths=input_lengths_, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, ) - self.cuda_graphs[bs]["logits"] = logits - self.cuda_graphs[bs]["speculative_logits"] = speculative_logits + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -1295,23 +1341,28 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = Seqlen(input_lengths=input_lengths) - logits, speculative_logits = self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=kv_cache, + with self._forward_context( block_tables=block_tables, - slots=slots, + cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=input_lengths, - max_s=max_s, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=lm_head_indices, - adapter_data=adapter_data, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits, speculative_logits + ): + input_lengths = Seqlen(input_lengths=input_lengths) + logits, speculative_logits = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + 
input_lengths=input_lengths, + max_s=max_s, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=lm_head_indices, + adapter_data=adapter_data, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -1325,8 +1376,16 @@ class FlashCausalLM(Model): cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths - # Replay the graph - cuda_graph["graph"].replay() + state = cuda_graph.get("state") + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths=input_lengths, + state=state, + ): + # Replay the graph + cuda_graph["graph"].replay() + # Slice output to the correct shape speculative_logits = ( cuda_graph["speculative_logits"][:bs] @@ -1698,3 +1757,39 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + def _forward_context( + self, + *, + block_tables: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + input_lengths: torch.Tensor, + state: Optional[Any] = None, + ) -> ContextManager: + if not FLASH_INFER: + return nullcontext() + + from text_generation_server.layers.attention.flash_infer import ( + use_decode_state, + use_prefill_state, + ) + + if cu_seqlen_prefill is not None: + return use_prefill_state( + state=state if state is not None else self.prefill_state, + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + ) + else: + assert input_lengths is not None + return use_decode_state( + state=state if state is not None else self.decode_state, + input_lengths=input_lengths, + block_tables=block_tables.view(-1), + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 8d2431db..42b43c87 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,10 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} +if FLASH_INFER: + log_master(logger.info, "Using FLASH_INFER") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} @@ -12,6 +16,7 @@ BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: log_master(logger.info, "Using FLASH_DECODING") + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: From c6d5039cd7c05c9d1323b94844ca30ab593f21f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 12:32:37 +0200 Subject: [PATCH 19/72] Add experimental flake (#2384) Add flake.nix --- flake.lock | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ flake.nix | 73 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..bf06240f --- /dev/null +++ b/flake.lock @@ -0,0 +1,99 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1696426674, + 
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1723099294, + "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "45892b6ec142eaf300d777926983a433b5842c88", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudnn-9.3", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ], + "tgi-nix": "tgi-nix" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "tgi-nix": { + "inputs": { + "flake-compat": "flake-compat", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1723188417, + "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=", + "owner": "danieldk", + "repo": "tgi-nix", + "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3", + "type": "github" + }, + "original": { + "owner": "danieldk", + "repo": "tgi-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..c17950c1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,73 @@ +{ + inputs = { + tgi-nix.url = "github:danieldk/tgi-nix"; + nixpkgs.follows = "tgi-nix/nixpkgs"; + flake-utils.url = "github:numtide/flake-utils"; + }; + outputs = + { + self, + nixpkgs, + flake-utils, + tgi-nix, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + config = { + allowUnfree = true; + cudaSupport = true; + }; + pkgs = import nixpkgs { + inherit config system; + overlays = [ tgi-nix.overlay ]; + }; + in + { + devShells.default = + with pkgs; + mkShell { + buildInputs = + [ + cargo + openssl.dev + pkg-config + ] + ++ (with python3.pkgs; [ + venvShellHook + pip + + einops + fbgemm-gpu + flash-attn + flash-attn-layer-norm + flash-attn-rotary + grpc-interceptor + grpcio-reflection + grpcio-status + hf-transfer + loguru + marlin-kernels + opentelemetry-api + opentelemetry-exporter-otlp + opentelemetry-instrumentation-grpc + opentelemetry-semantic-conventions + peft + tokenizers + torch + transformers + vllm + ]); + + venvDir = "./.venv"; + + postVenv = '' + unset SOURCE_DATE_EPOCH + ''; + postShellHook = '' + unset SOURCE_DATE_EPOCH + ''; + }; + } + ); +} From 952b450a3b4ddf7e41c24f807357a158fd3eae0c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 14:25:44 +0200 Subject: [PATCH 20/72] Using HF_HOME instead of CACHE to get token read in addition to models. 
(#2288) --- Dockerfile | 2 +- Dockerfile_amd | 2 +- Dockerfile_intel | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0d57e38d..c68f76f6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -191,7 +191,7 @@ ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 diff --git a/Dockerfile_amd b/Dockerfile_amd index 51231638..cdad0d28 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -167,7 +167,7 @@ RUN python setup.py build FROM base AS base-copy # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 diff --git a/Dockerfile_intel b/Dockerfile_intel index d20f0a01..158c5a89 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -57,7 +57,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 From 977534bcb8b078ed7ee9df1ae8038d083cd1a58d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 14:56:20 +0200 Subject: [PATCH 21/72] flake: add fmt and clippy (#2389) --- flake.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flake.nix b/flake.nix index c17950c1..a67879fc 100644 --- a/flake.nix +++ b/flake.nix @@ -30,8 +30,10 @@ buildInputs = [ cargo + clippy openssl.dev pkg-config + rustfmt ] ++ (with python3.pkgs; [ venvShellHook From b2b9c427246d895a943fa6a5f8e9b702eff03559 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 9 Aug 2024 15:01:34 +0200 Subject: [PATCH 22/72] Update documentation for Supported models (#2386) * Minor doc fixes * up. * Other minor updates. --- README.md | 34 ++++++++++++++++--- docs/source/conceptual/quantization.md | 4 +-- docs/source/quicktour.md | 2 +- docs/source/supported_models.md | 8 ++--- .../text_generation_server/models/__init__.py | 6 ++-- update_doc.py | 2 +- 6 files changed, 41 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index cf7f1d22..803e9172 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Swagger API documentation -A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) +A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co) to power Hugging Chat, the Inference API and Inference Endpoint. 
@@ -42,6 +42,7 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan - Tensor Parallelism for faster inference on multiple GPUs - Token streaming using Server-Sent Events (SSE) - Continuous batching of incoming requests for increased total throughput +- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures - Quantization with : - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) @@ -49,7 +50,7 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan - [EETQ](https://github.com/NetEase-FuXi/EETQ) - [AWQ](https://github.com/casper-hansen/AutoAWQ) - [Marlin](https://github.com/IST-DASLab/marlin) - - [fp8]() + - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) - [Safetensors](https://github.com/huggingface/safetensors) weight loading - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) @@ -94,6 +95,29 @@ curl 127.0.0.1:8080/generate_stream \ -H 'Content-Type: application/json' ``` +You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses. + +```bash +curl localhost:3000/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ + -H 'Content-Type: application/json' +``` + **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar. **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above. 
@@ -122,7 +146,7 @@ For example, if you want to serve the gated Llama V2 model variants: or with Docker: ```shell -model=meta-llama/Llama-2-7b-chat-hf +model=meta-llama/Meta-Llama-3.1-8B-Instruct volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token= @@ -234,7 +258,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 ### Quantization -You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: +You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement: ```shell text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize @@ -242,6 +266,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`. +Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization). + ## Develop ```shell diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index 7507687f..a1ebe7e7 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -11,7 +11,7 @@ We recommend using the official quantization scripts for creating your quants: For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. -## Quantization with bitsandbytes +## Quantization with bitsandbytes, EETQ & fp8 bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing โ€“ weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision. @@ -32,7 +32,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingf You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). -Use `eetq` or `fp8` for other quantization schemes. +Similarly you can use pass you can pass `--quantize eetq` or `--quantize fp8` for respective quantization schemes. In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 2313c69b..18e1a107 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -21,7 +21,7 @@ TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPU ## Consuming TGI -Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. +Once TGI is running, you can use the `generate` endpoint or the Open AI Chat Completion API compatible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) by doing requests. 
To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index b78104df..832f88ef 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -1,22 +1,22 @@ # Supported Models and Hardware -Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. +Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported. ## Supported Models - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) -- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) +- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) -- [Gemma2](https://huggingface.co/google/gemma2-9b) +- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) -- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) +- [Mistral](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407) - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) - [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) - [Phi](https://huggingface.co/microsoft/phi-1_5) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index da14d083..960b426b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -180,7 +180,7 @@ class ModelType(enum.Enum): LLAMA = { "type": "llama", "name": "Llama", - "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f", } PHI3 = { "type": "phi3", @@ -200,7 +200,7 @@ class ModelType(enum.Enum): GEMMA2 = { "type": "gemma2", "name": "Gemma2", - "url": "https://huggingface.co/google/gemma2-9b", + "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315", } COHERE = { "type": "cohere", @@ -220,7 +220,7 @@ class ModelType(enum.Enum): MISTRAL = { "type": "mistral", "name": "Mistral", - "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", + "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407", } MIXTRAL = { "type": "mixtral", diff --git a/update_doc.py b/update_doc.py index 428d4452..e887e1c6 100644 --- a/update_doc.py +++ b/update_doc.py @@ -7,7 +7,7 @@ import os TEMPLATE = """ # Supported Models and Hardware -Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. 
+Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported. ## Supported Models From 6e127dcc9616a0e39efd95a3e08324138e8c4df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 15:24:21 +0200 Subject: [PATCH 23/72] flake: use rust-overlay (#2390) --- flake.lock | 22 ++++++++++++++++++++++ flake.nix | 19 +++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/flake.lock b/flake.lock index bf06240f..043218e8 100644 --- a/flake.lock +++ b/flake.lock @@ -56,9 +56,31 @@ "tgi-nix", "nixpkgs" ], + "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1723170066, + "narHash": "sha256-SFkQfOA+8AIYJsPlQtxNP+z5jRLfz91z/aOrV94pPmw=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "fecfe4d7c96fea2982c7907997b387a6b52c1093", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, "systems": { "locked": { "lastModified": 1681028828, diff --git a/flake.nix b/flake.nix index a67879fc..fdd67d00 100644 --- a/flake.nix +++ b/flake.nix @@ -3,12 +3,17 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; }; outputs = { self, nixpkgs, flake-utils, + rust-overlay, tgi-nix, }: flake-utils.lib.eachDefaultSystem ( @@ -20,7 +25,10 @@ }; pkgs = import nixpkgs { inherit config system; - overlays = [ tgi-nix.overlay ]; + overlays = [ + rust-overlay.overlays.default + tgi-nix.overlay + ]; }; in { @@ -29,11 +37,14 @@ mkShell { buildInputs = [ - cargo - clippy openssl.dev pkg-config - rustfmt + (rust-bin.stable.latest.default.override { + extensions = [ + "rust-analyzer" + "rust-src" + ]; + }) ] ++ (with python3.pkgs; [ venvShellHook From 7a48a84784b74f9ea12cf04a3c4572c027cec2e4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 9 Aug 2024 16:41:17 +0200 Subject: [PATCH 24/72] Using an enum for flash backens (paged/flashdecoding/flashinfer) (#2385) * Using an enum for flash backens (paged/flashdecoding/flashinfer) * Early exit on server too. * Clippy. * Fix clippy and fmt. 
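As a usage sketch of the new selection mechanism, the backend is now chosen through a single `ATTENTION` variable instead of the old `FLASH_DECODING` flag (the accepted values come from the enum added below; the model id is only a placeholder):

```bash
# Pick the attention backend explicitly; "paged" remains the default when unset.
ATTENTION=flashdecoding text-generation-launcher --model-id meta-llama/Meta-Llama-3-8B-Instruct
# other accepted values: ATTENTION=paged, ATTENTION=flashinfer
```
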
--- .gitignore | 1 + backends/v3/src/backend.rs | 16 ++++++---- docs/openapi.json | 2 +- docs/source/conceptual/quantization.md | 4 +-- launcher/src/main.rs | 2 +- router/src/infer/v2/scheduler.rs | 18 ++++++++---- router/src/lib.rs | 29 +++++++++++++++++++ .../layers/attention/common.py | 4 +-- .../layers/attention/cuda.py | 11 ++++--- .../layers/attention/rocm.py | 4 +-- .../models/flash_causal_lm.py | 11 ++++--- .../text_generation_server/models/globals.py | 14 ++++----- 12 files changed, 78 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 0de8b848..bd9d9125 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json +server/fbgemmm diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 6b3e0526..68ddf00b 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{FinishReason, PrefillToken, Token}; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -35,12 +35,18 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, diff --git a/docs/openapi.json b/docs/openapi.json index 9d281a48..ed9b0b96 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2080,4 +2080,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} \ No newline at end of file +} diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md index a1ebe7e7..b7672a9f 100644 --- a/docs/source/conceptual/quantization.md +++ b/docs/source/conceptual/quantization.md @@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants: 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py) 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md) -For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. +For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. ## Quantization with bitsandbytes, EETQ & fp8 @@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num- You can learn more about the quantization options by running `text-generation-server quantize --help`. If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). -You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). 
\ No newline at end of file +You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 8acfda0c..a64b1d71 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1461,7 +1461,7 @@ fn main() -> Result<(), LauncherError> { if config.model_type == Some("gemma2".to_string()) { tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("FLASH_DECODING", "1"); + std::env::set_var("ATTENTION", "flashdecoding"); } let config: Config = config.into(); diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index cc333674..0e5fc8a3 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -1,10 +1,10 @@ /// Batching and inference logic use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::{ - Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, + Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, }; use crate::validation::ValidGenerateRequest; -use crate::{FinishReason, PrefillToken, Token}; +use crate::{Attention, FinishReason, PrefillToken, Token}; use nohash_hasher::IntMap; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -40,12 +40,18 @@ impl BackendV2 { generation_health: Arc, ) -> Self { // Infer shared state - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + let attention = if let Ok(attention) = std::env::var("ATTENTION") { + attention + .parse() + .expect(&format!("Invalid attention was specified :`{attention}`")) } else { - false + Attention::Paged + }; + let block_size = if attention == Attention::FlashDecoding { + 256 + } else { + 16 }; - let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); diff --git a/router/src/lib.rs b/router/src/lib.rs index a956b058..66738706 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -15,6 +15,35 @@ use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(PartialEq)] +pub enum Attention { + Paged, + FlashDecoding, + FlashInfer, +} + +#[derive(Debug)] +pub struct ParseError; + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cannot parse attention value") + } +} +impl std::error::Error for ParseError {} + +impl std::str::FromStr for Attention { + type Err = ParseError; + fn from_str(s: &str) -> Result { + match s { + "paged" => Ok(Attention::Paged), + "flashdecoding" => Ok(Attention::FlashDecoding), + "flashinfer" => Ok(Attention::FlashInfer), + _ => Err(ParseError), + } + } +} + #[derive(Clone, Deserialize, ToSchema)] pub(crate) struct VertexInstance { #[schema(example = "What is Deep Learning?")] diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index b986a082..f162230c 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from text_generation_server.models.globals import FLASH_DECODING, FLASH_INFER +from text_generation_server.models.globals import ATTENTION import torch from typing import Optional -if FLASH_DECODING or FLASH_INFER: +if ATTENTION in 
{"flashinfer", "flashdecoding"}: @dataclass class Seqlen: diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 1b8e9209..d039e1e7 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,9 +1,8 @@ import torch from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( - FLASH_DECODING, + ATTENTION, BLOCK_SIZE, - FLASH_INFER, ) from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -27,7 +26,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -76,7 +75,7 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import decode_state return decode_state.get().forward( @@ -85,7 +84,7 @@ def paged_attention( logits_soft_cap=softcap, sm_scale=softmax_scale, ) - elif FLASH_DECODING: + elif ATTENTION == "flashdecoding": max_q = 1 max_k = max_s import flash_attn_2_cuda @@ -219,7 +218,7 @@ except ImportError: SUPPORTS_WINDOWING = V2 -if FLASH_INFER: +if ATTENTION == "flashinfer": def attention( q, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 69e64162..16ce8d2b 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.models.globals import FLASH_DECODING +from text_generation_server.models.globals import ATTENTION from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master from loguru import logger @@ -28,7 +28,7 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if FLASH_DECODING: + if ATTENTION == "flashdecoding": shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 12aa7dcd..21b66a68 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -40,8 +40,7 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, - FLASH_DECODING, - FLASH_INFER, + ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, @@ -938,7 +937,7 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_prefill_state, create_decode_state, @@ -990,7 +989,7 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if FLASH_DECODING or FLASH_INFER: + if ATTENTION in {"flashdecoding", "flashinfer"}: self.kv_cache = [ ( 
torch.empty( @@ -1062,7 +1061,7 @@ class FlashCausalLM(Model): graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph - if FLASH_INFER: + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flash_infer import ( create_decode_state_cuda_graphs, ) @@ -1766,7 +1765,7 @@ class FlashCausalLM(Model): input_lengths: torch.Tensor, state: Optional[Any] = None, ) -> ContextManager: - if not FLASH_INFER: + if ATTENTION != "flashinfer": return nullcontext() from text_generation_server.layers.attention.flash_infer import ( diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 42b43c87..b58a5b80 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,16 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -FLASH_INFER = os.getenv("FLASH_INFER") in {"1", "true", "True"} -if FLASH_INFER: - log_master(logger.info, "Using FLASH_INFER") +ATTENTION = os.getenv("ATTENTION", "paged") +_expected = {"paged", "flashdecoding", "flashinfer"} +assert ( + ATTENTION in _expected +), f"Attention is not valid {ATTENTION}, expected {_expected}" +log_master(logger.info, f"Using Attention = {ATTENTION}") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli -FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} -BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 -if FLASH_DECODING: - log_master(logger.info, "Using FLASH_DECODING") +BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 cuda_graphs = os.getenv("CUDA_GRAPHS") From 0d06aed02dba2cfadca5c9fb7e9183545c78f39e Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 10:56:45 -0400 Subject: [PATCH 25/72] feat: add guideline to chat request and template (#2391) * feat: add guideline to chat request and template * fix: add template test and update docs --- docs/openapi.json | 7 +++++++ router/src/infer/chat_template.rs | 15 +++++++++++++++ router/src/infer/mod.rs | 3 ++- router/src/lib.rs | 6 ++++++ router/src/server.rs | 9 +++++++-- 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ed9b0b96..ecd56e4d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -819,6 +819,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 24a00352..7c2753ed 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -48,6 +48,7 @@ impl ChatTemplate { pub(crate) fn apply( &self, + guideline: Option<&str>, mut messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { @@ -65,6 +66,7 @@ impl ChatTemplate { self.template .render(ChatTemplateInputs { + guideline, messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), @@ -731,6 +733,19 @@ mod tests { }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. 
How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, + ChatTemplateTestItem { + name: "google/shieldgemma-9b", + chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + guideline: Some("Do not use offensive language."), + ..Default::default() + }, + target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n", + }, ]; #[allow(unused_variables)] // name is unused diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 534a2647..58d5cf9a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -138,13 +138,14 @@ impl Infer { #[instrument(skip_all)] pub(crate) fn apply_chat_template( &self, + guideline: Option, messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? 
- .apply(messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, grammar_with_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/lib.rs b/router/src/lib.rs index 66738706..0a15c495 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -858,6 +858,11 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub response_format: Option, + + /// A guideline to be used in the chat_template + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub guideline: Option, } fn default_tool_prompt() -> Option { @@ -965,6 +970,7 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, tools: Option<&'a str>, tools_prompt: Option<&'a str>, + guideline: Option<&'a str>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] diff --git a/router/src/server.rs b/router/src/server.rs index 1d1cd36a..8c0bd762 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -141,6 +141,7 @@ async fn get_chat_tokenize( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -151,6 +152,7 @@ async fn get_chat_tokenize( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -1123,6 +1125,7 @@ async fn chat_completions( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -1142,6 +1145,7 @@ async fn chat_completions( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -2402,6 +2406,7 @@ fn prepare_chat_input( tools: Option>, tool_choice: ToolChoice, tool_prompt: &str, + guideline: Option, messages: Vec, ) -> Result { if response_format.is_some() && tools.is_some() { @@ -2411,7 +2416,7 @@ fn prepare_chat_input( } if let Some(format) = response_format { - let inputs = infer.apply_chat_template(messages, None)?; + let inputs = infer.apply_chat_template(guideline, messages, None)?; return Ok((inputs, Some(format), None)); } @@ -2423,6 +2428,6 @@ fn prepare_chat_input( let tools_grammar_prompt = tool_grammar .as_ref() .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; Ok((inputs, grammar, tool_grammar)) } From 8dcc7d3f6ba5dacef1405a2cd1bb8e7b7531070f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 9 Aug 2024 22:36:51 +0200 Subject: [PATCH 26/72] Update flake for 9.0a capability in Torch (#2394) --- flake.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flake.lock b/flake.lock index 043218e8..7889e7cf 100644 --- a/flake.lock +++ b/flake.lock @@ -102,11 +102,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1723188417, - "narHash": "sha256-GXdFuRMU9N9W0CryUTJIWhJkGwQFHSR2EW5xR0ZyBjk=", + "lastModified": 1723234585, + "narHash": "sha256-HChJpNP155FPhHr9C5BtqllV8Uv/Ebg59HhMc/HhQrc=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "491db7e234ecf79513ddb94da6ecc14167b9c0b3", + "rev": "15bd4a978d6c2c8b04b7afc335d137dbe41e73df", "type": "github" }, "original": { From 01a515dea2f49b0e12a2a4f79e456fc387ab9624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 12 Aug 2024 09:28:38 +0200 Subject: [PATCH 27/72] nix: add router to the devshell (#2396) --- flake.nix | 4 ++++ router.nix | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 
router.nix diff --git a/flake.nix b/flake.nix index fdd67d00..1d14561f 100644 --- a/flake.nix +++ b/flake.nix @@ -70,6 +70,10 @@ torch transformers vllm + + (callPackage ./router.nix { + inherit (rustPlatform) buildRustPackage importCargoLock; + }) ]); venvDir = "./.venv"; diff --git a/router.nix b/router.nix new file mode 100644 index 00000000..eeeac199 --- /dev/null +++ b/router.nix @@ -0,0 +1,18 @@ +{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: + +buildRustPackage { + name = "text-generation-router"; + + src = ./.; + + sourceDir = ./backends/v3; + + cargoLock = { + lockFile = ./Cargo.lock; + }; + + nativeBuildInputs = [ pkg-config ]; + + buildInputs = [ openssl.dev protobuf ]; + +} From 9c739651cd05ee69daab8d7b99c0c26d082783af Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:38 +0200 Subject: [PATCH 28/72] Upgrade fbgemm (#2398) * Upgrade fbgemm * Fix fbgemm version --- .../models/test_grammar_response_format_llama.py | 4 ++-- server/Makefile-fbgemm | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/models/test_grammar_response_format_llama.py b/integration-tests/models/test_grammar_response_format_llama.py index ea25fa1c..25bf9d98 100644 --- a/integration-tests/models/test_grammar_response_format_llama.py +++ b/integration-tests/models/test_grammar_response_format_llama.py @@ -98,6 +98,6 @@ async def test_grammar_response_format_llama_error_if_tools_not_installed( # 422 means the server was unable to process the request because it contains invalid data. assert response.status_code == 422 assert response.json() == { - "error": "Grammar and tools are mutually exclusive", - "error_type": "grammar and tools", + "error": "Tool error: Grammar and tools are mutually exclusive", + "error_type": "tool_error", } diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm index 57526577..5f3c0eaa 100644 --- a/server/Makefile-fbgemm +++ b/server/Makefile-fbgemm @@ -1,4 +1,4 @@ -fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856 +fbgemm_commit := v0.8.0 build-fbgemm: git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ From 730fa00e2072bccfd61a2c4fd67086128ba76075 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:46 +0200 Subject: [PATCH 29/72] Adding launcher to build. 
(#2397) --- _launcher.nix | 18 ++++++++++++++++++ flake.nix | 3 +++ 2 files changed, 21 insertions(+) create mode 100644 _launcher.nix diff --git a/_launcher.nix b/_launcher.nix new file mode 100644 index 00000000..1acae7e1 --- /dev/null +++ b/_launcher.nix @@ -0,0 +1,18 @@ +{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: + +buildRustPackage { + name = "text-generation-lancher"; + + src = ./.; + + sourceDir = ./launcher; + + cargoLock = { + lockFile = ./Cargo.lock; + }; + + nativeBuildInputs = [ pkg-config ]; + + buildInputs = [ openssl.dev protobuf ]; + +} diff --git a/flake.nix b/flake.nix index 1d14561f..761c4af8 100644 --- a/flake.nix +++ b/flake.nix @@ -74,6 +74,9 @@ (callPackage ./router.nix { inherit (rustPlatform) buildRustPackage importCargoLock; }) + (callPackage ./_launcher.nix { + inherit (rustPlatform) buildRustPackage importCargoLock; + }) ]); venvDir = "./.venv"; From 84bc3d7b7d65586f7f249b0e9065588b93e7cab3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 14:08:59 +0200 Subject: [PATCH 30/72] Fixing import exl2 (#2399) --- .../layers/gptq/__init__.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index f6616d3e..9c9b69d1 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -8,34 +8,6 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weight, Weights, WeightsLoader -try: - major, _minor = torch.cuda.get_device_capability() -except Exception: - major = 1 - -HAS_EXLLAMA = False -CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" -V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -if os.getenv("DISABLE_EXLLAMA") == "True": - HAS_EXLLAMA = False -elif CAN_EXLLAMA: - try: - if V2: - from text_generation_server.layers.gptq.exllamav2 import ( - QuantLinear as ExllamaQuantLinear, # noqa: F401 - ) - - HAS_EXLLAMA = "2" - else: - from text_generation_server.layers.gptq.exllama import ( - Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 - ) - - HAS_EXLLAMA = "1" - - except ImportError: - pass - @dataclass class GPTQWeight(Weight): @@ -432,3 +404,33 @@ class GPTQWeightsLoader(WeightsLoader): else False ) self.quant_method = "gptq" + + +# Needs to be at the end because circular import. 
+try: + major, _minor = torch.cuda.get_device_capability() +except Exception: + major = 1 + +HAS_EXLLAMA = False +CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm" +V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA = False +elif CAN_EXLLAMA: + try: + if V2: + from text_generation_server.layers.gptq.exllamav2 import ( + QuantLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "2" + else: + from text_generation_server.layers.gptq.exllama import ( + Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + ) + + HAS_EXLLAMA = "1" + + except ImportError: + pass From b6bb1d5160083a011d69c1a32547346a3b4d7d94 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 12 Aug 2024 20:10:30 +0800 Subject: [PATCH 31/72] Cpu dockerimage (#2367) add intel-cpu docker image Signed-off-by: Wang, Yi A --- .github/workflows/build.yaml | 17 +++++++++++++++-- .github/workflows/ci_build.yaml | 2 +- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 89d5bdf5..fd059e70 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -50,6 +50,7 @@ jobs: export label_extension="" export docker_devices="" export runs_on="aws-g6-12xlarge-plus-priv" + export platform="" ;; rocm) export dockerfile="Dockerfile_amd" @@ -58,12 +59,21 @@ jobs: # TODO Re-enable when they pass. # export runs_on="amd-gpu-tgi" export runs_on="ubuntu-latest" + export platform="" ;; - intel) + intel-xpu) export dockerfile="Dockerfile_intel" - export label_extension="-intel" + export label_extension="-intel-xpu" export docker_devices="" export runs_on="ubuntu-latest" + export platform="xpu" + ;; + intel-cpu) + export dockerfile="Dockerfile_intel" + export label_extension="-intel-cpu" + export docker_devices="" + export runs_on="ubuntu-latest" + export platform="cpu" ;; esac echo $dockerfile @@ -71,8 +81,10 @@ jobs: echo $label_extension echo $docker_devices echo $runs_on + echo $platform echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV echo "LABEL=${label_extension}" >> $GITHUB_ENV + echo "PLATFORM=${platform}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV @@ -139,6 +151,7 @@ jobs: build-args: | GIT_SHA=${{ env.GITHUB_SHA }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} + PLATFORM=${{ env.PLATFORM }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 5ca2854a..6000cec3 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -37,7 +37,7 @@ jobs: # fail-fast is true by default fail-fast: false matrix: - hardware: ["cuda", "rocm", "intel"] + hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"] uses: ./.github/workflows/build.yaml # calls the one above ^ with: hardware: ${{ matrix.hardware }} From 8deeaca4ff2080381e4ed00c98e1711d896687ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 12 Aug 2024 14:59:17 +0200 Subject: [PATCH 32/72] Add support for prefix caching to the v3 
router (#2392) This change adds support for prefix caching to the v3 router. This is broken up from the backend support to ease reviewing. For now prefix caching is only enabled with `USE_PREFIX_CACHING=1` in this case, the router will switch to `RadixAllocator`. This allocator uses a radix trie to keep track of prefills that were seen prior. If a new prefill is a prefix of a previously-seen prefil, the router will send a request with `prefix_len>0`, which can be used by the backend to decide to reuse KV blocks from the cache, rather than recomputing them. Even though backend support is not added in this PR, the backend will still work with prefix caching enabled. The prefix lengths are just ignored and not used. --- Cargo.lock | 1 + backends/client/src/v3/client.rs | 1 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/Cargo.toml | 1 + backends/v3/src/backend.rs | 10 + backends/v3/src/block_allocator.rs | 182 +++-- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/sharded_client.rs | 1 + backends/v3/src/lib.rs | 1 + backends/v3/src/queue.rs | 53 +- backends/v3/src/radix.rs | 755 ++++++++++++++++++ benchmark/src/generation.rs | 1 + proto/v3/generate.proto | 356 ++++----- router/src/validation.rs | 19 +- .../text_generation_server/models/globals.py | 17 +- 15 files changed, 1145 insertions(+), 255 deletions(-) create mode 100644 backends/v3/src/radix.rs diff --git a/Cargo.lock b/Cargo.lock index 92367d1e..3a5845a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4045,6 +4045,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index a996b14f..b321278c 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -156,6 +156,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index ae8a899b..1cc173e3 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b..129ceb9c 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,6 +33,7 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" tokenizers = { workspace = true} tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 68ddf00b..cbcbff72 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,15 +35,24 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { + let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { + matches!(prefix_caching.as_str(), "true" | "1") + } else { + false + }; let attention = if let Ok(attention) = std::env::var("ATTENTION") { attention .parse() .unwrap_or_else(|_| panic!("Invalid attention 
was specified :`{attention}`")) + } else if prefix_caching { + Attention::FlashInfer } else { Attention::Paged }; let block_size = if attention == Attention::FlashDecoding { 256 + } else if attention == Attention::FlashInfer { + 1 } else { 16 }; @@ -51,6 +60,7 @@ impl BackendV3 { let queue = Queue::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd85..05c2bd30 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,16 +1,26 @@ -use std::cmp::min; +use std::{cmp::min, sync::Arc}; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] pub(crate) struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,54 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) 
/ window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -128,9 +119,92 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub(crate) trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} + +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index c407687b..6282759e 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -157,6 +157,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git 
a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index afb13cdc..2f78da03 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -245,6 +245,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f89169..c8fc55f8 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -2,6 +2,7 @@ mod backend; mod block_allocator; mod client; mod queue; +mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index b457389c..13544235 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -46,6 +46,7 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -57,6 +58,7 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -109,6 +111,7 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -117,6 +120,7 @@ async fn queue_task( let mut state = State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -176,12 +180,19 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let block_allocator = (!requires_padding).then(|| { + BlockAllocator::new( + max_batch_total_tokens, + block_size, + prefix_caching, + window_size, + ) + }); Self { entries: VecDeque::with_capacity(128), @@ -305,7 +316,10 @@ impl State { + self.speculate - 1; - match block_allocator.allocate(tokens).await { + match block_allocator + .allocate(tokens, entry.request.input_ids.clone()) + .await + { None => { // Entry is over budget // Add it back to the front @@ -331,11 +345,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -372,6 +387,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + prefix_len, adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time @@ -480,6 +496,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -492,6 +510,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, truncate: 0, decoder_input_details: false, @@ -527,7 +546,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -543,7 +562,7 @@ mod 
tests { #[tokio::test] async fn test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -551,7 +570,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -583,7 +602,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -603,7 +622,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -636,14 +655,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -651,7 +670,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -684,7 +703,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -700,7 +719,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -725,7 +744,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -744,7 +763,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..0464b9f8 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,755 @@ +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +use slotmap::{DefaultKey, SlotMap}; + +use crate::block_allocator::{Allocator, BlockAllocation}; + +pub 
struct RadixAllocator { + allocation_id: u64, + + allocations: HashMap, + + cache_blocks: RadixTrie, + + /// Blocks that are immediately available for allocation. + free_blocks: Vec, +} + +impl RadixAllocator { + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); + if window_size.is_some() { + unimplemented!("Window size not supported in the prefix-caching block allocator yet"); + } + + RadixAllocator { + allocation_id: 0, + allocations: HashMap::new(), + cache_blocks: RadixTrie::new(), + + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), + } + } + + fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option> { + if self.free_blocks.len() < n_blocks_needed { + // This is a bit annoying, we first extend the free list and then + // split it off again below. This is because we need to put it on + // the free list if we cannot allocate enough blocks. This is only + // temporary, the trie needs to be able to report whether it can + // allocate the requested amount. Just not implemented yet. + self.free_blocks.extend( + self.cache_blocks + .evict(n_blocks_needed - self.free_blocks.len()), + ); + } + + if self.free_blocks.len() >= n_blocks_needed { + Some( + self.free_blocks + .split_off(self.free_blocks.len() - n_blocks_needed), + ) + } else { + None + } + } +} + +impl Allocator for RadixAllocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { + let mut blocks = vec![]; + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. + + node_id + } else { + self.cache_blocks.root_id() + }; + + self.cache_blocks + .incref(prefix_node) + .expect("Failed to increment refcount"); + + let prefix_len = blocks.len(); + let suffix_len = tokens - prefix_len as u32; + + match self.alloc_or_reclaim(suffix_len as usize) { + Some(suffix_blocks) => blocks.extend(suffix_blocks), + None => { + self.cache_blocks + .decref(prefix_node) + .expect("Failed to decrement refcount"); + return None; + } + } + + // 1:1 mapping of blocks and slots. + let slots = blocks.clone(); + + let allocation = RadixAllocation { + prefix_node, + cached_prefix_len: prefix_len, + prefill_tokens: prefill_tokens.clone(), + }; + + self.allocation_id += 1; + self.allocations.insert(self.allocation_id, allocation); + + Some(BlockAllocation { + allocation_id: self.allocation_id, + block_allocator: None, + blocks, + slots, + prefix_len: prefix_len as u32, + }) + } + + fn free(&mut self, blocks: Vec, allocation_id: u64) { + let allocation = match self.allocations.remove(&allocation_id) { + Some(allocation) => allocation, + None => unreachable!("Tried to free an unknown allocation."), + }; + + self.cache_blocks + .decref(allocation.prefix_node) + .expect("Failed to decrement refcount"); + + if let Some(prefill_tokens) = allocation.prefill_tokens { + let prefill_tokens = prefill_tokens.as_slice(); + + // If there are prefill tokens that did not come from the cache, + // add them to the cache. + if prefill_tokens.len() > allocation.cached_prefix_len { + let prefix_len = self + .cache_blocks + .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) + // Unwrap, failing is a programming error. 
+ .expect("Failed to store prefill tokens"); + + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + self.free_blocks + .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + } + + // Free non-prefill blocks. + self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + } else { + self.free_blocks.extend(blocks); + } + } +} + +struct RadixAllocation { + prefix_node: NodeId, + cached_prefix_len: usize, + prefill_tokens: Option>>, +} + +// Radix trie that is heavily inspired by radix attention from sglang. +// +// The trie is optimized for prefix caching: +// +// - A normal radix trie stores discrete values. In this radix trie, +// inserting *abc* with value *xyz* will also enable lookup for +// *a* (*x*) and *ab* (*xy*). +// - As a result, every value is required to have the same length as +// the key. +// - We store additional information in each node, such as last access +// time and a reference count. + +#[derive(Debug)] +pub enum TrieError { + InvalidNodeId, + RefCountUnderflow, + BlockTokenCountMismatch, +} + +pub type NodeId = DefaultKey; + +#[derive(Debug)] +pub struct RadixTrie { + /// Identifier of the root nod. + root: DefaultKey, + + /// Leave node identifiers ordered by increasing recency. + leaves: BTreeSet<(u64, NodeId)>, + + /// All trie nodes. + nodes: SlotMap, + + /// Time as a monotonically increating counter to avoid the system + /// call that a real time lookup would require. + time: u64, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new() -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = child.key.shared_prefix_len(key); + blocks.extend(&child.blocks[..shared_prefix_len]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. 
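+ // (When a non-root node's count does reach zero below, it becomes an
+ // eviction candidate and joins the recency-ordered leaves set.)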
+ if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could beevicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. + let mut evicted = Vec::new(); + + while let Some((last_access, node_id)) = self.leaves.pop_first() { + let blocks_needed = n_blocks - evicted.len(); + + let node = self.nodes.get(node_id).expect("Leave does not exist"); + if blocks_needed >= node.blocks.len() { + // We need to evict the whole node if we need more blocks than it has. + let node = self.remove_node(node_id); + evicted.extend(node.blocks); + + if evicted.len() >= n_blocks { + break; + } + } else { + // The node has more blocks than needed, so we'll just remove + // the required number of blocks and leave the remaining blocks + // untouched. + let node = self.nodes.get_mut(node_id).expect("Leave does not exist"); + node.key.truncate(node.blocks.len() - blocks_needed); + evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed)); + self.leaves.insert((last_access, node_id)); + break; + } + } + + evicted + } + + /// Insert a prefill along with its blocks. + /// + /// This method returns the length of the prefix that was already + /// in the trie. E.g. if the length is 10, this means that for + /// the first 10 elements of the tree **the blocks are not updated**. + pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { + self.time += 1; + self.insert_(self.root, tokens, blocks) + } + + /// Insertion worker. + fn insert_( + &mut self, + node_id: NodeId, + tokens: &[u32], + blocks: &[u32], + ) -> Result { + // TODO: in the future we may want to check that the blocks match for + // the part of the prefix that is already in the trie to detect + // mismatches. + + if tokens.len() != blocks.len() { + return Err(TrieError::BlockTokenCountMismatch); + } + + if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) { + self.update_access_time(child_id); + let child = self + .nodes + .get_mut(child_id) + // Unwrap here, since failure is a bug. + .expect("Child node does not exist"); + let shared_prefix_len = child.key.shared_prefix_len(tokens); + + // We are done, the prefix is already in the trie. + if shared_prefix_len == tokens.len() { + return Ok(shared_prefix_len); + } + + // The node's prefix is a prefix of the insertion prefix. 
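+ // Recurse with the remaining tokens below this child; the prefix length
+ // reported by the recursive call is offset by the part matched here.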
+ if shared_prefix_len == child.key.len() { + return Ok(shared_prefix_len + + self.insert_( + child_id, + &tokens[shared_prefix_len..], + &blocks[shared_prefix_len..], + )?); + } + + // The node's prefix and the insertion prefix only match partially, + // split the node to just contain the matching part. Then insert the + // remainder of the prefix into the node again + let child_id = self.split_node(child_id, shared_prefix_len); + let key = &tokens[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len..]; + Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) + } else { + self.add_node(node_id, tokens, blocks); + Ok(0) + } + } + + fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId { + // We have to make the current node a child to ensure that its + // properties and node id stay the same. + + // This funcion unwraps, an invalid node_id is a programming error. + + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + let mut parent_key = node.key.split_off(prefix_len); + let mut parent_blocks = node.blocks.split_off(prefix_len); + + // Move first part of the prefix to the parent. We swap to avoid + // an allocation + copy for both splits of the key/blocks. + std::mem::swap(&mut node.key, &mut parent_key); + std::mem::swap(&mut node.blocks, &mut parent_blocks); + + let node_key = node.key[0]; + + let grandparent_id = node.parent.expect("Node does not have a parent"); + let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks); + self.add_node_to_parent(parent_id, node_key, node_id); + + // Reborrow to make the borrow checker happy. + let node = self + .nodes + .get_mut(node_id) + .expect("Node to-be split does not exist"); + node.parent = Some(parent_id); + + parent_id + } + + /// Create a node and add it to the parent. + fn add_node( + &mut self, + parent_id: NodeId, + key: impl Into>, + blocks: impl Into>, + ) -> NodeId { + let key = key.into(); + let blocks = blocks.into(); + let first = key[0]; + + let child = TrieNode::new(key, blocks, self.time, Some(parent_id)); + let child_id = self.nodes.insert(child); + + self.add_node_to_parent(parent_id, first, child_id); + self.leaves.insert((self.time, child_id)); + + child_id + } + + /// Add a node to the parent. + fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + if parent.children.insert(first, child_id).is_none() { + // Only increase reference count if child does not replace another child. + self.incref(parent_id) + .expect("Failed to increase parent refcount"); + } + } + + /// Remove a node from the trie. + fn remove_node(&mut self, node_id: NodeId) -> TrieNode { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.remove(node_id).expect("Unknown node"); + let parent_id = node.parent.expect("Attempted to remove root node"); + let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node"); + parent.children.remove(&node.key[0]); + self.decref(parent_id) + .expect("Failed to decrease parent refcount"); + self.nodes.remove(node_id); + node + } + + fn update_access_time(&mut self, node_id: NodeId) { + // Unwrap here, passing in an unknown id is a programming error. + let node = self.nodes.get_mut(node_id).expect("Unknown node"); + + // Update the ordered leaves set if the node is a leave. 
+ if self.leaves.remove(&(node.last_accessed, node_id)) { + self.leaves.insert((self.time, node_id)); + } + + node.last_accessed = self.time; + } + + #[allow(dead_code)] + #[doc(hidden)] + /// Print debugging output for the trie. + /// + /// In contrast to `Debug` nicely formatted. + pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +/// Helper trait to get the length of the shared prefix of two sequences. +trait SharedPrefixLen { + fn shared_prefix_len(&self, other: &Self) -> usize; +} + +impl SharedPrefixLen for [T] +where + T: PartialEq, +{ + fn shared_prefix_len(&self, other: &Self) -> usize { + self.iter().zip(other).take_while(|(a, b)| a == b).count() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::block_allocator::Allocator; + + use super::RadixAllocator; + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. 
+ let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = super::RadixTrie::new(); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. 
+ assert_eq!( + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(), + 4 + ); + } + + #[test] + fn trie_get_returns_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + + let mut blocks = Vec::new(); + trie.find(&[0], &mut blocks); + assert_eq!(blocks, vec![0]); + + blocks.clear(); + trie.find(&[0, 1, 2], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2]); + + blocks.clear(); + trie.find(&[1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 4]); + + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5]); + } + + #[test] + fn trie_evict_removes_correct_blocks() { + let mut trie = super::RadixTrie::new(); + trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); + trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) + .unwrap(); + trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); + trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); + + let mut blocks = Vec::new(); + + // Remove less than the leave blocks. + assert_eq!(trie.evict(1), vec![7]); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]); + + // Refresh other leaf. + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + trie.find(&[1, 2, 3], &mut blocks); + + // Remove the leave blocks exactly. + assert_eq!(trie.evict(2), vec![5, 6]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks); + assert_eq!(blocks, vec![0, 1, 2, 3]); + + trie.find(&[1, 2, 3], &mut blocks); + + // Remove more than the leave blocks. + assert_eq!(trie.evict(3), vec![4, 3, 2]); + blocks.clear(); + trie.find(&[0, 1, 2, 3, 4], &mut blocks); + assert_eq!(blocks, vec![0, 1]); + + // Clear out the whole trie. 
+ assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703..7494d5b5 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + prefix_len: 0, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 926c878e..68eea7ac 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -3,22 +3,23 @@ syntax = "proto3"; package generate.v3; service TextGenerationService { - /// Model Info - rpc Info (InfoRequest) returns (InfoResponse) {} - /// Service discovery - rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} - /// Empties batch cache - rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); - /// Remove requests from a cached batch - rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); - /// Warmup the model and compute max cache size - rpc Warmup (WarmupRequest) returns (WarmupResponse); - /// Prefill batch and decode first token - rpc Prefill (PrefillRequest) returns (PrefillResponse); - /// Decode token for a list of prefilled batches - rpc Decode (DecodeRequest) returns (DecodeResponse); - /// Health check - rpc Health (HealthRequest) returns (HealthResponse); + /// Model Info + rpc Info(InfoRequest) returns (InfoResponse) {} + /// Service discovery + rpc ServiceDiscovery(ServiceDiscoveryRequest) + returns (ServiceDiscoveryResponse) {} + /// Empties batch cache + rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse); + /// Remove requests from a cached batch + rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse); + /// Warmup the model and compute max cache size + rpc Warmup(WarmupRequest) returns (WarmupResponse); + /// Prefill batch and decode first token + rpc Prefill(PrefillRequest) returns (PrefillResponse); + /// Decode token for a list of prefilled batches + rpc Decode(DecodeRequest) returns (DecodeResponse); + /// Health check + rpc Health(HealthRequest) returns (HealthResponse); } message HealthRequest {} @@ -28,240 +29,239 @@ message HealthResponse {} message InfoRequest {} message InfoResponse { - bool requires_padding = 1; - string dtype = 2; - string device_type = 3; - optional uint32 window_size = 4; - uint32 speculate = 5; + bool requires_padding = 1; + string dtype = 2; + string device_type = 3; + optional uint32 window_size = 4; + uint32 speculate = 5; } /// Empty request message ServiceDiscoveryRequest {} message ServiceDiscoveryResponse { - /// Other shards urls - repeated string urls = 1; + /// Other shards urls + repeated string urls = 1; } message ClearCacheRequest { - /// Optional batch id - optional uint64 id = 1; + /// Optional batch id + optional uint64 id = 1; } /// Empty response message ClearCacheResponse {} message Image { - /// Binary image data. - bytes data = 1; + /// Binary image data. + bytes data = 1; - /// Image MIME type. - string mimetype = 2; + /// Image MIME type. 
+ string mimetype = 2; } message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } } -message Input { - repeated InputChunk chunks = 1; - } +message Input { repeated InputChunk chunks = 1; } enum GrammarType { - GRAMMAR_TYPE_NONE = 0; - GRAMMAR_TYPE_JSON = 1; - GRAMMAR_TYPE_REGEX = 2; + GRAMMAR_TYPE_NONE = 0; + GRAMMAR_TYPE_JSON = 1; + GRAMMAR_TYPE_REGEX = 2; } message NextTokenChooserParameters { - /// exponential scaling output probability distribution - float temperature = 1; - /// restricting to the k highest probability elements - uint32 top_k = 2; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float top_p = 3; - /// restricting to top tokens summing to prob_cut_off <= prob_cut_off - float typical_p = 4; - /// apply sampling on the logits - bool do_sample = 5; - /// random seed for sampling - uint64 seed = 6; - /// repetition penalty - float repetition_penalty = 7; - /// frequency penalty - float frequency_penalty = 9; - /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 8; - /// grammar (applied if not empty) - string grammar = 10; - /// grammar type - GrammarType grammar_type = 11; + /// exponential scaling output probability distribution + float temperature = 1; + /// restricting to the k highest probability elements + uint32 top_k = 2; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float top_p = 3; + /// restricting to top tokens summing to prob_cut_off <= prob_cut_off + float typical_p = 4; + /// apply sampling on the logits + bool do_sample = 5; + /// random seed for sampling + uint64 seed = 6; + /// repetition penalty + float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 9; + /// token watermarking using "A Watermark for Large Language Models" + bool watermark = 8; + /// grammar (applied if not empty) + string grammar = 10; + /// grammar type + GrammarType grammar_type = 11; } message StoppingCriteriaParameters { - /// Maximum number of generated tokens - uint32 max_new_tokens = 1; - /// Optional stopping sequences - repeated string stop_sequences = 2; - /// Ignore end of sequence token - /// used for benchmarking - bool ignore_eos_token = 3; + /// Maximum number of generated tokens + uint32 max_new_tokens = 1; + /// Optional stopping sequences + repeated string stop_sequences = 2; + /// Ignore end of sequence token + /// used for benchmarking + bool ignore_eos_token = 3; } message Request { - /// Request ID - uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks - string inputs = 2; - /// Context truncation - uint32 truncate = 3; - /// Next Token Chooser Parameters - NextTokenChooserParameters parameters = 4; - /// Stopping Criteria Parameters - StoppingCriteriaParameters stopping_parameters = 5; - /// Return prefill logprobs - bool prefill_logprobs = 6; - /// Return most likely n tokens - uint32 top_n_tokens = 7; - /// Paged attention blocks - repeated uint32 blocks = 9; - /// Paged attention slots - repeated uint32 slots = 10; - /// LORA adapter index - optional string adapter_id = 11; + /// Request ID + uint64 id = 1; + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks + string inputs = 2; + /// Context truncation + uint32 truncate = 3; + /// Next Token 
Chooser Parameters + NextTokenChooserParameters parameters = 4; + /// Stopping Criteria Parameters + StoppingCriteriaParameters stopping_parameters = 5; + /// Return prefill logprobs + bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; + /// Paged attention blocks + repeated uint32 blocks = 9; + /// Paged attention slots + repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; + /// Prefix length that can be retrieved from the KV cache. + uint32 prefix_len = 12; } message Batch { - /// Batch ID - uint64 id = 1; - /// Individual requests - repeated Request requests = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; - /// Maximum number of Paged Attention blocks - uint32 max_blocks = 5; + /// Batch ID + uint64 id = 1; + /// Individual requests + repeated Request requests = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; + /// Maximum number of Paged Attention blocks + uint32 max_blocks = 5; } message CachedBatch { - /// Batch ID - uint64 id = 1; - /// Individual requests ids - repeated uint64 request_ids = 2; - /// Batch size (==len(requests)) - uint32 size = 3; - /// Maximum number of tokens this batch will grow to - uint32 max_tokens = 4; + /// Batch ID + uint64 id = 1; + /// Individual requests ids + repeated uint64 request_ids = 2; + /// Batch size (==len(requests)) + uint32 size = 3; + /// Maximum number of tokens this batch will grow to + uint32 max_tokens = 4; } enum FinishReason { - FINISH_REASON_LENGTH = 0; - FINISH_REASON_EOS_TOKEN = 1; - FINISH_REASON_STOP_SEQUENCE = 2; + FINISH_REASON_LENGTH = 0; + FINISH_REASON_EOS_TOKEN = 1; + FINISH_REASON_STOP_SEQUENCE = 2; } message GeneratedText { - /// Output - string text = 1; - /// Number of generated tokens - uint32 generated_tokens = 2; - /// Finish reason - FinishReason finish_reason = 3; - /// Seed - optional uint64 seed = 4; + /// Output + string text = 1; + /// Number of generated tokens + uint32 generated_tokens = 2; + /// Finish reason + FinishReason finish_reason = 3; + /// Seed + optional uint64 seed = 4; } message Tokens { - /// Token IDs - repeated uint32 ids = 1; - /// Logprobs - repeated float logprobs = 2; - /// tokens - repeated string texts = 3; - /// special - repeated bool is_special = 4; + /// Token IDs + repeated uint32 ids = 1; + /// Logprobs + repeated float logprobs = 2; + /// tokens + repeated string texts = 3; + /// special + repeated bool is_special = 4; } message Generation { - /// Request ID - uint64 request_id = 1; - /// Prefill tokens (optional) - Tokens prefill_tokens = 2; - Tokens tokens = 3; - /// Complete generated text - optional GeneratedText generated_text = 4; - /// Top tokens - repeated Tokens top_tokens = 5; + /// Request ID + uint64 request_id = 1; + /// Prefill tokens (optional) + Tokens prefill_tokens = 2; + Tokens tokens = 3; + /// Complete generated text + optional GeneratedText generated_text = 4; + /// Top tokens + repeated Tokens top_tokens = 5; } message FilterBatchRequest { - /// Batch ID - uint64 batch_id = 1; - /// Requests to keep - repeated uint64 request_ids = 2; + /// Batch ID + uint64 batch_id = 1; + /// Requests to keep + repeated uint64 request_ids = 2; } message FilterBatchResponse { - /// Filtered Batch (cached) - CachedBatch batch = 1; + /// Filtered Batch (cached) + CachedBatch batch = 1; } - message PrefillRequest { - /// Batch - 
Batch batch = 1; + /// Batch + Batch batch = 1; } message PrefillResponse { - /// Generation - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; + /// Generation + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; } message DecodeRequest { - /// Cached batches - repeated CachedBatch batches = 1; + /// Cached batches + repeated CachedBatch batches = 1; } message DecodeResponse { - /// Decodes - repeated Generation generations = 1; - /// Next batch (cached) - optional CachedBatch batch = 2; - /// Forward elapsed time in nanoseconds - uint64 forward_ns = 3; - /// Decode elapsed time in nanoseconds - uint64 decode_ns = 4; - /// Total elapsed time in nanoseconds - uint64 total_ns = 5; - /// Concatenate elapsed time in nanoseconds - optional uint64 concat_ns = 6; + /// Decodes + repeated Generation generations = 1; + /// Next batch (cached) + optional CachedBatch batch = 2; + /// Forward elapsed time in nanoseconds + uint64 forward_ns = 3; + /// Decode elapsed time in nanoseconds + uint64 decode_ns = 4; + /// Total elapsed time in nanoseconds + uint64 total_ns = 5; + /// Concatenate elapsed time in nanoseconds + optional uint64 concat_ns = 6; } message WarmupRequest { - /// Batch to warmup on - Batch batch = 1; - uint32 max_input_length = 2; - uint32 max_prefill_tokens = 3; - uint32 max_total_tokens = 4; + /// Batch to warmup on + Batch batch = 1; + uint32 max_input_length = 2; + uint32 max_prefill_tokens = 3; + uint32 max_total_tokens = 4; } message WarmupResponse { - /// Maximum number of tokens supported by the model - optional uint32 max_supported_total_tokens = 1; + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; } diff --git a/router/src/validation.rs b/router/src/validation.rs index 3d1a4103..5011158a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -11,6 +11,7 @@ use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; use std::iter; +use std::sync::Arc; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; @@ -115,13 +116,14 @@ impl Validation { } } + #[allow(clippy::type_complexity)] #[instrument(skip(self, inputs))] async fn validate_input( &self, inputs: String, truncate: Option, max_new_tokens: Option, - ) -> Result<(Vec, usize, u32), ValidationError> { + ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? 
{ // Create response channel @@ -156,8 +158,10 @@ impl Validation { )); } + let input_ids = encoding.get_ids()[..input_length].to_owned(); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, input_length, max_new_tokens)) + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } // Return inputs without validation else { @@ -180,7 +184,12 @@ impl Validation { input_length = input_length.saturating_sub(max_new_tokens as usize); } - Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) + Ok(( + vec![Chunk::Text(inputs)], + None, + input_length, + max_new_tokens, + )) } } @@ -314,7 +323,7 @@ impl Validation { .unwrap_or(Ok(None))?; // Validate inputs - let (inputs, input_length, max_new_tokens) = self + let (inputs, input_ids, input_length, max_new_tokens) = self .validate_input(request.inputs, truncate, max_new_tokens) .await?; @@ -391,6 +400,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, + input_ids: input_ids.map(Arc::new), decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -707,6 +717,7 @@ pub struct ValidStoppingParameters { #[derive(Debug, Clone)] pub struct ValidGenerateRequest { pub inputs: Vec, + pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, pub decoder_input_details: bool, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index b58a5b80..abc35421 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,16 +5,29 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -ATTENTION = os.getenv("ATTENTION", "paged") +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False) +log_master(logger.info, f"Using Attention = {PREFIX_CACHING}") + +ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") +if PREFIX_CACHING and ATTENTION != "flashinfer": + raise RuntimeError("Prefix caching is only supported with flashinfer") + MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + # This is overridden by the cli -BLOCK_SIZE: int = 256 if ATTENTION == "flashdecoding" else 16 +BLOCK_SIZE: int +if ATTENTION == "flashdecoding": + BLOCK_SIZE = 256 +elif ATTENTION == "flashinfer": + BLOCK_SIZE = 1 +else: + BLOCK_SIZE = 16 cuda_graphs = os.getenv("CUDA_GRAPHS") From 136bcc812870e36aea69c3bb9cb8012f0a63d973 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 15:22:02 +0200 Subject: [PATCH 33/72] Keeping the benchmark somewhere (#2401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Danieฬˆl de Kok --- Cargo.lock | 171 ++++++++++++++++++++++++++-- backends/v3/Cargo.toml | 22 +++- backends/v3/benches/prefix_cache.rs | 47 ++++++++ backends/v3/src/block_allocator.rs | 6 +- backends/v3/src/lib.rs | 4 +- backends/v3/src/queue.rs | 2 +- backends/v3/src/radix.rs | 5 + router/Cargo.toml | 20 +++- 8 files changed, 255 insertions(+), 22 deletions(-) create mode 100644 backends/v3/benches/prefix_cache.rs diff --git a/Cargo.lock b/Cargo.lock index 3a5845a7..bb63422e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,17 @@ version = "1.1.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -565,6 +576,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.1.7" @@ -617,6 +634,17 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + [[package]] name = "clap" version = "4.5.11" @@ -735,6 +763,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1060,7 +1124,7 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", "miniz_oxide", "rayon-core", @@ -1367,6 +1431,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1404,6 +1474,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1804,6 +1883,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1844,7 +1932,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.11", "fancy-regex", "fraction", "getrandom", @@ -2132,7 +2220,7 @@ version = "1.0.1" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2400,7 +2488,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2456,6 +2544,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" version = "0.10.66" @@ -2783,6 +2877,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.13" @@ -3525,6 +3647,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde_cbor" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.204" @@ -3891,7 +4023,7 @@ version = "2.2.1-dev0" dependencies = [ "async-stream", "async-trait", - "clap", + "clap 4.5.11", "cmake", "cxx", "cxx-build", @@ -3912,7 +4044,7 @@ name = "text-generation-benchmark" version = "2.2.1-dev0" dependencies = [ "average", - "clap", + "clap 4.5.11", "crossterm", "float-ord", "hf-hub", @@ -3950,7 +4082,7 @@ dependencies = [ name = "text-generation-launcher" version = "2.2.1-dev0" dependencies = [ - "clap", + "clap 4.5.11", "ctrlc", "float_eq", "hf-hub", @@ -3974,7 +4106,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.11", "csv", "futures", "futures-util", @@ -4022,13 +4154,15 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.11", + "criterion", "futures", "futures-util", "grpc-metadata", "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.13.0", "jsonschema", "metrics", "metrics-exporter-prometheus", @@ -4062,6 +4196,15 @@ dependencies = [ "utoipa-swagger-ui", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -4136,6 +4279,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" 
+version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 129ceb9c..06a44bec 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -35,8 +35,14 @@ serde = "1.0.188" serde_json = "1.0.107" slotmap = "1.0.7" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" @@ -44,7 +50,9 @@ tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -60,8 +68,16 @@ tower = "^0.4" tonic-build = "0.10.1" prost-build = "0.12.1" +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + [features] default = ["ngrok"] ngrok = ["text-generation-router/ngrok"] google = ["text-generation-router/google"] kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..d9df45b2 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 05c2bd30..c5503b8c 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -4,7 +4,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; #[derive(Debug, Clone)] -pub(crate) struct BlockAllocation { +pub struct BlockAllocation { pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, @@ -25,7 +25,7 @@ impl 
Drop for BlockAllocation { } #[derive(Debug, Clone)] -pub(crate) struct BlockAllocator { +pub struct BlockAllocator { /// Channel to communicate with the background task block_allocator: mpsc::UnboundedSender, } @@ -128,7 +128,7 @@ enum BlockAllocatorCommand { }, } -pub(crate) trait Allocator { +pub trait Allocator { fn allocate( &mut self, tokens: u32, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index c8fc55f8..77a9a11a 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,8 +1,8 @@ mod backend; -mod block_allocator; +pub mod block_allocator; mod client; mod queue; -mod radix; +pub mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 13544235..0fb05a98 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -250,7 +250,7 @@ impl State { // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); - next_batch_span.follows_from(&Span::current()); + next_batch_span.follows_from(Span::current()); let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_entries = diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 0464b9f8..ef963532 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -205,6 +205,11 @@ pub struct RadixTrie { /// call that a real time lookup would require. time: u64, } +impl Default for RadixTrie { + fn default() -> Self { + Self::new() + } +} impl RadixTrie { /// Construct a new radix trie. diff --git a/router/Cargo.toml b/router/Cargo.toml index 1be74546..7773e212 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.40" @@ -37,7 +43,9 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } minijinja = { version = "2.0.2" } minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } futures-util = "0.3.30" @@ -46,7 +54,11 @@ once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } sysinfo = "0.30.13" -uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +uuid = { version = "1.9.1", default-features = false, features = [ + "v4", + "fast-rng", + "macro-diagnostics", +] } csv = "1.3.0" ureq = "=2.9" From 155f9c98e2a272d255d42f0da8d2e88cc1eaa18a Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 10:58:40 -0400 Subject: [PATCH 34/72] =?UTF-8?q?feat:=20validate=20template=20variables?= =?UTF-8?q?=20before=20apply=20and=20improve=20sliding=20wi=E2=80=A6=20(#2?= =?UTF-8?q?403)?= MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: validate template variables before apply and improve sliding window check * fix: improve missing template var test --- router/src/infer/chat_template.rs | 61 +++++++++++++++++-- router/src/infer/mod.rs | 3 + router/src/server.rs | 1 + router/src/validation.rs | 2 +- .../text_generation_server/models/__init__.py | 18 +++--- 5 files changed, 70 insertions(+), 15 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 7c2753ed..ef4beee2 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use crate::infer::InferError; use crate::{ ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken, @@ -16,6 +18,7 @@ pub(crate) struct ChatTemplate { bos_token: Option, eos_token: Option, use_default_tool_template: bool, + variables: HashSet, } impl ChatTemplate { @@ -30,19 +33,22 @@ impl ChatTemplate { let template_str = template.into_boxed_str(); env.add_function("raise_exception", raise_exception); - // check if contains the tools variable within the template - let use_default_tool_template = - !template_str.as_ref().replace(' ', "").contains("{{tools}}"); // leaking env and template_str as read-only, static resources for performance. let template = Box::leak(env) .template_from_str(Box::leak(template_str)) .unwrap(); + // get the list of variables that are used in the template + let variables = template.undeclared_variables(true); + // check if the `tools` variable is used in the template + let use_default_tool_template = !variables.contains("tools"); + Self { template, bos_token: bos_token.map(|token| token.as_str().to_string()), eos_token: eos_token.map(|token| token.as_str().to_string()), use_default_tool_template, + variables, } } @@ -64,6 +70,11 @@ impl ChatTemplate { let messages: Vec = messages.into_iter().map(|c| c.into()).collect(); + // check if guideline is expected but not provided + if self.variables.contains("guideline") && guideline.is_none() { + return Err(InferError::MissingTemplateVariable("guideline".to_string())); + } + self.template .render(ChatTemplateInputs { guideline, @@ -82,7 +93,8 @@ impl ChatTemplate { #[cfg(test)] mod tests { use crate::infer::chat_template::raise_exception; - use crate::{ChatTemplateInputs, TextMessage}; + use crate::infer::ChatTemplate; + use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; use minijinja::Environment; #[test] @@ -770,4 +782,45 @@ mod tests { assert_eq!(result, target); } } + + #[test] + fn test_chat_template_invalid_with_guideline() { + let ct = ChatTemplate::new( + "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. 
And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText( + "I'm doing great. How can I help you today?".to_string(), + ), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Hello, how are you?".to_string()), + }, + ]; + + let result = ct.apply(None, msgs, None); + + match result { + Ok(_) => panic!("Should have failed since no guideline is provided"), + Err(e) => { + assert_eq!(e.to_string(), "Missing template vatiable: guideline") + } + } + } } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 58d5cf9a..c9354d9a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -337,6 +337,8 @@ pub enum InferError { IncompleteGeneration, #[error("Template error: {0}")] TemplateError(#[from] minijinja::Error), + #[error("Missing template vatiable: {0}")] + MissingTemplateVariable(String), #[error("Tool error: {0}")] ToolError(String), } @@ -349,6 +351,7 @@ impl InferError { InferError::ValidationError(_) => "validation", InferError::IncompleteGeneration => "incomplete_generation", InferError::TemplateError(_) => "template_error", + InferError::MissingTemplateVariable(_) => "missing_template_variable", InferError::ToolError(_) => "tool_error", } } diff --git a/router/src/server.rs b/router/src/server.rs index 8c0bd762..99ec077f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2297,6 +2297,7 @@ impl From for (StatusCode, Json) { InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, + InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, }; diff --git a/router/src/validation.rs b/router/src/validation.rs index 5011158a..0024723c 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -830,7 +830,7 @@ mod tests { .await { // Err(ValidationError::MaxNewTokens(1, 10)) => (), - Ok((_s, 0, 10)) => (), + Ok((_s, _, 0, 10)) => (), r => panic!("Unexpected not max new tokens: {r:?}"), } } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 960b426b..4fa9e66d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py 
@@ -497,17 +497,15 @@ def get_model(
         else -1
     )
 
-    if max_input_tokens is not None and max_input_tokens <= sliding_window:
-        sliding_window = -1
+    should_use_sliding_window = (
+        sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
+    )
 
-    if (
-        (sliding_window is not None and sliding_window != -1)
-        and not SUPPORTS_WINDOWING
-        and max_input_tokens > sliding_window
-    ):
-        raise ValueError(
-            f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
-        )
+    if should_use_sliding_window:
+        if max_input_tokens is not None and max_input_tokens > sliding_window:
+            raise ValueError(
+                f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
+            )
 
     if model_type == DEEPSEEK_V2:
         if FLASH_ATTENTION:

From 4c3f8a70a1c8590851aa3d7c82a7cabf01ed6e87 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 12 Aug 2024 11:24:32 -0400
Subject: [PATCH 35/72] fix: allocate tmp based on sgmv kernel if available
 (#2345)

* fix: allocate tmp based on sgmv kernel if available
* fix: re add copy build artifacts step for punica kernels
---
 Dockerfile                                   |  2 ++
 server/text_generation_server/utils/sgmv.py | 12 ++++++++----
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index c68f76f6..458ff699 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -226,6 +226,8 @@ COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-31
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from lorax punica kernels builder
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from fbgemm builder
 COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from vllm builder
diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py
index e0aec25f..2d0a73a5 100644
--- a/server/text_generation_server/utils/sgmv.py
+++ b/server/text_generation_server/utils/sgmv.py
@@ -151,13 +151,17 @@ def get_tmp_expand_size(size: int) -> int:
 def get_tmp_tensors(
     nsegments: int, lora_rank: int, device: torch.device
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if use_cutlass_shrink(lora_rank) and has_sgmv():
+    use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
+    has_sgmv_available = has_sgmv()
+
+    if use_cutlass:
         tmp = get_tmp_tensor_for_size(nsegments, device)
         return tmp, tmp
+    elif has_sgmv_available:
+        return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
     else:
-        tmp_shrink = get_tmp_tensor(device)
-        tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
-        return tmp_shrink, tmp_expand
+        tmp = 
get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp def lora_a_sgmv_cutlass( From 30395b09f4eff271dd1dfdc49be4fd46f4a546dd Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 11:26:11 -0400 Subject: [PATCH 36/72] fix: improve completions to send a final chunk with usage details (#2336) * fix: improve completions to send a final chunk with usage details * fix: include finish reason string * fix: remove dev debug trait and unneeded mut * fix: update openapi schema --- docs/openapi.json | 9 ++++++++- router/src/lib.rs | 2 ++ router/src/server.rs | 42 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ecd56e4d..df21e19d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1824,7 +1824,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1836,6 +1837,12 @@ "example": 1, "minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", diff --git a/router/src/lib.rs b/router/src/lib.rs index 0a15c495..d7eb4475 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1219,6 +1219,8 @@ pub(crate) struct StreamDetails { pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, + #[schema(example = 1)] + pub input_length: u32, } #[derive(Serialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 99ec077f..ab268efa 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -533,7 +533,7 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, response_stream)) => { + Ok((_permit, input_length, response_stream)) => { let mut index = 0; let mut response_stream = Box::pin(response_stream); // Server-Sent Event stream @@ -576,6 +576,7 @@ async fn generate_stream_internal( finish_reason: generated_text.finish_reason, generated_tokens: generated_text.generated_tokens, seed: generated_text.seed, + input_length, }), false => None, }; @@ -801,21 +802,46 @@ async fn completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); - event - .json_data(Completion::Chunk(Chunk { - id: "".to_string(), - created: current_time, + let message = match stream_token.details { + Some(details) => { + let completion_tokens = details.generated_tokens; + let prompt_tokens = details.input_length; + let total_tokens = prompt_tokens + completion_tokens; + Completion::Final(CompletionFinal { + id: String::new(), + created: current_time, + model: model_id.clone(), + system_fingerprint: system_fingerprint.clone(), + choices: vec![CompletionComplete { + finish_reason: details.finish_reason.to_string(), + index: index as u32, + logprobs: None, + text: stream_token.token.text, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens, + }, + }) + } + None => Completion::Chunk(Chunk { + id: String::new(), + created: current_time, choices: vec![CompletionComplete { - finish_reason: "".to_string(), + finish_reason: String::new(), index: index as u32, logprobs: None, text: stream_token.token.text, }], - model: model_id.clone(), system_fingerprint: system_fingerprint.clone(), - })) + }), + }; + + event + .json_data(message) .unwrap_or_else(|_e| Event::default()) }; From 
19ea85f8dc7d8d29ddfefa7f906a16747e9996b9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 12 Aug 2024 18:09:16 +0200 Subject: [PATCH 37/72] Updating the flake. (#2404) --- _server.nix | 12 ++++ flake.lock | 158 +++++++++++++++++++++++++++++++++++++++++++++++----- flake.nix | 10 ++++ 3 files changed, 165 insertions(+), 15 deletions(-) create mode 100644 _server.nix diff --git a/_server.nix b/_server.nix new file mode 100644 index 00000000..bc11ba6b --- /dev/null +++ b/_server.nix @@ -0,0 +1,12 @@ +{ mkPoetryApplication, pkg-config, protobuf, openssl }: + +mkPoetryApplication { + # name = "text-generation-server"; + + projectDir = ./server; + + # nativeBuildInputs = [ pkg-config ]; + + # buildInputs = [ openssl.dev protobuf ]; + +} diff --git a/flake.lock b/flake.lock index 7889e7cf..47f14626 100644 --- a/flake.lock +++ b/flake.lock @@ -33,19 +33,96 @@ "type": "github" } }, - "nixpkgs": { + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, "locked": { - "lastModified": 1723099294, - "narHash": "sha256-kkijy6sXo/SOhFw7ZEfYHbj1FJHxoeetOVOn3qNHc5o=", - "owner": "danieldk", - "repo": "nixpkgs", - "rev": "45892b6ec142eaf300d777926983a433b5842c88", + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", "type": "github" }, "original": { - "owner": "danieldk", - "ref": "cudnn-9.3", + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-github-actions": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1703863825, + "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=", + "owner": "nix-community", + "repo": "nix-github-actions", + "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nix-github-actions", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1719763542, + "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=", + "owner": "NixOS", "repo": "nixpkgs", + "rev": "e6cdd8a11b26b4d60593733106042141756b54a3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1723418128, + "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "129f579cbb5b4c1ad258fd96bdfb78eb14802727", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "poetry2nix": { + "inputs": { + "flake-utils": "flake-utils_2", + "nix-github-actions": "nix-github-actions", + "nixpkgs": "nixpkgs", + "systems": "systems_3", + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1723343306, + "narHash": "sha256-/6sRkPq7/5weX2y0V8sQ29Sz35nt8kyj+BsFtkhgbJE=", + "owner": "nix-community", + "repo": "poetry2nix", + "rev": "4a1c112ff0c67f496573dc345bd0b2247818fc29", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "poetry2nix", "type": "github" } }, @@ -56,6 +133,7 @@ "tgi-nix", "nixpkgs" ], + "poetry2nix": "poetry2nix", "rust-overlay": "rust-overlay", "tgi-nix": "tgi-nix" } @@ -68,11 +146,11 @@ ] }, "locked": { - "lastModified": 1723170066, - "narHash": "sha256-SFkQfOA+8AIYJsPlQtxNP+z5jRLfz91z/aOrV94pPmw=", + "lastModified": 1723429325, + "narHash": 
"sha256-4x/32xTCd+xCwFoI/kKSiCr5LQA2ZlyTRYXKEni5HR8=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "fecfe4d7c96fea2982c7907997b387a6b52c1093", + "rev": "65e3dc0fe079fe8df087cd38f1fe6836a0373aad", "type": "github" }, "original": { @@ -96,17 +174,46 @@ "type": "github" } }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "id": "systems", + "type": "indirect" + } + }, "tgi-nix": { "inputs": { "flake-compat": "flake-compat", - "nixpkgs": "nixpkgs" + "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1723234585, - "narHash": "sha256-HChJpNP155FPhHr9C5BtqllV8Uv/Ebg59HhMc/HhQrc=", + "lastModified": 1723450799, + "narHash": "sha256-cuT/ce7R2D5Lx6Ted4YS4y+WrAAOXQFAbzLsM1vtPo8=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "15bd4a978d6c2c8b04b7afc335d137dbe41e73df", + "rev": "29f9b45bd613eced65c4a5241a00aa9346f63d90", "type": "github" }, "original": { @@ -114,6 +221,27 @@ "repo": "tgi-nix", "type": "github" } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719749022, + "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 761c4af8..1f842e8e 100644 --- a/flake.nix +++ b/flake.nix @@ -3,6 +3,7 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; @@ -15,6 +16,7 @@ flake-utils, rust-overlay, tgi-nix, + poetry2nix, }: flake-utils.lib.eachDefaultSystem ( system: @@ -30,6 +32,11 @@ tgi-nix.overlay ]; }; + + inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; + text-generation-server = mkPoetryEditablePackage { + editablePackageSources = ./server; + }; in { devShells.default = @@ -50,14 +57,17 @@ venvShellHook pip + click einops fbgemm-gpu + flashinfer flash-attn flash-attn-layer-norm flash-attn-rotary grpc-interceptor grpcio-reflection grpcio-status + grpcio-tools hf-transfer loguru marlin-kernels From 9a7830bd287fb1ca2a3e94f208acd2eb881eb311 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 14:38:59 -0400 Subject: [PATCH 38/72] Pr 2395 ci run (#2406) * fix(router): Fix appending to message content * feat: add message and chat template test --------- Co-authored-by: Simone Rossi --- router/src/infer/chat_template.rs | 40 ++++++++++++++++++++++++++++++- router/src/lib.rs | 35 +++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index ef4beee2..a8537818 100644 --- a/router/src/infer/chat_template.rs +++ 
b/router/src/infer/chat_template.rs @@ -94,7 +94,9 @@ impl ChatTemplate { mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; - use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; + use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + }; use minijinja::Environment; #[test] @@ -823,4 +825,40 @@ mod tests { } } } + + #[test] + fn test_chat_template_with_default_tool_template() { + let ct = ChatTemplate::new( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText("Great! How can I help you today?".to_string()), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Just testing".to_string()), + }, + ]; + let tools = serde_json::json!("[]"); + let tool_prompt = "This default prompt will be used".to_string(); + let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); + let result = ct.apply(None, msgs, Some(grammer_with_prompt)); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? 
[INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + assert_eq!(result.unwrap(), expected); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index d7eb4475..1b2ff153 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1016,8 +1016,10 @@ impl MessageContent { pub fn push(&mut self, chunk: MessageChunk) { match self { MessageContent::SingleText(text) => { - *self = - MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); + *self = MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: text.clone() }, + chunk, + ]); } MessageContent::MultipleChunks(chunks) => { chunks.push(chunk); @@ -1348,6 +1350,35 @@ mod tests { ); } + #[test] + fn test_message_content_append() { + let mut content = MessageContent::SingleText("Initial text".to_string()); + let chunk = MessageChunk::Text { + text: "Additional text".to_string(), + }; + + content.push(chunk); + + match content { + MessageContent::MultipleChunks(chunks) => { + assert_eq!(chunks.len(), 2); + assert_eq!( + chunks[0], + MessageChunk::Text { + text: "Initial text".to_string() + } + ); + assert_eq!( + chunks[1], + MessageChunk::Text { + text: "Additional text".to_string() + } + ); + } + _ => panic!("Expected MultipleChunks, but got a different variant"), + } + } + #[test] fn test_chat_request() { let json = json!({ From 8a7749b8fb678934a53ee72102644572a8d7b0a6 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 12 Aug 2024 17:59:37 -0400 Subject: [PATCH 39/72] fix: include create_exllama_buffers and set_device for exllama (#2407) --- server/text_generation_server/layers/gptq/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 9c9b69d1..505caa59 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -422,12 +422,16 @@ elif CAN_EXLLAMA: if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 ) HAS_EXLLAMA = "2" else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, # noqa: F401 + create_exllama_buffers, # noqa: F401 + set_device, # noqa: F401 ) HAS_EXLLAMA = "1" From 6f4bb4f26ffa6351a7caa27d6e81fc4872b3a16b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 13 Aug 2024 10:44:15 +0200 Subject: [PATCH 40/72] nix: incremental build of the launcher (#2410) --- _launcher.nix | 18 -- flake.lock | 844 +++++++++++++++++++++++++++++++++++++++++++++++++- flake.nix | 14 +- 3 files changed, 846 insertions(+), 30 deletions(-) delete mode 100644 _launcher.nix diff --git a/_launcher.nix b/_launcher.nix deleted file mode 100644 index 1acae7e1..00000000 --- a/_launcher.nix +++ /dev/null @@ -1,18 +0,0 @@ -{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: - -buildRustPackage { - name = "text-generation-lancher"; - - src = ./.; - - sourceDir = ./launcher; - - cargoLock = { - lockFile = ./Cargo.lock; - }; - - nativeBuildInputs = [ pkg-config ]; - - buildInputs = [ openssl.dev protobuf ]; - -} diff --git a/flake.lock b/flake.lock index 47f14626..bc114032 100644 --- a/flake.lock +++ b/flake.lock @@ -1,6 +1,309 @@ { "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "crate2nix" + ], + "flake-compat": [ + "crate2nix" + ], + "nixpkgs": "nixpkgs", + "pre-commit-hooks": [ + 
"crate2nix" + ] + }, + "locked": { + "lastModified": 1709700175, + "narHash": "sha256-A0/6ZjLmT9qdYzKHmevnEIC7G+GiZ4UCr8v0poRPzds=", + "owner": "cachix", + "repo": "cachix", + "rev": "be97b37989f11b724197b5f4c7ffd78f12c8c4bf", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "cachix_2": { + "inputs": { + "devenv": [ + "crate2nix", + "crate2nix_stable" + ], + "flake-compat": [ + "crate2nix", + "crate2nix_stable" + ], + "nixpkgs": "nixpkgs_2", + "pre-commit-hooks": [ + "crate2nix", + "crate2nix_stable" + ] + }, + "locked": { + "lastModified": 1716549461, + "narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=", + "owner": "cachix", + "repo": "cachix", + "rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "cachix_3": { + "inputs": { + "devenv": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ], + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ], + "nixpkgs": "nixpkgs_3", + "pre-commit-hooks": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable" + ] + }, + "locked": { + "lastModified": 1716549461, + "narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=", + "owner": "cachix", + "repo": "cachix", + "rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "crate2nix": { + "inputs": { + "cachix": "cachix", + "crate2nix_stable": "crate2nix_stable", + "devshell": "devshell_3", + "flake-compat": "flake-compat_3", + "flake-parts": "flake-parts_3", + "nix-test-runner": "nix-test-runner_3", + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ], + "pre-commit-hooks": "pre-commit-hooks_3" + }, + "locked": { + "lastModified": 1723311214, + "narHash": "sha256-xdGZQBEa1AC2us/sY3igS/CucWY6jErXsAvCFRhB2LI=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "236f6addfd452a48be805819e3216af79e988fd5", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "crate2nix", + "type": "github" + } + }, + "crate2nix_stable": { + "inputs": { + "cachix": "cachix_2", + "crate2nix_stable": "crate2nix_stable_2", + "devshell": "devshell_2", + "flake-compat": "flake-compat_2", + "flake-parts": "flake-parts_2", + "nix-test-runner": "nix-test-runner_2", + "nixpkgs": "nixpkgs_5", + "pre-commit-hooks": "pre-commit-hooks_2" + }, + "locked": { + "lastModified": 1719760004, + "narHash": "sha256-esWhRnt7FhiYq0CcIxw9pvH+ybOQmWBfHYMtleaMhBE=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "1dee214bb20855fa3e1e7bb98d28922ddaff8c57", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.14.1", + "repo": "crate2nix", + "type": "github" + } + }, + "crate2nix_stable_2": { + "inputs": { + "cachix": "cachix_3", + "crate2nix_stable": "crate2nix_stable_3", + "devshell": "devshell", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "nix-test-runner": "nix-test-runner", + "nixpkgs": "nixpkgs_4", + "pre-commit-hooks": "pre-commit-hooks" + }, + "locked": { + "lastModified": 1712821484, + "narHash": "sha256-rGT3CW64cJS9nlnWPFWSc1iEa3dNZecVVuPVGzcsHe8=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "42883afcad3823fa5811e967fb7bff54bc3c9d6d", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.14.0", + "repo": "crate2nix", + 
"type": "github" + } + }, + "crate2nix_stable_3": { + "inputs": { + "flake-utils": "flake-utils" + }, + "locked": { + "lastModified": 1702842982, + "narHash": "sha256-A9AowkHIjsy1a4LuiPiVP88FMxyCWK41flZEZOUuwQM=", + "owner": "nix-community", + "repo": "crate2nix", + "rev": "75ac2973affa6b9b4f661a7b592cba6e4f51d426", + "type": "github" + }, + "original": { + "owner": "nix-community", + "ref": "0.12.0", + "repo": "crate2nix", + "type": "github" + } + }, + "devshell": { + "inputs": { + "flake-utils": "flake-utils_2", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1717408969, + "narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=", + "owner": "numtide", + "repo": "devshell", + "rev": "1ebbe68d57457c8cae98145410b164b5477761f4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, + "devshell_2": { + "inputs": { + "flake-utils": "flake-utils_3", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1717408969, + "narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=", + "owner": "numtide", + "repo": "devshell", + "rev": "1ebbe68d57457c8cae98145410b164b5477761f4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, + "devshell_3": { + "inputs": { + "flake-utils": "flake-utils_4", + "nixpkgs": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1711099426, + "narHash": "sha256-HzpgM/wc3aqpnHJJ2oDqPBkNsqWbW0WfWUO8lKu8nGk=", + "owner": "numtide", + "repo": "devshell", + "rev": "2d45b54ca4a183f2fdcf4b19c895b64fbf620ee8", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "devshell", + "type": "github" + } + }, "flake-compat": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_3": { + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "revCount": 57, + "type": "tarball", + "url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz" + }, + "original": { + "type": "tarball", + "url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz" + } + }, + "flake-compat_4": { "locked": { "lastModified": 1696426674, "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", @@ -15,10 +318,148 @@ "type": "github" } }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + 
"lastModified": 1719745305, + "narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_2": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719745305, + "narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_3": { + "inputs": { + "nixpkgs-lib": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1712014858, + "narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "9126214d0a59633752a136528f5f3b9aa8565b7d", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, "flake-utils": { "inputs": { "systems": "systems" }, + "locked": { + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_3": { + "inputs": { + "systems": "systems_3" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_4": { + "inputs": { + "systems": "systems_4" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_5": { + "inputs": { + "systems": "systems_5" + }, "locked": { "lastModified": 1710146030, "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", @@ -33,9 +474,9 @@ "type": "github" } }, - "flake-utils_2": { + "flake-utils_6": { "inputs": { - "systems": "systems_2" + "systems": "systems_6" }, "locked": { "lastModified": 1710146030, @@ -51,6 +492,93 @@ "type": "github" } }, + "flake-utils_7": { + "inputs": { + "systems": "systems_7" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "gitignore": 
{ + "inputs": { + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "gitignore_2": { + "inputs": { + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "gitignore_3": { + "inputs": { + "nixpkgs": [ + "crate2nix", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -72,7 +600,129 @@ "type": "github" } }, + "nix-test-runner": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, + "nix-test-runner_2": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, + "nix-test-runner_3": { + "flake": false, + "locked": { + "lastModified": 1588761593, + "narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=", + "owner": "stoeffel", + "repo": "nix-test-runner", + "rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2", + "type": "github" + }, + "original": { + "owner": "stoeffel", + "repo": "nix-test-runner", + "type": "github" + } + }, "nixpkgs": { + "locked": { + "lastModified": 1700612854, + "narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1715534503, + "narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2057814051972fa1453ddfb0d98badbea9b83c06", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_3": { + "locked": { + "lastModified": 1715534503, + "narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2057814051972fa1453ddfb0d98badbea9b83c06", + 
"type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_4": { + "locked": { + "lastModified": 1719506693, + "narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=", + "path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source", + "rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a", + "type": "path" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "nixpkgs_5": { + "locked": { + "lastModified": 1719506693, + "narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=", + "path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source", + "rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a", + "type": "path" + }, + "original": { + "id": "nixpkgs", + "type": "indirect" + } + }, + "nixpkgs_6": { "locked": { "lastModified": 1719763542, "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=", @@ -88,7 +738,7 @@ "type": "github" } }, - "nixpkgs_2": { + "nixpkgs_7": { "locked": { "lastModified": 1723418128, "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=", @@ -106,10 +756,10 @@ }, "poetry2nix": { "inputs": { - "flake-utils": "flake-utils_2", + "flake-utils": "flake-utils_7", "nix-github-actions": "nix-github-actions", - "nixpkgs": "nixpkgs", - "systems": "systems_3", + "nixpkgs": "nixpkgs_6", + "systems": "systems_8", "treefmt-nix": "treefmt-nix" }, "locked": { @@ -126,9 +776,110 @@ "type": "github" } }, + "pre-commit-hooks": { + "inputs": { + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "crate2nix_stable", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719259945, + "narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "pre-commit-hooks_2": { + "inputs": { + "flake-compat": [ + "crate2nix", + "crate2nix_stable", + "flake-compat" + ], + "gitignore": "gitignore_2", + "nixpkgs": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "crate2nix_stable", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1719259945, + "narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "pre-commit-hooks_3": { + "inputs": { + "flake-compat": [ + "crate2nix", + "flake-compat" + ], + "flake-utils": "flake-utils_5", + "gitignore": "gitignore_3", + "nixpkgs": [ + "crate2nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "crate2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1712055707, + "narHash": "sha256-4XLvuSIDZJGS17xEwSrNuJLL7UjDYKGJSbK1WWX2AK8=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "e35aed5fda3cc79f88ed7f1795021e559582093a", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, "root": { "inputs": { - "flake-utils": "flake-utils", + "crate2nix": "crate2nix", + "flake-utils": "flake-utils_6", "nixpkgs": [ "tgi-nix", 
"nixpkgs" @@ -190,6 +941,81 @@ } }, "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_4": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_5": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_6": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_7": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_8": { "locked": { "lastModified": 1681028828, "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", @@ -205,8 +1031,8 @@ }, "tgi-nix": { "inputs": { - "flake-compat": "flake-compat", - "nixpkgs": "nixpkgs_2" + "flake-compat": "flake-compat_4", + "nixpkgs": "nixpkgs_7" }, "locked": { "lastModified": 1723450799, diff --git a/flake.nix b/flake.nix index 1f842e8e..40b845b7 100644 --- a/flake.nix +++ b/flake.nix @@ -1,5 +1,9 @@ { inputs = { + crate2nix = { + url = "github:nix-community/crate2nix"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; @@ -12,6 +16,7 @@ outputs = { self, + crate2nix, nixpkgs, flake-utils, rust-overlay, @@ -21,6 +26,10 @@ flake-utils.lib.eachDefaultSystem ( system: let + cargoNix = crate2nix.tools.${system}.appliedCargoNix { + name = "tgi"; + src = ./.; + }; config = { allowUnfree = true; cudaSupport = true; @@ -81,12 +90,11 @@ transformers vllm + cargoNix.workspaceMembers.text-generation-launcher.build + (callPackage ./router.nix { inherit (rustPlatform) buildRustPackage importCargoLock; }) - (callPackage ./_launcher.nix { - inherit (rustPlatform) buildRustPackage importCargoLock; - }) ]); venvDir = "./.venv"; From cd9b15d17f0cc89bc443ff876b3bee762bb4856e Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 13 Aug 2024 10:49:18 +0200 Subject: [PATCH 41/72] Adding more kernels to flake. 
(#2411) --- flake.lock | 18 +++++++++--------- flake.nix | 2 ++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/flake.lock b/flake.lock index bc114032..ef05c258 100644 --- a/flake.lock +++ b/flake.lock @@ -763,11 +763,11 @@ "treefmt-nix": "treefmt-nix" }, "locked": { - "lastModified": 1723343306, - "narHash": "sha256-/6sRkPq7/5weX2y0V8sQ29Sz35nt8kyj+BsFtkhgbJE=", + "lastModified": 1723512448, + "narHash": "sha256-VSTtxGKre1p6zd6ACuBmfDcR+BT9+ml8Y3KrSbfGFYU=", "owner": "nix-community", "repo": "poetry2nix", - "rev": "4a1c112ff0c67f496573dc345bd0b2247818fc29", + "rev": "ed52f844c4dd04dde45550c3189529854384124e", "type": "github" }, "original": { @@ -897,11 +897,11 @@ ] }, "locked": { - "lastModified": 1723429325, - "narHash": "sha256-4x/32xTCd+xCwFoI/kKSiCr5LQA2ZlyTRYXKEni5HR8=", + "lastModified": 1723515680, + "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "65e3dc0fe079fe8df087cd38f1fe6836a0373aad", + "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3", "type": "github" }, "original": { @@ -1035,11 +1035,11 @@ "nixpkgs": "nixpkgs_7" }, "locked": { - "lastModified": 1723450799, - "narHash": "sha256-cuT/ce7R2D5Lx6Ted4YS4y+WrAAOXQFAbzLsM1vtPo8=", + "lastModified": 1723532088, + "narHash": "sha256-6h/P/BkFDw8txlikonKXp5IbluHSPhHJTQRftJLkbLQ=", "owner": "danieldk", "repo": "tgi-nix", - "rev": "29f9b45bd613eced65c4a5241a00aa9346f63d90", + "rev": "32335a37ce0f703bab901baf7b74eb11e9972d5f", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 40b845b7..8784be37 100644 --- a/flake.nix +++ b/flake.nix @@ -66,6 +66,7 @@ venvShellHook pip + causal-conv1d click einops fbgemm-gpu @@ -79,6 +80,7 @@ grpcio-tools hf-transfer loguru + mamba-ssm marlin-kernels opentelemetry-api opentelemetry-exporter-otlp From 59922f9bc16afee9efcc7ee1c5f9d753ef314ffa Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 13 Aug 2024 21:33:55 +0800 Subject: [PATCH 42/72] add numa to improve cpu inference perf (#2330) Signed-off-by: Wang, Yi A --- Dockerfile_intel | 12 +++---- .../models/flash_causal_lm.py | 31 +++++++++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 158c5a89..12480c70 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins g++ \ git \ wget \ - cmake + cmake \ + libnuma-dev ENV HUGGINGFACE_HUB_CACHE=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ @@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install triton +RUN pip install triton numa WORKDIR /usr/src @@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . 
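# OpenMP thread count and CPU affinity are configured per rank at runtime by
# init_cpu_threads_env (see flash_causal_lm.py below), using the numa package
# installed above.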
-ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib -ENV KMP_BLOCKTIME=1 -ENV KMP_TPAUSE=0 -ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist -ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist -ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist # Install server COPY proto proto diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 21b66a68..42d93a12 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -74,6 +74,36 @@ def get_sliding_windows() -> int: return SLIDING_WINDOW +def init_cpu_threads_env(rank_id: int, world_size: int): + import importlib.util + + if importlib.util.find_spec("numa") is not None: + import numa + import psutil + + nodes = numa.get_max_node() + 1 + rank_per_node = math.ceil(world_size / nodes) + num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes) + node_id = int(rank_id / rank_per_node) + rank_offset_per_node = rank_id % rank_per_node + if os.getenv("OMP_NUM_THREADS") is None: + num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1) + else: + num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS")) + if len(numa.get_membind()) == nodes: + numa.set_membind([node_id]) + torch.set_num_threads(num_cpus_per_rank) + if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True): + cpu_start = num_cpus_per_rank * rank_offset_per_node + numa.set_affinity( + 0, + list(numa.node_to_cpus(node_id))[ + cpu_start : cpu_start + num_cpus_per_rank + ], + ) + logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}") + + @dataclass class FlashCausalLMBatch(Batch): batch_id: int @@ -854,6 +884,7 @@ class FlashCausalLM(Model): device = torch.device("cpu") # Float16 doesn't exist on target. 
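            # bfloat16 is used instead. init_cpu_threads_env (defined above) then pins this
            # rank's OpenMP threads to the cores of a single NUMA node and binds its memory
            # allocations to that node, reducing cross-node memory traffic during CPU inference.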
dtype = torch.bfloat16 if dtype is None else dtype + init_cpu_threads_env(rank_id=rank, world_size=world_size) else: raise NotImplementedError(f"{model_class} is only available on GPU") From 1cebccc72b3233f705724be43ea760239d5d2717 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 13 Aug 2024 10:19:46 -0400 Subject: [PATCH 43/72] fix: adds causal to attention params (#2408) fix: adds causal to attention params to check when using flash attn v1 --- server/text_generation_server/layers/attention/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index d039e1e7..8703eb94 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -293,6 +293,7 @@ else: max_s, softmax_scale, window_size_left=-1, + causal=None, softcap=None, ): if window_size_left != -1: From c5fff92b48431fcd38a88cee6042eccbbc958c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 14 Aug 2024 11:06:28 +0200 Subject: [PATCH 44/72] nix: partial incremental build of the router (#2416) This is less incremental than crate2nix, but does build all dependencies separately, so avoids full rebuilds. --- _server.nix | 7 ++++++- flake.lock | 22 ++++++++++++++++++++++ flake.nix | 37 +++++++++++++++++++++++++++++-------- router.nix | 18 ------------------ 4 files changed, 57 insertions(+), 27 deletions(-) delete mode 100644 router.nix diff --git a/_server.nix b/_server.nix index bc11ba6b..2cb2f887 100644 --- a/_server.nix +++ b/_server.nix @@ -1,4 +1,9 @@ -{ mkPoetryApplication, pkg-config, protobuf, openssl }: +{ + mkPoetryApplication, + pkg-config, + protobuf, + openssl, +}: mkPoetryApplication { # name = "text-generation-server"; diff --git a/flake.lock b/flake.lock index ef05c258..dab582a5 100644 --- a/flake.lock +++ b/flake.lock @@ -579,6 +579,27 @@ "type": "github" } }, + "naersk": { + "inputs": { + "nixpkgs": [ + "tgi-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1721727458, + "narHash": "sha256-r/xppY958gmZ4oTfLiHN0ZGuQ+RSTijDblVgVLFi1mw=", + "owner": "nix-community", + "repo": "naersk", + "rev": "3fb418eaf352498f6b6c30592e3beb63df42ef11", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "naersk", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -880,6 +901,7 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "naersk": "naersk", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 8784be37..e1f44212 100644 --- a/flake.nix +++ b/flake.nix @@ -7,6 +7,10 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; + naersk = { + url = "github:nix-community/naersk"; + inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; + }; poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; @@ -17,6 +21,7 @@ { self, crate2nix, + naersk, nixpkgs, flake-utils, rust-overlay, @@ -41,13 +46,32 @@ tgi-nix.overlay ]; }; + naersk' = pkgs.callPackage naersk { }; + router = + with pkgs; + naersk'.buildPackage { + name = "router"; + src = ./.; + cargoBuildOptions = + x: + x + ++ [ + "-p" + "text-generation-router-v3" + ]; + nativeBuildInputs = [ pkg-config ]; + buildInputs = [ + openssl.dev + protobuf + ]; + doCheck = false; + }; inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; - 
text-generation-server = mkPoetryEditablePackage { - editablePackageSources = ./server; - }; + text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; in { + defaultPackage = router; devShells.default = with pkgs; mkShell { @@ -92,11 +116,8 @@ transformers vllm - cargoNix.workspaceMembers.text-generation-launcher.build - - (callPackage ./router.nix { - inherit (rustPlatform) buildRustPackage importCargoLock; - }) + cargoNix.workspaceMembers.text-generation-launcher.build + router ]); venvDir = "./.venv"; diff --git a/router.nix b/router.nix deleted file mode 100644 index eeeac199..00000000 --- a/router.nix +++ /dev/null @@ -1,18 +0,0 @@ -{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }: - -buildRustPackage { - name = "text-generation-router"; - - src = ./.; - - sourceDir = ./backends/v3; - - cargoLock = { - lockFile = ./Cargo.lock; - }; - - nativeBuildInputs = [ pkg-config ]; - - buildInputs = [ openssl.dev protobuf ]; - -} From f3b5c6944173779d4c762702d8369e42a9ce0180 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 14 Aug 2024 11:58:08 +0200 Subject: [PATCH 45/72] Upgrading exl2. (#2415) * Upgrading exl2. * Fixing the other pathways. * Fix idefics. --- .gitignore | 2 +- Dockerfile | 6 +-- flake.nix | 1 + server/Makefile | 1 + server/Makefile-exllamav2 | 12 ++++++ .../layers/gptq/exllamav2.py | 43 ++++++++++++------- .../models/causal_lm.py | 1 + .../models/flash_causal_lm.py | 1 + .../text_generation_server/models/idefics.py | 1 + .../models/idefics_causal_lm.py | 1 + .../models/seq2seq_lm.py | 1 + server/text_generation_server/server.py | 6 +-- 12 files changed, 54 insertions(+), 22 deletions(-) create mode 100644 server/Makefile-exllamav2 diff --git a/.gitignore b/.gitignore index bd9d9125..f79d8faa 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ backends/client/src/v3/pb # ROCm auto-generated files *.hip -server/exllamav2_kernels/exllamav2_kernels/hip/ +server/exllamav2 server/exllama_kernels/exllama_kernels/hip/ server/exllama_kernels/exllama_kernels/hip_func/ *_hip.cuh diff --git a/Dockerfile b/Dockerfile index 458ff699..74e7d990 100644 --- a/Dockerfile +++ b/Dockerfile @@ -123,10 +123,10 @@ RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build # Build Transformers exllama kernels FROM kernel-builder AS exllamav2-kernels-builder WORKDIR /usr/src -COPY server/exllamav2_kernels/ . 
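# Build the exllamav2 kernels from the upstream turboderp/exllamav2 repository at the
# tag pinned in server/Makefile-exllamav2 (added below), instead of from vendored sources.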
+COPY server/Makefile-exllamav2/ Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-exllamav2 # Build Transformers awq kernels FROM kernel-builder AS awq-kernels-builder @@ -221,7 +221,7 @@ COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 / # Copy build artifacts from exllama kernels builder COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from eetq kernels builder diff --git a/flake.nix b/flake.nix index e1f44212..229184d2 100644 --- a/flake.nix +++ b/flake.nix @@ -93,6 +93,7 @@ causal-conv1d click einops + exllamav2 fbgemm-gpu flashinfer flash-attn diff --git a/server/Makefile b/server/Makefile index 209fc44e..51ea8b32 100644 --- a/server/Makefile +++ b/server/Makefile @@ -6,6 +6,7 @@ include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica include Makefile-fbgemm +include Makefile-exllamav2 unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-exllamav2 b/server/Makefile-exllamav2 new file mode 100644 index 00000000..0d4cc385 --- /dev/null +++ b/server/Makefile-exllamav2 @@ -0,0 +1,12 @@ +exllamav2_commit := v0.1.8 + +build-exllamav2: + git clone https://github.com/turboderp/exllamav2.git exllamav2 && \ + cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build + +install-exllamav2: build-exllamav2 + cd exllamav2/ && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index dc3b832f..920a6adf 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -12,7 +12,10 @@ from text_generation_server.layers.gptq import GPTQWeight from text_generation_server.utils.log import log_master try: - from exllamav2_kernels import make_q_matrix, gemm_half_q_half + from exllamav2.ext import exllamav2_ext + + make_q_matrix = exllamav2_ext.make_q_matrix + gemm_half_q_half = exllamav2_ext.gemm_half_q_half except ImportError: log_master(logger.warning, "exllamav2_kernels not installed.") raise @@ -70,6 +73,10 @@ def ext_make_q_matrix( """ Create Q matrix """ + # max_dq_size = 512*(1024**2) + # max_dq_rows = max_dq_size // out_features[0] + max_dq_rows = 0 + # EXL2 if isinstance(w, Exl2Weight): extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0]) @@ -83,10 +90,12 @@ def ext_make_q_matrix( w.q_scale_max, w.q_groups, extra.q_group_map, - none_tensor, - 
none_tensor, - none_tensor, + none_tensor, # zeros + none_tensor, # scales + none_tensor, # g_idx + none_tensor, # bias temp_dq, + max_dq_rows, ) # GPTQ elif isinstance(w, GPTQWeight): @@ -106,29 +115,33 @@ def ext_make_q_matrix( w.qweight, extra.q_perm, extra.q_invperm, - none_tensor, - none_tensor, - none_tensor, - none_tensor, + none_tensor, # q_scale + none_tensor, # q_scale_max + none_tensor, # q_groups + none_tensor, # q_group_map w.qzeros, w.scales, w.g_idx.cpu(), + none_tensor, # bias temp_dq, + max_dq_rows, ) # GPTQ without g_idx else: return make_q_matrix( w.qweight, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, + none_tensor, # q_perm + none_tensor, # q_invperm + none_tensor, # q_scale + none_tensor, # q_scale_max + none_tensor, # q_groups + none_tensor, # q_group_map w.qzeros, w.scales, - none_tensor, + none_tensor, # g_idx + none_tensor, # bias temp_dq, + max_dq_rows, ) else: RuntimeError("Cannot create handle") diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 212ab7a9..ba168b13 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -511,6 +511,7 @@ class CausalLM(Model): config_class=AutoConfig, batch_class=CausalLMBatch, ): + self.quantize = quantize self.batch_class = batch_class self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42d93a12..5e2fd20a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -872,6 +872,7 @@ class FlashCausalLM(Model): head_size: Optional[int] = None, skip_special_tokens: bool = True, ): + self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 29929b98..9058cb96 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -33,6 +33,7 @@ class IDEFICSSharded(IdeficsCausalLM): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 8a80ed68..c5480952 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -580,6 +580,7 @@ class IdeficsCausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize from text_generation_server.models.custom_modeling.idefics_modeling import ( IdeficsForVisionText2Text, ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 79c001b0..3c92128a 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -553,6 +553,7 @@ class Seq2SeqLM(Model): tokenizer_class=AutoTokenizer, aliases=None, ): + self.quantize = quantize 
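        # The resolved quantization method is kept on the model so that the gRPC
        # TextGenerationService can read it back after loading (see the server.py change
        # below) instead of receiving quantize as a separate constructor argument.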
self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index b92ab572..22871ec5 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -50,12 +50,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): self, model: Model, cache: Cache, - quantize: Optional[str], server_urls: List[str], ): self.cache = cache self.model = model - self.quantize = quantize + # Quantize is resolved during model loading + self.quantize = model.quantize self.server_urls = server_urls # For some reason, inference_mode does not work well with GLOO which we use on CPU if model.device.type == "cuda": @@ -255,7 +255,7 @@ def serve( ], ) generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( - TextGenerationService(model, Cache(), quantize, server_urls), server + TextGenerationService(model, Cache(), server_urls), server ) SERVICE_NAMES = ( generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name, From 3f385991b084c59968817506a54398a1accd03b9 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Wed, 14 Aug 2024 12:02:05 +0200 Subject: [PATCH 46/72] More fixes trtllm (#2342) * (backend) use parking_lot crate for RwLock fairness * (docker) let's put rust in the TRTLLM folder when building * (docker) build ompi with SLURM support * (launcher) default new server::run parameters to false for now * (chore) fmt ... why? --- Cargo.lock | 1 + backends/trtllm/Cargo.toml | 5 +++-- backends/trtllm/Dockerfile | 9 +++++---- backends/trtllm/src/backend.rs | 3 ++- backends/trtllm/src/main.rs | 8 ++++---- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bb63422e..d298c379 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4028,6 +4028,7 @@ dependencies = [ "cxx", "cxx-build", "log", + "parking_lot", "pkg-config", "text-generation-router", "thiserror", diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index 7079d3d1..43a114ba 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -8,17 +8,18 @@ homepage.workspace = true [dependencies] async-trait = "0.1" async-stream = "0.3" +clap = { version = "4.5", features = ["derive"] } cxx = "1.0" +log = { version = "0.4", features = [] } text-generation-router = { path = "../../router" } tokenizers = { version = "0.19", features = ["hf-hub"] } tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.15" -clap = { version = "4.5", features = ["derive"] } thiserror = "1.0.62" tracing = "0.1" tracing-opentelemetry = "0.24" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -log = { version = "0.4", features = [] } +parking_lot = "0.12" [build-dependencies] cmake = "0.1" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile index 60ad03f7..5fd2f89f 100644 --- a/backends/trtllm/Dockerfile +++ b/backends/trtllm/Dockerfile @@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6" # Build dependencies resolver stage FROM lukemathwalker/cargo-chef:latest AS chef -WORKDIR /usr/src/text-generation-inference +WORKDIR /usr/src/text-generation-inference/backends/trtllm FROM chef AS planner COPY . . 
@@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE mkdir /usr/src/mpi && \ tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ cd /usr/src/mpi && \ - ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \ make -j all && \ make install && \ rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" @@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH" RUN cargo install cargo-chef # Cache dependencies -COPY --from=planner /usr/src/text-generation-inference/recipe.json . +COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json . RUN cargo chef cook --release --recipe-path recipe.json # Build actual TGI @@ -79,7 +79,8 @@ COPY . . COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ - CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + cd backends/trtllm && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime WORKDIR /usr/local/tgi/bin diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index b26d06a6..b23aa6c0 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -12,12 +12,13 @@ use cxx::UniquePtr; use log::{error, warn}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; -use tokio::sync::RwLock; use tokio::time::{sleep, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::{Stream, StreamExt}; use tracing::{instrument, span, Level}; +// use tokio::sync::RwLock; +use parking_lot::RwLock; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidationError::UnsupportedModality; use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 6d6ee146..e0ba46c7 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -1,12 +1,10 @@ +use clap::Parser; use std::collections::HashMap; use std::path::PathBuf; - -use clap::Parser; -use tokenizers::{FromPretrainedParameters, Tokenizer}; - use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackend; use text_generation_router::server; +use tokenizers::{FromPretrainedParameters, Tokenizer}; /// App Configuration #[derive(Parser, Debug)] @@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { messages_api_enabled, true, max_client_batch_size, + false, + false, ) .await?; Ok(()) From 9aaa12e7ac54d859ebfee4b023f66010cef1bca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 15 Aug 2024 10:21:51 +0200 Subject: [PATCH 47/72] nix: build router incrementally (#2422) --- flake.lock | 22 ---------------------- flake.nix | 53 +++++++++++++++++++++++++---------------------------- 2 files changed, 25 insertions(+), 50 deletions(-) diff --git a/flake.lock b/flake.lock index dab582a5..ef05c258 100644 --- a/flake.lock +++ b/flake.lock @@ -579,27 +579,6 @@ "type": "github" } }, - "naersk": { - "inputs": { - "nixpkgs": [ - "tgi-nix", - "nixpkgs" - ] - }, - 
"locked": { - "lastModified": 1721727458, - "narHash": "sha256-r/xppY958gmZ4oTfLiHN0ZGuQ+RSTijDblVgVLFi1mw=", - "owner": "nix-community", - "repo": "naersk", - "rev": "3fb418eaf352498f6b6c30592e3beb63df42ef11", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "naersk", - "type": "github" - } - }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -901,7 +880,6 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", - "naersk": "naersk", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 229184d2..168c8649 100644 --- a/flake.nix +++ b/flake.nix @@ -7,10 +7,6 @@ tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; - naersk = { - url = "github:nix-community/naersk"; - inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; - }; poetry2nix.url = "github:nix-community/poetry2nix"; rust-overlay = { url = "github:oxalica/rust-overlay"; @@ -21,7 +17,6 @@ { self, crate2nix, - naersk, nixpkgs, flake-utils, rust-overlay, @@ -34,6 +29,7 @@ cargoNix = crate2nix.tools.${system}.appliedCargoNix { name = "tgi"; src = ./.; + additionalCargoNixArgs = [ "--all-features" ]; }; config = { allowUnfree = true; @@ -46,32 +42,10 @@ tgi-nix.overlay ]; }; - naersk' = pkgs.callPackage naersk { }; - router = - with pkgs; - naersk'.buildPackage { - name = "router"; - src = ./.; - cargoBuildOptions = - x: - x - ++ [ - "-p" - "text-generation-router-v3" - ]; - nativeBuildInputs = [ pkg-config ]; - buildInputs = [ - openssl.dev - protobuf - ]; - doCheck = false; - }; - inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; in { - defaultPackage = router; devShells.default = with pkgs; mkShell { @@ -118,7 +92,30 @@ vllm cargoNix.workspaceMembers.text-generation-launcher.build - router + + (cargoNix.workspaceMembers.text-generation-router-v3.build.override { + crateOverrides = defaultCrateOverrides // { + aws-lc-rs = attrs: { + # aws-lc-rs does its own custom parsing of Cargo environment + # variables like DEP_.*_INCLUDE. However buildRustCrate does + # not use the version number, so the parsing fails. + postPatch = '' + substituteInPlace build.rs \ + --replace-fail \ + "assert!(!selected.is_empty()" \ + "// assert!(!selected.is_empty()" + ''; + }; + rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; + text-generation-router-v3 = attrs: { + # We need to do the src/source root dance so that the build + # has access to the protobuf file. + src = ./.; + postPatch = "cd backends/v3"; + buildInputs = [ protobuf ]; + }; + }; + }) ]); venvDir = "./.venv"; From 57b34958235ee64b7c310a4b5410bcc491a0ef28 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Aug 2024 11:12:51 +0200 Subject: [PATCH 48/72] Fixing exl2 and other quanize tests again. (#2419) * Fixing exl2 and other quanize tests again. * Mark exl2 as non release (so CI tests them, needs to be removed latet). * Fixing exl2 (by disabling cuda graphs) * Fix quantization defaults without cuda graphs on exl2 (linked to new issues with it). * Removing serde override. * Go back to released exl2 and remove log. * Adding warnings for deprecated bitsandbytes + upgrade info to warn. 
--- Dockerfile | 3 + .../models/test_flash_llama_exl2.py | 3 - launcher/src/main.rs | 156 ++++++++++-------- server/poetry.lock | 69 +++++++- server/pyproject.toml | 1 + server/requirements_cuda.txt | 4 + server/requirements_intel.txt | 4 + server/requirements_rocm.txt | 4 + .../models/causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + .../models/seq2seq_lm.py | 1 + 11 files changed, 176 insertions(+), 71 deletions(-) diff --git a/Dockerfile b/Dockerfile index 74e7d990..b2d274d7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -250,6 +250,9 @@ RUN cd server && \ pip install nvidia-nccl-cu12==2.22.3 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 +# This is needed because exl2 tries to load flash-attn +# And fails with our builds. +ENV EXLLAMA_NO_FLASH_ATTN=1 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py index 7169c999..18319f60 100644 --- a/integration-tests/models/test_flash_llama_exl2.py +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle): return flash_llama_exl2_handle.client -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): @@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_all_params( @@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params( assert response == ignore_logprob_response_snapshot -@pytest.mark.release @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_load( diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a64b1d71..9a90a673 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -30,11 +30,18 @@ struct RawConfig { n_positions: Option, model_type: Option, max_seq_len: Option, + quantization_config: Option, +} + +#[derive(Deserialize)] +struct QuantizationConfig { + quant_method: Option, } #[derive(Deserialize)] struct Config { max_position_embeddings: Option, + quantize: Option, } impl From for Config { @@ -43,13 +50,16 @@ impl From for Config { .max_position_embeddings .or(other.max_seq_len) .or(other.n_positions); + let quantize = other.quantization_config.and_then(|q| q.quant_method); Config { max_position_embeddings, + quantize, } } } -#[derive(Clone, Copy, Debug, ValueEnum)] +#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)] +#[serde(rename_all = "kebab-case")] enum Quantization { /// 4 bit quantization. Requires a specific AWQ quantized model: /// . @@ -72,17 +82,17 @@ enum Quantization { Marlin, /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, /// but it is known that the model will be much slower to run than the native f16. - #[deprecated( - since = "1.1.0", - note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" - )] + // #[deprecated( + // since = "1.1.0", + // note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" + // )] Bitsandbytes, /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, /// but it is known that the model will be much slower to run than the native f16. 
- BitsandbytesNF4, + BitsandbytesNf4, /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better /// perplexity performance for you model - BitsandbytesFP4, + BitsandbytesFp4, /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above /// This dtype has native ops should be the fastest if available. /// This is currently not the fastest because of local unpacking + padding to satisfy matrix @@ -99,10 +109,10 @@ impl std::fmt::Display for Quantization { Quantization::Bitsandbytes => { write!(f, "bitsandbytes") } - Quantization::BitsandbytesNF4 => { + Quantization::BitsandbytesNf4 => { write!(f, "bitsandbytes-nf4") } - Quantization::BitsandbytesFP4 => { + Quantization::BitsandbytesFp4 => { write!(f, "bitsandbytes-fp4") } Quantization::Exl2 => { @@ -1085,6 +1095,7 @@ fn spawn_shards( cuda_graphs: Vec, max_total_tokens: usize, max_input_tokens: usize, + quantize: Option, max_log_level: LevelFilter, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1106,7 +1117,6 @@ fn spawn_shards( let shutdown_sender = shutdown_sender.clone(); let otlp_endpoint = args.otlp_endpoint.clone(); let otlp_service_name = args.otlp_service_name.clone(); - let quantize = args.quantize; let speculate = args.speculate; let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; @@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:#?}", args); - let get_max_position_embeddings = || -> Result> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id + let get_max_positions_quantize = + || -> Result<(usize, Option), Box> { + let model_id = args.model_id.clone(); + let mut path = std::path::Path::new(&args.model_id).to_path_buf(); + let filename = if !path.exists() { + // Assume it's a hub id - let api = if let Ok(token) = std::env::var("HF_TOKEN") { - // env variable has precedence over on file token. - ApiBuilder::new().with_token(Some(token)).build()? + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = args.revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? } else { - Api::new()? + path.push("config.json"); + path }; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; - if config.model_type == Some("gemma2".to_string()) { - tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - let config: Config = config.into(); - - // Quantization usually means you're even more RAM constrained. 
- let max_default = 4096; - - if let Some(max_position_embeddings) = config.max_position_embeddings { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - Ok(max_default) - } else { - Ok(max_position_embeddings) + if config.model_type == Some("gemma2".to_string()) { + tracing::info!("Forcing flash decoding because of softcap usage"); + std::env::set_var("ATTENTION", "flashdecoding"); } - } else { - Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))) - } - }; - let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); + let config: Config = config.into(); + let quantize = config.quantize; + + // Quantization usually means you're even more RAM constrained. + let max_default = 4096; + + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + } + Ok((max_default, quantize)) + } else { + Ok((max_position_embeddings, quantize)) + } + } else { + Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))) + } + }; + let (max_position_embeddings, quantize): (usize, Option) = + get_max_positions_quantize().unwrap_or((4096, None)); let max_input_tokens = { match (args.max_input_tokens, args.max_input_length) { @@ -1544,18 +1557,26 @@ fn main() -> Result<(), LauncherError> { ))); } - let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { + if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { + tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases."); + } + let quantize = args.quantize.or(quantize); + let cuda_graphs = match (&args.cuda_graphs, &quantize) { (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), #[allow(deprecated)] ( None, Some( Quantization::Bitsandbytes - | Quantization::BitsandbytesNF4 - | Quantization::BitsandbytesFP4, + | Quantization::BitsandbytesNf4 + | Quantization::BitsandbytesFp4, ), ) => { - tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); + vec![] + } + (None, Some(Quantization::Exl2)) => { + tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them"); vec![] } _ => { @@ -1672,6 +1693,7 @@ fn main() -> Result<(), LauncherError> { cuda_graphs, max_total_tokens, max_input_tokens, + quantize, max_log_level, shutdown.clone(), &shutdown_receiver, diff --git a/server/poetry.lock b/server/poetry.lock index 
5072aa0b..fc1a54a3 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} [package.extras] dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "2.1.5" @@ -1207,6 +1231,17 @@ torch = "*" type = "url" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -2277,6 +2312,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pygments" +version = "2.18.0" +description = "Pygments is a syntax highlighting package written in Python." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pytest" version = "7.4.4" @@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.7.1" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, + {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.19.0" @@ -3584,4 +3651,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" +content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd" diff --git a/server/pyproject.toml b/server/pyproject.toml index 15da4a8f..57deb1b8 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -46,6 +46,7 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] +rich = "^13.7.1" [tool.poetry.extras] torch = ["torch"] diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" 
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_intel.txt +++ b/server/requirements_intel.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 828b6fca..eb521bd6 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" @@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" +rich==13.7.1 ; python_version >= 
"3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba168b13..28534d0f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -652,6 +652,7 @@ class CausalLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 5d6ce3c7..f6dcde68 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -412,6 +412,7 @@ class Mamba(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize self.process_group, _rank, world_size = initialize_torch_distributed() if world_size > 1: raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3c92128a..04d4c28b 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -676,6 +676,7 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property From 1b0aa06204d0bad0c06e6352473ccef3aeed7ebf Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Aug 2024 13:28:42 +0200 Subject: [PATCH 49/72] Upgrading the tests to match the current workings. (#2423) --- .../test_bloom_560m/test_bloom_560m.json | 30 +- .../test_flash_deepseek_v2_all_params.json | 8 +- .../test_flash_gemma/test_flash_gemma.json | 32 +- .../test_flash_starcoder_gptq.json | 146 +------- ...t_flash_starcoder_gptq_default_params.json | 144 +------- .../test_flash_starcoder_gptq_load.json | 344 ++++-------------- .../test_mamba/test_mamba_all_params.json | 24 +- .../test_mt0_base_all_params.json | 6 +- .../models/test_flash_starcoder_gptq.py | 4 +- integration-tests/models/test_mamba.py | 2 +- 10 files changed, 166 insertions(+), 574 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json index b274992e..5d0eeef6 100644 --- a/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json +++ b/integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json @@ -11,52 +11,52 @@ }, { "id": 49833, - "logprob": -10.5703125, + "logprob": -10.546875, "text": " dรฉg" }, { "id": 21543, - "logprob": -0.14746094, + "logprob": -0.18457031, "text": "uster" }, { "id": 447, - "logprob": -1.9277344, + "logprob": -1.9287109, "text": " un" }, { "id": 46341, - "logprob": -15.421875, + "logprob": -15.4296875, "text": " ort" }, { "id": 35567, - "logprob": -7.5820312, + "logprob": -7.578125, "text": "olan" }, { "id": 15, - "logprob": -1.4013672, + "logprob": -1.4003906, "text": "," }, { "id": 1669, - "logprob": -1.5664062, + "logprob": -1.5439453, "text": " il" }, { "id": 11580, - "logprob": -0.94189453, + "logprob": -0.93896484, "text": " faut" }, { "id": 3913, - "logprob": -3.6816406, + "logprob": -3.7207031, "text": " tout" }, { "id": 39261, - "logprob": -1.7753906, + "logprob": -1.5742188, "text": " d'abord" } 
], @@ -64,13 +64,13 @@ "tokens": [ { "id": 578, - "logprob": -1.6318359, + "logprob": -1.6474609, "special": false, "text": " le" }, { "id": 5608, - "logprob": -2.4882812, + "logprob": -2.4707031, "special": false, "text": " faire" }, @@ -88,19 +88,19 @@ }, { "id": 693, - "logprob": -2.4472656, + "logprob": -2.4628906, "special": false, "text": " ร " }, { "id": 366, - "logprob": -1.1972656, + "logprob": -1.1953125, "special": false, "text": " la" }, { "id": 48844, - "logprob": -1.7890625, + "logprob": -1.7978516, "special": false, "text": " cass" }, diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json index 6b45cf6b..3ac8d050 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -11,7 +11,7 @@ }, { "id": 3533, - "logprob": -9.625, + "logprob": -9.5625, "text": "Test" }, { @@ -24,13 +24,13 @@ "tokens": [ { "id": 2143, - "logprob": -1.828125, + "logprob": -1.8203125, "special": false, "text": " sent" }, { "id": 10081, - "logprob": -0.41210938, + "logprob": -0.55078125, "special": false, "text": " successfully" }, @@ -42,7 +42,7 @@ }, { "id": 100001, - "logprob": -0.16015625, + "logprob": -0.12695312, "special": true, "text": "<๏ฝœendโ–ofโ–sentence๏ฝœ>" } diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json index 8829f9fe..96f2ce17 100644 --- a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -24,13 +24,13 @@ "tokens": [ { "id": 1736, - "logprob": -2.03125, + "logprob": -2.046875, "special": false, "text": " form" }, { "id": 109, - "logprob": -1.8671875, + "logprob": -1.8828125, "special": false, "text": "\n\n" }, @@ -42,48 +42,48 @@ }, { "id": 2121, - "logprob": -1.8125, + "logprob": -1.78125, "special": false, "text": " test" }, { "id": 3853, - "logprob": -0.24121094, + "logprob": -0.23632812, "special": false, "text": " request" }, { "id": 1736, - "logprob": -0.100097656, + "logprob": -0.09326172, "special": false, "text": " form" }, { "id": 603, - "logprob": -0.9453125, + "logprob": -0.8828125, "special": false, "text": " is" }, { - "id": 476, - "logprob": -1.703125, + "id": 1671, + "logprob": -1.6171875, "special": false, - "text": " a" + "text": " used" }, { - "id": 4551, - "logprob": -2.453125, + "id": 577, + "logprob": -0.390625, "special": false, - "text": " document" + "text": " to" }, { - "id": 674, - "logprob": -0.796875, + "id": 3853, + "logprob": -1.2265625, "special": false, - "text": " that" + "text": " request" } ], "top_tokens": null }, - "generated_text": " form\n\nThe test request form is a document that" + "generated_text": " form\n\nThe test request form is used to request" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 5e537bb7..26224118 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json 
@@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.8515625, "text": "ometric" }, { "id": 81, - "logprob": -0.2668457, + "logprob": -0.21875, "text": "_" }, { "id": 6009, - "logprob": -1.6416016, + "logprob": -1.2773438, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25195312, "text": "(" }, { "id": 62, - "logprob": -5.2304688, + "logprob": -4.8203125, "text": "L" }, { "id": 44, - "logprob": -3.0976562, + "logprob": -3.7734375, "text": ":" }, { "id": 1682, - "logprob": -1.1044922, + "logprob": -0.8310547, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22766113, "text": "[" }, { "id": 1808, - "logprob": -0.32299805, + "logprob": -0.46240234, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0234375, "text": "]):" } ], @@ -69,126 +69,18 @@ "tokens": [ { "id": 284, - "logprob": -0.1282959, + "logprob": -0.04626465, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.97998047, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.7006836, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.1933594, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2697754, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.0836792, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.018737793, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028640747, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29467773, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31518555, - "special": false, - "text": " a" - }, - { - "id": 1149, - "logprob": -0.20605469, - "special": false, - "text": " list" - }, - { - "id": 432, - "logprob": -0.23254395, - "special": false, - "text": " of" - }, - { - "id": 7515, - "logprob": -0.4489746, - "special": false, - "text": " numbers" - }, - { - "id": 32, - "logprob": -0.6044922, - "special": false, - "text": "." 
- }, - { - "id": 446, - "logprob": -0.63964844, - "special": false, - "text": "\n\n " - }, - { - "id": 499, - "logprob": -1.1953125, - "special": false, - "text": " :" - }, - { - "id": 753, - "logprob": -0.03515625, - "special": false, - "text": "param" - }, - { - "id": 498, - "logprob": -0.06311035, - "special": false, - "text": " L" - }, - { - "id": 44, - "logprob": -0.003414154, - "special": false, - "text": ":" - }, - { - "id": 1682, - "logprob": -1.3310547, - "special": false, - "text": " List" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" + "generated_text": "\n " } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index bf0f5146..015912f8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5898438, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26586914, + "logprob": -0.21984863, "text": "_" }, { "id": 6009, - "logprob": -1.6347656, + "logprob": -1.2861328, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25219727, "text": "(" }, { "id": 62, - "logprob": -5.2382812, + "logprob": -4.8007812, "text": "L" }, { "id": 44, - "logprob": -3.0996094, + "logprob": -3.7949219, "text": ":" }, { "id": 1682, - "logprob": -1.1025391, + "logprob": -0.8046875, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22424316, "text": "[" }, { "id": 1808, - "logprob": -0.32226562, + "logprob": -0.46191406, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -74,121 +74,13 @@ "text": "\n " }, { - "id": 442, - "logprob": -1.3134766, - "special": false, - "text": " return" - }, - { - "id": 11665, - "logprob": -0.10021973, - "special": false, - "text": " reduce" - }, - { - "id": 26, - "logprob": 0.0, - "special": false, - "text": "(" - }, - { - "id": 5962, - "logprob": 0.0, - "special": false, - "text": "lambda" - }, - { - "id": 816, - "logprob": 0.0, - "special": false, - "text": " x" - }, - { - "id": 30, - "logprob": 0.0, - "special": false, - "text": "," - }, - { - "id": 533, - "logprob": 0.0, - "special": false, - "text": " y" - }, - { - "id": 44, - "logprob": 0.0, - "special": false, - "text": ":" - }, - { - "id": 816, - "logprob": 0.0, - "special": false, - "text": " x" - }, - { - "id": 319, - "logprob": -0.42871094, - "special": false, - "text": " *" - }, - { - "id": 533, - "logprob": 0.0, - "special": false, - "text": " y" - }, - { - "id": 30, - "logprob": 0.0, - "special": false, - "text": "," - }, - { - "id": 498, - "logprob": 0.0, - "special": false, - "text": " L" - }, - { - "id": 27, - "logprob": 0.0, - "special": false, - "text": ")" - }, - { - "id": 1115, - "logprob": 0.0, - "special": false, - "text": " **" - }, - { - 
"id": 308, - "logprob": 0.0, - "special": false, - "text": " (" - }, - { - "id": 35, - "logprob": 0.0, - "special": false, - "text": "1" - }, - { - "id": 32, - "logprob": -0.31323242, - "special": false, - "text": "." - }, - { - "id": 34, - "logprob": 0.0, - "special": false, - "text": "0" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" + "generated_text": "\n " } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 46a21ed8..c9b5ab20 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -2,8 +2,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5820312, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26708984, + "logprob": -0.21826172, "text": "_" }, { "id": 6009, - "logprob": -1.6386719, + "logprob": -1.3085938, "text": "mean" }, { "id": 26, - "logprob": -0.22717285, + "logprob": -0.2548828, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.8007812, "text": "L" }, { "id": 44, - "logprob": -3.1015625, + "logprob": -3.7871094, "text": ":" }, { "id": 1682, - "logprob": -1.1083984, + "logprob": -0.81152344, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22644043, "text": "[" }, { "id": 1808, - "logprob": -0.32592773, + "logprob": -0.46313477, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -70,74 +70,26 @@ "tokens": [ { "id": 284, - "logprob": -0.12817383, + "logprob": -0.046936035, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.9863281, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.7011719, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2050781, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08465576, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028625488, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29418945, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.3161621, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -146,57 +98,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9375, "text": " ge" }, { "id": 21017, - "logprob": -7.59375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": 
-0.26953125, + "logprob": -0.21899414, "text": "_" }, { "id": 6009, - "logprob": -1.640625, + "logprob": -1.3105469, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25561523, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.8085938, "text": "L" }, { "id": 44, - "logprob": -3.1132812, + "logprob": -3.7890625, "text": ":" }, { "id": 1682, - "logprob": -1.1123047, + "logprob": -0.80615234, "text": " List" }, { "id": 77, - "logprob": -0.14294434, + "logprob": -0.22375488, "text": "[" }, { "id": 1808, - "logprob": -0.32299805, + "logprob": -0.46801758, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0253906, "text": "]):" } ], @@ -204,74 +156,26 @@ "tokens": [ { "id": 284, - "logprob": -0.12854004, + "logprob": -0.046447754, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.9897461, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69970703, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2050781, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08496094, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.029037476, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.2939453, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31591797, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -280,57 +184,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26586914, + "logprob": -0.2163086, "text": "_" }, { "id": 6009, - "logprob": -1.6347656, + "logprob": -1.2958984, "text": "mean" }, { "id": 26, - "logprob": -0.22766113, + "logprob": -0.2529297, "text": "(" }, { "id": 62, - "logprob": -5.2265625, + "logprob": -4.796875, "text": "L" }, { "id": 44, - "logprob": -3.0976562, + "logprob": -3.7910156, "text": ":" }, { "id": 1682, - "logprob": -1.1025391, + "logprob": -0.8076172, "text": " List" }, { "id": 77, - "logprob": -0.1427002, + "logprob": -0.22375488, "text": "[" }, { "id": 1808, - "logprob": -0.32592773, + "logprob": -0.46655273, "text": "float" }, { "id": 10794, - "logprob": -2.8164062, + "logprob": -3.0234375, "text": "]):" } ], @@ -338,74 +242,26 @@ "tokens": [ { "id": 284, - "logprob": -0.13012695, + "logprob": -0.0463562, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.98046875, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69921875, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.1992188, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.2668457, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.083496094, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.01902771, - "special": false, - "text": "ometric" - }, - 
{ - "id": 5651, - "logprob": -0.029006958, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29248047, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.3161621, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " }, { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "eos_token", + "generated_tokens": 2, "prefill": [ { "id": 589, @@ -414,57 +270,57 @@ }, { "id": 3226, - "logprob": -8.5859375, + "logprob": -8.9453125, "text": " ge" }, { "id": 21017, - "logprob": -7.5859375, + "logprob": -8.859375, "text": "ometric" }, { "id": 81, - "logprob": -0.26904297, + "logprob": -0.21862793, "text": "_" }, { "id": 6009, - "logprob": -1.6386719, + "logprob": -1.3095703, "text": "mean" }, { "id": 26, - "logprob": -0.22705078, + "logprob": -0.25512695, "text": "(" }, { "id": 62, - "logprob": -5.234375, + "logprob": -4.796875, "text": "L" }, { "id": 44, - "logprob": -3.1132812, + "logprob": -3.7890625, "text": ":" }, { "id": 1682, - "logprob": -1.1074219, + "logprob": -0.79589844, "text": " List" }, { "id": 77, - "logprob": -0.14477539, + "logprob": -0.22692871, "text": "[" }, { "id": 1808, - "logprob": -0.3256836, + "logprob": -0.46801758, "text": "float" }, { "id": 10794, - "logprob": -2.8027344, + "logprob": -3.0097656, "text": "]):" } ], @@ -472,67 +328,19 @@ "tokens": [ { "id": 284, - "logprob": -0.12915039, + "logprob": -0.04638672, "special": false, "text": "\n " }, { - "id": 1524, - "logprob": -0.98535156, - "special": false, - "text": " \"\"\"" - }, - { - "id": 284, - "logprob": -0.69921875, - "special": false, - "text": "\n " - }, - { - "id": 14883, - "logprob": -2.2011719, - "special": false, - "text": " Calculate" - }, - { - "id": 322, - "logprob": -0.26708984, - "special": false, - "text": " the" - }, - { - "id": 3226, - "logprob": -0.08502197, - "special": false, - "text": " ge" - }, - { - "id": 21017, - "logprob": -0.019012451, - "special": false, - "text": "ometric" - }, - { - "id": 5651, - "logprob": -0.028625488, - "special": false, - "text": " mean" - }, - { - "id": 432, - "logprob": -0.29589844, - "special": false, - "text": " of" - }, - { - "id": 312, - "logprob": -0.31591797, - "special": false, - "text": " a" + "id": 0, + "logprob": null, + "special": true, + "text": "<|endoftext|>" } ], "top_tokens": null }, - "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" + "generated_text": "\n " } ] diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json index ef88926c..93724fe4 100644 --- a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -11,22 +11,22 @@ }, { "id": 13, - "logprob": -2.734375, + "logprob": -2.59375, "text": "," }, { "id": 8862, - "logprob": -3.6875, + "logprob": -3.5625, "text": " yellow" }, { "id": 13, - "logprob": -0.40234375, + "logprob": -0.44726562, "text": "," }, { "id": 209, - "logprob": -8.25, + "logprob": -8.0, "text": " " } ], @@ -52,7 +52,7 @@ }, { "id": 9830, - "logprob": -2.25, + "logprob": -2.03125, "special": false, "text": " colors" }, @@ -64,13 +64,13 @@ }, { "id": 329, - "logprob": -2.171875, + "logprob": -2.734375, "special": 
false, "text": " A" }, { "id": 1180, - "logprob": -2.046875, + "logprob": -2.0, "special": false, "text": " number" }, @@ -81,19 +81,19 @@ "text": " of" }, { - "id": 1027, - "logprob": -1.5546875, + "id": 253, + "logprob": -0.69140625, "special": false, - "text": " different" + "text": " the" }, { "id": 3295, - "logprob": -0.97265625, + "logprob": -0.8203125, "special": false, "text": " color" } ], "top_tokens": null }, - "generated_text": "blue, red, yellow, \nand blue colors. A number of different color" + "generated_text": "blue, red, yellow, \nand blue colors. A number of the color" } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index 5cacf3e9..40ec7e2f 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -26,13 +26,13 @@ }, { "id": 259, - "logprob": -0.4716797, + "logprob": -0.46948242, "special": false, "text": " " }, { "id": 261, - "logprob": -0.044677734, + "logprob": -0.15307617, "special": false, "text": "," }, @@ -56,7 +56,7 @@ }, { "id": 35622, - "logprob": -1.1630859, + "logprob": -1.2998047, "special": false, "text": " cloud" }, diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py index f1007d6e..6d46e54d 100644 --- a/integration-tests/models/test_flash_starcoder_gptq.py +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -21,7 +21,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap max_new_tokens=20, decoder_input_details=True, ) - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 2 assert response == generous_response_snapshot @@ -38,7 +38,7 @@ async def test_flash_starcoder_gptq_default_params( decoder_input_details=True, seed=0, ) - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 2 assert response == generous_response_snapshot diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index bc946de8..8548970a 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -47,7 +47,7 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): assert response.details.generated_tokens == 10 assert ( response.generated_text - == "blue, red, yellow, \nand blue colors. A number of different color" + == "blue, red, yellow, \nand blue colors. A number of the color" ) assert response == response_snapshot From 1411bfb9897bafc4d5fbc61d94b7d295db0ba70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 16 Aug 2024 10:01:01 +0200 Subject: [PATCH 50/72] nix: try to reduce the number of Rust rebuilds (#2424) Try to reduce the number of router/launcher rebuilds by filtering sources. In this way, recompiles should only be triggered by changes in Cargo or Rust files. 
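In practice this is done with per-crate source filters; a minimal sketch of the pattern (assuming the same `nix-filter` combinators used in the new `nix/crate-overrides.nix` below):

```nix
# Sketch only: restrict a crate's src to directories and Rust sources so that
# unrelated changes (e.g. docs or CI) no longer trigger a rebuild.
{ pkgs, nix-filter }:

let
  filter = nix-filter.lib;
in
with pkgs;
defaultCrateOverrides
// {
  text-generation-router = attrs: {
    src = filter {
      root = ../router;
      include = with filter; [
        isDirectory        # keep the directory structure
        (matchExt "rs")    # keep only *.rs files
      ];
    };
  };
}
```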
--- flake.lock | 16 ++++++++++ flake.nix | 30 ++++--------------- nix/crate-overrides.nix | 66 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 25 deletions(-) create mode 100644 nix/crate-overrides.nix diff --git a/flake.lock b/flake.lock index ef05c258..7c772377 100644 --- a/flake.lock +++ b/flake.lock @@ -579,6 +579,21 @@ "type": "github" } }, + "nix-filter": { + "locked": { + "lastModified": 1710156097, + "narHash": "sha256-1Wvk8UP7PXdf8bCCaEoMnOT1qe5/Duqgj+rL8sRQsSM=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "3342559a24e85fc164b295c3444e8a139924675b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, "nix-github-actions": { "inputs": { "nixpkgs": [ @@ -880,6 +895,7 @@ "inputs": { "crate2nix": "crate2nix", "flake-utils": "flake-utils_6", + "nix-filter": "nix-filter", "nixpkgs": [ "tgi-nix", "nixpkgs" diff --git a/flake.nix b/flake.nix index 168c8649..cf05746a 100644 --- a/flake.nix +++ b/flake.nix @@ -4,6 +4,7 @@ url = "github:nix-community/crate2nix"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; + nix-filter.url = "github:numtide/nix-filter"; tgi-nix.url = "github:danieldk/tgi-nix"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; @@ -17,6 +18,7 @@ { self, crate2nix, + nix-filter, nixpkgs, flake-utils, rust-overlay, @@ -44,6 +46,7 @@ }; inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage; text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; }; + crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; }; in { devShells.default = @@ -91,31 +94,8 @@ transformers vllm - cargoNix.workspaceMembers.text-generation-launcher.build - - (cargoNix.workspaceMembers.text-generation-router-v3.build.override { - crateOverrides = defaultCrateOverrides // { - aws-lc-rs = attrs: { - # aws-lc-rs does its own custom parsing of Cargo environment - # variables like DEP_.*_INCLUDE. However buildRustCrate does - # not use the version number, so the parsing fails. - postPatch = '' - substituteInPlace build.rs \ - --replace-fail \ - "assert!(!selected.is_empty()" \ - "// assert!(!selected.is_empty()" - ''; - }; - rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; - text-generation-router-v3 = attrs: { - # We need to do the src/source root dance so that the build - # has access to the protobuf file. - src = ./.; - postPatch = "cd backends/v3"; - buildInputs = [ protobuf ]; - }; - }; - }) + (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; }) + (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; }) ]); venvDir = "./.venv"; diff --git a/nix/crate-overrides.nix b/nix/crate-overrides.nix new file mode 100644 index 00000000..343b3b25 --- /dev/null +++ b/nix/crate-overrides.nix @@ -0,0 +1,66 @@ +{ pkgs, nix-filter }: + +let + filter = nix-filter.lib; +in +with pkgs; +defaultCrateOverrides +// { + aws-lc-rs = attrs: { + # aws-lc-rs does its own custom parsing of Cargo environment + # variables like DEP_.*_INCLUDE. However buildRustCrate does + # not use the version number, so the parsing fails. 
+ postPatch = '' + substituteInPlace build.rs \ + --replace-fail \ + "assert!(!selected.is_empty()" \ + "// assert!(!selected.is_empty()" + ''; + }; + rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; }; + + grpc-metadata = attrs: { + src = + filter { + root = ../backends/grpc-metadata; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-launcer = attrs: { + src = + filter { + root = ../launcher; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-router = attrs: { + src = + filter { + root = ../router; + include = with filter; [ + isDirectory + (matchExt "rs") + ]; + }; + }; + text-generation-router-v3 = attrs: { + # We need to do the src/source root dance so that the build + # has access to the protobuf file. + src = filter { + root = ../.; + include = with filter; [ + isDirectory + (and (inDirectory "backends/v3") (matchExt "rs")) + (and (inDirectory "proto") (matchExt "proto")) + ]; + }; + postPatch = "cd backends/v3"; + buildInputs = [ protobuf ]; + }; +} From 99b662f8c27ddaf469df0f035bc0fdcfa4245bbf Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Fri, 16 Aug 2024 12:43:08 +0200 Subject: [PATCH 51/72] Improve the Consuming TGI + Streaming docs. (#2412) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve the Consuming TGI docs. * Fix erronous update to . * add info about Open AI client. * More updates. * Apply suggestions from code review Co-authored-by: Erik Kaunismรคki * Suggestions from Lucain. * Update Gradio snippet. * Up. * Apply suggestions from code review Co-authored-by: Lucain * Update docs/source/basic_tutorials/consuming_tgi.md Co-authored-by: Lucain * Up. * Apply suggestions from code review Co-authored-by: Omar Sanseviero * Up. * Up. * Doc review from Nico. * Doc review from Nico. x2 * Last nit --------- Co-authored-by: Erik Kaunismรคki Co-authored-by: Lucain Co-authored-by: Omar Sanseviero --- docs/openapi.json | 2 +- docs/source/basic_tutorials/consuming_tgi.md | 196 +++++++++++------- docs/source/basic_tutorials/using_guidance.md | 2 +- docs/source/conceptual/streaming.md | 94 +++++---- 4 files changed, 174 insertions(+), 120 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index df21e19d..9ddabaa1 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2094,4 +2094,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md index 4829ec7c..6e4ec49c 100644 --- a/docs/source/basic_tutorials/consuming_tgi.md +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -1,81 +1,125 @@ # Consuming Text Generation Inference -There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. +There are many ways to consume Text Generation Inference (TGI) server in your applications. 
After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens. + +For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference). + +You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models. ## curl -After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: +After a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant to the OpenAI Chat Completion spec: + +```bash +curl localhost:8080/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is deep learning?" + } + ], + "stream": true, + "max_tokens": 20 +}' \ + -H 'Content-Type: application/json' +``` + +For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes. ```bash curl 127.0.0.1:8080/generate \ -X POST \ - -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ + -d '{ + "inputs":"What is Deep Learning?", + "parameters":{ + "max_new_tokens":20 + } +}' \ -H 'Content-Type: application/json' ``` +## Python -## Inference Client +### Inference Client -[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. -You can simply install `huggingface-hub` package with pip. +[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface. + +Install `huggingface_hub` package via pip. ```bash -pip install huggingface-hub +pip install huggingface_hub ``` -Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python. 
+You can now use `InferenceClient` the exact same way you would use `OpenAI` client in Python ```python from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") -client.text_generation(prompt="Write a code for snake game") +client = InferenceClient( + base_url="http://localhost:8080/v1/", +) + +output = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, +) + +for chunk in output: + print(chunk.choices[0].delta.content) ``` -You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: +You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility). + +There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) + +### OpenAI Client + +You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI. + +Install the OpenAI Python package via pip. + +```bash +pip install openai +``` ```python -for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): - print(token) +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + base_url="http://localhost:8080/v1/", + api_key="-" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message) ``` -Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. +## UI -```python -output = client.text_generation(prompt="Meaning of life is", details=True) -print(output) - -# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..)) -``` - -You can see how to stream below. - -```python -output = client.text_generation(prompt="Meaning of life is", stream=True, details=True) -print(next(iter(output))) - -# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None) -``` - -You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. 
You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) - - -## ChatUI - -ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces. - -To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. - -``` -{ -// rest of the model config here -"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] -} -``` - -![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) - -## Gradio +### Gradio Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. @@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference import gradio as gr from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") +client = InferenceClient(base_url="http://127.0.0.1:8080") def inference(message, history): partial_message = "" - for token in client.text_generation(message, max_new_tokens=20, stream=True): - partial_message += token + output = client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": message}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + partial_message += chunk.choices[0].delta.content yield partial_message gr.ChatInterface( inference, chatbot=gr.Chatbot(height=300), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), - description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + description="This is the demo for Gradio UI consuming TGI endpoint.", title="Gradio ๐Ÿค TGI", examples=["Are tomatoes vegetables?"], retry_btn="Retry", @@ -110,20 +163,7 @@ gr.ChatInterface( ).queue().launch() ``` -The UI looks like this ๐Ÿ‘‡ - -
-
-You can try the demo directly here 👇
+You can check out the UI and try the demo directly here 👇