diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 557e03cb..d3d6bc59 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -10,10 +10,12 @@ use crate::{
 };
 use async_stream::stream;
 use async_trait::async_trait;
+use axum::response::sse::Event;
 use chat_template::ChatTemplate;
 use futures::future::try_join_all;
 use futures::Stream;
 use minijinja::ErrorKind;
+use serde::Serialize;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use thiserror::Error;
@@ -373,4 +375,26 @@ impl InferError {
             InferError::StreamSerializationError(_) => "stream_serialization_error",
         }
     }
+
+    pub(crate) fn into_openai_event(self) -> Event {
+        Event::default()
+            .json_data(OpenaiErrorEvent {
+                error: APIError {
+                    message: self.to_string(),
+                    http_status_code: 422,
+                },
+            })
+            .unwrap()
+    }
+}
+
+#[derive(Serialize)]
+pub struct APIError {
+    message: String,
+    http_status_code: usize,
+}
+
+#[derive(Serialize)]
+pub struct OpenaiErrorEvent {
+    error: APIError,
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index da63e9b1..cbb04174 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -7,6 +7,10 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::sagemaker::{
+    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
+    __path_sagemaker_compatibility,
+};
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@@ -15,7 +19,8 @@ use crate::{
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
     OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
-    TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
+    TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
+    Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -41,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
 use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
+use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use regex::Regex;
 use serde_json::Value;
@@ -50,7 +56,6 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use thiserror::Error;
-use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
@@ -60,6 +65,41 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

+fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    let offsets = encoding.get_offsets();
+    let input_ids = encoding.get_ids();
+    if offsets.len() == input_ids.len() {
+        input_ids
+            .iter()
+            .zip(offsets)
+            .map(|(&id, &(start, stop))| {
+                let text = input
+                    .chars()
+                    .skip(start)
+                    .take(stop - start)
+                    .collect::<String>();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect()
+    } else {
+        encoding
+            .get_ids()
+            .iter()
+            .map(|&id| SimpleToken {
+                id,
+                text: "".to_string(),
+                start: 0,
+                stop: 0,
+            })
+            .collect()
+    }
+}
+
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
 post,
@@ -69,7 +109,7 @@
 request_body = CompatGenerateRequest,
 responses(
 (status = 200, description = "Generated Text", content(
-("application/json" = GenerateResponse),
+("application/json" = Vec<GenerateResponse>),
 ("text/event-stream" = StreamResponse),
 )),
 (status = 424, description = "Generation Error", body = ErrorResponse,
@@ -83,7 +123,7 @@ example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(skip(infer, req))]
-async fn compat_generate(
+pub(crate) async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
@@ -141,12 +181,16 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
     })
 }

+/// Template and tokenize ChatRequest
 #[utoipa::path(
 post,
 tag = "Text Generation Inference",
 path = "/chat_tokenize",
 request_body = ChatRequest,
-    responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
+    responses(
+        (status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse),
+        (status = 404, description = "Failed to tokenize ChatRequest", body = ErrorResponse),
+    )
 )]
 async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
@@ -157,40 +201,14 @@ async fn get_chat_tokenize(
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
     let encoding = infer.tokenize(generate_request).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
-        let resp = ChatTokenizeResponse {
-            tokenize_response: TokenizeResponse(tokens),
-            templated_text: input,
-        };
-        Ok((HeaderMap::new(), Json(resp)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let tokens = encoding_to_tokens(&encoding, &input);
+
+    let resp = ChatTokenizeResponse {
+        tokenize_response: TokenizeResponse(tokens),
+        templated_text: input,
+    };
+    Ok((HeaderMap::new(), Json(resp)))
 }

 #[utoipa::path(
@@ -678,7 +696,7 @@ time_per_token,
 seed,
 )
 )]
-async fn completions(
+pub(crate) async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -848,14 +866,7 @@ async fn completions(
                     yield Ok(event);
                 }
-                Err(err) => {
-                    let event = Event::default()
-                        .json_data(ErrorEvent::into_api_error(err, 422))
-                        .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
-                    println!("{:?}", event);
-                    yield Ok::(event);
-                    break
-                }
+                Err(err) => yield Ok(err.into_openai_event()),
             }
         }
     };
@@ -1209,7 +1220,7 @@ time_per_token,
 seed,
 )
 )]
-async fn chat_completions(
+pub(crate) async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
@@ -1263,107 +1274,102 @@ async fn chat_completions(
         };
         let mut response_as_tool = using_tools;
         while let Some(result) = response_stream.next().await {
-            match result {
-                Ok(stream_token) => {
-                    let token_text = &stream_token.token.text.clone();
-                    match state {
-                        StreamState::Buffering => {
-                            json_buffer.push_str(&token_text.replace(" ", ""));
-                            buffer.push(stream_token);
-                            if let Some(captures) = function_regex.captures(&json_buffer) {
-                                let function_name = captures[1].to_string();
-                                if function_name ==
"no_tool" { - state = StreamState::BufferTrailing; - response_as_tool = false; - buffer.clear(); - json_buffer.clear(); - } else { - state = StreamState::Content { - skip_close_quote: false, - }; - // send all the buffered messages - for stream_token in &buffer { - let event = create_event_from_stream_token( - stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - yield Ok::(event); - } - } - } - } - // if we skipped sending the buffer we need to avoid sending the following json key and quotes - StreamState::BufferTrailing => { - let infix_text = "\"content\":\""; - json_buffer.push_str(&token_text.replace(" ", "")); - // keep capturing until we find the infix text - match json_buffer.find(infix_text) { - Some(content_key_index) => { - json_buffer = - json_buffer[content_key_index + infix_text.len()..].to_string(); - } - None => { - continue; - } - } - // if there is leftover text after removing the infix text, we need to send it - if !json_buffer.is_empty() { - let event = Event::default(); - let current_time = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_else(|_| std::time::Duration::from_secs(0)) - .as_secs(); - let chat_complete = - CompletionType::ChatCompletionChunk(ChatCompletionChunk::new( - model_id.clone(), + match result{ + Ok(stream_token) => { + let token_text = &stream_token.token.text.clone(); + match state { + StreamState::Buffering => { + json_buffer.push_str(&token_text.replace(" ", "")); + buffer.push(stream_token); + if let Some(captures) = function_regex.captures(&json_buffer) { + let function_name = captures[1].to_string(); + if function_name == "no_tool" { + state = StreamState::BufferTrailing; + response_as_tool = false; + buffer.clear(); + json_buffer.clear(); + } else { + state = StreamState::Content { + skip_close_quote: false, + }; + // send all the buffered messages + for stream_token in &buffer { + let event = create_event_from_stream_token( + stream_token, + logprobs, + stream_options.clone(), + response_as_tool, system_fingerprint.clone(), - Some(json_buffer.clone()), - None, - current_time, - None, - None, - None, - )); - yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| { - InferError::StreamSerializationError(e.to_string()).into() - })); + model_id.clone(), + ); + yield Ok::(event); + } } - // cleanup the buffers - buffer.clear(); - json_buffer.clear(); - state = StreamState::Content { - skip_close_quote: true, - }; - } - StreamState::Content { skip_close_quote } => { - if skip_close_quote && token_text.contains('"') { - break; - } - // send the content - let event = create_event_from_stream_token( - &stream_token, - logprobs, - stream_options.clone(), - response_as_tool, - system_fingerprint.clone(), - model_id.clone(), - ); - - yield Ok::(event); } } + // if we skipped sending the buffer we need to avoid sending the following json key and quotes + StreamState::BufferTrailing => { + let infix_text = "\"content\":\""; + json_buffer.push_str(&token_text.replace(" ", "")); + // keep capturing until we find the infix text + match json_buffer.find(infix_text) { + Some(content_key_index) => { + json_buffer = + json_buffer[content_key_index + infix_text.len()..].to_string(); + } + None => { + continue; + } + } + // if there is leftover text after removing the infix text, we need to send it + if !json_buffer.is_empty() { + let event = Event::default(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + 
+                            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+                            .as_secs();
+                        let chat_complete =
+                            CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(
+                                model_id.clone(),
+                                system_fingerprint.clone(),
+                                Some(json_buffer.clone()),
+                                None,
+                                current_time,
+                                None,
+                                None,
+                                None,
+                            ));
+                        yield Ok(event.json_data(chat_complete).unwrap_or_else(|e| {
+                            InferError::StreamSerializationError(e.to_string()).into()
+                        }));
+                    }
+                    // cleanup the buffers
+                    buffer.clear();
+                    json_buffer.clear();
+                    state = StreamState::Content {
+                        skip_close_quote: true,
+                    };
+                }
+                StreamState::Content { skip_close_quote } => {
+                    if skip_close_quote && token_text.contains('"') {
+                        break;
+                    }
+
+                    // send the content
+                    let event = create_event_from_stream_token(
+                        &stream_token,
+                        logprobs,
+                        stream_options.clone(),
+                        response_as_tool,
+                        system_fingerprint.clone(),
+                        model_id.clone(),
+                    );
+
+                    yield Ok::(event);
+                }
+            }
             }
-            Err(err) => {
-                let event = Event::default()
-                    .json_data(ErrorEvent::into_api_error(err, 422))
-                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into());
-                yield Ok::(event);
-                break;
-            }
+            }
+            Err(err) => yield Ok(err.into_openai_event())
         }
     }
     yield Ok::(Event::default().data("[DONE]"));
