Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Support continue final message (#2733)
* feat: support continue_final_message param in chat request
* feat: add test for continue final message
* fix: bump openapi docs
* fix: remove continue_final_message chat request param
* fix: remove unneeded launcher args in continue test
* fix: bump test output
* fix: remove accidentally included guideline from rebase
* fix: remove guideline tests
* fix: adjust continuation tests expected text
* fix: replace expected output for continue test
This commit is contained in:
parent caff779dd4
commit d471805134
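The feature in brief: when the final entry in `messages` has role `assistant`, the rendered prompt now ends with that partial assistant message instead of opening a fresh assistant turn, so the model continues it mid-message. A minimal sketch of such a request, assuming a TGI server at http://localhost:8080 (placeholder address; the payload mirrors the integration test added below):

import requests

# The trailing assistant message is the prefix the model will continue,
# rather than the start of a brand-new assistant turn.
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "tgi",
        "messages": [
            {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
            {"role": "assistant", "content": "the elephant, but have you heard about"},
        ],
        "max_tokens": 30,
    },
)
print(response.json()["choices"][0]["message"]["content"])
# e.g. " the royal mouse? ..." -- only the continuation, not the prefix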
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541189,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 49,
    "total_tokens": 79
  }
}
@@ -0,0 +1,23 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds",
        "role": "assistant"
      }
    }
  ],
  "created": 1732541190,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "chat.completion",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": {
    "completion_tokens": 30,
    "prompt_tokens": 73,
    "total_tokens": 103
  }
}
integration-tests/models/test_continue_final_message.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import pytest
import requests


@pytest.fixture(scope="module")
def llama_continue_final_message_handle(launcher):
    with launcher("TinyLlama/TinyLlama-1.1B-Chat-v1.0") as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_continue_final_message(llama_continue_final_message_handle):
    await llama_continue_final_message_handle.health(300)
    return llama_continue_final_message_handle.client


def test_llama_completion_single_prompt(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == "Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\n\n1"
    )
    assert response == response_snapshot


def test_llama_completion_single_prompt_continue(
    llama_continue_final_message, response_snapshot
):
    response = requests.post(
        f"{llama_continue_final_message.base_url}/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
                {
                    "role": "assistant",
                    "content": "the elephant, but have you heard about",
                },
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
        headers=llama_continue_final_message.headers,
        stream=False,
    )
    response = response.json()
    print(response)
    assert len(response["choices"]) == 1
    content = response["choices"][0]["message"]["content"]
    assert (
        content
        == " the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds"
    )
    assert response == response_snapshot
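Worth noting from the continuation test and its snapshot: the returned content is only the newly generated text (it begins with a space, picking up right after the prefix), so a caller reconstructs the complete assistant message by concatenation. A small illustration, with values taken from the test above:

# Values from test_llama_completion_single_prompt_continue and its snapshot.
prefix = "the elephant, but have you heard about"
continuation = (
    " the royal mouse? It is a little more slender and only weighs around"
    " 1.5 pounds for males and 1.3 pounds"
)
full_reply = prefix + continuation  # the assistant message as a user would read it
print(full_reply)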
@@ -75,8 +75,9 @@ impl ChatTemplate {
         };

         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-        self.template
+        let final_message = messages.last().cloned();
+        let mut rendered_template = self
+            .template
             .render(ChatTemplateInputs {
                 messages,
                 bos_token: self.bos_token.as_deref(),
@@ -84,7 +85,24 @@ impl ChatTemplate {
                 add_generation_prompt: true,
                 tools,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError)?;
+
+        // if the last message is from the assistant, continue the generation prompt
+        rendered_template = match final_message {
+            Some(msg) if msg.role == "assistant" => {
+                match rendered_template.rfind(msg.content.as_str()) {
+                    // implementation based on feature in transformers pipeline
+                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
+                    Some(index) => rendered_template[..index + msg.content.len()]
+                        .trim_end()
+                        .to_string(),
+                    None => rendered_template,
+                }
+            }
+            _ => rendered_template,
+        };
+
+        Ok(rendered_template)
     }
 }
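For reference, the same trimming approach as a standalone Python sketch (an illustration of the technique, not the shipped Rust code): render the chat template as usual, generation prompt included, then cut the result just after the last occurrence of the final assistant message. That drops any end-of-turn markers and the fresh assistant header the template appended, leaving a prompt the model resumes mid-message, as in the transformers pipeline linked in the code comment.

def continue_final_message(rendered: str, final_role: str, final_content: str) -> str:
    # Only a trailing assistant message triggers continuation.
    if final_role != "assistant":
        return rendered
    # Locate the last occurrence of the partial assistant content in the
    # rendered template, mirroring rfind() in the Rust implementation.
    index = rendered.rfind(final_content)
    if index == -1:
        return rendered
    # Keep everything up to the end of that content and strip trailing
    # whitespace, like trim_end() in Rust.
    return rendered[: index + len(final_content)].rstrip()


# Illustrative rendering in a Zephyr/TinyLlama-style template (assumed format):
rendered = (
    "<|user|>\nWhich is bigger an elephant or a mouse?</s>\n"
    "<|assistant|>\nthe elephant, but have you heard about</s>\n"
    "<|assistant|>\n"
)
print(continue_final_message(
    rendered, "assistant", "the elephant, but have you heard about"
))
# The prompt now ends with "...have you heard about", ready to be continued.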