From 3b25cd3213947703c6fb9a3f15826616c4cb2f92 Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 9 Aug 2024 13:53:47 +0000
Subject: [PATCH] feat: add guideline to chat request and template

---
 docs/openapi.json                      | 2 +-
 docs/source/conceptual/quantization.md | 4 ++--
 router/src/infer/chat_template.rs      | 2 ++
 router/src/infer/mod.rs                | 3 ++-
 router/src/lib.rs                      | 6 ++++++
 router/src/server.rs                   | 9 +++++++--
 6 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/docs/openapi.json b/docs/openapi.json
index 9d281a48..ed9b0b96 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -2080,4 +2080,4 @@
       "description": "Hugging Face Text Generation Inference API"
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/docs/source/conceptual/quantization.md b/docs/source/conceptual/quantization.md
index a1ebe7e7..b7672a9f 100644
--- a/docs/source/conceptual/quantization.md
+++ b/docs/source/conceptual/quantization.md
@@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants:
 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
 
-For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest. 
+For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
 
 ## Quantization with bitsandbytes, EETQ & fp8
 
@@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-
 You can learn more about the quantization options by running `text-generation-server quantize --help`.
 
 If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
-You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
\ No newline at end of file
+You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 24a00352..9203c1b2 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -48,6 +48,7 @@ impl ChatTemplate {
 
     pub(crate) fn apply(
         &self,
+        guideline: Option<&str>,
         mut messages: Vec<Message>,
         grammar_with_prompt: Option<(GrammarType, String)>,
     ) -> Result<String, InferError> {
@@ -65,6 +66,7 @@ impl ChatTemplate {
 
         self.template
             .render(ChatTemplateInputs {
+                guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 534a2647..58d5cf9a 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -138,13 +138,14 @@ impl Infer {
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(
         &self,
+        guideline: Option<String>,
         messages: Vec<Message>,
         grammar_with_prompt: Option<(GrammarType, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, grammar_with_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
diff --git a/router/src/lib.rs b/router/src/lib.rs
index a956b058..4e189a3b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -829,6 +829,11 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub response_format: Option<GrammarType>,
+
+    /// A guideline to be used in the chat_template
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub guideline: Option<String>,
 }
 
 fn default_tool_prompt() -> Option<String> {
@@ -936,6 +941,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
     tools: Option<&'a str>,
     tools_prompt: Option<&'a str>,
+    guideline: Option<&'a str>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
diff --git a/router/src/server.rs b/router/src/server.rs
index 1d1cd36a..8c0bd762 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -141,6 +141,7 @@ async fn get_chat_tokenize(
         tool_prompt,
         temperature,
         response_format,
+        guideline,
         ..
     } = req;
 
@@ -151,6 +152,7 @@ async fn get_chat_tokenize(
         tools,
         tool_choice,
         &tool_prompt,
+        guideline,
         messages,
     )?;
 
@@ -1123,6 +1125,7 @@ async fn chat_completions(
         tool_prompt,
         temperature,
         response_format,
+        guideline,
         ..
     } = req;
 
@@ -1142,6 +1145,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         &tool_prompt,
+        guideline,
         messages,
     )?;
 
@@ -2402,6 +2406,7 @@ fn prepare_chat_input(
     tools: Option<Vec<Tool>>,
     tool_choice: ToolChoice,
     tool_prompt: &str,
+    guideline: Option<String>,
     messages: Vec<Message>,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
@@ -2411,7 +2416,7 @@
     }
 
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(messages, None)?;
+        let inputs = infer.apply_chat_template(guideline, messages, None)?;
         return Ok((inputs, Some(format), None));
     }
 
@@ -2423,6 +2428,6 @@
     let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
+    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
     Ok((inputs, grammar, tool_grammar))
 }
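
Usage note (illustration only, not part of the diff): once `guideline` is threaded
into `ChatTemplateInputs`, a model's chat template can reference it like any other
template variable. The sketch below uses the `minijinja` crate, which the router
already uses for rendering; the template string, guideline text, and message
content are hypothetical, not taken from a real model's chat template.

    // Sketch: rendering a chat template that consumes `guideline`.
    // The template and all values here are hypothetical.
    use minijinja::{context, Environment};

    fn main() {
        let mut env = Environment::new();
        env.add_template(
            "chat",
            "{% if guideline %}Guideline: {{ guideline }}\n{% endif %}\
             {% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
        )
        .unwrap();

        let rendered = env
            .get_template("chat")
            .unwrap()
            // Mirrors `ChatTemplateInputs { guideline, messages, .. }` above;
            // the values are made up for illustration.
            .render(context! {
                guideline => "Only answer questions about cooking.",
                messages => vec![
                    context! { role => "user", content => "How do I poach an egg?" },
                ],
            })
            .unwrap();
        println!("{rendered}");
    }

Templates that never mention `guideline` render exactly as before, which is why
the field can safely default to `None`.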
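
On the request side, `ChatRequest::guideline` is `#[serde(default)]` and optional,
so existing clients are unaffected. A hypothetical request body for the router's
/v1/chat/completions endpoint, built with `serde_json` (the model id, message,
and guideline text are made up):

    // Sketch: a chat request carrying the new optional `guideline` field.
    use serde_json::json;

    fn main() {
        let body = json!({
            // Model id and texts are illustrative.
            "model": "tgi",
            "messages": [
                { "role": "user", "content": "How do I poach an egg?" }
            ],
            // New in this patch; omit it and `ChatRequest::guideline`
            // deserializes to `None`, preserving the old behaviour.
            "guideline": "Only answer questions about cooking."
        });
        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }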