@@ -1469,35 +1475,8 @@ async fn tokenize(
 ) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
-        Ok(Json(TokenizeResponse(tokens)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let tokens = encoding_to_tokens(&encoding, &input);
+    Ok(Json(TokenizeResponse(tokens)))
 }

 /// Prometheus metrics scrape endpoint
@@ -1528,11 +1507,14 @@ completions,
 tokenize,
 metrics,
 openai_get_model_info,
+sagemaker_compatibility,
+get_chat_tokenize,
 ),
 components(
 schemas(
 Info,
 CompatGenerateRequest,
+SagemakerRequest,
 GenerateRequest,
 GrammarType,
 ChatRequest,
@@ -1555,6 +1537,8 @@ ChatCompletionTopLogprob,
 ChatCompletion,
 CompletionRequest,
 CompletionComplete,
+SagemakerResponse,
+SagemakerStreamResponse,
 Chunk,
 Completion,
 CompletionFinal,
@@ -1582,6 +1566,7 @@ Function,
 FunctionDefinition,
 ToolChoice,
 ModelInfo,
+ChatTokenizeResponse,
 )
 ),
 tags(
@@ -1601,6 +1586,71 @@ pub fn schema() -> ApiDoc {
     ApiDoc
 }
+fn py_resolve_tokenizer(
+    py: pyo3::Python,
+    tokenizer_name: &str,
+    revision: Option<&str>,
+    trust_remote_code: bool,
+) -> pyo3::PyResult<()> {
+    let transformers = py.import_bound("transformers")?;
+    let auto = transformers.getattr("AutoTokenizer")?;
+    let from_pretrained = auto.getattr("from_pretrained")?;
+    let args = (tokenizer_name,);
+    let kwargs = if let Some(rev) = &revision {
+        [
+            ("revision", rev.to_string().into_py(py)),
+            ("trust_remote_code", trust_remote_code.into_py(py)),
+        ]
+        .into_py_dict_bound(py)
+    } else {
+        [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
+    };
+    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+    let save = tokenizer.getattr("save_pretrained")?;
+    let args = ("out".to_string(),);
+    save.call1(args)?;
+    Ok(())
+}
+
+fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
+    // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
+    // and state-spaces/mamba-130m
+    tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
+
+    #[derive(serde::Deserialize)]
+    struct FallbackConfig {
+        base_model_name_or_path: Option<String>,
+        model_type: Option<String>,
+        ssm_config: Option<serde_json::Value>,
+    }
+    config_filename.and_then(|filename| {
+        std::fs::read_to_string(filename)
+            .ok()
+            .as_ref()
+            .and_then(|c| {
+                let config: Result<FallbackConfig, _> = serde_json::from_str(c);
+                if let Ok(config) = config {
+                    if config.model_type.is_none() {
+                        if let Some(base) = config.base_model_name_or_path {
+                            pyo3::Python::with_gil(|py| -> PyResult<()> {
+                                py_resolve_tokenizer(py, &base, Some("main"), false)
+                            })
+                            .ok()?;
+                        }
+                    }
+                    if config.ssm_config.is_some() {
+                        // XXX Legacy mamba
+                        pyo3::Python::with_gil(|py| -> PyResult<()> {
+                            py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false)
+                        })
+                        .ok()?;
+                    }
+                }
+                Some(())
+            })
+    })
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
@@ -1616,13 +1666,13 @@ pub async fn run(
     tokenizer_name: String,
     tokenizer_config_path: Option<String>,
     revision: Option<String>,
+    trust_remote_code: bool,
     hostname: String,
     port: u16,
     cors_allow_origin: Option<Vec<String>>,
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
@@ -1694,7 +1744,6 @@ pub async fn run(
     // Load tokenizer and model info
     let (
-        tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
         preprocessor_config_filename,
@@ -1702,7 +1751,6 @@
         model_info,
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
@@ -1716,10 +1764,6 @@
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));

-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@@ -1732,7 +1776,6 @@
                 None
             };
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
@@ -1747,7 +1790,6 @@
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
@@ -1769,36 +1811,31 @@ pub async fn run(
         HubTokenizerConfig::default()
     });

