feat: add guideline to chat request and template

drbh 2024-08-09 13:53:47 +00:00
parent 6e127dcc96
commit 3b25cd3213
6 changed files with 20 additions and 6 deletions

```diff
@@ -2080,4 +2080,4 @@
       "description": "Hugging Face Text Generation Inference API"
     }
   ]
 }
```

```diff
@@ -9,7 +9,7 @@ We recommend using the official quantization scripts for creating your quants:
 2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
 3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
 
 For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
 
 ## Quantization with bitsandbytes, EETQ & fp8
@@ -69,4 +69,4 @@ text-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-
 You can learn more about the quantization options by running `text-generation-server quantize --help`.
 
 If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
 
 You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
```

```diff
@@ -48,6 +48,7 @@ impl ChatTemplate {
     pub(crate) fn apply(
         &self,
+        guideline: Option<&str>,
         mut messages: Vec<Message>,
         grammar_with_prompt: Option<(GrammarType, String)>,
     ) -> Result<String, InferError> {
@@ -65,6 +66,7 @@ impl ChatTemplate {
         self.template
             .render(ChatTemplateInputs {
+                guideline,
                 messages,
                 bos_token: self.bos_token.as_deref(),
                 eos_token: self.eos_token.as_deref(),
```
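For context, `ChatTemplateInputs` is rendered with MiniJinja in the router, so a chat template can branch on the new `guideline` variable. A minimal sketch of that mechanism; the template text here is made up for illustration, since real templates ship with the model's tokenizer config:

```rust
use minijinja::{context, Environment};

fn main() {
    let mut env = Environment::new();
    // Hypothetical template: the branching pattern is the point, not the text.
    env.add_template(
        "chat",
        "{% if guideline %}Guideline: {{ guideline }}\n{% endif %}\
         {% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}",
    )
    .unwrap();

    let rendered = env
        .get_template("chat")
        .unwrap()
        .render(context! {
            guideline => "No dangerous content.",
            messages => vec![context! { role => "user", content => "Hi" }],
        })
        .unwrap();
    assert!(rendered.starts_with("Guideline: No dangerous content."));
}
```

When no guideline is sent, the variable is simply absent and the `{% if %}` branch renders nothing.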

```diff
@@ -138,13 +138,14 @@ impl Infer {
     #[instrument(skip_all)]
     pub(crate) fn apply_chat_template(
         &self,
+        guideline: Option<String>,
         messages: Vec<Message>,
         grammar_with_prompt: Option<(GrammarType, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, grammar_with_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
```
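A small detail worth noting: `Infer::apply_chat_template` owns the guideline as `Option<String>`, while `ChatTemplate::apply` only borrows it as `Option<&str>`, and `Option::as_deref` bridges the two without cloning. A standalone illustration:

```rust
fn main() {
    // `as_deref` borrows the owned String as &str; the Option itself is
    // untouched, so nothing is cloned and the owner remains usable.
    let guideline: Option<String> = Some("Refuse unsafe requests.".to_string());
    let borrowed: Option<&str> = guideline.as_deref();
    assert_eq!(borrowed, Some("Refuse unsafe requests."));
    assert!(guideline.is_some()); // still owned here
}
```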

```diff
@@ -829,6 +829,11 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = "null")]
     pub response_format: Option<GrammarType>,
+
+    /// A guideline to be used in the chat_template
+    #[serde(default)]
+    #[schema(nullable = true, default = "null", example = "null")]
+    pub guideline: Option<String>,
 }
 
 fn default_tool_prompt() -> Option<String> {
@@ -936,6 +941,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
     tools: Option<&'a str>,
     tools_prompt: Option<&'a str>,
+    guideline: Option<&'a str>,
 }
 
 #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
```
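Since the new field is an `Option` with `#[serde(default)]`, clients that never send `guideline` are unaffected: it simply deserializes to `None`. A trimmed, hypothetical stand-in for `ChatRequest` showing the wire behaviour (the real struct has many more fields):

```rust
use serde::Deserialize;

// Hypothetical stand-in for ChatRequest, reduced to the new field.
#[derive(Deserialize, Debug)]
struct ChatRequestSketch {
    #[serde(default)]
    guideline: Option<String>,
}

fn main() {
    let with: ChatRequestSketch =
        serde_json::from_str(r#"{"guideline":"Refuse medical advice."}"#).unwrap();
    let without: ChatRequestSketch = serde_json::from_str("{}").unwrap();
    assert_eq!(with.guideline.as_deref(), Some("Refuse medical advice."));
    assert!(without.guideline.is_none());
}
```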

```diff
@@ -141,6 +141,7 @@ async fn get_chat_tokenize(
         tool_prompt,
         temperature,
         response_format,
+        guideline,
         ..
     } = req;
@@ -151,6 +152,7 @@ async fn get_chat_tokenize(
         tools,
         tool_choice,
         &tool_prompt,
+        guideline,
         messages,
     )?;
@@ -1123,6 +1125,7 @@ async fn chat_completions(
         tool_prompt,
         temperature,
         response_format,
+        guideline,
         ..
     } = req;
@@ -1142,6 +1145,7 @@ async fn chat_completions(
         tools,
         tool_choice,
         &tool_prompt,
+        guideline,
         messages,
     )?;
@@ -2402,6 +2406,7 @@ fn prepare_chat_input(
     tools: Option<Vec<Tool>>,
     tool_choice: ToolChoice,
     tool_prompt: &str,
+    guideline: Option<String>,
     messages: Vec<Message>,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
@@ -2411,7 +2416,7 @@ fn prepare_chat_input(
     }
 
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(messages, None)?;
+        let inputs = infer.apply_chat_template(guideline, messages, None)?;
         return Ok((inputs, Some(format), None));
     }
@@ -2423,6 +2428,6 @@ fn prepare_chat_input(
     let tools_grammar_prompt = tool_grammar
         .as_ref()
         .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
+    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
     Ok((inputs, grammar, tool_grammar))
 }
```
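Taken together, a client can now attach a guideline to a chat completion call. A hedged end-to-end sketch: the host, port, and guideline text are assumptions about a typical local deployment, not something this diff pins down, and the guideline only takes effect when the model's chat template actually uses the variable:

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumed local TGI endpoint; adjust host and port for your deployment.
    let body = json!({
        "model": "tgi",
        "guideline": "Answer politely and refuse unsafe requests.",
        "messages": [{"role": "user", "content": "Hello!"}]
    });
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()
        .await?;
    println!("{}", resp.text().await?);
    Ok(())
}
```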