From 5ba7805f1caca1362ede8ac4434961e1113f79c1 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 17 Sep 2024 16:16:51 +0200
Subject: [PATCH] We can have a tokenizer anywhere.

---
 router/client/src/pb/generate.v2.rs | 647 ++++++++++++++++++++++++++++
 router/client/src/pb/mod.rs         |   6 +
 router/src/infer/mod.rs             |   4 +-
 router/src/lib.rs                   |  80 ++++
 router/src/server.rs                | 189 +++-----
 router/src/validation.rs            | 191 ++++----
 6 files changed, 885 insertions(+), 232 deletions(-)
 create mode 100644 router/client/src/pb/generate.v2.rs
 create mode 100644 router/client/src/pb/mod.rs

diff --git a/router/client/src/pb/generate.v2.rs b/router/client/src/pb/generate.v2.rs
new file mode 100644
index 00000000..1a206360
--- /dev/null
+++ b/router/client/src/pb/generate.v2.rs
@@ -0,0 +1,647 @@
+// This file is @generated by prost-build.
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct HealthRequest {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct HealthResponse {}
+/// / Empty request
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct InfoRequest {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct InfoResponse {
+    #[prost(bool, tag = "1")]
+    pub requires_padding: bool,
+    #[prost(string, tag = "2")]
+    pub dtype: ::prost::alloc::string::String,
+    #[prost(string, tag = "3")]
+    pub device_type: ::prost::alloc::string::String,
+    #[prost(uint32, optional, tag = "4")]
+    pub window_size: ::core::option::Option<u32>,
+    #[prost(uint32, tag = "5")]
+    pub speculate: u32,
+}
+/// / Empty request
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ServiceDiscoveryRequest {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ServiceDiscoveryResponse {
+    /// / Other shards urls
+    #[prost(string, repeated, tag = "1")]
+    pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ClearCacheRequest {
+    /// / Optional batch id
+    #[prost(uint64, optional, tag = "1")]
+    pub id: ::core::option::Option<u64>,
+}
+/// / Empty response
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ClearCacheResponse {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NextTokenChooserParameters {
+    /// / exponential scaling output probability distribution
+    #[prost(float, tag = "1")]
+    pub temperature: f32,
+    /// / restricting to the k highest probability elements
+    #[prost(uint32, tag = "2")]
+    pub top_k: u32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    #[prost(float, tag = "3")]
+    pub top_p: f32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    #[prost(float, tag = "4")]
+    pub typical_p: f32,
+    /// / apply sampling on the logits
+    #[prost(bool, tag = "5")]
+    pub do_sample: bool,
+    /// / random seed for sampling
+    #[prost(uint64, tag = "6")]
+    pub seed: u64,
+    /// / repetition penalty
+    #[prost(float, tag = "7")]
+    pub repetition_penalty: f32,
+    /// / frequency penalty
+    #[prost(float, tag = "9")]
+    pub frequency_penalty: f32,
+    /// / token watermarking using "A Watermark for Large Language Models"
+    #[prost(bool, tag = "8")]
+    pub watermark: bool,
+    /// / grammar (applied if not empty)
+    #[prost(string, tag = "10")]
+    pub grammar: ::prost::alloc::string::String,
+    /// / grammar type
+    #[prost(enumeration = "GrammarType", tag = "11")]
+    pub grammar_type: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StoppingCriteriaParameters {
+    /// / Maximum number of generated tokens
+    #[prost(uint32, tag = "1")]
+    pub max_new_tokens: u32,
+    /// / Optional stopping sequences
+    #[prost(string, repeated, tag = "2")]
+    pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// / Ignore end of sequence token
+    /// / used for benchmarking
+    #[prost(bool, tag = "3")]
+    pub ignore_eos_token: bool,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Request {
+    /// / Request ID
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    /// / The generation context
+    #[prost(string, tag = "2")]
+    pub inputs: ::prost::alloc::string::String,
+    /// / Context truncation
+    #[prost(uint32, tag = "3")]
+    pub truncate: u32,
+    /// / Next Token Chooser Parameters
+    #[prost(message, optional, tag = "4")]
+    pub parameters: ::core::option::Option<NextTokenChooserParameters>,
+    /// / Stopping Criteria Parameters
+    #[prost(message, optional, tag = "5")]
+    pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
+    /// / Return prefill logprobs
+    #[prost(bool, tag = "6")]
+    pub prefill_logprobs: bool,
+    /// / Return most likely n tokens
+    #[prost(uint32, tag = "7")]
+    pub top_n_tokens: u32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Batch {
+    /// / Batch ID
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    /// / Individual requests
+    #[prost(message, repeated, tag = "2")]
+    pub requests: ::prost::alloc::vec::Vec<Request>,
+    /// / Batch size (==len(requests))
+    #[prost(uint32, tag = "3")]
+    pub size: u32,
+    /// / Maximum number of tokens this batch will grow to
+    #[prost(uint32, tag = "4")]
+    pub max_tokens: u32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct CachedBatch {
+    /// / Batch ID
+    #[prost(uint64, tag = "1")]
+    pub id: u64,
+    /// / Individual requests ids
+    #[prost(uint64, repeated, tag = "2")]
+    pub request_ids: ::prost::alloc::vec::Vec<u64>,
+    /// / Batch size (==len(requests))
+    #[prost(uint32, tag = "3")]
+    pub size: u32,
+    /// / Maximum number of tokens this batch will grow to
+    #[prost(uint32, tag = "4")]
+    pub max_tokens: u32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GeneratedText {
+    /// / Output
+    #[prost(string, tag = "1")]
+    pub text: ::prost::alloc::string::String,
+    /// / Number of generated tokens
+    #[prost(uint32, tag = "2")]
+    pub generated_tokens: u32,
+    /// / Finish reason
+    #[prost(enumeration = "FinishReason", tag = "3")]
+    pub finish_reason: i32,
+    /// / Seed
+    #[prost(uint64, optional, tag = "4")]
+    pub seed: ::core::option::Option<u64>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Tokens {
+    /// / Token IDs
+    #[prost(uint32, repeated, tag = "1")]
+    pub ids: ::prost::alloc::vec::Vec<u32>,
+    /// / Logprobs
+    #[prost(float, repeated, tag = "2")]
+    pub logprobs: ::prost::alloc::vec::Vec<f32>,
+    /// / tokens
+    #[prost(string, repeated, tag = "3")]
+    pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
+    /// / special
+    #[prost(bool, repeated, tag = "4")]
+    pub is_special: ::prost::alloc::vec::Vec<bool>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Generation {
+    /// / Request ID
+    #[prost(uint64, tag = "1")]
+    pub request_id: u64,
+    /// / Prefill tokens (optional)
+    #[prost(message, optional, tag = "2")]
+    pub prefill_tokens: ::core::option::Option<Tokens>,
+    #[prost(message, optional, tag = "3")]
+    pub tokens: ::core::option::Option<Tokens>,
+    /// / Complete generated text
+    #[prost(message, optional, tag = "4")]
+    pub generated_text: ::core::option::Option<GeneratedText>,
+    /// / Top tokens
+    #[prost(message, repeated, tag = "5")]
+    pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FilterBatchRequest {
+    /// / Batch ID
+    #[prost(uint64, tag = "1")]
+    pub batch_id: u64,
+    /// / Requests to keep
+    #[prost(uint64, repeated, tag = "2")]
+    pub request_ids: ::prost::alloc::vec::Vec<u64>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FilterBatchResponse {
+    /// / Filtered Batch (cached)
+    #[prost(message, optional, tag = "1")]
+    pub batch: ::core::option::Option<CachedBatch>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PrefillRequest {
+    /// / Batch
+    #[prost(message, optional, tag = "1")]
+    pub batch: ::core::option::Option<Batch>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PrefillResponse {
+    /// / Generation
+    #[prost(message, repeated, tag = "1")]
+    pub generations: ::prost::alloc::vec::Vec<Generation>,
+    /// / Next batch (cached)
+    #[prost(message, optional, tag = "2")]
+    pub batch: ::core::option::Option<CachedBatch>,
+    /// / Forward elapsed time in nanoseconds
+    #[prost(uint64, tag = "3")]
+    pub forward_ns: u64,
+    /// / Decode elapsed time in nanoseconds
+    #[prost(uint64, tag = "4")]
+    pub decode_ns: u64,
+    /// / Total elapsed time in nanoseconds
+    #[prost(uint64, tag = "5")]
+    pub total_ns: u64,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DecodeRequest {
+    /// / Cached batches
+    #[prost(message, repeated, tag = "1")]
+    pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DecodeResponse {
+    /// / Decodes
+    #[prost(message, repeated, tag = "1")]
+    pub generations: ::prost::alloc::vec::Vec<Generation>,
+    /// / Next batch (cached)
+    #[prost(message, optional, tag = "2")]
+    pub batch: ::core::option::Option<CachedBatch>,
+    /// / Forward elapsed time in nanoseconds
+    #[prost(uint64, tag = "3")]
+    pub forward_ns: u64,
+    /// / Decode elapsed time in nanoseconds
+    #[prost(uint64, tag = "4")]
+    pub decode_ns: u64,
+    /// / Total elapsed time in nanoseconds
+    #[prost(uint64, tag = "5")]
+    pub total_ns: u64,
+    /// / Concatenate elapsed time in nanoseconds
+    #[prost(uint64, optional, tag = "6")]
+    pub concat_ns: ::core::option::Option<u64>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct WarmupRequest {
+    /// / Batch to warmup on
+    #[prost(message, optional, tag = "1")]
+    pub batch: ::core::option::Option<Batch>,
+    #[prost(uint32, tag = "2")]
+    pub max_input_length: u32,
+    #[prost(uint32, tag = "3")]
+    pub max_prefill_tokens: u32,
+    #[prost(uint32, tag = "4")]
+    pub max_total_tokens: u32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct WarmupResponse {
+    /// / Maximum number of tokens supported by the model
+    #[prost(uint32, optional, tag = "1")]
+    pub max_supported_total_tokens: ::core::option::Option<u32>,
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum GrammarType {
+    None = 0,
+    Json = 1,
+    Regex = 2,
+}
+impl GrammarType {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            GrammarType::None => "GRAMMAR_TYPE_NONE",
+            GrammarType::Json => "GRAMMAR_TYPE_JSON",
+            GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "GRAMMAR_TYPE_NONE" => Some(Self::None),
+            "GRAMMAR_TYPE_JSON" => Some(Self::Json),
+            "GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
+            _ => None,
+        }
+    }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
+pub enum FinishReason {
+    Length = 0,
+    EosToken = 1,
+    StopSequence = 2,
+}
+impl FinishReason {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            FinishReason::Length => "FINISH_REASON_LENGTH",
+            FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
+            FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "FINISH_REASON_LENGTH" => Some(Self::Length),
+            "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
+            "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
+            _ => None,
+        }
+    }
+}
+/// Generated client implementations.
+pub mod text_generation_service_client {
+    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
+    use tonic::codegen::*;
+    use tonic::codegen::http::Uri;
+    #[derive(Debug, Clone)]
+    pub struct TextGenerationServiceClient<T> {
+        inner: tonic::client::Grpc<T>,
+    }
+    impl TextGenerationServiceClient<tonic::transport::Channel> {
+        /// Attempt to create a new client by connecting to a given endpoint.
+        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
+        where
+            D: TryInto<tonic::transport::Endpoint>,
+            D::Error: Into<StdError>,
+        {
+            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
+            Ok(Self::new(conn))
+        }
+    }
+    impl<T> TextGenerationServiceClient<T>
+    where
+        T: tonic::client::GrpcService<tonic::body::BoxBody>,
+        T::Error: Into<StdError>,
+        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
+        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
+    {
+        pub fn new(inner: T) -> Self {
+            let inner = tonic::client::Grpc::new(inner);
+            Self { inner }
+        }
+        pub fn with_origin(inner: T, origin: Uri) -> Self {
+            let inner = tonic::client::Grpc::with_origin(inner, origin);
+            Self { inner }
+        }
+        pub fn with_interceptor<F>(
+            inner: T,
+            interceptor: F,
+        ) -> TextGenerationServiceClient<InterceptedService<T, F>>
+        where
+            F: tonic::service::Interceptor,
+            T::ResponseBody: Default,
+            T: tonic::codegen::Service<
+                http::Request<tonic::body::BoxBody>,
+                Response = http::Response<
+                    <T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
+                >,
+            >,
+            <T as tonic::codegen::Service<
+                http::Request<tonic::body::BoxBody>,
+            >>::Error: Into<StdError> + Send + Sync,
+        {
+            TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
+        }
+        /// Compress requests with the given encoding.
+        ///
+        /// This requires the server to support it otherwise it might respond with an
+        /// error.
+        #[must_use]
+        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
+            self.inner = self.inner.send_compressed(encoding);
+            self
+        }
+        /// Enable decompressing responses.
+        #[must_use]
+        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
+            self.inner = self.inner.accept_compressed(encoding);
+            self
+        }
+        /// Limits the maximum size of a decoded message.
+        ///
+        /// Default: `4MB`
+        #[must_use]
+        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
+            self.inner = self.inner.max_decoding_message_size(limit);
+            self
+        }
+        /// Limits the maximum size of an encoded message.
+        ///
+        /// Default: `usize::MAX`
+        #[must_use]
+        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
+            self.inner = self.inner.max_encoding_message_size(limit);
+            self
+        }
+        /// / Model Info
+        pub async fn info(
+            &mut self,
+            request: impl tonic::IntoRequest<super::InfoRequest>,
+        ) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/Info",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Service discovery
+        pub async fn service_discovery(
+            &mut self,
+            request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
+        ) -> std::result::Result<
+            tonic::Response<super::ServiceDiscoveryResponse>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/ServiceDiscovery",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new(
+                        "generate.v2.TextGenerationService",
+                        "ServiceDiscovery",
+                    ),
+                );
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Empties batch cache
+        pub async fn clear_cache(
+            &mut self,
+            request: impl tonic::IntoRequest<super::ClearCacheRequest>,
+        ) -> std::result::Result<
+            tonic::Response<super::ClearCacheResponse>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/ClearCache",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
+                );
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Remove requests from a cached batch
+        pub async fn filter_batch(
+            &mut self,
+            request: impl tonic::IntoRequest<super::FilterBatchRequest>,
+        ) -> std::result::Result<
+            tonic::Response<super::FilterBatchResponse>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/FilterBatch",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
+                );
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Warmup the model and compute max cache size
+        pub async fn warmup(
+            &mut self,
+            request: impl tonic::IntoRequest<super::WarmupRequest>,
+        ) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/Warmup",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Prefill batch and decode first token
+        pub async fn prefill(
+            &mut self,
+            request: impl tonic::IntoRequest<super::PrefillRequest>,
+        ) -> std::result::Result<
+            tonic::Response<super::PrefillResponse>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/Prefill",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Decode token for a list of prefilled batches
+        pub async fn decode(
+            &mut self,
+            request: impl tonic::IntoRequest<super::DecodeRequest>,
+        ) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/Decode",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
+            self.inner.unary(req, path, codec).await
+        }
+        /// / Health check
+        pub async fn health(
+            &mut self,
+            request: impl tonic::IntoRequest<super::HealthRequest>,
+        ) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/generate.v2.TextGenerationService/Health",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
+            self.inner.unary(req, path, codec).await
+        }
+    }
+}
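For review context, a rough sketch of how the generated client is driven (not part of the patch; the crate path `text_generation_client::pb::generate::v2` is inferred from the `pb/mod.rs` layout below, and the TCP endpoint is only illustrative, since TGI shards normally listen on unix domain sockets):

    use text_generation_client::pb::generate::v2::{
        text_generation_service_client::TextGenerationServiceClient, HealthRequest, InfoRequest,
    };

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Connect to a running shard and query its capabilities.
        let mut client = TextGenerationServiceClient::connect("http://[::1]:50051").await?;
        let info = client.info(InfoRequest {}).await?.into_inner();
        println!("dtype={} speculate={}", info.dtype, info.speculate);
        // Every RPC follows the same unary pattern, e.g. the health check:
        client.health(HealthRequest {}).await?;
        Ok(())
    }
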
diff --git a/router/client/src/pb/mod.rs b/router/client/src/pb/mod.rs
new file mode 100644
index 00000000..095ead1f
--- /dev/null
+++ b/router/client/src/pb/mod.rs
@@ -0,0 +1,6 @@
+// This file is @generated by prost-build.
+pub mod generate {
+    pub mod v2 {
+        include!("generate.v2.rs");
+    }
+}
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 896f4f43..557e03cb 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -135,7 +135,7 @@ impl Infer {
     pub(crate) async fn tokenize(
         &self,
         request: GenerateRequest,
-    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+    ) -> Result<tokenizers::Encoding, InferError> {
         // Tokenize request
         let inputs = request.inputs;
         let add_special_tokens = request.add_special_tokens;
@@ -150,7 +150,7 @@ impl Infer {
         })?;
 
         // Return Encoding
-        Ok(encoding.map(|(encoding, _)| encoding))
+        Ok(encoding.0)
     }
 
     /// Apply the chat template to the chat request
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 7c40c7e3..f01c2f7f 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -14,11 +14,91 @@ mod vertex;
 
 use crate::infer::{Infer, InferError};
 use crate::server::prepare_chat_input;
+use pyo3::prelude::*;
+use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;
 
+#[derive(Clone)]
+pub enum Tokenizer {
+    Python {
+        tokenizer_name: String,
+        revision: Option<String>,
+    },
+    Rust(tokenizers::Tokenizer),
+}
+
+impl Tokenizer {
+    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
+        match self {
+            Self::Python {
+                tokenizer_name,
+                revision,
+            } => {
+                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
+                    let transformers = py.import_bound("transformers")?;
+                    let auto = transformers.getattr("AutoTokenizer")?;
+                    let from_pretrained = auto.getattr("from_pretrained")?;
+                    let args = (tokenizer_name.to_string(),);
+                    let kwargs = if let Some(rev) = &revision {
+                        [("revision", rev.to_string())].into_py_dict_bound(py)
+                    } else {
+                        pyo3::types::PyDict::new_bound(py)
+                    };
+                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+                    Ok(tokenizer)
+                }()
+                .expect("Cannot load the tokenizer");
+                tracing::info!("Loaded a python tokenizer");
+                OwnedTokenizer::Python(pytok)
+            }
+            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
+        }
+    }
+}
+
+pub enum OwnedTokenizer<'a> {
+    Python(pyo3::Bound<'a, pyo3::PyAny>),
+    Rust(tokenizers::Tokenizer),
+}
+
+impl<'a> OwnedTokenizer<'a> {
+    fn encode(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        match self {
+            Self::Python(pytok) => {
+                let py = pytok.py();
+                let kwargs = [
+                    ("text", query.into_py(py)),
+                    ("add_special_tokens", add_special_tokens.into_py(py)),
+                ]
+                .into_py_dict_bound(py);
+                let encode = pytok.getattr("encode")?;
+                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+                Ok(Encoding::new(
+                    input_ids,
+                    vec![],                           // type ids
+                    vec![],                           // tokens (strings)
+                    vec![],                           // words
+                    vec![],                           // offsets
+                    vec![],                           // special_tokens_mask
+                    vec![],                           // attention_mask
+                    vec![],                           // overflowing
+                    std::collections::HashMap::new(), // sequence_ranges
+                ))
+            }
+            Self::Rust(tok) => tok.encode(query, add_special_tokens),
+        }
+    }
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
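The `Tokenizer` enum above is deliberately lazy: the Python variant stores only a model name and revision, and `into_owned` materializes a `transformers.AutoTokenizer` once a GIL token is available. A minimal sketch of the intended call pattern (illustrative only: both helpers are crate-private, and `gpt2` is a placeholder model id):

    fn demo() {
        let tokenizer = Tokenizer::Python {
            tokenizer_name: "gpt2".to_string(), // placeholder model id
            revision: None,
        };
        pyo3::Python::with_gil(|py| {
            // Acquire the GIL once, then reuse the owned tokenizer for many queries.
            let owned = tokenizer.into_owned(py);
            let encoding = owned
                .encode("Hello world".to_string(), true)
                .expect("encoding failed");
            println!("{} input ids", encoding.get_ids().len());
        });
    }

Note that the Python path builds an `Encoding` carrying only ids: offsets and token strings are left empty, so offset-based consumers (such as the tokenize routes changed below) only get character spans from the Rust variant.
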
diff --git a/router/src/server.rs b/router/src/server.rs
index eb1d2544..2f86bcb6 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -19,7 +19,8 @@ use crate::{
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
     OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
-    TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
+    TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
+    Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -52,9 +53,8 @@ use std::convert::Infallible;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::{Path, PathBuf};
+use std::path::Path;
 use thiserror::Error;
-use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
@@ -161,40 +161,30 @@ async fn get_chat_tokenize(
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
     let encoding = infer.tokenize(generate_request).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
+    let tokens: Vec<SimpleToken> = encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .map(|(&id, &(start, stop))| {
+            let text = input
+                .chars()
+                .skip(start)
+                .take(stop - start)
+                .collect::<String>();
+            SimpleToken {
+                id,
+                text,
+                start,
+                stop,
+            }
+        })
+        .collect();
 
-        let resp = ChatTokenizeResponse {
-            tokenize_response: TokenizeResponse(tokens),
-            templated_text: input,
-        };
-        Ok((HeaderMap::new(), Json(resp)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let resp = ChatTokenizeResponse {
+        tokenize_response: TokenizeResponse(tokens),
+        templated_text: input,
+    };
+    Ok((HeaderMap::new(), Json(resp)))
 }
 
 #[utoipa::path(
@@ -1458,35 +1448,25 @@ async fn tokenize(
 ) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
-        Ok(Json(TokenizeResponse(tokens)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let tokens: Vec<SimpleToken> = encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .map(|(&id, &(start, stop))| {
+            let text = input
+                .chars()
+                .skip(start)
+                .take(stop - start)
+                .collect::<String>();
+            SimpleToken {
+                id,
+                text,
+                start,
+                stop,
+            }
+        })
+        .collect();
+    Ok(Json(TokenizeResponse(tokens)))
 }
 
 /// Prometheus metrics scrape endpoint
@@ -1687,7 +1667,6 @@ pub async fn run(
 
     // Load tokenizer and model info
     let (
-        tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
        preprocessor_config_filename,
@@ -1695,7 +1674,6 @@ pub async fn run(
         model_info,
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
@@ -1709,10 +1687,6 @@ pub async fn run(
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));
 
-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@@ -1725,7 +1699,6 @@ pub async fn run(
                 None
             };
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
@@ -1740,7 +1713,6 @@ pub async fn run(
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
@@ -1762,21 +1734,22 @@ pub async fn run(
         HubTokenizerConfig::default()
     });
 
-    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+    let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
-        let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
             let transformers = py.import_bound("transformers")?;
             let auto = transformers.getattr("AutoTokenizer")?;
             let from_pretrained = auto.getattr("from_pretrained")?;
             let args = (tokenizer_name.to_string(),);
-            let kwargs = [
-                (
-                    "revision",
-                    (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
-                ),
-                ("trust_remote_code", trust_remote_code.into_py(py)),
-            ]
-            .into_py_dict_bound(py);
+            let kwargs = if let Some(rev) = &revision {
+                [
+                    ("revision", rev.to_string().into_py(py)),
+                    ("trust_remote_code", trust_remote_code.into_py(py)),
+                ]
+                .into_py_dict_bound(py)
+            } else {
+                [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
+            };
             let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
             let save = tokenizer.getattr("save_pretrained")?;
             let args = ("out".to_string(),);
@@ -1785,16 +1758,18 @@ pub async fn run(
         })
         .inspect_err(|err| {
             tracing::error!("Failed to import python tokenizer {err}");
-        });
-        let filename = if convert.is_ok() {
-            // If we have correctly loaded and resaved with transformers
-            // We might have modified the tokenizer.json according to transformers
-            "out/tokenizer.json".into()
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
         } else {
-            filename
-        };
-        Tokenizer::from_file(filename).ok()
-    });
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.clone(),
+                revision: revision.clone(),
+            }
+        }
+    };
 
     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
@@ -1822,10 +1797,6 @@ pub async fn run(
         preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
 
     tracing::info!("Using config {config:?}");
-    if tokenizer.is_none() {
-        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
-        tracing::warn!("Rust input length validation and truncation is disabled");
-    }
 
     // Only send usage stats when TGI is run in container and the function returns Some
     let is_container = matches!(usage_stats::is_container(), Ok(true));
@@ -1940,7 +1911,7 @@ async fn start(
     validation_workers: usize,
     api_key: Option<String>,
     config: Option<Config>,
-    (tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
+    (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig),
     (preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
     hostname: String,
     port: u16,
@@ -2400,30 +2371,6 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
     }
 }
 
-/// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
-    let config_filename = api_repo.get("config.json").await.ok()?;
-
-    // Open the file in read-only mode with buffer.
-    let file = File::open(config_filename).ok()?;
-    let reader = BufReader::new(file);
-
-    // Read the JSON contents of the file as an instance of `User`.
-    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
-
-    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
-        let api_base_repo = api.repo(Repo::with_revision(
-            base_model_id.to_string(),
-            RepoType::Model,
-            "main".to_string(),
-        ));
-
-        api_base_repo.get("tokenizer.json").await.ok()
-    } else {
-        None
-    }
-}
-
 /// get tokenizer_config from the Huggingface Hub
 pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
     let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
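The validation changes below keep the existing worker scheme: requests fan out over an unbounded mpsc channel to blocking tokenizer threads, each request carrying a oneshot sender for the reply. The worker count is forced to 1 for the Python variant, since every worker would serialize on the GIL anyway. A stripped-down sketch of that channel pattern (illustrative only; string length stands in for actual encoding):

    use tokio::sync::{mpsc, oneshot};

    type TokenizerRequest = (String, oneshot::Sender<usize>);

    #[tokio::main]
    async fn main() {
        let (sender, mut receiver) = mpsc::unbounded_channel::<TokenizerRequest>();
        // Dedicated blocking thread, mirroring `tokenizer_worker` below.
        std::thread::spawn(move || {
            while let Some((inputs, response_tx)) = receiver.blocking_recv() {
                let _ = response_tx.send(inputs.len()); // stand-in for tokenization
            }
        });
        let (response_sender, response_receiver) = oneshot::channel();
        sender.send(("hello".to_string(), response_sender)).unwrap();
        assert_eq!(response_receiver.await.unwrap(), 5);
    }
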
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 85b4220b..2cd92e5d 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -4,6 +4,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
 };
+use crate::{OwnedTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -13,7 +14,6 @@ use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
 use thiserror::Error;
-use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
@@ -30,14 +30,14 @@ pub struct Validation {
     max_total_tokens: usize,
     disable_grammar_support: bool,
     /// Channel to communicate with the background tokenization task
-    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    sender: mpsc::UnboundedSender<TokenizerRequest>,
 }
 
 impl Validation {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         workers: usize,
-        tokenizer: Option<Tokenizer>,
+        tokenizer: Tokenizer,
         config: Option<Config>,
         preprocessor_config: Option<HubPreprocessorConfig>,
         max_best_of: usize,
@@ -47,8 +47,13 @@ impl Validation {
         max_total_tokens: usize,
         disable_grammar_support: bool,
     ) -> Self {
+        let workers = if let Tokenizer::Python { .. } = &tokenizer {
+            1
+        } else {
+            workers
+        };
         // If we have a fast tokenizer
-        let sender = if let Some(tokenizer) = tokenizer {
+        let sender = {
             // Create round robin channel
             let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
             let mut senders = Vec::with_capacity(workers);
@@ -75,9 +80,7 @@ impl Validation {
             // Create tokenization round robin task
             tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
 
-            Some(validation_sender)
-        } else {
-            None
+            validation_sender
         };
 
         Self {
@@ -97,28 +100,25 @@ impl Validation {
         inputs: String,
         add_special_tokens: bool,
         truncate: Option<usize>,
-    ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
+    ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
         // If we have a fast tokenizer
-        if let Some(sender) = &self.sender {
-            // Create response channel
-            let (response_sender, response_receiver) = oneshot::channel();
-            // Send request to the background validation task
-            // Unwrap is safe here
-            sender
-                .send((
-                    (inputs, add_special_tokens, truncate),
-                    response_sender,
-                    Span::current(),
-                ))
-                .unwrap();
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send request to the background validation task
+        // Unwrap is safe here
+        let _ = &self
+            .sender
+            .send((
+                (inputs, add_special_tokens, truncate),
+                response_sender,
+                Span::current(),
+            ))
+            .unwrap();
 
-            // Await on response channel
-            // Unwrap is safe here
-            let encoding = response_receiver.await.unwrap()?;
-            Ok(Some(encoding))
-        } else {
-            Ok(None)
-        }
+        // Await on response channel
+        // Unwrap is safe here
+        let encoding = response_receiver.await.unwrap()?;
+        Ok(encoding)
     }
 
     #[allow(clippy::type_complexity)]
@@ -131,76 +131,46 @@ impl Validation {
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self
+        let (encoding, inputs) = self
             .tokenize(inputs.clone(), add_special_tokens, truncate)
-            .await?
-        {
-            // Create response channel
-            let input_length = if let Some(truncate) = truncate {
-                std::cmp::min(encoding.len(), truncate)
-            } else {
-                encoding.len()
-            };
+            .await?;
+        // Create response channel
+        let input_length = if let Some(truncate) = truncate {
+            std::cmp::min(encoding.len(), truncate)
+        } else {
+            encoding.len()
+        };
 
-            // Get total tokens
-            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-                max_new_tokens
-            } else {
-                self.max_total_tokens.saturating_sub(input_length) as u32
-            };
-            let total_tokens = input_length + max_new_tokens as usize;
+        // Get total tokens
+        let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
+            max_new_tokens
+        } else {
+            self.max_total_tokens.saturating_sub(input_length) as u32
+        };
+        let total_tokens = input_length + max_new_tokens as usize;
 
-            // Validate MaxTotalTokens
-            if total_tokens > self.max_total_tokens {
-                return Err(ValidationError::MaxTotalTokens(
-                    self.max_total_tokens,
-                    input_length,
-                    max_new_tokens,
-                ));
-            }
-
-            // Validate InputLength
-            if input_length > self.max_input_length {
-                return Err(ValidationError::InputLength(
-                    self.max_input_length,
-                    input_length,
-                ));
-            }
-
-            let ids = encoding.get_ids();
-            let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
-
-            metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-            Ok((inputs, Some(input_ids), input_length, max_new_tokens))
-        }
-        // Return inputs without validation
-        else {
-            // In this case, we don't know the real length in tokens of the inputs
-            // However, the inputs will be truncated by the python servers
-            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
-            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-                max_new_tokens
-            } else if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens);
-            };
-            let mut input_length = truncate.unwrap_or(self.max_input_length);
-
-            // We don't have a tokenizer, therefore we have no idea how long is the query, let
-            // them through and hope for the best.
-            // Validate MaxNewTokens
-            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
-                input_length = input_length.saturating_sub(max_new_tokens as usize);
-            }
-
-            Ok((
-                vec![Chunk::Text(inputs)],
-                None,
+        // Validate MaxTotalTokens
+        if total_tokens > self.max_total_tokens {
+            return Err(ValidationError::MaxTotalTokens(
+                self.max_total_tokens,
                 input_length,
                 max_new_tokens,
-            ))
+            ));
         }
+
+        // Validate InputLength
+        if input_length > self.max_input_length {
+            return Err(ValidationError::InputLength(
+                self.max_input_length,
+                input_length,
+            ));
+        }
+
+        let ids = encoding.get_ids();
+        let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
+
+        metrics::histogram!("tgi_request_input_length").record(input_length as f64);
+        Ok((inputs, Some(input_ids), input_length, max_new_tokens))
     }
 
     /// Validate a payload and get the number of tokens in the input
@@ -464,23 +434,26 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    // Loop over requests
-    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-        receiver.blocking_recv()
-    {
-        parent_span.in_scope(|| {
-            response_tx
-                .send(prepare_input(
-                    inputs,
-                    truncate,
-                    add_special_tokens,
-                    &tokenizer,
-                    config.as_ref(),
-                    preprocessor_config.as_ref(),
-                ))
-                .unwrap_or(())
-        })
-    }
+    pyo3::Python::with_gil(|py| {
+        let tokenizer = tokenizer.into_owned(py);
+        // Loop over requests
+        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+            receiver.blocking_recv()
+        {
+            parent_span.in_scope(|| {
+                response_tx
+                    .send(prepare_input(
+                        inputs,
+                        truncate,
+                        add_special_tokens,
+                        &tokenizer,
+                        config.as_ref(),
+                        preprocessor_config.as_ref(),
+                    ))
+                    .unwrap_or(())
+            })
+        }
+    });
 }
 
 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
@@ -612,7 +585,7 @@ fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &Tokenizer,
+    tokenizer: &OwnedTokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {