diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index ea1fc773..557e03cb 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -14,7 +14,6 @@ use chat_template::ChatTemplate; use futures::future::try_join_all; use futures::Stream; use minijinja::ErrorKind; -use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use thiserror::Error; @@ -374,25 +373,4 @@ impl InferError { InferError::StreamSerializationError(_) => "stream_serialization_error", } } - - pub(crate) fn into_openai_event(self) -> Event { - let message = self.to_string(); - Event::default().json_data(OpenaiErrorEvent { - error: APIError { - message, - http_status_code: 422, - }, - }) - } -} - -#[derive(Serialize)] -pub struct APIError { - message: String, - http_status_code: usize, -} - -#[derive(Serialize)] -pub struct OpenaiErrorEvent { - error: APIError, } diff --git a/router/src/server.rs b/router/src/server.rs index f0469ca5..da63e9b1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,10 +7,6 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; -use crate::sagemaker::{ - sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse, - __path_sagemaker_compatibility, -}; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; @@ -19,8 +15,7 @@ use crate::{ GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, - TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, - Validation, + TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -46,7 +41,6 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; -use pyo3::prelude::*; use pyo3::types::IntoPyDict; use regex::Regex; use serde_json::Value; @@ -56,6 +50,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use thiserror::Error; +use tokenizers::Tokenizer; use tokio::select; use tokio::signal; use tokio::sync::oneshot; @@ -65,41 +60,6 @@ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec { - let offsets = encoding.get_offsets(); - let input_ids = encoding.get_ids(); - if offsets.len() == input_ids.len() { - input_ids - .iter() - .zip(offsets) - .map(|(&id, &(start, stop))| { - let text = input - .chars() - .skip(start) - .take(stop - start) - .collect::(); - SimpleToken { - id, - text, - start, - stop, - } - }) - .collect() - } else { - encoding - .get_ids() - .iter() - .map(|&id| SimpleToken { - id, - text: "".to_string(), - start: 0, - stop: 0, - }) - .collect() - } -} - /// Generate tokens if `stream == false` or a stream of token if `stream == true` #[utoipa::path( post, @@ -109,7 +69,7 @@ request_body = CompatGenerateRequest, responses( (status = 200, description = "Generated Text", 
content( -("application/json" = Vec), +("application/json" = GenerateResponse), ("text/event-stream" = StreamResponse), )), (status = 424, description = "Generation Error", body = ErrorResponse, @@ -123,7 +83,7 @@ example = json ! ({"error": "Incomplete generation"})), ) )] #[instrument(skip(infer, req))] -pub(crate) async fn compat_generate( +async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, compute_type: Extension, @@ -181,16 +141,12 @@ async fn openai_get_model_info(info: Extension) -> Json { }) } -/// Template and tokenize ChatRequest #[utoipa::path( post, tag = "Text Generation Inference", path = "/chat_tokenize", request_body = ChatRequest, - responses( - (status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse), - (status = 404, description = "Failed to tokenize ChatRequest", body = ErrorResponse), - ) + responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse)) )] async fn get_chat_tokenize( Extension(infer): Extension, @@ -201,14 +157,40 @@ async fn get_chat_tokenize( let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0; let input = generate_request.inputs.clone(); let encoding = infer.tokenize(generate_request).await?; + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); - let tokens = encoding_to_tokens(&encoding, &input); - - let resp = ChatTokenizeResponse { - tokenize_response: TokenizeResponse(tokens), - templated_text: input, - }; - Ok((HeaderMap::new(), Json(resp))) + let resp = ChatTokenizeResponse { + tokenize_response: TokenizeResponse(tokens), + templated_text: input, + }; + Ok((HeaderMap::new(), Json(resp))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } } #[utoipa::path( @@ -696,7 +678,7 @@ time_per_token, seed, ) )] -pub(crate) async fn completions( +async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, @@ -866,7 +848,14 @@ pub(crate) async fn completions( yield Ok(event); } - Err(err) => yield Ok(err.into_openai_event()), + Err(err) => { + let event = Event::default() + .json_data(ErrorEvent::into_api_error(err, 422)) + .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()); + println!("{:?}", event); + yield Ok::(event); + break + } } } }; @@ -1220,7 +1209,7 @@ time_per_token, seed, ) )] -pub(crate) async fn chat_completions( +async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, @@ -1274,102 +1263,107 @@ pub(crate) async fn chat_completions( }; let mut response_as_tool = using_tools; while let Some(result) = response_stream.next().await { - match result{ - Ok(stream_tokens) => { - let token_text = &stream_token.token.text.clone(); - match state { - StreamState::Buffering => { - json_buffer.push_str(&token_text.replace(" ", "")); - buffer.push(stream_token); - if let Some(captures) = function_regex.captures(&json_buffer) { - let function_name = captures[1].to_string(); - if function_name == "no_tool" { - state = StreamState::BufferTrailing; - response_as_tool = 
false; - buffer.clear(); - json_buffer.clear(); - } else { - state = StreamState::Content { - skip_close_quote: false, - }; - // send all the buffered messages - for stream_token in &buffer { - let event = create_event_from_stream_token( - stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); + match result { + Ok(stream_token) => { + let token_text = &stream_token.token.text.clone(); + match state { + StreamState::Buffering => { + json_buffer.push_str(&token_text.replace(" ", "")); + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { + let function_name = captures[1].to_string(); + if function_name == "no_tool" { + state = StreamState::BufferTrailing; + response_as_tool = false; + buffer.clear(); + json_buffer.clear(); + } else { + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); + yield Ok::(event); + } } } } - } - // if we skipped sending the buffer we need to avoid sending the following json key and quotes - StreamState::BufferTrailing => { - let infix_text = "\"content\":\""; - json_buffer.push_str(&token_text.replace(" ", "")); - // keep capturing until we find the infix text - match json_buffer.find(infix_text) { - Some(content_key_index) => { - json_buffer = - json_buffer[content_key_index + infix_text.len()..].to_string(); + // if we skipped sending the buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } } - None => { - continue; + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_else(|_| std::time::Duration::from_secs(0)) + .as_secs(); + let chat_complete = + CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( + model_id.clone(), + system_fingerprint.clone(), + Some(json_buffer.clone()), + None, + current_time, + None, + None, + None, + )); + yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { + InferError::StreamSerializationError(e.to_string()).into() + })); } + // cleanup the buffers + buffer.clear(); + json_buffer.clear(); + state = StreamState::Content { + skip_close_quote: true, + }; } - // if there is leftover text after removing the infix text, we need to send it - if !json_buffer.is_empty() { - let event = Event::default(); - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - let chat_complete = - CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( - model_id.clone(), - system_fingerprint.clone(), - Some(json_buffer.clone()), - None, - current_time, - None, - None, - None, - )); - yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { - 
InferError::StreamSerializationError(e.to_string()).into() - })); - } - // cleanup the buffers - buffer.clear(); - json_buffer.clear(); - state = StreamState::Content { - skip_close_quote: true, - }; - } - StreamState::Content { skip_close_quote } => { - if skip_close_quote && token_text.contains('"') { - break; - } + StreamState::Content { skip_close_quote } => { + if skip_close_quote && token_text.contains('"') { + break; + } + // send the content + let event = create_event_from_stream_token( + &stream_token, + logprobs, + stream_options.clone(), + response_as_tool, + system_fingerprint.clone(), + model_id.clone(), + ); - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - - yield Ok::(event); + yield Ok::(event); + } } } - }, - Err(err) => yield Event::from_openai(err) + Err(err) => { + let event = Event::default() + .json_data(ErrorEvent::into_api_error(err, 422)) + .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into()); + yield Ok::(event); + break; + } } } yield Ok::(Event::default().data("[DONE]")); @@ -1475,8 +1469,35 @@ async fn tokenize( ) -> Result, (StatusCode, Json)> { let input = req.inputs.clone(); let encoding = infer.tokenize(req).await?; - let tokens = encoding_to_tokens(&encoding, &input); - Ok(Json(TokenizeResponse(tokens))) + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); + Ok(Json(TokenizeResponse(tokens))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } } /// Prometheus metrics scrape endpoint @@ -1507,14 +1528,11 @@ completions, tokenize, metrics, openai_get_model_info, -sagemaker_compatibility, -get_chat_tokenize, ), components( schemas( Info, CompatGenerateRequest, -SagemakerRequest, GenerateRequest, GrammarType, ChatRequest, @@ -1537,8 +1555,6 @@ ChatCompletionTopLogprob, ChatCompletion, CompletionRequest, CompletionComplete, -SagemakerResponse, -SagemakerStreamResponse, Chunk, Completion, CompletionFinal, @@ -1566,7 +1582,6 @@ Function, FunctionDefinition, ToolChoice, ModelInfo, -ChatTokenizeResponse, ) ), tags( @@ -1586,71 +1601,6 @@ pub fn schema() -> ApiDoc { ApiDoc } -fn py_resolve_tokenizer( - py: pyo3::Python, - tokenizer_name: &str, - revision: Option<&str>, - trust_remote_code: bool, -) -> pyo3::PyResult<()> { - let transformers = py.import_bound("transformers")?; - let auto = transformers.getattr("AutoTokenizer")?; - let from_pretrained = auto.getattr("from_pretrained")?; - let args = (tokenizer_name,); - let kwargs = if let Some(rev) = &revision { - [ - ("revision", rev.to_string().into_py(py)), - ("trust_remote_code", trust_remote_code.into_py(py)), - ] - .into_py_dict_bound(py) - } else { - [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py) - }; - let tokenizer = from_pretrained.call(args, Some(&kwargs))?; - let save = tokenizer.getattr("save_pretrained")?; - let args = ("out".to_string(),); - save.call1(args)?; - Ok(()) -} - -fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> { - // XXX Legacy case for 
FasterDecoding/medusa-vicuna-7b-v1.3 - // and state-spaces/mamba-130m - tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization"); - - #[derive(serde::Deserialize)] - struct FallbackConfig { - base_model_name_or_path: Option, - model_type: Option, - ssm_config: Option, - } - config_filename.and_then(|filename| { - std::fs::read_to_string(filename) - .ok() - .as_ref() - .and_then(|c| { - let config: Result = serde_json::from_str(c); - if let Ok(config) = config { - if config.model_type.is_none() { - if let Some(base) = config.base_model_name_or_path { - pyo3::Python::with_gil(|py| -> PyResult<()> { - py_resolve_tokenizer(py, &base, Some("main"), false) - }) - .ok()?; - } - } - if config.ssm_config.is_some() { - // XXX Legacy mamba - pyo3::Python::with_gil(|py| -> PyResult<()> { - py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false) - }) - .ok()?; - } - } - Some(()) - }) - }) -} - /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -1666,13 +1616,13 @@ pub async fn run( tokenizer_name: String, tokenizer_config_path: Option, revision: Option, - trust_remote_code: bool, hostname: String, port: u16, cors_allow_origin: Option>, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, + messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, usage_stats_level: usage_stats::UsageStatsLevel, @@ -1744,6 +1694,7 @@ pub async fn run( // Load tokenizer and model info let ( + tokenizer_filename, config_filename, tokenizer_config_filename, preprocessor_config_filename, @@ -1751,6 +1702,7 @@ pub async fn run( model_info, ) = match api { Type::None => ( + Some(local_path.join("tokenizer.json")), Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), Some(local_path.join("preprocessor_config.json")), @@ -1764,6 +1716,10 @@ pub async fn run( revision.clone().unwrap_or_else(|| "main".to_string()), )); + let tokenizer_filename = match api_repo.get("tokenizer.json").await { + Ok(tokenizer_filename) => Some(tokenizer_filename), + Err(_) => get_base_tokenizer(&api, &api_repo).await, + }; let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); @@ -1776,6 +1732,7 @@ pub async fn run( None }; ( + tokenizer_filename, config_filename, tokenizer_config_filename, preprocessor_config_filename, @@ -1790,6 +1747,7 @@ pub async fn run( revision.clone().unwrap_or_else(|| "main".to_string()), )); ( + repo.get("tokenizer.json"), repo.get("config.json"), repo.get("tokenizer_config.json"), repo.get("preprocessor_config.json"), @@ -1811,31 +1769,36 @@ pub async fn run( HubTokenizerConfig::default() }); - let tokenizer: Tokenizer = { + let tokenizer: Option = tokenizer_filename.and_then(|filename| { use pyo3::prelude::*; - pyo3::Python::with_gil(|py| -> PyResult<()> { - py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; + let convert = pyo3::Python::with_gil(|py| -> PyResult<()> { + let transformers = py.import_bound("transformers")?; + let auto = transformers.getattr("AutoTokenizer")?; + let from_pretrained = auto.getattr("from_pretrained")?; + let args = (tokenizer_name.to_string(),); + let kwargs = [( + "revision", + revision.clone().unwrap_or_else(|| "main".to_string()), + )] + .into_py_dict_bound(py); + let tokenizer = from_pretrained.call(args, Some(&kwargs))?; + let save = 
tokenizer.getattr("save_pretrained")?; + let args = ("out".to_string(),); + save.call1(args)?; Ok(()) }) .inspect_err(|err| { tracing::error!("Failed to import python tokenizer {err}"); - }) - .or_else(|err| { - let out = legacy_tokenizer_handle(config_filename.as_ref()); - out.ok_or(err) - }) - .expect("We cannot load a tokenizer"); - let filename = "out/tokenizer.json"; - if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { - Tokenizer::Rust(tok) + }); + let filename = if convert.is_ok() { + // If we have correctly loaded and resaved with transformers + // We might have modified the tokenizer.json according to transformers + "out/tokenizer.json".into() } else { - Tokenizer::Python { - tokenizer_name: tokenizer_name.clone(), - revision: revision.clone(), - trust_remote_code, - } - } - }; + filename + }; + Tokenizer::from_file(filename).ok() + }); let config: Option = config_filename.and_then(|filename| { std::fs::read_to_string(filename) @@ -1863,6 +1826,10 @@ pub async fn run( preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); tracing::info!("Using config {config:?}"); + if tokenizer.is_none() { + tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); + tracing::warn!("Rust input length validation and truncation is disabled"); + } // Only send usage stats when TGI is run in container and the function returns Some let is_container = matches!(usage_stats::is_container(), Ok(true)); @@ -1884,6 +1851,7 @@ pub async fn run( // max_batch_size, revision.clone(), validation_workers, + messages_api_enabled, disable_grammar_support, max_client_batch_size, usage_stats_level, @@ -1925,6 +1893,7 @@ pub async fn run( ngrok, _ngrok_authtoken, _ngrok_edge, + messages_api_enabled, disable_grammar_support, max_client_batch_size, model_info, @@ -1977,13 +1946,14 @@ async fn start( validation_workers: usize, api_key: Option, config: Option, - (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig), + (tokenizer, tokenizer_config): (Option, HubTokenizerConfig), (preprocessor_config, processor_config): (Option, HubProcessorConfig), hostname: String, port: u16, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, + messages_api_enabled: bool, disable_grammar_support: bool, max_client_batch_size: usize, model_info: HubModelInfo, @@ -2298,7 +2268,6 @@ async fn start( .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/invocations", post(sagemaker_compatibility)) .route("/tokenize", post(tokenize)); if let Some(api_key) = api_key { @@ -2334,6 +2303,13 @@ async fn start( .route("/metrics", get(metrics)) .route("/v1/models", get(openai_get_model_info)); + // Conditional AWS Sagemaker route + let aws_sagemaker_route = if messages_api_enabled { + Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED + } else { + Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise + }; + let compute_type = ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string())); @@ -2341,7 +2317,8 @@ async fn start( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) - .merge(info_routes); + .merge(info_routes) + .merge(aws_sagemaker_route); #[cfg(feature = "google")] { @@ -2437,6 +2414,30 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option { } } +/// get base tokenizer +pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> 
Option { + let config_filename = api_repo.get("config.json").await.ok()?; + + // Open the file in read-only mode with buffer. + let file = File::open(config_filename).ok()?; + let reader = BufReader::new(file); + + // Read the JSON contents of the file as an instance of `User`. + let config: serde_json::Value = serde_json::from_reader(reader).ok()?; + + if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { + let api_base_repo = api.repo(Repo::with_revision( + base_model_id.to_string(), + RepoType::Model, + "main".to_string(), + )); + + api_base_repo.get("tokenizer.json").await.ok() + } else { + None + } +} + /// get tokenizer_config from the Huggingface Hub pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; @@ -2520,6 +2521,28 @@ impl From for Event { } } +#[derive(serde::Serialize)] +pub struct APIError { + message: String, + http_status_code: usize, +} + +#[derive(serde::Serialize)] +pub struct ErrorEvent { + error: APIError, +} + +impl ErrorEvent { + fn into_api_error(err: InferError, http_status_code: usize) -> Self { + ErrorEvent { + error: APIError { + message: err.to_string(), + http_status_code, + }, + } + } +} + #[derive(Debug, Error)] pub enum WebServerError { #[error("Axum error: {0}")] @@ -2579,11 +2602,10 @@ mod tests { use crate::TokenizerConfigToken; use crate::Tool; - use crate::tests::get_tokenizer; use serde_json::json; - #[tokio::test] - async fn test_prepare_chat_input() { + #[test] + fn test_prepare_chat_input() { // Mock Backend to avoid network requests struct MockBackend; @@ -2624,11 +2646,9 @@ mod tests { ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + 
system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) ); - let tokenizer = get_tokenizer(); - let infer = Infer::new( backend, - Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false), + Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), 1, tokenizer_config, HubProcessorConfig::default(),