diff --git a/Cargo.lock b/Cargo.lock
index 68d96726..6aa50975 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -267,6 +267,12 @@ version = "0.21.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
 
+[[package]]
+name = "base64"
+version = "0.22.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
+
 [[package]]
 name = "bit-set"
 version = "0.5.3"
@@ -2960,7 +2966,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "1.4.4"
+version = "1.4.5"
 dependencies = [
  "average",
  "clap",
@@ -2981,7 +2987,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "1.4.4"
+version = "1.4.5"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -2998,7 +3004,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "1.4.4"
+version = "1.4.5"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3014,7 +3020,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "1.4.4"
+version = "1.4.5"
 dependencies = [
  "async-stream",
  "axum",
@@ -3601,11 +3607,11 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
 
 [[package]]
 name = "ureq"
-version = "2.9.6"
+version = "2.9.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35"
+checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
 dependencies = [
- "base64 0.21.7",
+ "base64 0.22.0",
  "flate2",
  "log",
  "native-tls",
diff --git a/Cargo.toml b/Cargo.toml
index d76cbc68..77a30f55 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ members = [
 resolver = "2"
 
 [workspace.package]
-version = "1.4.4"
+version = "1.4.5"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"
diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs
index b27c56b4..48ac976a 100644
--- a/benchmark/src/app.rs
+++ b/benchmark/src/app.rs
@@ -444,7 +444,7 @@ fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Ga
 }
 
 /// Throughput paragraph
-fn throughput_paragraph<'a>(throughput: &Vec<f64>, name: &'static str) -> Paragraph<'a> {
+fn throughput_paragraph<'a>(throughput: &[f64], name: &'static str) -> Paragraph<'a> {
     // Throughput average/high/low texts
     let throughput_texts = statis_spans(throughput, "tokens/secs");
 
@@ -457,7 +457,7 @@
 }
 
 /// Latency paragraph
-fn latency_paragraph<'a>(latency: &mut Vec<f64>, name: &'static str) -> Paragraph<'a> {
+fn latency_paragraph<'a>(latency: &mut [f64], name: &'static str) -> Paragraph<'a> {
     // Latency average/high/low texts
     let mut latency_texts = statis_spans(latency, "ms");
 
@@ -483,7 +483,7 @@
 }
 
 /// Average/High/Low spans
-fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Line<'a>> {
+fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
     vec![
         Line::from(vec![Span::styled(
             format!(
@@ -543,7 +543,7 @@
 
 /// Latency/Throughput chart
 fn latency_throughput_chart<'a>(
-    latency_throughput: &'a Vec<(f64, f64)>,
+    latency_throughput: &'a [(f64, f64)],
     batch_sizes: &'a [u32],
     zoom: bool,
     name: &'static str,
diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs
index c4819ff3..e18d7310 100644
--- a/benchmark/src/table.rs
+++ b/benchmark/src/table.rs
@@ -151,7 +151,7 @@ fn add_throuhgputs(
     }
 }
 
-fn avg_min_max(data: &Vec<f64>) -> (f64, f64, f64) {
+fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
     let average = data.iter().sum::<f64>() / data.len() as f64;
     let min = data
         .iter()
@@ -164,7 +164,7 @@
     (average, *min, *max)
 }
 
-fn px(data: &Vec<f64>, p: u32) -> f64 {
+fn px(data: &[f64], p: u32) -> f64 {
     let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
     *data.get(i).unwrap_or(&std::f64::NAN)
 }
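The benchmark changes above swap `&Vec<f64>` parameters for `&[f64]` slices, the form Clippy's `ptr_arg` lint recommends: a slice borrow accepts a `Vec`, a fixed-size array, or a sub-slice, while `&Vec<f64>` only accepts a borrowed `Vec`. A minimal sketch of the difference; the `avg` helper below is hypothetical, modeled on `avg_min_max`:

```rust
// Taking &[f64] instead of &Vec<f64> widens what callers can pass:
// &Vec<f64> deref-coerces to &[f64], but not the other way around.
fn avg(data: &[f64]) -> f64 {
    data.iter().sum::<f64>() / data.len() as f64
}

fn main() {
    let v: Vec<f64> = vec![1.0, 2.0, 3.0];
    let a: [f64; 3] = [4.0, 5.0, 6.0];

    println!("{}", avg(&v));      // &Vec<f64> coerces to &[f64]
    println!("{}", avg(&a));      // arrays also coerce to slices
    println!("{}", avg(&v[..2])); // sub-slices work too
}
```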
diff --git a/docs/openapi.json b/docs/openapi.json
index 75965d98..fdf1c804 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "1.4.4"
+    "version": "1.4.5"
   },
   "paths": {
     "/": {
@@ -471,6 +471,90 @@
           }
         }
       }
+    },
+    "/v1/completions": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens",
+        "description": "Generate tokens",
+        "operationId": "completions",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/CompletionRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Text",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatCompletionChunk"
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Input validation error"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Request failed during generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Model is overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                },
+                "example": {
+                  "error": "Incomplete generation"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
@@ -669,17 +753,25 @@
       "ChatCompletionDelta": {
         "type": "object",
         "required": [
-          "role",
-          "content"
+          "role"
         ],
         "properties": {
           "content": {
             "type": "string",
-            "example": "What is Deep Learning?"
+            "example": "What is Deep Learning?",
+            "nullable": true
           },
           "role": {
             "type": "string",
             "example": "user"
+          },
+          "tool_calls": {
+            "allOf": [
+              {
+                "$ref": "#/components/schemas/DeltaToolCall"
+              }
+            ],
+            "nullable": true
           }
         }
       },
@@ -739,7 +831,8 @@
       "ChatRequest": {
         "type": "object",
         "required": [
-          "model"
+          "model",
+          "messages"
         ],
         "properties": {
           "frequency_penalty": {
@@ -777,11 +870,12 @@
           "items": {
             "$ref": "#/components/schemas/Message"
           },
-          "description": "A list of messages comprising the conversation so far."
+          "description": "A list of messages comprising the conversation so far.",
+          "example": "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]"
         },
         "model": {
           "type": "string",
-          "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+          "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
           "example": "mistralai/Mistral-7B-Instruct-v0.2"
         },
         "n": {
@@ -806,6 +900,15 @@
           "nullable": true,
           "minimum": 0
         },
+        "stop": {
+          "type": "array",
+          "items": {
+            "type": "string"
+          },
+          "description": "Up to 4 sequences where the API will stop generating further tokens.",
+          "example": "null",
+          "nullable": true
+        },
         "stream": {
           "type": "boolean"
         },
         "temperature": {
           "type": "number",
           "format": "float",
           "example": 1.0,
           "nullable": true
         },
+        "tool_choice": {
+          "allOf": [
+            {
+              "$ref": "#/components/schemas/ToolType"
+            }
+          ],
+          "nullable": true
+        },
+        "tool_prompt": {
+          "type": "string",
+          "description": "A prompt to be appended before the tools",
+          "example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
+          "nullable": true
+        },
+        "tools": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/Tool"
+          },
+          "description": "A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of\nfunctions the model may generate JSON inputs for.",
+          "example": "null",
+          "nullable": true
+        },
         "top_logprobs": {
           "type": "integer",
           "format": "int32",
@@ -852,6 +978,164 @@
         }
       }
     },
+      "CompletionComplete": {
+        "type": "object",
+        "required": [
+          "index",
+          "text",
+          "finish_reason"
+        ],
+        "properties": {
+          "finish_reason": {
+            "type": "string"
+          },
+          "index": {
+            "type": "integer",
+            "format": "int32",
+            "minimum": 0
+          },
+          "logprobs": {
+            "type": "array",
+            "items": {
+              "type": "number",
+              "format": "float"
+            },
+            "nullable": true
+          },
+          "text": {
+            "type": "string"
+          }
+        }
+      },
+      "CompletionCompleteChunk": {
+        "type": "object",
+        "required": [
+          "id",
+          "object",
+          "created",
+          "choices",
+          "model",
+          "system_fingerprint"
+        ],
+        "properties": {
+          "choices": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/CompletionComplete"
+            }
+          },
+          "created": {
+            "type": "integer",
+            "format": "int64",
+            "minimum": 0
+          },
+          "id": {
+            "type": "string"
+          },
+          "model": {
+            "type": "string"
+          },
+          "object": {
+            "type": "string"
+          },
+          "system_fingerprint": {
+            "type": "string"
+          }
+        }
+      },
+      "CompletionRequest": {
+        "type": "object",
+        "required": [
+          "model",
+          "prompt"
+        ],
+        "properties": {
+          "frequency_penalty": {
+            "type": "number",
+            "format": "float",
+            "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
+            "example": "1.0",
+            "nullable": true
+          },
+          "max_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "description": "The maximum number of tokens that can be generated in the chat completion.",
+            "default": "32",
+            "nullable": true,
+            "minimum": 0
+          },
+          "model": {
+            "type": "string",
+            "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
+            "example": "mistralai/Mistral-7B-Instruct-v0.2"
+          },
+          "prompt": {
+            "type": "string",
+            "description": "The prompt to generate completions for.",
+            "example": "What is Deep Learning?"
+          },
+          "repetition_penalty": {
+            "type": "number",
+            "format": "float",
+            "nullable": true
+          },
+          "seed": {
+            "type": "integer",
+            "format": "int64",
+            "example": 42,
+            "nullable": true,
+            "minimum": 0
+          },
+          "stream": {
+            "type": "boolean"
+          },
+          "suffix": {
+            "type": "string",
+            "description": "The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.\nplease see the completion_template field in the model's tokenizer_config.json file for completion template.",
+            "nullable": true
+          },
+          "temperature": {
+            "type": "number",
+            "format": "float",
+            "description": "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\nlower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.",
+            "example": 1.0,
+            "nullable": true
+          },
+          "top_p": {
+            "type": "number",
+            "format": "float",
+            "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
+            "example": 0.95,
+            "nullable": true
+          }
+        }
+      },
+      "DeltaToolCall": {
+        "type": "object",
+        "required": [
+          "index",
+          "id",
+          "type",
+          "function"
+        ],
+        "properties": {
+          "function": {
+            "$ref": "#/components/schemas/Function"
+          },
+          "id": {
+            "type": "string"
+          },
+          "index": {
+            "type": "integer",
+            "format": "int32",
+            "minimum": 0
+          },
+          "type": {
+            "type": "string"
+          }
+        }
+      },
       "Details": {
         "type": "object",
         "required": [
@@ -931,6 +1215,38 @@
         ],
         "example": "Length"
       },
+      "Function": {
+        "type": "object",
+        "required": [
+          "arguments"
+        ],
+        "properties": {
+          "arguments": {
+            "type": "string"
+          },
+          "name": {
+            "type": "string",
+            "nullable": true
+          }
+        }
+      },
+      "FunctionDefinition": {
+        "type": "object",
+        "required": [
+          "name",
+          "parameters"
+        ],
+        "properties": {
+          "description": {
+            "type": "string",
+            "nullable": true
+          },
+          "name": {
+            "type": "string"
+          },
+          "parameters": {}
+        }
+      },
       "GenerateParameters": {
         "type": "object",
         "properties": {
@@ -1261,13 +1577,13 @@
       "Message": {
         "type": "object",
         "required": [
-          "role",
-          "content"
+          "role"
        ],
         "properties": {
           "content": {
             "type": "string",
-            "example": "My name is David and I"
+            "example": "My name is David and I",
+            "nullable": true
           },
           "name": {
             "type": "string",
@@ -1277,6 +1593,13 @@
           "role": {
             "type": "string",
             "example": "user"
+          },
+          "tool_calls": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/ToolCall"
+            },
+            "nullable": true
           }
         }
       },
@@ -1437,6 +1760,64 @@
           "$ref": "#/components/schemas/SimpleToken"
         }
       },
+      "Tool": {
+        "type": "object",
+        "required": [
+          "type",
+          "function"
+        ],
+        "properties": {
+          "function": {
+            "$ref": "#/components/schemas/FunctionDefinition"
+          },
+          "type": {
+            "type": "string",
+            "example": "function"
+          }
+        }
+      },
+      "ToolCall": {
+        "type": "object",
+        "required": [
+          "id",
+          "type",
+          "function"
+        ],
+        "properties": {
+          "function": {
+            "$ref": "#/components/schemas/FunctionDefinition"
+          },
+          "id": {
+            "type": "integer",
+            "format": "int32",
+            "minimum": 0
+          },
+          "type": {
+            "type": "string"
+          }
+        }
+      },
+      "ToolType": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "FunctionName"
+            ],
+            "properties": {
+              "FunctionName": {
+                "type": "string"
+              }
+            }
+          },
+          {
+            "type": "string",
+            "enum": [
+              "OneOf"
+            ]
+          }
+        ]
+      },
       "Usage": {
         "type": "object",
         "required": [
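The OpenAPI changes above document the new `/v1/completions` route and the tool-calling schemas (`Tool`, `ToolCall`, `ToolType`, `Function`, `FunctionDefinition`, `DeltaToolCall`). A sketch of a request body matching the `CompletionRequest` schema, built with `serde_json`; the field values mirror the examples in the spec above, while the host and port are placeholders, not defaults confirmed by this diff:

```rust
use serde_json::json;

fn main() {
    // Body matching the CompletionRequest schema documented above.
    // "model" is required by the schema even though it is marked UNUSED.
    let body = json!({
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": "What is Deep Learning?",
        "max_tokens": 32,      // the schema's documented default
        "temperature": 1.0,
        "top_p": 0.95,
        "seed": 42,
        "stream": false
    });

    // POST this JSON to http://<host>:<port>/v1/completions.
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}
```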
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json
index 467b8ce3..543be115 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json
@@ -17,7 +17,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.4-native",
+  "system_fingerprint": "1.4.5-native",
   "usage": {
     "completion_tokens": 100,
     "prompt_tokens": 60,
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
index 8bdb7465..728e90a4 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
@@ -31,7 +31,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.4-native",
+  "system_fingerprint": "1.4.5-native",
   "usage": {
     "completion_tokens": 29,
     "prompt_tokens": 316,
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
index 5ba297b1..2e0efb86 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
@@ -31,7 +31,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.4-native",
+  "system_fingerprint": "1.4.5-native",
   "usage": {
     "completion_tokens": 29,
     "prompt_tokens": 316,
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json
index 522624bc..91854223 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json
@@ -30,7 +30,7 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.4-native",
+  "system_fingerprint": "1.4.5-native",
   "usage": {
     "completion_tokens": 21,
     "prompt_tokens": 187,
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
index c085100d..e0c7aed6 100644
--- a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
@@ -23,5 +23,5 @@
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
-  "system_fingerprint": "1.4.4-native"
+  "system_fingerprint": "1.4.5-native"
 }
diff --git a/integration-tests/pyproject.toml b/integration-tests/pyproject.toml
index cab74c46..ad217072 100644
--- a/integration-tests/pyproject.toml
+++ b/integration-tests/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.4.4"
+version = "1.4.5"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry "]
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 35e23316..2630862d 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -612,7 +612,7 @@ fn shard_manager(
     // We read stderr in another thread as it seems that lines() can block in some cases
     let (err_sender, err_receiver) = mpsc::channel();
     thread::spawn(move || {
-        for line in shard_stderr_reader.lines().flatten() {
+        for line in shard_stderr_reader.lines().map_while(Result::ok) {
             err_sender.send(line).unwrap_or(());
         }
     });
@@ -730,7 +730,7 @@ impl TryFrom<&String> for PythonLogMessage {
 }
 
 fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
-    for line in lines.flatten() {
+    for line in lines.map_while(Result::ok) {
         match PythonLogMessage::try_from(&line) {
             Ok(log) => log.trace(),
             Err(_) => tracing::debug!("{line}"),
@@ -882,7 +882,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
     // We read stderr in another thread as it seems that lines() can block in some cases
     let (err_sender, err_receiver) = mpsc::channel();
     thread::spawn(move || {
-        for line in download_stderr.lines().flatten() {
+        for line in download_stderr.lines().map_while(Result::ok) {
             err_sender.send(line).unwrap_or(());
         }
     });
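The launcher changes above replace `lines().flatten()` with `lines().map_while(Result::ok)`, matching Clippy's `lines_filter_map_ok` lint: `flatten()` silently discards `Err` items and keeps polling the reader, which can spin forever on a stream that errors persistently, whereas `map_while` stops at the first error. A self-contained sketch of the pattern; the `Cursor` stands in for the child-process pipes used in the launcher:

```rust
use std::io::{BufRead, BufReader, Cursor};

fn main() {
    // Cursor<&str> is a stand-in for a child process's stdout/stderr pipe.
    let reader = BufReader::new(Cursor::new("first\nsecond\nthird\n"));

    // map_while(Result::ok) yields lines until the first Err, then stops;
    // flatten() would skip the Err and poll the reader again, potentially
    // looping forever if the error is persistent (e.g. invalid UTF-8).
    for line in reader.lines().map_while(Result::ok) {
        println!("{line}");
    }
}
```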
diff --git a/router/src/server.rs b/router/src/server.rs
index 2b3a4248..7ec9ddee 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -14,7 +14,7 @@ use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
     ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
     ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
-    CompletionRequest, VertexRequest, VertexResponse,
+    CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
 use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
@@ -1213,6 +1213,12 @@ pub async fn run(
             ErrorResponse,
             GrammarType,
             Usage,
+            DeltaToolCall,
+            ToolType,
+            Tool,
+            ToolCall,
+            Function,
+            FunctionDefinition,
             )
         ),
         tags(
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 2fdfa8b8..6f892c14 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-server"
-version = "1.4.4"
+version = "1.4.5"
 description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene "]
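The `router/src/server.rs` change above adds the new tool-calling types to the `components(schemas(...))` list of the utoipa-generated OpenAPI document, which is how they end up in `docs/openapi.json`. For context, a minimal sketch of that pattern under stated assumptions: the type and field names below are abbreviated stand-ins, and the real definitions live in the router crate:

```rust
use utoipa::{OpenApi, ToSchema};

// Each type exposed in the spec derives ToSchema...
#[derive(ToSchema)]
#[allow(dead_code)]
struct Function {
    arguments: String,
    name: Option<String>,
}

// ...and is listed in components(schemas(...)) so it is emitted
// into the generated OpenAPI document.
#[derive(OpenApi)]
#[openapi(components(schemas(Function)))]
struct ApiDoc;

fn main() {
    // Rendering the document shows the registered schema.
    println!("{}", ApiDoc::openapi().to_pretty_json().unwrap());
}
```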