use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
    ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
    StreamOptions, Tool, ToolChoice,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[schema(nullable = true, default = "null", example = "null")]
    pub parameters: Option<GenerateParameters>,
}

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
    messages: Vec<Message>,
    // Messages is ignored there.
    #[serde(default)]
    parameters: VertexParameters,
}

#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
    #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
    pub model: Option<String>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    #[serde(default)]
    #[schema(example = "1.0")]
    pub frequency_penalty: Option<f32>,

    /// UNUSED
    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
    /// result in a ban or exclusive selection of the relevant token.
    #[serde(default)]
    pub logit_bias: Option<Vec<f32>>,

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
    /// output token returned in the content of message.
    #[serde(default)]
    #[schema(example = "false")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
    /// an associated log probability. logprobs must be set to true if this parameter is used.
    #[serde(default)]
    #[schema(example = "5")]
    pub top_logprobs: Option<u32>,

    /// The maximum number of tokens that can be generated in the chat completion.
    #[serde(default)]
    #[schema(example = "32")]
    pub max_tokens: Option<u32>,

    /// UNUSED
    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the
    /// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
    #[serde(default)]
    #[schema(nullable = true, example = "2")]
    pub n: Option<u32>,

    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    #[serde(default)]
    #[schema(nullable = true, example = 0.1)]
    pub presence_penalty: Option<f32>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stop: Option<Vec<String>>,

    #[serde(default = "bool::default")]
    pub stream: bool,

    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
    /// lower values like 0.2 will make it more focused and deterministic.
    ///
    /// We generally recommend altering this or `top_p` but not both.
    #[serde(default)]
    #[schema(nullable = true, example = 1.0)]
    pub temperature: Option<f32>,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,

    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
    /// functions the model may generate JSON inputs for.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tools: Option<Vec<Tool>>,

    /// A prompt to be appended before the tools
    #[serde(default)]
    #[schema(
        nullable = true,
        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}. Do not use variables."
    )]
    pub tool_prompt: Option<String>,

    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tool_choice: ToolChoice,

    /// Response format constraints for the generation.
    ///
    /// NOTE: A request can use `response_format` OR `tools` but not both.
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub response_format: Option<GrammarType>,

    /// A guideline to be used in the chat_template
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub guideline: Option<String>,

    /// Options for streaming response. Only set this when you set stream: true.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub stream_options: Option<StreamOptions>,
}

impl From<VertexChat> for ChatRequest {
    fn from(val: VertexChat) -> Self {
        Self {
            messages: val.messages,
            frequency_penalty: val.parameters.frequency_penalty,
            guideline: val.parameters.guideline,
            logit_bias: val.parameters.logit_bias,
            logprobs: val.parameters.logprobs,
            max_tokens: val.parameters.max_tokens,
            model: val.parameters.model,
            n: val.parameters.n,
            presence_penalty: val.parameters.presence_penalty,
            response_format: val.parameters.response_format,
            seed: val.parameters.seed,
            stop: val.parameters.stop,
            stream_options: val.parameters.stream_options,
            stream: val.parameters.stream,
            temperature: val.parameters.temperature,
            tool_choice: val.parameters.tool_choice,
            tool_prompt: val.parameters.tool_prompt,
            tools: val.parameters.tools,
            top_logprobs: val.parameters.top_logprobs,
            top_p: val.parameters.top_p,
        }
    }
}

#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(VertexChat),
}

#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
    pub predictions: Vec<String>,
}
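// Illustrative only: a sketch of the two request shapes the `/vertex` route accepts,
// derived from the `VertexInstance` untagged enum above. The field values are made up;
// a real deployment's model and parameter choices will differ.
//
// "Generate"-style instance:
//   {"instances": [{"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 128}}]}
//
// "Chat"-style instance:
//   {"instances": [{"messages": [{"role": "user", "content": "What is Deep Learning?"}],
//                   "parameters": {"max_tokens": 128, "temperature": 0.7}}]}
//
// In either case the response collects one generated string per instance:
//   {"predictions": ["..."]}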
({"error": "Incomplete generation"})), ) )] #[instrument( skip_all, fields( total_time, validation_time, queue_time, inference_time, time_per_token, seed, ) )] pub(crate) async fn vertex_compatibility( Extension(infer): Extension, Extension(compute_type): Extension, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance if req.instances.is_empty() { return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { error: "Input validation error".to_string(), error_type: "Input validation error".to_string(), }), )); } // Prepare futures for all instances let mut futures = Vec::with_capacity(req.instances.len()); for instance in req.instances.into_iter() { let generate_request = match instance { VertexInstance::Generate(instance) => GenerateRequest { inputs: instance.inputs.clone(), add_special_tokens: true, parameters: GenerateParameters { do_sample: true, max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), seed: instance.parameters.as_ref().and_then(|p| p.seed), details: true, decoder_input_details: true, ..Default::default() }, }, VertexInstance::Chat(instance) => { let chat_request: ChatRequest = instance.into(); let (generate_request, _using_tools): (GenerateRequest, bool) = chat_request.try_into_generate(&infer)?; generate_request } }; let infer_clone = infer.clone(); let compute_type_clone = compute_type.clone(); let span_clone = span.clone(); futures.push(async move { generate_internal( Extension(infer_clone), compute_type_clone, Json(generate_request), span_clone, ) .await .map(|(_, Json(generation))| generation.generated_text) .map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse { error: "Incomplete generation".into(), error_type: "Incomplete generation".into(), }), ) }) }); } // execute all futures in parallel, collect results, returning early if any error occurs let results = futures::future::join_all(futures).await; let predictions: Result, _> = results.into_iter().collect(); let predictions = predictions?; let response = VertexResponse { predictions }; Ok((HeaderMap::new(), Json(response)).into_response()) } #[cfg(test)] mod tests { use super::*; use crate::{Message, MessageContent}; #[test] fn vertex_deserialization() { let string = serde_json::json!({ "messages": [{"role": "user", "content": "What's Deep Learning?"}], "parameters": { "max_tokens": 128, "top_p": 0.95, "temperature": 0.7 } }); let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize"); let string = serde_json::json!({ "messages": [{"role": "user", "content": "What's Deep Learning?"}], }); let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize"); let string = serde_json::json!({ "instances": [ { "messages": [{"role": "user", "content": "What's Deep Learning?"}], "parameters": { "max_tokens": 128, "top_p": 0.95, "temperature": 0.7 } } ] }); let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize"); assert_eq!( request, VertexRequest { instances: vec![VertexInstance::Chat(VertexChat { messages: vec![Message { role: "user".to_string(), content: MessageContent::SingleText("What's Deep Learning?".to_string()), name: None, },], parameters: VertexParameters { max_tokens: Some(128), top_p: Some(0.95), temperature: Some(0.7), ..Default::default() } })] } ); } }