mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix: make eos and bos optional and only init template once
This commit is contained in:
parent
db835509ed
commit
f378c60517
@ -34,7 +34,11 @@ pub struct Infer {
|
|||||||
/// Inference limit
|
/// Inference limit
|
||||||
limit_concurrent_requests: Arc<Semaphore>,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
/// Chat template (template, bos_token, eos_token)
|
/// Chat template (template, bos_token, eos_token)
|
||||||
template: (Option<Template<'static, 'static>>, String, String),
|
template: (
|
||||||
|
Option<Template<'static, 'static>>,
|
||||||
|
Option<String>,
|
||||||
|
Option<String>,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer shared state
|
/// Infer shared state
|
||||||
@ -92,18 +96,20 @@ impl Infer {
|
|||||||
.template_from_str(Box::leak(template_str))
|
.template_from_str(Box::leak(template_str))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
let eos_token = tokenizer_config
|
||||||
|
.eos_token
|
||||||
|
.map_or_else(String::new, |t| t)
|
||||||
|
.into();
|
||||||
|
let bos_token = tokenizer_config
|
||||||
|
.bos_token
|
||||||
|
.map_or_else(String::new, |t| t)
|
||||||
|
.into();
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
queue,
|
queue,
|
||||||
shared,
|
shared,
|
||||||
limit_concurrent_requests: semaphore,
|
limit_concurrent_requests: semaphore,
|
||||||
template: (
|
template: (template, eos_token, bos_token),
|
||||||
template,
|
|
||||||
// initialize bos_token and eos_token to empty strings if not provided
|
|
||||||
tokenizer_config.bos_token.unwrap_or_default(),
|
|
||||||
tokenizer_config.eos_token.unwrap_or_default(),
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,14 +166,14 @@ impl Infer {
|
|||||||
/// Apply the chat template to the chat request
|
/// Apply the chat template to the chat request
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
|
||||||
let (template, bos_token, eos_token) = self.template.clone();
|
let (template, bos_token, eos_token) = &self.template;
|
||||||
template
|
template
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
messages,
|
messages,
|
||||||
eos_token,
|
eos_token: eos_token.as_deref(),
|
||||||
bos_token,
|
bos_token: bos_token.as_deref(),
|
||||||
})
|
})
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
@ -773,8 +779,8 @@ mod tests {
|
|||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: "[BOS]".to_string(),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: "[EOS]".to_string(),
|
eos_token: Some("[EOS]"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
@ -852,8 +858,8 @@ magic!"#
|
|||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: "[BOS]".to_string(),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: "[EOS]".to_string(),
|
eos_token: Some("[EOS]"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||||
@ -924,8 +930,8 @@ magic!"#
|
|||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: "[BOS]".to_string(),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: "[EOS]".to_string(),
|
eos_token: Some("[EOS]"),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
@ -368,10 +368,10 @@ pub(crate) struct ChatRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize)]
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
pub(crate) struct ChatTemplateInputs {
|
pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
bos_token: String,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: String,
|
eos_token: Option<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
@ -659,9 +659,9 @@ async fn chat_completions(
|
|||||||
|
|
||||||
// build the complete response object with the full text
|
// build the complete response object with the full text
|
||||||
let response = ChatCompletion::new(
|
let response = ChatCompletion::new(
|
||||||
generation.generated_text,
|
|
||||||
model_id,
|
model_id,
|
||||||
system_fingerprint,
|
system_fingerprint,
|
||||||
|
generation.generated_text,
|
||||||
current_time,
|
current_time,
|
||||||
generation.details.unwrap(),
|
generation.details.unwrap(),
|
||||||
logprobs,
|
logprobs,
|
||||||
|
Loading…
Reference in New Issue
Block a user