diff --git a/Cargo.lock b/Cargo.lock index 4a466cf0..4603f77d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -128,9 +128,6 @@ name = "arbitrary" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" -dependencies = [ - "derive_arbitrary", -] [[package]] name = "arc-swap" @@ -308,7 +305,7 @@ dependencies = [ "http-body 0.4.6", "hyper 0.14.32", "itoa", - "matchit 0.7.3", + "matchit", "memchr", "mime", "percent-encoding", @@ -341,41 +338,7 @@ dependencies = [ "hyper 1.6.0", "hyper-util", "itoa", - "matchit 0.7.3", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", - "sync_wrapper 1.0.2", - "tokio", - "tower 0.5.2", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "axum" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" -dependencies = [ - "axum-core 0.5.0", - "bytes", - "form_urlencoded", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "http-body-util", - "hyper 1.6.0", - "hyper-util", - "itoa", - "matchit 0.8.4", + "matchit", "memchr", "mime", "percent-encoding", @@ -431,26 +394,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-core" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" -dependencies = [ - "bytes", - "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper 1.0.2", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "axum-tracing-opentelemetry" version = "0.16.0" @@ -1165,17 +1108,6 @@ dependencies = [ "powerfmt", ] -[[package]] -name = "derive_arbitrary" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - [[package]] name = "derive_builder" version = "0.20.2" @@ -2455,12 +2387,6 @@ dependencies = [ "scopeguard", ] -[[package]] -name = "lockfree-object-pool" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" - [[package]] name = "log" version = "0.4.25" @@ -2522,12 +2448,6 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" -[[package]] -name = "matchit" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" - [[package]] name = "maybe-rayon" version = "0.1.1" @@ -4784,7 +4704,7 @@ dependencies = [ "anyhow", "async-stream", "async-trait", - "axum 0.8.1", + "axum 0.7.9", "axum-tracing-opentelemetry", "base64 0.22.1", "chrono", @@ -4852,7 +4772,7 @@ version = "3.1.1-dev0" dependencies = [ "async-stream", "async-trait", - "axum 0.8.1", + "axum 0.7.9", "axum-tracing-opentelemetry", "base64 0.22.1", "clap 4.5.30", @@ -4901,7 +4821,7 @@ version = "3.1.1-dev0" dependencies = [ "async-stream", "async-trait", - "axum 0.8.1", + "axum 0.7.9", "axum-tracing-opentelemetry", "base64 0.22.1", "clap 4.5.30", @@ 
-5639,9 +5559,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utoipa" -version = "5.3.1" +version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "435c6f69ef38c9017b4b4eea965dfb91e71e53d869e896db40d1cf2441dd75c0" +checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ "indexmap 2.7.1", "serde", @@ -5651,10 +5571,11 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "5.3.1" +version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a77d306bc75294fd52f3e99b13ece67c02c1a2789190a6f31d32f736624326f7" +checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392" dependencies = [ + "proc-macro-error", "proc-macro2", "quote", "regex", @@ -5663,18 +5584,16 @@ dependencies = [ [[package]] name = "utoipa-swagger-ui" -version = "9.0.0" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "161166ec520c50144922a625d8bc4925cc801b2dda958ab69878527c0e5c5d61" +checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" dependencies = [ - "axum 0.8.1", - "base64 0.22.1", + "axum 0.7.9", "mime_guess", "regex", "rust-embed", "serde", "serde_json", - "url", "utoipa", "zip", ] @@ -6404,33 +6323,14 @@ dependencies = [ [[package]] name = "zip" -version = "2.2.2" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae9c1ea7b3a5e1f4b922ff856a129881167511563dc219869afe3787fc0c1a45" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" dependencies = [ - "arbitrary", + "byteorder", "crc32fast", "crossbeam-utils", - "displaydoc", "flate2", - "indexmap 2.7.1", - "memchr", - "thiserror 2.0.11", - "zopfli", -] - -[[package]] -name = "zopfli" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5019f391bac5cf252e93bbcc53d039ffd62c7bfb7c150414d61369afe57e946" -dependencies = [ - "bumpalo", - "crc32fast", - "lockfree-object-pool", - "log", - "once_cell", - "simd-adler32", ] [[package]] diff --git a/backends/v2/Cargo.toml b/backends/v2/Cargo.toml index b5638bfc..0decf41a 100644 --- a/backends/v2/Cargo.toml +++ b/backends/v2/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [dependencies] async-trait = "0.1.74" async-stream = "0.3.5" -axum = { version = "0.8", features = ["json"] } +axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" text-generation-router = { path = "../../router" } clap = { version = "4.4.5", features = ["derive", "env"] } @@ -48,8 +48,8 @@ tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } -utoipa = { version = "5.3.1", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 7bc3181b..996290ed 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [dependencies] async-trait = "0.1.74" async-stream = "0.3.5" -axum = { version = "0.8", features = ["json"] } +axum = { version = "0.7", features = ["json"] } 
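Reviewer note: the axum 0.8 → 0.7 pin above also drags matchit back from 0.8.4 to 0.7.3 (hence the Cargo.lock churn), and the two majors use different route-capture syntax, so any route strings registered on the router must match the pinned major. A minimal sketch of the difference, with a hypothetical route that is not taken from this diff:

```rust
use axum::{extract::Path, routing::get, Router};

// Hypothetical handler, shown only to illustrate the capture syntax
// that changed between the two axum majors.
async fn get_model(Path(model_id): Path<String>) -> String {
    model_id
}

fn routes() -> Router {
    // axum 0.7 (matchit 0.7), the version this diff pins, uses colon captures:
    Router::new().route("/models/:model_id", get(get_model))
    // axum 0.8 (matchit 0.8), the version being reverted, would instead be:
    //   Router::new().route("/models/{model_id}", get(get_model))
}
```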
axum-tracing-opentelemetry = "0.16" text-generation-router = { path = "../../router" } clap = { version = "4.4.5", features = ["derive", "env"] } @@ -48,8 +48,8 @@ tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } -utoipa = { version = "5.3.1", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", ] } diff --git a/docs/openapi.json b/docs/openapi.json index c1b64a21..9de76e47 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1,5 +1,5 @@ { - "openapi": "3.1.0", + "openapi": "3.0.3", "info": { "title": "Text Generation Inference", "description": "Text Generation Webserver", @@ -757,12 +757,10 @@ } }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "example": 42, + "nullable": true, "minimum": 0 }, "tokens": { @@ -831,10 +829,8 @@ "$ref": "#/components/schemas/ChatCompletionDelta" }, "finish_reason": { - "type": [ - "string", - "null" - ] + "type": "string", + "nullable": true }, "index": { "type": "integer", @@ -842,14 +838,12 @@ "minimum": 0 }, "logprobs": { - "oneOf": [ - { - "type": "null" - }, + "allOf": [ { "$ref": "#/components/schemas/ChatCompletionLogprobs" } - ] + ], + "nullable": true } } }, @@ -886,14 +880,12 @@ "type": "string" }, "usage": { - "oneOf": [ - { - "type": "null" - }, + "allOf": [ { "$ref": "#/components/schemas/Usage" } - ] + ], + "nullable": true } } }, @@ -914,14 +906,12 @@ "minimum": 0 }, "logprobs": { - "oneOf": [ - { - "type": "null" - }, + "allOf": [ { "$ref": "#/components/schemas/ChatCompletionLogprobs" } - ] + ], + "nullable": true }, "message": { "$ref": "#/components/schemas/OutputMessage" @@ -998,42 +988,34 @@ ], "properties": { "frequency_penalty": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", - "example": "1.0" + "example": "1.0", + "nullable": true }, "logit_bias": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "type": "number", "format": "float" }, - "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token." + "description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. 
The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.", + "nullable": true }, "logprobs": { - "type": [ - "boolean", - "null" - ], + "type": "boolean", "description": "Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\noutput token returned in the content of message.", - "example": "false" + "example": "false", + "nullable": true }, "max_tokens": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "The maximum number of tokens that can be generated in the chat completion.", "default": "1024", "example": "32", + "nullable": true, "minimum": 0 }, "messages": { @@ -1045,136 +1027,107 @@ "example": "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]" }, "model": { - "type": [ - "string", - "null" - ], + "type": "string", "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "UNUSED\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.", "example": "2", + "nullable": true, "minimum": 0 }, "presence_penalty": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics", - "example": 0.1 + "example": 0.1, + "nullable": true }, "response_format": { - "oneOf": [ + "allOf": [ { - "type": "null" - }, - { - "$ref": "#/components/schemas/GrammarType", - "description": "Response format constraints for the generation.\n\nNOTE: A request can use `response_format` OR `tools` but not both." + "$ref": "#/components/schemas/GrammarType" } ], - "default": "null" + "default": "null", + "nullable": true }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "example": 42, + "nullable": true, "minimum": 0 }, "stop": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "type": "string" }, "description": "Up to 4 sequences where the API will stop generating further tokens.", - "example": "null" + "example": "null", + "nullable": true }, "stream": { "type": "boolean" }, "stream_options": { - "oneOf": [ + "allOf": [ { - "type": "null" - }, - { - "$ref": "#/components/schemas/StreamOptions", - "description": "Options for streaming response. Only set this when you set stream: true." + "$ref": "#/components/schemas/StreamOptions" } - ] + ], + "nullable": true }, "temperature": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.", - "example": 1.0 + "example": 1.0, + "nullable": true }, "tool_choice": { - "oneOf": [ + "allOf": [ { - "type": "null" - }, - { - "$ref": "#/components/schemas/ToolChoice", - "description": "A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter." + "$ref": "#/components/schemas/ToolChoice" } ], - "default": "auto" + "default": "auto", + "nullable": true }, "tool_prompt": { - "type": [ - "string", - "null" - ], + "type": "string", "description": "A prompt to be appended before the tools", - "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables." + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", + "nullable": true }, "tools": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "$ref": "#/components/schemas/Tool" }, "description": "A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of\nfunctions the model may generate JSON inputs for.", - "example": "null" + "example": "null", + "nullable": true }, "top_logprobs": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.", "example": "5", + "nullable": true, "minimum": 0 }, "top_p": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", - "example": 0.95 + "example": 0.95, + "nullable": true } } }, @@ -1288,7 +1241,10 @@ } ] } - ] + ], + "discriminator": { + "propertyName": "object" + } }, "CompletionComplete": { "type": "object", @@ -1307,14 +1263,12 @@ "minimum": 0 }, "logprobs": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "type": "number", "format": "float" - } + }, + "nullable": true }, "text": { "type": "string" @@ -1366,91 +1320,72 @@ ], "properties": { "frequency_penalty": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", - "example": "1.0" + "example": "1.0", + "nullable": true }, "max_tokens": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "The maximum number of tokens that can be generated in the chat completion.", "default": "1024", "example": "32", + "nullable": true, "minimum": 0 }, "model": { - "type": [ - "string", - "null" - ], + "type": "string", "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { - "$ref": "#/components/schemas/Prompt", - "description": "The prompt to generate completions for." + "$ref": "#/components/schemas/Prompt" }, "repetition_penalty": { - "type": [ - "number", - "null" - ], - "format": "float" + "type": "number", + "format": "float", + "nullable": true }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "example": 42, + "nullable": true, "minimum": 0 }, "stop": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "type": "string" }, "description": "Up to 4 sequences where the API will stop generating further tokens.", - "example": "null" + "example": "null", + "nullable": true }, "stream": { "type": "boolean" }, "suffix": { - "type": [ - "string", - "null" - ], - "description": "The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.\nplease see the completion_template field in the model's tokenizer_config.json file for completion template." + "type": "string", + "description": "The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.\nplease see the completion_template field in the model's tokenizer_config.json file for completion template.", + "nullable": true }, "temperature": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.", - "example": 1.0 + "example": 1.0, + "nullable": true }, "top_p": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. 
So 0.1 means only the tokens comprising the top 10% probability mass are considered.", - "example": 0.95 + "example": 0.95, + "nullable": true } } }, @@ -1489,13 +1424,11 @@ ], "properties": { "best_of_sequences": { - "type": [ - "array", - "null" - ], + "type": "array", "items": { "$ref": "#/components/schemas/BestOfSequence" - } + }, + "nullable": true }, "finish_reason": { "$ref": "#/components/schemas/FinishReason" @@ -1513,12 +1446,10 @@ } }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "example": 42, + "nullable": true, "minimum": 0 }, "tokens": { @@ -1572,10 +1503,8 @@ "type": "string" }, "name": { - "type": [ - "string", - "null" - ] + "type": "string", + "nullable": true } } }, @@ -1588,10 +1517,8 @@ "properties": { "arguments": {}, "description": { - "type": [ - "string", - "null" - ] + "type": "string", + "nullable": true }, "name": { "type": "string" @@ -1613,22 +1540,18 @@ "type": "object", "properties": { "adapter_id": { - "type": [ - "string", - "null" - ], + "type": "string", "description": "Lora adapter id", "default": "null", - "example": "null" + "example": "null", + "nullable": true }, "best_of": { - "type": [ - "integer", - "null" - ], + "type": "integer", "description": "Generate best_of sequences and return the one if the highest token logprobs.", "default": "null", "example": 1, + "nullable": true, "minimum": 0, "exclusiveMinimum": 0 }, @@ -1649,68 +1572,55 @@ "example": true }, "frequency_penalty": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.", "default": "null", "example": 0.1, + "nullable": true, "exclusiveMinimum": -2 }, "grammar": { - "oneOf": [ + "allOf": [ { - "type": "null" - }, - { - "$ref": "#/components/schemas/GrammarType", - "description": "Grammar constraints for the generation." + "$ref": "#/components/schemas/GrammarType" } ], - "default": "null" + "default": "null", + "nullable": true }, "max_new_tokens": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "Maximum number of tokens to generate.", "default": "1024", "example": "20", + "nullable": true, "minimum": 0 }, "repetition_penalty": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "The parameter for repetition penalty. 
1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.", "default": "null", "example": 1.03, + "nullable": true, "exclusiveMinimum": 0 }, "return_full_text": { - "type": [ - "boolean", - "null" - ], + "type": "boolean", "description": "Whether to prepend the prompt to the generated text", "default": "null", - "example": false + "example": false, + "nullable": true }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "description": "Random sampling seed.", "default": "null", "example": "null", + "nullable": true, "minimum": 0, "exclusiveMinimum": 0 }, @@ -1726,70 +1636,58 @@ "maxItems": 4 }, "temperature": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "The value used to module the logits distribution.", "default": "null", "example": 0.5, + "nullable": true, "exclusiveMinimum": 0 }, "top_k": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", "default": "null", "example": 10, + "nullable": true, "exclusiveMinimum": 0 }, "top_n_tokens": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int32", "description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.", "default": "null", "example": 5, + "nullable": true, "minimum": 0, "exclusiveMinimum": 0 }, "top_p": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "Top-p value for nucleus sampling.", "default": "null", "example": 0.95, + "nullable": true, "maximum": 1, "exclusiveMinimum": 0 }, "truncate": { - "type": [ - "integer", - "null" - ], + "type": "integer", "description": "Truncate inputs tokens to the given size.", "default": "null", "example": "null", + "nullable": true, "minimum": 0 }, "typical_p": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", "description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.", "default": "null", "example": 0.95, + "nullable": true, "maximum": 1, "exclusiveMinimum": 0 }, @@ -1823,14 +1721,12 @@ ], "properties": { "details": { - "oneOf": [ - { - "type": "null" - }, + "allOf": [ { "$ref": "#/components/schemas/Details" } - ] + ], + "nullable": true }, "generated_text": { "type": "string", @@ -1842,10 +1738,9 @@ "oneOf": [ { "type": "object", - "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", "required": [ - "value", - "type" + "type", + "value" ], "properties": { "type": { @@ -1857,20 +1752,13 @@ "value": { "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions." 
} - }, - "example": { - "properties": { - "location": { - "type": "string" - } - } } }, { "type": "object", "required": [ - "value", - "type" + "type", + "value" ], "properties": { "type": { @@ -1884,7 +1772,10 @@ } } } - ] + ], + "discriminator": { + "propertyName": "type" + } }, "Info": { "type": "object", @@ -1902,11 +1793,9 @@ ], "properties": { "docker_label": { - "type": [ - "string", - "null" - ], - "example": "null" + "type": "string", + "example": "null", + "nullable": true }, "max_best_of": { "type": "integer", @@ -1945,18 +1834,14 @@ "example": "bigscience/blomm-560m" }, "model_pipeline_tag": { - "type": [ - "string", - "null" - ], - "example": "text-generation" + "type": "string", + "example": "text-generation", + "nullable": true }, "model_sha": { - "type": [ - "string", - "null" - ], - "example": "e985a63cdc139290c5f700ff1929f0b5942cced2" + "type": "string", + "example": "e985a63cdc139290c5f700ff1929f0b5942cced2", + "nullable": true }, "router": { "type": "string", @@ -1964,11 +1849,9 @@ "example": "text-generation-router" }, "sha": { - "type": [ - "string", - "null" - ], - "example": "null" + "type": "string", + "example": "null", + "nullable": true }, "validation_workers": { "type": "integer", @@ -1982,42 +1865,57 @@ } }, "Message": { - "type": "object", - "required": [ - "role" - ], - "properties": { - "content": { - "oneOf": [ - { - "type": "null" + "allOf": [ + { + "$ref": "#/components/schemas/MessageBody" + }, + { + "type": "object", + "required": [ + "role" + ], + "properties": { + "name": { + "type": "string", + "example": "\"David\"", + "nullable": true }, - { - "$ref": "#/components/schemas/MessageContent" + "role": { + "type": "string", + "example": "user" } - ] - }, - "name": { - "type": [ - "string", - "null" - ], - "example": "\"David\"" - }, - "role": { - "type": "string", - "example": "user" - }, - "tool_calls": { - "type": [ - "array", - "null" - ], - "items": { - "$ref": "#/components/schemas/ToolCall" } } - } + ] + }, + "MessageBody": { + "oneOf": [ + { + "type": "object", + "required": [ + "content" + ], + "properties": { + "content": { + "$ref": "#/components/schemas/MessageContent" + } + } + }, + { + "type": "object", + "required": [ + "tool_calls" + ], + "properties": { + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + } + ] }, "MessageChunk": { "oneOf": [ @@ -2057,7 +1955,10 @@ } } } - ] + ], + "discriminator": { + "propertyName": "type" + } }, "MessageContent": { "oneOf": [ @@ -2126,12 +2027,10 @@ "minimum": 0 }, "logprob": { - "type": [ - "number", - "null" - ], + "type": "number", "format": "float", - "example": -0.34 + "example": -0.34, + "nullable": true }, "text": { "type": "string", @@ -2239,12 +2138,10 @@ "minimum": 0 }, "seed": { - "type": [ - "integer", - "null" - ], + "type": "integer", "format": "int64", "example": 42, + "nullable": true, "minimum": 0 } } @@ -2270,23 +2167,19 @@ ], "properties": { "details": { - "oneOf": [ - { - "type": "null" - }, + "allOf": [ { "$ref": "#/components/schemas/StreamDetails" } ], - "default": "null" + "default": "null", + "nullable": true }, "generated_text": { - "type": [ - "string", - "null" - ], + "type": "string", "default": "null", - "example": "test" + "example": "test", + "nullable": true }, "index": { "type": "integer", @@ -2318,6 +2211,10 @@ "role": { "type": "string", "example": "user" + }, + "tool_call_id": { + "type": "string", + "nullable": true } } }, @@ -2337,12 +2234,10 @@ "minimum": 0 }, "logprob": { - "type": [ - "number", - "null" - 
], + "type": "number", "format": "float", - "example": -0.34 + "example": -0.34, + "nullable": true }, "special": { "type": "boolean", @@ -2455,14 +2350,12 @@ }, { "type": "object", - "description": "Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.", "required": [ "function" ], "properties": { "function": { - "$ref": "#/components/schemas/FunctionName", - "description": "Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool." + "$ref": "#/components/schemas/FunctionName" } } } diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py index 3a5e9e3e..b8a90cff 100644 --- a/integration-tests/models/test_tools_llama.py +++ b/integration-tests/models/test_tools_llama.py @@ -481,7 +481,6 @@ async def test_flash_llama_tool_reply_response( messages=[ {"role": "user", "content": "What's the weather like in Paris today?"}, { - "content": "", "role": "assistant", "tool_calls": [ { diff --git a/router/Cargo.toml b/router/Cargo.toml index 278efef3..9326258d 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -11,7 +11,7 @@ homepage.workspace = true anyhow = "1" async-trait = "0.1.74" async-stream = "0.3.5" -axum = { version = "0.8", features = ["json"] } +axum = { version = "0.7", features = ["json"] } axum-tracing-opentelemetry = "0.16" clap = { version = "4.4.5", features = ["derive", "env"] } futures = "0.3.28" @@ -42,8 +42,8 @@ tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.40" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] } -utoipa = { version = "5.3.1", features = ["axum_extras"] } -utoipa-swagger-ui = { version = "9.0.0", features = ["axum"] } +utoipa = { version = "4.2.0", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } ngrok = { version = "0.13.1", features = ["axum"], optional = true } init-tracing-opentelemetry = { version = "0.14.1", features = [ "opentelemetry-otlp", diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 37fd48bd..f937e776 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -1,5 +1,7 @@ use crate::infer::InferError; -use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool}; +use crate::{ + ChatTemplateInputs, Message, MessageBody, MessageChunk, TextMessage, TokenizerConfigToken, Tool, +}; use chrono::Local; use minijinja::{Environment, ErrorKind, Template}; use minijinja_contrib::pycompat; @@ -73,8 +75,10 @@ impl ChatTemplate { // if the `tools` variable is used in the template, we just append the tool_prompt format!("\n---\n{}", tool_prompt) }; - if let Some(content) = messages.last_mut().and_then(|msg| msg.content.as_mut()) { - content.push(MessageChunk::Text { text }) + if let Some(last_message) = messages.last_mut() { + if let MessageBody::Content { content } = &mut last_message.body { + content.push(MessageChunk::Text { text }); + } } Some(tools) } @@ -158,18 +162,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), 
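The openapi.json churn above is mostly mechanical fallout of the utoipa 5 → 4 downgrade: utoipa 4 emits OpenAPI 3.0.3, which cannot put "null" in a type array, so every optional field moves from the 3.1 spelling to 3.0's `nullable: true`, and nullable `$ref`s move from `oneOf` with a null arm to `allOf` (a bare `$ref` cannot carry sibling keywords in 3.0). A sketch of the two spellings, using serde_json purely for illustration:

```rust
use serde_json::{json, Value};

fn nullable_spellings() -> (Value, Value) {
    // OpenAPI 3.1 (what utoipa 5 emits): optionality via a type array.
    let v3_1 = json!({ "type": ["integer", "null"], "format": "int64" });
    // OpenAPI 3.0.3 (what utoipa 4 emits): the same schema via `nullable`.
    let v3_0 = json!({ "type": "integer", "format": "int64", "nullable": true });
    (v3_1, v3_0)
}
```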
content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -186,6 +194,182 @@ mod tests { ); } + #[test] + fn test_chat_template_with_tool_response() { + let env = Environment::new(); + + // template modified from Llama-3.1-8B-Instruct + // https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/0e9e39f249a16976918f6564b8830bc894c89659/tokenizer_config.json#L2053 + // the main change is accesing `message.tool_call_id` from the messages + let source = r#" + {{- bos_token }} + {%- if custom_tools is defined %} + {%- set tools = custom_tools %} + {%- endif %} + {%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} + {%- endif %} + {%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} + {%- if not tools is defined %} + {%- set tools = none %} + {%- endif %} + + {#- This block extracts the system message, so we can slot it into the right place. #} + {%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} + + {#- System message + builtin tools #} + {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} + {%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} + {%- endif %} + {{- "Cutting Knowledge Date: December 2023\n" }} + {{- "Today Date: " + date_string + "\n\n" }} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot_id|>" }} + + {#- Custom tools are passed in a user message with some extra guidance #} + {%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} + {%- endif %} + + {%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {{- "TOOL CALL ID: " + message.tool_call_id + "\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- endfor %} + {%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} + {%- endif %} + "#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + TextMessage { + role: "user".to_string(), + content: "Hi!".to_string(), + ..Default::default() + }, + TextMessage { + role: "assistant".to_string(), + content: r#"[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]"#.to_string(), + ..Default::default() + }, + TextMessage { + role: "tool".to_string(), + content: "6.7".to_string(), + tool_call_id: Some("0".to_string()), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + ..Default::default() + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + + assert_eq!( + result, + r#"[BOS]<|start_header_id|>system<|end_header_id|> + +Cutting Knowledge Date: December 2023 +Today Date: 26 Jul 2024 + +<|eot_id|><|start_header_id|>user<|end_header_id|> + +Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +[ { "id": "0", "function": { "arguments": '{"longitude": 2.2945, "latitude": 48.8567}', "name": "get_weather", "description": None, }, "type": "function", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|> + +TOOL CALL ID: 0 + +"6.7"<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +"# + ); + } + #[test] fn test_chat_template_loop_controls() { // some chat templates as e.g. 
CohereForAI/c4ai-command-r7b-12-202 contain `break` @@ -224,18 +408,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -287,22 +475,27 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "Hi again!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -359,18 +552,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -426,18 +623,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -479,18 +680,22 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hi!".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "Hello how can I help?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "What is Deep Learning?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "magic!".to_string(), + ..Default::default() }, ], bos_token: Some("[BOS]"), @@ -516,14 +721,17 @@ mod tests { TextMessage { role: "user".to_string(), content: "Hello, how are you?".to_string(), + ..Default::default() }, TextMessage { role: "assistant".to_string(), content: "I'm doing great. 
How can I help you today?".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "I'd like to show off how chat templating works!".to_string(), + ..Default::default() }, ]; @@ -531,6 +739,7 @@ mod tests { role: "system".to_string(), content: "You are a friendly chatbot who always responds in the style of a pirate" .to_string(), + ..Default::default() }] .iter() .chain(&example_chat) @@ -674,10 +883,12 @@ mod tests { TextMessage { role: "system".to_string(), content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(), + ..Default::default() }, TextMessage { role: "user".to_string(), content: "How many helicopters can a human eat in one sitting?".to_string(), + ..Default::default() }, ], add_generation_prompt: true, diff --git a/router/src/lib.rs b/router/src/lib.rs index 7da29823..94c7a48d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -663,6 +663,7 @@ impl ChatCompletion { (Some(content), None) => OutputMessage::ChatMessage(TextMessage { role: "assistant".into(), content, + ..Default::default() }), (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage { role: "assistant".to_string(), @@ -673,6 +674,7 @@ impl ChatCompletion { OutputMessage::ChatMessage(TextMessage { role: "assistant".into(), content: output, + ..Default::default() }) } (None, None) => { @@ -680,6 +682,7 @@ impl ChatCompletion { OutputMessage::ChatMessage(TextMessage { role: "assistant".into(), content: "".to_string(), + ..Default::default() }) } }; @@ -767,6 +770,7 @@ impl ChatCompletionChunk { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage { role: "assistant".to_string(), content: delta, + ..Default::default() }), (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta { role: "assistant".to_string(), @@ -783,6 +787,7 @@ impl ChatCompletionChunk { (None, None) => ChatCompletionDelta::Chat(TextMessage { role: "assistant".to_string(), content: "".to_string(), + ..Default::default() }), }; Self { @@ -1129,7 +1134,7 @@ where } #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)] -pub(crate) struct FunctionDefinition { +pub struct FunctionDefinition { #[serde(default)] pub description: Option<String>, pub name: String, @@ -1157,7 +1162,7 @@ pub(crate) struct ChatTemplateInputs<'a> { } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] -pub(crate) struct ToolCall { +pub struct ToolCall { pub id: String, pub r#type: String, pub function: FunctionDefinition, @@ -1176,17 +1181,31 @@ pub enum MessageChunk { ImageUrl { image_url: Url }, } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] pub struct Message { #[schema(example = "user")] - role: String, + pub role: String, + #[serde(flatten)] #[schema(example = "My name is David and I")] - pub content: Option<MessageContent>, + pub body: MessageBody, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] - name: Option<String>, - #[serde(default, skip_serializing_if = "Option::is_none")] - tool_calls: Option<Vec<ToolCall>>, + pub name: Option<String>, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] +#[serde(untagged)] +pub enum MessageBody { + // When a regular text message is provided. + Content { + #[serde(rename = "content")] + content: MessageContent, + }, + // When tool calls are provided.
+ Tool { + #[serde(rename = "tool_calls")] + tool_calls: Vec<ToolCall>, + }, } #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)] @@ -1213,22 +1232,25 @@ impl MessageContent { } } -#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)] +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)] pub struct TextMessage { #[schema(example = "user")] pub role: String, #[schema(example = "My name is David and I")] pub content: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<String>, } impl From<Message> for TextMessage { fn from(value: Message) -> Self { - let content = value - .tool_calls - .map(|calls| serde_json::to_string(&calls).unwrap_or_default()) - .map(MessageContent::SingleText) - .or(value.content) - .unwrap_or_else(|| MessageContent::SingleText(String::new())); + let content = match value.body { + MessageBody::Content { content } => content, + MessageBody::Tool { tool_calls } => { + let content = serde_json::to_string(&tool_calls).unwrap_or_default(); + MessageContent::SingleText(content) + } + }; TextMessage { role: value.role, content: match content { @@ -1242,6 +1264,7 @@ impl From<Message> for TextMessage { .collect::<Vec<String>>() .join(""), }, + ..Default::default() } } } @@ -1680,6 +1703,7 @@ mod tests { let message = OutputMessage::ChatMessage(TextMessage { role: "assistant".to_string(), content: "This is the answer".to_string(), + ..Default::default() }); let serialized = serde_json::to_string(&message).unwrap(); assert_eq!( diff --git a/router/src/sagemaker.rs b/router/src/sagemaker.rs index 41b6a332..750ef222 100644 --- a/router/src/sagemaker.rs +++ b/router/src/sagemaker.rs @@ -49,8 +49,8 @@ request_body = SagemakerRequest, responses( (status = 200, description = "Generated Chat Completion", content( -(SagemakerResponse = "application/json"), -(SagemakerStreamResponse = "text/event-stream"), +("application/json" = SagemakerResponse), +("text/event-stream" = SagemakerStreamResponse), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation", "error_type": "generation"})), diff --git a/router/src/server.rs b/router/src/server.rs index 791167c0..e9aa4612 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -28,7 +28,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, }; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice}; -use crate::{ModelInfo, ModelsInfo}; +use crate::{MessageBody, ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::{DefaultBodyLimit, Extension}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -111,9 +111,8 @@ request_body = CompatGenerateRequest, responses( (status = 200, description = "Generated Text", content( -(Vec<GenerateResponse> = "application/json"), -(Vec<GenerateResponse> = "application/json"), -(StreamResponse = "text/event-stream"), +("application/json" = Vec<GenerateResponse>), +("text/event-stream" = StreamResponse), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -442,17 +441,17 @@ responses( (status = 200, description = "Generated Text", body = StreamResponse, content_type = "text/event-stream"), (status = 424, description = "Generation Error", body = ErrorResponse, -content_type = "text/event-stream", -example = json ! ({"error": "Request failed during generation"})), +example = json !
({"error": "Request failed during generation"}), +content_type = "text/event-stream"), (status = 429, description = "Model is overloaded", body = ErrorResponse, -content_type = "text/event-stream", -example = json!({"error": "Model is overloaded"})), +example = json ! ({"error": "Model is overloaded"}), +content_type = "text/event-stream"), (status = 422, description = "Input validation error", body = ErrorResponse, -content_type = "text/event-stream", -example = json!({"error": "Input validation error"})), +example = json ! ({"error": "Input validation error"}), +content_type = "text/event-stream"), (status = 500, description = "Incomplete generation", body = ErrorResponse, -content_type = "text/event-stream", -example = json!({"error": "Incomplete generation"})), +example = json ! ({"error": "Incomplete generation"}), +content_type = "text/event-stream"), ) )] #[instrument( @@ -676,8 +675,8 @@ request_body = CompletionRequest, responses( (status = 200, description = "Generated Chat Completion", content( -(CompletionFinal= "application/json"), -(Chunk= "text/event-stream"), +("application/json" = CompletionFinal), +("text/event-stream" = Chunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -1202,8 +1201,8 @@ request_body = ChatRequest, responses( (status = 200, description = "Generated Chat Completion", content( -(ChatCompletion = "application/json"), -(ChatCompletionChunk = "text/event-stream"), +("application/json" = ChatCompletion), +("text/event-stream" = ChatCompletionChunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -1578,6 +1577,7 @@ FunctionDefinition, ToolChoice, ModelInfo, ChatTokenizeResponse, +MessageBody, ) ), tags( diff --git a/router/src/vertex.rs b/router/src/vertex.rs index c31b3059..5a4a3876 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -174,7 +174,7 @@ mod tests { "What's Deep Learning?".to_string() )), name: None, - tool_calls: None, + ..Default::default() },], max_tokens: Some(128), top_p: Some(0.95),