-    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+    let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
-        let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
-            let transformers = py.import_bound("transformers")?;
-            let auto = transformers.getattr("AutoTokenizer")?;
-            let from_pretrained = auto.getattr("from_pretrained")?;
-            let args = (tokenizer_name.to_string(),);
-            let kwargs = [(
-                "revision",
-                revision.clone().unwrap_or_else(|| "main".to_string()),
-            )]
-            .into_py_dict_bound(py);
-            let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
-            let save = tokenizer.getattr("save_pretrained")?;
-            let args = ("out".to_string(),);
-            save.call1(args)?;
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
+            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
             Ok(())
         })
         .inspect_err(|err| {
             tracing::error!("Failed to import python tokenizer {err}");
-        });
-        let filename = if convert.is_ok() {
-            // If we have correctly loaded and resaved with transformers
-            // We might have modified the tokenizer.json according to transformers
-            "out/tokenizer.json".into()
+        })
+        .or_else(|err| {
+            let out = legacy_tokenizer_handle(config_filename.as_ref());
+            out.ok_or(err)
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
         } else {
-            filename
-        };
-        Tokenizer::from_file(filename).ok()
-    });
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.clone(),
+                revision: revision.clone(),
+                trust_remote_code,
+            }
+        }
+    };

     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
@@ -1826,10 +1863,6 @@ pub async fn run(
         preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);

     tracing::info!("Using config {config:?}");
-    if tokenizer.is_none() {
-        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
-        tracing::warn!("Rust input length validation and truncation is disabled");
-    }

     // Only send usage stats when TGI is run in container and the function returns Some
     let is_container = matches!(usage_stats::is_container(), Ok(true));
@@ -1851,7 +1884,6 @@ pub async fn run(
         // max_batch_size,
         revision.clone(),
         validation_workers,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         usage_stats_level,
@@ -1893,7 +1925,6 @@ pub async fn run(
         ngrok,
         _ngrok_authtoken,
         _ngrok_edge,
-        messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
         model_info,
@@ -1946,14 +1977,13 @@ async fn start(
     validation_workers: usize,
     api_key: Option<String>,
     config: Option<Config>,
-    (tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
+    (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig),
     (preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
     hostname: String,
     port: u16,
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    messages_api_enabled: bool,
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     model_info: HubModelInfo,
@@ -2268,6 +2298,7 @@ async fn start(
         .route("/v1/chat/completions", post(chat_completions))
         .route("/v1/completions", post(completions))
         .route("/vertex", post(vertex_compatibility))
+        .route("/invocations", post(sagemaker_compatibility))
         .route("/tokenize", post(tokenize));

     if let Some(api_key) = api_key {
@@ -2303,13 +2334,6 @@ async fn start(
         .route("/metrics", get(metrics))
         .route("/v1/models", get(openai_get_model_info));

-    // Conditional AWS Sagemaker route
-    let aws_sagemaker_route = if messages_api_enabled {
-        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
-    } else {
-        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
-    };
-
     let compute_type =
         ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@@ -2317,8 +2341,7 @@ async fn start(
     let mut app = Router::new()
         .merge(swagger_ui)
        .merge(base_routes)
-        .merge(info_routes)
-        .merge(aws_sagemaker_route);
+        .merge(info_routes);

     #[cfg(feature = "google")]
     {
@@ -2414,30 +2437,6 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
     }
 }

-/// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
-
let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. - let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - api_base_repo.get("tokenizer.json").await.ok() - } else { - None - } -} - /// get tokenizer_config from the Huggingface Hub pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; @@ -2521,28 +2520,6 @@ impl From for Event { } } -#[derive(serde::Serialize)] -pub struct APIError { - message: String, - http_status_code: usize, -} - -#[derive(serde::Serialize)] -pub struct ErrorEvent { - error: APIError, -} - -impl ErrorEvent { - fn into_api_error(err: InferError, http_status_code: usize) -> Self { - ErrorEvent { - error: APIError { - message: err.to_string(), - http_status_code, - }, - } - } -} - #[derive(Debug, Error)] pub enum WebServerError { #[error("Axum error: {0}")] @@ -2602,10 +2579,11 @@ mod tests { use crate::TokenizerConfigToken; use crate::Tool; + use crate::tests::get_tokenizer; use serde_json::json; - #[test] - fn test_prepare_chat_input() { + #[tokio::test] + async fn test_prepare_chat_input() { // Mock Backend to avoid network requests struct MockBackend; @@ -2646,9 +2624,11 @@ mod tests { ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + 
\"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) ); + let tokenizer = get_tokenizer(); + let infer = Infer::new( backend, - Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false), 1, tokenizer_config, HubProcessorConfig::default(),