diff --git a/router/client/src/pb/generate.v2.rs b/router/client/src/pb/generate.v2.rs
deleted file mode 100644
index 1a206360..00000000
--- a/router/client/src/pb/generate.v2.rs
+++ /dev/null
@@ -1,647 +0,0 @@
-// This file is @generated by prost-build.
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct HealthRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct HealthResponse {}
-/// / Empty request
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct InfoRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct InfoResponse {
-    #[prost(bool, tag = "1")]
-    pub requires_padding: bool,
-    #[prost(string, tag = "2")]
-    pub dtype: ::prost::alloc::string::String,
-    #[prost(string, tag = "3")]
-    pub device_type: ::prost::alloc::string::String,
-    #[prost(uint32, optional, tag = "4")]
-    pub window_size: ::core::option::Option<u32>,
-    #[prost(uint32, tag = "5")]
-    pub speculate: u32,
-}
-/// / Empty request
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ServiceDiscoveryRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ServiceDiscoveryResponse {
-    /// / Other shards urls
-    #[prost(string, repeated, tag = "1")]
-    pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ClearCacheRequest {
-    /// / Optional batch id
-    #[prost(uint64, optional, tag = "1")]
-    pub id: ::core::option::Option<u64>,
-}
-/// / Empty response
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ClearCacheResponse {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct NextTokenChooserParameters {
-    /// / exponential scaling output probability distribution
-    #[prost(float, tag = "1")]
-    pub temperature: f32,
-    /// / restricting to the k highest probability elements
-    #[prost(uint32, tag = "2")]
-    pub top_k: u32,
-    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
-    #[prost(float, tag = "3")]
-    pub top_p: f32,
-    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
-    #[prost(float, tag = "4")]
-    pub typical_p: f32,
-    /// / apply sampling on the logits
-    #[prost(bool, tag = "5")]
-    pub do_sample: bool,
-    /// / random seed for sampling
-    #[prost(uint64, tag = "6")]
-    pub seed: u64,
-    /// / repetition penalty
-    #[prost(float, tag = "7")]
-    pub repetition_penalty: f32,
-    /// / frequency penalty
-    #[prost(float, tag = "9")]
-    pub frequency_penalty: f32,
-    /// / token watermarking using "A Watermark for Large Language Models"
-    #[prost(bool, tag = "8")]
-    pub watermark: bool,
-    /// / grammar (applied if not empty)
-    #[prost(string, tag = "10")]
-    pub grammar: ::prost::alloc::string::String,
-    /// / grammar type
-    #[prost(enumeration = "GrammarType", tag = "11")]
-    pub grammar_type: i32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct StoppingCriteriaParameters {
-    /// / Maximum number of generated tokens
-    #[prost(uint32, tag = "1")]
-    pub max_new_tokens: u32,
-    /// / Optional stopping sequences
-    #[prost(string, repeated, tag = "2")]
-    pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-    /// / Ignore end of sequence token
-    /// / used for benchmarking
-    #[prost(bool, tag = "3")]
-    pub ignore_eos_token: bool,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Request {
-    /// / Request ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / The generation context
-    #[prost(string, tag = "2")]
-    pub inputs: ::prost::alloc::string::String,
-    /// / Context truncation
-    #[prost(uint32, tag = "3")]
-    pub truncate: u32,
-    /// / Next Token Chooser Parameters
-    #[prost(message, optional, tag = "4")]
-    pub parameters: ::core::option::Option<NextTokenChooserParameters>,
-    /// / Stopping Criteria Parameters
-    #[prost(message, optional, tag = "5")]
-    pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
-    /// / Return prefill logprobs
-    #[prost(bool, tag = "6")]
-    pub prefill_logprobs: bool,
-    /// / Return most likely n tokens
-    #[prost(uint32, tag = "7")]
-    pub top_n_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Batch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests
-    #[prost(message, repeated, tag = "2")]
-    pub requests: ::prost::alloc::vec::Vec<Request>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct CachedBatch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests ids
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct GeneratedText {
-    /// / Output
-    #[prost(string, tag = "1")]
-    pub text: ::prost::alloc::string::String,
-    /// / Number of generated tokens
-    #[prost(uint32, tag = "2")]
-    pub generated_tokens: u32,
-    /// / Finish reason
-    #[prost(enumeration = "FinishReason", tag = "3")]
-    pub finish_reason: i32,
-    /// / Seed
-    #[prost(uint64, optional, tag = "4")]
-    pub seed: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Tokens {
-    /// / Token IDs
-    #[prost(uint32, repeated, tag = "1")]
-    pub ids: ::prost::alloc::vec::Vec<u32>,
-    /// / Logprobs
-    #[prost(float, repeated, tag = "2")]
-    pub logprobs: ::prost::alloc::vec::Vec<f32>,
-    /// / tokens
-    #[prost(string, repeated, tag = "3")]
-    pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-    /// / special
-    #[prost(bool, repeated, tag = "4")]
-    pub is_special: ::prost::alloc::vec::Vec<bool>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Generation {
-    /// / Request ID
-    #[prost(uint64, tag = "1")]
-    pub request_id: u64,
-    /// / Prefill tokens (optional)
-    #[prost(message, optional, tag = "2")]
-    pub prefill_tokens: ::core::option::Option<Tokens>,
-    #[prost(message, optional, tag = "3")]
-    pub tokens: ::core::option::Option<Tokens>,
-    /// / Complete generated text
-    #[prost(message, optional, tag = "4")]
-    pub generated_text: ::core::option::Option<GeneratedText>,
-    /// / Top tokens
-    #[prost(message, repeated, tag = "5")]
-    pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchRequest {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub batch_id: u64,
-    /// / Requests to keep
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchResponse {
-    /// / Filtered Batch (cached)
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillRequest {
-    /// / Batch
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillResponse {
-    /// / Generation
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeRequest {
-    /// / Cached batches
-    #[prost(message, repeated, tag = "1")]
-    pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeResponse {
-    /// / Decodes
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-    /// / Concatenate elapsed time in nanoseconds
-    #[prost(uint64, optional, tag = "6")]
-    pub concat_ns: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupRequest {
-    /// / Batch to warmup on
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-    #[prost(uint32, tag = "2")]
-    pub max_input_length: u32,
-    #[prost(uint32, tag = "3")]
-    pub max_prefill_tokens: u32,
-    #[prost(uint32, tag = "4")]
-    pub max_total_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupResponse {
-    /// / Maximum number of tokens supported by the model
-    #[prost(uint32, optional, tag = "1")]
-    pub max_supported_total_tokens: ::core::option::Option<u32>,
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum GrammarType {
-    None = 0,
-    Json = 1,
-    Regex = 2,
-}
-impl GrammarType {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            GrammarType::None => "GRAMMAR_TYPE_NONE",
-            GrammarType::Json => "GRAMMAR_TYPE_JSON",
-            GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "GRAMMAR_TYPE_NONE" => Some(Self::None),
-            "GRAMMAR_TYPE_JSON" => Some(Self::Json),
-            "GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
-            _ => None,
-        }
-    }
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum FinishReason {
-    Length = 0,
-    EosToken = 1,
-    StopSequence = 2,
-}
-impl FinishReason {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            FinishReason::Length => "FINISH_REASON_LENGTH",
-            FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
-            FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "FINISH_REASON_LENGTH" => Some(Self::Length),
-            "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
-            "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
-            _ => None,
-        }
-    }
-}
-/// Generated client implementations.
-pub mod text_generation_service_client {
-    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
-    use tonic::codegen::*;
-    use tonic::codegen::http::Uri;
-    #[derive(Debug, Clone)]
-    pub struct TextGenerationServiceClient<T> {
-        inner: tonic::client::Grpc<T>,
-    }
-    impl TextGenerationServiceClient<tonic::transport::Channel> {
-        /// Attempt to create a new client by connecting to a given endpoint.
-        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
-        where
-            D: TryInto<tonic::transport::Endpoint>,
-            D::Error: Into<StdError>,
-        {
-            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
-            Ok(Self::new(conn))
-        }
-    }
-    impl<T> TextGenerationServiceClient<T>
-    where
-        T: tonic::client::GrpcService<tonic::body::BoxBody>,
-        T::Error: Into<StdError>,
-        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
-        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
-    {
-        pub fn new(inner: T) -> Self {
-            let inner = tonic::client::Grpc::new(inner);
-            Self { inner }
-        }
-        pub fn with_origin(inner: T, origin: Uri) -> Self {
-            let inner = tonic::client::Grpc::with_origin(inner, origin);
-            Self { inner }
-        }
-        pub fn with_interceptor<F>(
-            inner: T,
-            interceptor: F,
-        ) -> TextGenerationServiceClient<InterceptedService<T, F>>
-        where
-            F: tonic::service::Interceptor,
-            T::ResponseBody: Default,
-            T: tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-                Response = http::Response<
-                    <T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
-                >,
-            >,
-            <T as tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-            >>::Error: Into<StdError> + Send + Sync,
-        {
-            TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
-        }
-        /// Compress requests with the given encoding.
-        ///
-        /// This requires the server to support it otherwise it might respond with an
-        /// error.
-        #[must_use]
-        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.send_compressed(encoding);
-            self
-        }
-        /// Enable decompressing responses.
-        #[must_use]
-        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.accept_compressed(encoding);
-            self
-        }
-        /// Limits the maximum size of a decoded message.
-        ///
-        /// Default: `4MB`
-        #[must_use]
-        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_decoding_message_size(limit);
-            self
-        }
-        /// Limits the maximum size of an encoded message.
-        ///
-        /// Default: `usize::MAX`
-        #[must_use]
-        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_encoding_message_size(limit);
-            self
-        }
-        /// / Model Info
-        pub async fn info(
-            &mut self,
-            request: impl tonic::IntoRequest<super::InfoRequest>,
-        ) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Info",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Service discovery
-        pub async fn service_discovery(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ServiceDiscoveryResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/ServiceDiscovery",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new(
-                        "generate.v2.TextGenerationService",
-                        "ServiceDiscovery",
-                    ),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Empties batch cache
-        pub async fn clear_cache(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ClearCacheRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ClearCacheResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/ClearCache",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Remove requests from a cached batch
-        pub async fn filter_batch(
-            &mut self,
-            request: impl tonic::IntoRequest<super::FilterBatchRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::FilterBatchResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/FilterBatch",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Warmup the model and compute max cache size
-        pub async fn warmup(
-            &mut self,
-            request: impl tonic::IntoRequest<super::WarmupRequest>,
-        ) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
-            self.inner
-                .ready()
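Context for reviewers: everything deleted above is standard prost/tonic codegen, so no hand-written logic is lost. For reference, a minimal sketch of how a generated client like this is typically driven; the module path and endpoint are illustrative assumptions, not code from this PR (in practice the router reaches its shards over unix sockets):

```rust
// Hypothetical usage of the (now removed) v2 client. Assumes the pb module
// is in scope and a server is listening at the given address.
use generate::v2::text_generation_service_client::TextGenerationServiceClient;
use generate::v2::{HealthRequest, InfoRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut client = TextGenerationServiceClient::connect("http://[::1]:50051").await?;
    // Unary calls mirror the proto RPCs one-to-one.
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!("requires_padding: {}", info.requires_padding);
    client.health(HealthRequest {}).await?;
    Ok(())
}
```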
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Warmup",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
            self.inner.unary(req, path, codec).await
        }
        /// / Prefill batch and decode first token
        pub async fn prefill(
            &mut self,
            request: impl tonic::IntoRequest<super::PrefillRequest>,
        ) -> std::result::Result<
            tonic::Response<super::PrefillResponse>,
            tonic::Status,
        > {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Prefill",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
            self.inner.unary(req, path, codec).await
        }
        /// / Decode token for a list of prefilled batches
        pub async fn decode(
            &mut self,
            request: impl tonic::IntoRequest<super::DecodeRequest>,
        ) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Decode",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
            self.inner.unary(req, path, codec).await
        }
        /// / Health check
        pub async fn health(
            &mut self,
            request: impl tonic::IntoRequest<super::HealthRequest>,
        ) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Health",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
            self.inner.unary(req, path, codec).await
        }
    }
}
diff --git a/router/client/src/pb/mod.rs b/router/client/src/pb/mod.rs
deleted file mode 100644
index 095ead1f..00000000
--- a/router/client/src/pb/mod.rs
+++ /dev/null
@@ -1,6 +0,0 @@
-// This file is @generated by prost-build.
-pub mod generate {
-    pub mod v2 {
-        include!("generate.v2.rs");
-    }
-}
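These checked-in stubs are normally (re)generated at build time rather than edited by hand. A sketch of the kind of build script that produces them; the proto path is an assumption based on the repo layout, not taken from this diff:

```rust
// build.rs — sketch of the codegen step behind the deleted stubs.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("cargo:rerun-if-changed=../../proto/generate.proto");
    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        // Emit generate.v2.rs next to the mod.rs deleted above.
        .out_dir("src/pb")
        .compile(&["../../proto/generate.proto"], &["../../proto"])?;
    Ok(())
}
```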
diff --git a/router/src/lib.rs b/router/src/lib.rs
index f01c2f7f..3132d9b6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1421,13 +1421,12 @@ impl Default for ModelsInfo {
 mod tests {
     use super::*;
     use serde_json::json;
-    use tokenizers::Tokenizer;
 
-    pub(crate) async fn get_tokenizer() -> Tokenizer {
+    pub(crate) fn get_tokenizer() -> Tokenizer {
         let api = hf_hub::api::sync::Api::new().unwrap();
         let repo = api.model("gpt2".to_string());
         let filename = repo.get("tokenizer.json").unwrap();
-        Tokenizer::from_file(filename).unwrap()
+        Tokenizer::Rust(tokenizers::Tokenizer::from_file(filename).unwrap())
     }
 
     #[test]
diff --git a/router/src/server.rs b/router/src/server.rs
index be8e7290..8608ca2a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2515,10 +2515,11 @@ mod tests {
     use crate::TokenizerConfigToken;
     use crate::Tool;
 
+    use crate::tests::get_tokenizer;
     use serde_json::json;
 
-    #[test]
-    fn test_prepare_chat_input() {
+    #[tokio::test]
+    async fn test_prepare_chat_input() {
         // Mock Backend to avoid network requests
         struct MockBackend;
 
@@ -2559,9 +2560,11 @@
             ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n        {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n".to_string())
         );
 
+        let tokenizer = get_tokenizer();
+
         let infer = Infer::new(
             backend,
-            Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
+            Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false),
             1,
             tokenizer_config,
             HubProcessorConfig::default(),
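The test helpers above now hand `Validation` a router-owned `Tokenizer` value instead of an `Option<tokenizers::Tokenizer>`. A minimal sketch of the wrapper the call sites imply; the real definition lives elsewhere in the router and likely carries more variants:

```rust
// Sketch only: inferred from get_tokenizer() above, not the actual definition.
pub enum Tokenizer {
    // Wraps the Rust `tokenizers` crate tokenizer, as built in the tests.
    Rust(tokenizers::Tokenizer),
    // Presumably other variants exist (e.g. a Python-backed tokenizer for
    // models without a fast tokenizer); not shown in this diff.
}
```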
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 2cd92e5d..a02d53a5 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -797,7 +797,7 @@
 
     #[tokio::test]
     async fn test_validation_max_new_tokens() {
-        let tokenizer = None;
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -824,15 +824,15 @@
             .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
-            // Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            Ok((_s, _, 0, 10)) => (),
+            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
+            // Ok((_s, _, 0, 10)) => (),
             r => panic!("Unexpected not max new tokens: {r:?}"),
         }
     }
 
     #[tokio::test]
     async fn test_validation_input_length() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -866,7 +866,7 @@
 
     #[tokio::test]
     async fn test_validation_best_of_sampling() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -906,7 +906,7 @@
 
     #[tokio::test]
     async fn test_validation_top_p() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -977,7 +977,7 @@
 
     #[tokio::test]
     async fn test_validation_top_n_tokens() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequences = 3;
         let max_top_n_tokens = 4;
@@ -1062,7 +1062,7 @@
     async fn test_prepare_input_chunks() {
         let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
 
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
 
         let max_best_of = 2;
         let max_stop_sequence = 3;
@@ -1097,7 +1097,7 @@
             )
             .await
         {
-            Ok(Some((_encoding, chunks))) => chunks,
+            Ok((_encoding, chunks)) => chunks,
             _ => panic!("Unexpected tokenization failure"),
         };
 
@@ -1119,7 +1119,7 @@
     async fn test_idefics2_correct_n_fake_tokens() {
         let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
 
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
 
         let max_best_of = 2;
         let max_stop_sequence = 3;
@@ -1157,7 +1157,7 @@
             )
             .await
         {
-            Ok(Some((encoding, chunks))) => (encoding, chunks),
+            Ok((encoding, chunks)) => (encoding, chunks),
             _ => panic!("Unexpected tokenization failure"),
         };
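On the updated expectation in `test_validation_max_new_tokens`: with a real tokenizer wired in, "Hello" resolves to a single GPT-2 token, and reading the error tuple as (total budget, input tokens, requested new tokens), a 1-token prompt plus `max_new_tokens = 10` no longer fits a 6-token budget, so validation now rejects instead of silently clamping. A toy illustration of that arithmetic; this is not the router's actual code:

```rust
// Toy sketch of the check behind Err(ValidationError::MaxTotalTokens(6, 1, 10)).
fn fits_budget(input_len: u32, max_new_tokens: u32, max_total_tokens: u32) -> bool {
    input_len + max_new_tokens <= max_total_tokens
}

fn main() {
    // "Hello" is one token: 1 + 10 > 6, so the request is rejected.
    assert!(!fits_budget(1, 10, 6));
    // A smaller generation budget would pass validation.
    assert!(fits_budget(1, 5, 6));
}
```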