mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
fix: support tool call id in template and remove unnecessary changes
This commit is contained in:
parent
56f2d66828
commit
bcc44890a8
132
Cargo.lock
generated
132
Cargo.lock
generated
@ -128,9 +128,6 @@ name = "arbitrary"
|
|||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
|
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
|
||||||
dependencies = [
|
|
||||||
"derive_arbitrary",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "arc-swap"
|
name = "arc-swap"
|
||||||
@ -308,7 +305,7 @@ dependencies = [
|
|||||||
"http-body 0.4.6",
|
"http-body 0.4.6",
|
||||||
"hyper 0.14.32",
|
"hyper 0.14.32",
|
||||||
"itoa",
|
"itoa",
|
||||||
"matchit 0.7.3",
|
"matchit",
|
||||||
"memchr",
|
"memchr",
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
@ -341,41 +338,7 @@ dependencies = [
|
|||||||
"hyper 1.6.0",
|
"hyper 1.6.0",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
"itoa",
|
"itoa",
|
||||||
"matchit 0.7.3",
|
"matchit",
|
||||||
"memchr",
|
|
||||||
"mime",
|
|
||||||
"percent-encoding",
|
|
||||||
"pin-project-lite",
|
|
||||||
"rustversion",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"serde_path_to_error",
|
|
||||||
"serde_urlencoded",
|
|
||||||
"sync_wrapper 1.0.2",
|
|
||||||
"tokio",
|
|
||||||
"tower 0.5.2",
|
|
||||||
"tower-layer",
|
|
||||||
"tower-service",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "axum"
|
|
||||||
version = "0.8.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
|
|
||||||
dependencies = [
|
|
||||||
"axum-core 0.5.0",
|
|
||||||
"bytes",
|
|
||||||
"form_urlencoded",
|
|
||||||
"futures-util",
|
|
||||||
"http 1.2.0",
|
|
||||||
"http-body 1.0.1",
|
|
||||||
"http-body-util",
|
|
||||||
"hyper 1.6.0",
|
|
||||||
"hyper-util",
|
|
||||||
"itoa",
|
|
||||||
"matchit 0.8.4",
|
|
||||||
"memchr",
|
"memchr",
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
@ -431,26 +394,6 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "axum-core"
|
|
||||||
version = "0.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733"
|
|
||||||
dependencies = [
|
|
||||||
"bytes",
|
|
||||||
"futures-util",
|
|
||||||
"http 1.2.0",
|
|
||||||
"http-body 1.0.1",
|
|
||||||
"http-body-util",
|
|
||||||
"mime",
|
|
||||||
"pin-project-lite",
|
|
||||||
"rustversion",
|
|
||||||
"sync_wrapper 1.0.2",
|
|
||||||
"tower-layer",
|
|
||||||
"tower-service",
|
|
||||||
"tracing",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-tracing-opentelemetry"
|
name = "axum-tracing-opentelemetry"
|
||||||
version = "0.16.0"
|
version = "0.16.0"
|
||||||
@ -1165,17 +1108,6 @@ dependencies = [
|
|||||||
"powerfmt",
|
"powerfmt",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "derive_arbitrary"
|
|
||||||
version = "1.4.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800"
|
|
||||||
dependencies = [
|
|
||||||
"proc-macro2",
|
|
||||||
"quote",
|
|
||||||
"syn 2.0.98",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "derive_builder"
|
name = "derive_builder"
|
||||||
version = "0.20.2"
|
version = "0.20.2"
|
||||||
@ -2455,12 +2387,6 @@ dependencies = [
|
|||||||
"scopeguard",
|
"scopeguard",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "lockfree-object-pool"
|
|
||||||
version = "0.1.6"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.25"
|
version = "0.4.25"
|
||||||
@ -2522,12 +2448,6 @@ version = "0.7.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "matchit"
|
|
||||||
version = "0.8.4"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "maybe-rayon"
|
name = "maybe-rayon"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
@ -4784,7 +4704,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.8.1",
|
"axum 0.7.9",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"chrono",
|
"chrono",
|
||||||
@ -4852,7 +4772,7 @@ version = "3.1.1-dev0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.8.1",
|
"axum 0.7.9",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.30",
|
"clap 4.5.30",
|
||||||
@ -4901,7 +4821,7 @@ version = "3.1.1-dev0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.8.1",
|
"axum 0.7.9",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.30",
|
"clap 4.5.30",
|
||||||
@ -5639,9 +5559,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utoipa"
|
name = "utoipa"
|
||||||
version = "5.3.1"
|
version = "4.2.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "435c6f69ef38c9017b4b4eea965dfb91e71e53d869e896db40d1cf2441dd75c0"
|
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"indexmap 2.7.1",
|
"indexmap 2.7.1",
|
||||||
"serde",
|
"serde",
|
||||||
@ -5651,10 +5571,11 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utoipa-gen"
|
name = "utoipa-gen"
|
||||||
version = "5.3.1"
|
version = "4.3.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a77d306bc75294fd52f3e99b13ece67c02c1a2789190a6f31d32f736624326f7"
|
checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"proc-macro-error",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
@ -5663,18 +5584,16 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utoipa-swagger-ui"
|
name = "utoipa-swagger-ui"
|
||||||
version = "9.0.0"
|
version = "6.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "161166ec520c50144922a625d8bc4925cc801b2dda958ab69878527c0e5c5d61"
|
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum 0.8.1",
|
"axum 0.7.9",
|
||||||
"base64 0.22.1",
|
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
"regex",
|
"regex",
|
||||||
"rust-embed",
|
"rust-embed",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"url",
|
|
||||||
"utoipa",
|
"utoipa",
|
||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
@ -6404,33 +6323,14 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zip"
|
name = "zip"
|
||||||
version = "2.2.2"
|
version = "0.6.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ae9c1ea7b3a5e1f4b922ff856a129881167511563dc219869afe3787fc0c1a45"
|
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arbitrary",
|
"byteorder",
|
||||||
"crc32fast",
|
"crc32fast",
|
||||||
"crossbeam-utils",
|
"crossbeam-utils",
|
||||||
"displaydoc",
|
|
||||||
"flate2",
|
"flate2",
|
||||||
"indexmap 2.7.1",
|
|
||||||
"memchr",
|
|
||||||
"thiserror 2.0.11",
|
|
||||||
"zopfli",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "zopfli"
|
|
||||||
version = "0.8.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946"
|
|
||||||
dependencies = [
|
|
||||||
"bumpalo",
|
|
||||||
"crc32fast",
|
|
||||||
"lockfree-object-pool",
|
|
||||||
"log",
|
|
||||||
"once_cell",
|
|
||||||
"simd-adler32",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -16,7 +16,7 @@ path = "src/main.rs"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.8", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
axum-tracing-opentelemetry = "0.16"
|
axum-tracing-opentelemetry = "0.16"
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
@ -48,8 +48,8 @@ tower-http = { version = "0.5.1", features = ["cors"] }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-opentelemetry = "0.21.0"
|
tracing-opentelemetry = "0.21.0"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "5.3.1", features = ["axum_extras"] }
|
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] }
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
] }
|
] }
|
||||||
|
@ -16,7 +16,7 @@ path = "src/main.rs"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.8", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
axum-tracing-opentelemetry = "0.16"
|
axum-tracing-opentelemetry = "0.16"
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
@ -48,8 +48,8 @@ tower-http = { version = "0.5.1", features = ["cors"] }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-opentelemetry = "0.21.0"
|
tracing-opentelemetry = "0.21.0"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "5.3.1", features = ["axum_extras"] }
|
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] }
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
] }
|
] }
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -481,7 +481,6 @@ async def test_flash_llama_tool_reply_response(
|
|||||||
messages=[
|
messages=[
|
||||||
{"role": "user", "content": "What's the weather like in Paris today?"},
|
{"role": "user", "content": "What's the weather like in Paris today?"},
|
||||||
{
|
{
|
||||||
"content": "",
|
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
|
@ -11,7 +11,7 @@ homepage.workspace = true
|
|||||||
anyhow = "1"
|
anyhow = "1"
|
||||||
async-trait = "0.1.74"
|
async-trait = "0.1.74"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
axum = { version = "0.8", features = ["json"] }
|
axum = { version = "0.7", features = ["json"] }
|
||||||
axum-tracing-opentelemetry = "0.16"
|
axum-tracing-opentelemetry = "0.16"
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
@ -42,8 +42,8 @@ tower-http = { version = "0.5.1", features = ["cors"] }
|
|||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
tracing-opentelemetry = "0.21.0"
|
tracing-opentelemetry = "0.21.0"
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "5.3.1", features = ["axum_extras"] }
|
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] }
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
use crate::{
|
||||||
|
ChatTemplateInputs, Message, MessageBody, MessageChunk, TextMessage, TokenizerConfigToken, Tool,
|
||||||
|
};
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
@ -73,8 +75,10 @@ impl ChatTemplate {
|
|||||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
format!("\n---\n{}", tool_prompt)
|
format!("\n---\n{}", tool_prompt)
|
||||||
};
|
};
|
||||||
if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) {
|
if let Some(last_message) = messages.last_mut() {
|
||||||
content.push(MessageChunk::Text { text })
|
if let MessageBody::Content { content } = &mut last_message.body {
|
||||||
|
content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Some(tools)
|
Some(tools)
|
||||||
}
|
}
|
||||||
@ -158,18 +162,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -186,6 +194,182 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_with_tool_response() {
|
||||||
|
let env = Environment::new();
|
||||||
|
|
||||||
|
// template modified from Llama-3.1-8B-Instruct
|
||||||
|
// https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/0e9e39f249a16976918f6564b8830bc894c89659/tokenizer_config.json#L2053
|
||||||
|
// the main change is accessing `message.tool_call_id` from the messages
|
||||||
|
let source = r#"
|
||||||
|
{{- bos_token }}
|
||||||
|
{%- if custom_tools is defined %}
|
||||||
|
{%- set tools = custom_tools %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not tools_in_user_message is defined %}
|
||||||
|
{%- set tools_in_user_message = true %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not date_string is defined %}
|
||||||
|
{%- set date_string = "26 Jul 2024" %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if not tools is defined %}
|
||||||
|
{%- set tools = none %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{#- This block extracts the system message, so we can slot it into the right place. #}
|
||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{%- set system_message = messages[0]['content']|trim %}
|
||||||
|
{%- set messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = "" %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{#- System message + builtin tools #}
|
||||||
|
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
|
||||||
|
{%- if builtin_tools is defined or tools is not none %}
|
||||||
|
{{- "Environment: ipython\n" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if builtin_tools is defined %}
|
||||||
|
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "Cutting Knowledge Date: December 2023\n" }}
|
||||||
|
{{- "Today Date: " + date_string + "\n\n" }}
|
||||||
|
{%- if tools is not none and not tools_in_user_message %}
|
||||||
|
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
|
||||||
|
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
||||||
|
{{- "Do not use variables.\n\n" }}
|
||||||
|
{%- for t in tools %}
|
||||||
|
{{- t | tojson(indent=4) }}
|
||||||
|
{{- "\n\n" }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- system_message }}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
|
||||||
|
{#- Custom tools are passed in a user message with some extra guidance #}
|
||||||
|
{%- if tools_in_user_message and not tools is none %}
|
||||||
|
{#- Extract the first user message so we can plug it in here #}
|
||||||
|
{%- if messages | length != 0 %}
|
||||||
|
{%- set first_user_message = messages[0]['content']|trim %}
|
||||||
|
{%- set messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
|
||||||
|
{{- "Given the following functions, please respond with a JSON for a function call " }}
|
||||||
|
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
|
||||||
|
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
||||||
|
{{- "Do not use variables.\n\n" }}
|
||||||
|
{%- for t in tools %}
|
||||||
|
{{- t | tojson(indent=4) }}
|
||||||
|
{{- "\n\n" }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- first_user_message + "<|eot_id|>"}}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
||||||
|
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
|
||||||
|
{%- elif 'tool_calls' in message %}
|
||||||
|
{%- if not message.tool_calls|length == 1 %}
|
||||||
|
{{- raise_exception("This model only supports single tool-calls at once!") }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set tool_call = message.tool_calls[0].function %}
|
||||||
|
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
||||||
|
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
|
||||||
|
{%- for arg_name, arg_val in tool_call.arguments | items %}
|
||||||
|
{{- arg_name + '="' + arg_val + '"' }}
|
||||||
|
{%- if not loop.last %}
|
||||||
|
{{- ", " }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- ")" }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
||||||
|
{{- '{"name": "' + tool_call.name + '", ' }}
|
||||||
|
{{- '"parameters": ' }}
|
||||||
|
{{- tool_call.arguments | tojson }}
|
||||||
|
{{- "}" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if builtin_tools is defined %}
|
||||||
|
{#- This means we're in ipython mode #}
|
||||||
|
{{- "<|eom_id|>" }}
|
||||||
|
{%- else %}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- elif message.role == "tool" or message.role == "ipython" %}
|
||||||
|
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
|
||||||
|
{{- "TOOL CALL ID: " + message.tool_call_id + "\n\n" }}
|
||||||
|
{%- if message.content is mapping or message.content is iterable %}
|
||||||
|
{{- message.content | tojson }}
|
||||||
|
{%- else %}
|
||||||
|
{{- message.content }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "<|eot_id|>" }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
TextMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
TextMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: r#"[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]"#.to_string(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
TextMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content: "6.7".to_string(),
|
||||||
|
tool_call_id: Some("0".to_string()),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
r#"[BOS]<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 26 Jul 2024
|
||||||
|
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
TOOL CALL ID: 0
|
||||||
|
|
||||||
|
"6.7"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_loop_controls() {
|
fn test_chat_template_loop_controls() {
|
||||||
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
// some chat templates as e.g. CohereForAI/c4ai-command-r7b-12-202 contain `break`
|
||||||
@ -224,18 +408,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -287,22 +475,27 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi again!".to_string(),
|
content: "Hi again!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -359,18 +552,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -426,18 +623,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -479,18 +680,22 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hi!".to_string(),
|
content: "Hi!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "Hello how can I help?".to_string(),
|
content: "Hello how can I help?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "What is Deep Learning?".to_string(),
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "magic!".to_string(),
|
content: "magic!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
@ -516,14 +721,17 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "Hello, how are you?".to_string(),
|
content: "Hello, how are you?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "I'm doing great. How can I help you today?".to_string(),
|
content: "I'm doing great. How can I help you today?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "I'd like to show off how chat templating works!".to_string(),
|
content: "I'd like to show off how chat templating works!".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -531,6 +739,7 @@ mod tests {
|
|||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: "You are a friendly chatbot who always responds in the style of a pirate"
|
content: "You are a friendly chatbot who always responds in the style of a pirate"
|
||||||
.to_string(),
|
.to_string(),
|
||||||
|
..Default::default()
|
||||||
}]
|
}]
|
||||||
.iter()
|
.iter()
|
||||||
.chain(&example_chat)
|
.chain(&example_chat)
|
||||||
@ -674,10 +883,12 @@ mod tests {
|
|||||||
TextMessage {
|
TextMessage {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
content: "How many helicopters can a human eat in one sitting?".to_string(),
|
||||||
|
..Default::default()
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
|
@ -663,6 +663,7 @@ impl ChatCompletion {
|
|||||||
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
|
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content,
|
content,
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
|
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -673,6 +674,7 @@ impl ChatCompletion {
|
|||||||
OutputMessage::ChatMessage(TextMessage {
|
OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: output,
|
content: output,
|
||||||
|
..Default::default()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
@ -680,6 +682,7 @@ impl ChatCompletion {
|
|||||||
OutputMessage::ChatMessage(TextMessage {
|
OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: "".to_string(),
|
content: "".to_string(),
|
||||||
|
..Default::default()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -767,6 +770,7 @@ impl ChatCompletionChunk {
|
|||||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: delta,
|
content: delta,
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
(None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
@ -783,6 +787,7 @@ impl ChatCompletionChunk {
|
|||||||
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
(None, None) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "".to_string(),
|
content: "".to_string(),
|
||||||
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
@ -1129,7 +1134,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
|
||||||
pub(crate) struct FunctionDefinition {
|
pub struct FunctionDefinition {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@ -1157,7 +1162,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
|
||||||
pub(crate) struct ToolCall {
|
pub struct ToolCall {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub r#type: String,
|
pub r#type: String,
|
||||||
pub function: FunctionDefinition,
|
pub function: FunctionDefinition,
|
||||||
@ -1176,17 +1181,31 @@ pub enum MessageChunk {
|
|||||||
ImageUrl { image_url: Url },
|
ImageUrl { image_url: Url },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
role: String,
|
pub role: String,
|
||||||
|
#[serde(flatten)]
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: Option<MessageContent>,
|
pub body: MessageBody,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
#[schema(example = "\"David\"")]
|
#[schema(example = "\"David\"")]
|
||||||
name: Option<String>,
|
pub name: Option<String>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
}
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum MessageBody {
|
||||||
|
// When a regular text message is provided.
|
||||||
|
Content {
|
||||||
|
#[serde(rename = "content")]
|
||||||
|
content: MessageContent,
|
||||||
|
},
|
||||||
|
// When tool calls are provided.
|
||||||
|
Tool {
|
||||||
|
#[serde(rename = "tool_calls")]
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
|
||||||
@ -1213,22 +1232,25 @@ impl MessageContent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)]
|
||||||
pub struct TextMessage {
|
pub struct TextMessage {
|
||||||
#[schema(example = "user")]
|
#[schema(example = "user")]
|
||||||
pub role: String,
|
pub role: String,
|
||||||
#[schema(example = "My name is David and I")]
|
#[schema(example = "My name is David and I")]
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Message> for TextMessage {
|
impl From<Message> for TextMessage {
|
||||||
fn from(value: Message) -> Self {
|
fn from(value: Message) -> Self {
|
||||||
let content = value
|
let content = match value.body {
|
||||||
.tool_calls
|
MessageBody::Content { content } => content,
|
||||||
.map(|calls| serde_json::to_string(&calls).unwrap_or_default())
|
MessageBody::Tool { tool_calls } => {
|
||||||
.map(MessageContent::SingleText)
|
let content = serde_json::to_string(&tool_calls).unwrap_or_default();
|
||||||
.or(value.content)
|
MessageContent::SingleText(content)
|
||||||
.unwrap_or_else(|| MessageContent::SingleText(String::new()));
|
}
|
||||||
|
};
|
||||||
TextMessage {
|
TextMessage {
|
||||||
role: value.role,
|
role: value.role,
|
||||||
content: match content {
|
content: match content {
|
||||||
@ -1242,6 +1264,7 @@ impl From<Message> for TextMessage {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(""),
|
.join(""),
|
||||||
},
|
},
|
||||||
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1680,6 +1703,7 @@ mod tests {
|
|||||||
let message = OutputMessage::ChatMessage(TextMessage {
|
let message = OutputMessage::ChatMessage(TextMessage {
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: "This is the answer".to_string(),
|
content: "This is the answer".to_string(),
|
||||||
|
..Default::default()
|
||||||
});
|
});
|
||||||
let serialized = serde_json::to_string(&message).unwrap();
|
let serialized = serde_json::to_string(&message).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -49,8 +49,8 @@ request_body = SagemakerRequest,
|
|||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Chat Completion",
|
(status = 200, description = "Generated Chat Completion",
|
||||||
content(
|
content(
|
||||||
(SagemakerResponse = "application/json"),
|
("application/json" = SagemakerResponse),
|
||||||
(SagemakerStreamResponse = "text/event-stream"),
|
("text/event-stream" = SagemakerStreamResponse),
|
||||||
)),
|
)),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
|
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
|
||||||
|
@ -28,7 +28,7 @@ use crate::{
|
|||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{MessageBody, ModelInfo, ModelsInfo};
|
||||||
use async_stream::__private::AsyncStream;
|
use async_stream::__private::AsyncStream;
|
||||||
use axum::extract::{DefaultBodyLimit, Extension};
|
use axum::extract::{DefaultBodyLimit, Extension};
|
||||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||||
@ -111,9 +111,8 @@ request_body = CompatGenerateRequest,
|
|||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text",
|
(status = 200, description = "Generated Text",
|
||||||
content(
|
content(
|
||||||
(Vec<GenerateResponse> = "application/json"),
|
("application/json" = Vec<GenerateResponse>),
|
||||||
(Vec<GenerateResponse> = "application/json"),
|
("text/event-stream" = StreamResponse),
|
||||||
(StreamResponse = "text/event-stream"),
|
|
||||||
)),
|
)),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
@ -442,17 +441,17 @@ responses(
|
|||||||
(status = 200, description = "Generated Text", body = StreamResponse,
|
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||||
content_type = "text/event-stream"),
|
content_type = "text/event-stream"),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
content_type = "text/event-stream",
|
example = json ! ({"error": "Request failed during generation"}),
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
content_type = "text/event-stream"),
|
||||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
content_type = "text/event-stream",
|
example = json ! ({"error": "Model is overloaded"}),
|
||||||
example = json!({"error": "Model is overloaded"})),
|
content_type = "text/event-stream"),
|
||||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
content_type = "text/event-stream",
|
example = json ! ({"error": "Input validation error"}),
|
||||||
example = json!({"error": "Input validation error"})),
|
content_type = "text/event-stream"),
|
||||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
content_type = "text/event-stream",
|
example = json ! ({"error": "Incomplete generation"}),
|
||||||
example = json!({"error": "Incomplete generation"})),
|
content_type = "text/event-stream"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
@ -676,8 +675,8 @@ request_body = CompletionRequest,
|
|||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Chat Completion",
|
(status = 200, description = "Generated Chat Completion",
|
||||||
content(
|
content(
|
||||||
(CompletionFinal= "application/json"),
|
("application/json" = CompletionFinal),
|
||||||
(Chunk= "text/event-stream"),
|
("text/event-stream" = Chunk),
|
||||||
)),
|
)),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
@ -1202,8 +1201,8 @@ request_body = ChatRequest,
|
|||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Chat Completion",
|
(status = 200, description = "Generated Chat Completion",
|
||||||
content(
|
content(
|
||||||
(ChatCompletion = "application/json"),
|
("application/json" = ChatCompletion),
|
||||||
(ChatCompletionChunk = "text/event-stream"),
|
("text/event-stream" = ChatCompletionChunk),
|
||||||
)),
|
)),
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
@ -1578,6 +1577,7 @@ FunctionDefinition,
|
|||||||
ToolChoice,
|
ToolChoice,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
ChatTokenizeResponse,
|
ChatTokenizeResponse,
|
||||||
|
MessageBody,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
tags(
|
tags(
|
||||||
|
@ -174,7 +174,7 @@ mod tests {
|
|||||||
"What's Deep Learning?".to_string()
|
"What's Deep Learning?".to_string()
|
||||||
)),
|
)),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
..Default::default()
|
||||||
},],
|
},],
|
||||||
max_tokens: Some(128),
|
max_tokens: Some(128),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
|
Loading…
Reference in New Issue
Block a user