mirror of https://github.com/huggingface/text-generation-inference.git

commit 5ba7805f1c (parent ed87b464b4)

    We can have a tokenizer anywhere.

router/client/src/pb/generate.v2.rs | 647 (new file)
@@ -0,0 +1,647 @@
// This file is @generated by prost-build.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HealthResponse {}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct InfoResponse {
    #[prost(bool, tag = "1")]
    pub requires_padding: bool,
    #[prost(string, tag = "2")]
    pub dtype: ::prost::alloc::string::String,
    #[prost(string, tag = "3")]
    pub device_type: ::prost::alloc::string::String,
    #[prost(uint32, optional, tag = "4")]
    pub window_size: ::core::option::Option<u32>,
    #[prost(uint32, tag = "5")]
    pub speculate: u32,
}
/// / Empty request
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryRequest {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ServiceDiscoveryResponse {
    /// / Other shards urls
    #[prost(string, repeated, tag = "1")]
    pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheRequest {
    /// / Optional batch id
    #[prost(uint64, optional, tag = "1")]
    pub id: ::core::option::Option<u64>,
}
/// / Empty response
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ClearCacheResponse {}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NextTokenChooserParameters {
    /// / exponential scaling output probability distribution
    #[prost(float, tag = "1")]
    pub temperature: f32,
    /// / restricting to the k highest probability elements
    #[prost(uint32, tag = "2")]
    pub top_k: u32,
    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
    #[prost(float, tag = "3")]
    pub top_p: f32,
    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
    #[prost(float, tag = "4")]
    pub typical_p: f32,
    /// / apply sampling on the logits
    #[prost(bool, tag = "5")]
    pub do_sample: bool,
    /// / random seed for sampling
    #[prost(uint64, tag = "6")]
    pub seed: u64,
    /// / repetition penalty
    #[prost(float, tag = "7")]
    pub repetition_penalty: f32,
    /// / frequency penalty
    #[prost(float, tag = "9")]
    pub frequency_penalty: f32,
    /// / token watermarking using "A Watermark for Large Language Models"
    #[prost(bool, tag = "8")]
    pub watermark: bool,
    /// / grammar (applied if not empty)
    #[prost(string, tag = "10")]
    pub grammar: ::prost::alloc::string::String,
    /// / grammar type
    #[prost(enumeration = "GrammarType", tag = "11")]
    pub grammar_type: i32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StoppingCriteriaParameters {
    /// / Maximum number of generated tokens
    #[prost(uint32, tag = "1")]
    pub max_new_tokens: u32,
    /// / Optional stopping sequences
    #[prost(string, repeated, tag = "2")]
    pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
    /// / Ignore end of sequence token
    /// / used for benchmarking
    #[prost(bool, tag = "3")]
    pub ignore_eos_token: bool,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Request {
    /// / Request ID
    #[prost(uint64, tag = "1")]
    pub id: u64,
    /// / The generation context
    #[prost(string, tag = "2")]
    pub inputs: ::prost::alloc::string::String,
    /// / Context truncation
    #[prost(uint32, tag = "3")]
    pub truncate: u32,
    /// / Next Token Chooser Parameters
    #[prost(message, optional, tag = "4")]
    pub parameters: ::core::option::Option<NextTokenChooserParameters>,
    /// / Stopping Criteria Parameters
    #[prost(message, optional, tag = "5")]
    pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
    /// / Return prefill logprobs
    #[prost(bool, tag = "6")]
    pub prefill_logprobs: bool,
    /// / Return most likely n tokens
    #[prost(uint32, tag = "7")]
    pub top_n_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Batch {
    /// / Batch ID
    #[prost(uint64, tag = "1")]
    pub id: u64,
    /// / Individual requests
    #[prost(message, repeated, tag = "2")]
    pub requests: ::prost::alloc::vec::Vec<Request>,
    /// / Batch size (==len(requests))
    #[prost(uint32, tag = "3")]
    pub size: u32,
    /// / Maximum number of tokens this batch will grow to
    #[prost(uint32, tag = "4")]
    pub max_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CachedBatch {
    /// / Batch ID
    #[prost(uint64, tag = "1")]
    pub id: u64,
    /// / Individual requests ids
    #[prost(uint64, repeated, tag = "2")]
    pub request_ids: ::prost::alloc::vec::Vec<u64>,
    /// / Batch size (==len(requests))
    #[prost(uint32, tag = "3")]
    pub size: u32,
    /// / Maximum number of tokens this batch will grow to
    #[prost(uint32, tag = "4")]
    pub max_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct GeneratedText {
    /// / Output
    #[prost(string, tag = "1")]
    pub text: ::prost::alloc::string::String,
    /// / Number of generated tokens
    #[prost(uint32, tag = "2")]
    pub generated_tokens: u32,
    /// / Finish reason
    #[prost(enumeration = "FinishReason", tag = "3")]
    pub finish_reason: i32,
    /// / Seed
    #[prost(uint64, optional, tag = "4")]
    pub seed: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Tokens {
    /// / Token IDs
    #[prost(uint32, repeated, tag = "1")]
    pub ids: ::prost::alloc::vec::Vec<u32>,
    /// / Logprobs
    #[prost(float, repeated, tag = "2")]
    pub logprobs: ::prost::alloc::vec::Vec<f32>,
    /// / tokens
    #[prost(string, repeated, tag = "3")]
    pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
    /// / special
    #[prost(bool, repeated, tag = "4")]
    pub is_special: ::prost::alloc::vec::Vec<bool>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Generation {
    /// / Request ID
    #[prost(uint64, tag = "1")]
    pub request_id: u64,
    /// / Prefill tokens (optional)
    #[prost(message, optional, tag = "2")]
    pub prefill_tokens: ::core::option::Option<Tokens>,
    #[prost(message, optional, tag = "3")]
    pub tokens: ::core::option::Option<Tokens>,
    /// / Complete generated text
    #[prost(message, optional, tag = "4")]
    pub generated_text: ::core::option::Option<GeneratedText>,
    /// / Top tokens
    #[prost(message, repeated, tag = "5")]
    pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchRequest {
    /// / Batch ID
    #[prost(uint64, tag = "1")]
    pub batch_id: u64,
    /// / Requests to keep
    #[prost(uint64, repeated, tag = "2")]
    pub request_ids: ::prost::alloc::vec::Vec<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FilterBatchResponse {
    /// / Filtered Batch (cached)
    #[prost(message, optional, tag = "1")]
    pub batch: ::core::option::Option<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillRequest {
    /// / Batch
    #[prost(message, optional, tag = "1")]
    pub batch: ::core::option::Option<Batch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PrefillResponse {
    /// / Generation
    #[prost(message, repeated, tag = "1")]
    pub generations: ::prost::alloc::vec::Vec<Generation>,
    /// / Next batch (cached)
    #[prost(message, optional, tag = "2")]
    pub batch: ::core::option::Option<CachedBatch>,
    /// / Forward elapsed time in nanoseconds
    #[prost(uint64, tag = "3")]
    pub forward_ns: u64,
    /// / Decode elapsed time in nanoseconds
    #[prost(uint64, tag = "4")]
    pub decode_ns: u64,
    /// / Total elapsed time in nanoseconds
    #[prost(uint64, tag = "5")]
    pub total_ns: u64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeRequest {
    /// / Cached batches
    #[prost(message, repeated, tag = "1")]
    pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DecodeResponse {
    /// / Decodes
    #[prost(message, repeated, tag = "1")]
    pub generations: ::prost::alloc::vec::Vec<Generation>,
    /// / Next batch (cached)
    #[prost(message, optional, tag = "2")]
    pub batch: ::core::option::Option<CachedBatch>,
    /// / Forward elapsed time in nanoseconds
    #[prost(uint64, tag = "3")]
    pub forward_ns: u64,
    /// / Decode elapsed time in nanoseconds
    #[prost(uint64, tag = "4")]
    pub decode_ns: u64,
    /// / Total elapsed time in nanoseconds
    #[prost(uint64, tag = "5")]
    pub total_ns: u64,
    /// / Concatenate elapsed time in nanoseconds
    #[prost(uint64, optional, tag = "6")]
    pub concat_ns: ::core::option::Option<u64>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupRequest {
    /// / Batch to warmup on
    #[prost(message, optional, tag = "1")]
    pub batch: ::core::option::Option<Batch>,
    #[prost(uint32, tag = "2")]
    pub max_input_length: u32,
    #[prost(uint32, tag = "3")]
    pub max_prefill_tokens: u32,
    #[prost(uint32, tag = "4")]
    pub max_total_tokens: u32,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WarmupResponse {
    /// / Maximum number of tokens supported by the model
    #[prost(uint32, optional, tag = "1")]
    pub max_supported_total_tokens: ::core::option::Option<u32>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum GrammarType {
    None = 0,
    Json = 1,
    Regex = 2,
}
impl GrammarType {
    /// String value of the enum field names used in the ProtoBuf definition.
    ///
    /// The values are not transformed in any way and thus are considered stable
    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
    pub fn as_str_name(&self) -> &'static str {
        match self {
            GrammarType::None => "GRAMMAR_TYPE_NONE",
            GrammarType::Json => "GRAMMAR_TYPE_JSON",
            GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
        }
    }
    /// Creates an enum from field names used in the ProtoBuf definition.
    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
        match value {
            "GRAMMAR_TYPE_NONE" => Some(Self::None),
            "GRAMMAR_TYPE_JSON" => Some(Self::Json),
            "GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
            _ => None,
        }
    }
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum FinishReason {
    Length = 0,
    EosToken = 1,
    StopSequence = 2,
}
impl FinishReason {
    /// String value of the enum field names used in the ProtoBuf definition.
    ///
    /// The values are not transformed in any way and thus are considered stable
    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
    pub fn as_str_name(&self) -> &'static str {
        match self {
            FinishReason::Length => "FINISH_REASON_LENGTH",
            FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
            FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
        }
    }
    /// Creates an enum from field names used in the ProtoBuf definition.
    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
        match value {
            "FINISH_REASON_LENGTH" => Some(Self::Length),
            "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
            "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
            _ => None,
        }
    }
}
/// Generated client implementations.
pub mod text_generation_service_client {
    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
    use tonic::codegen::*;
    use tonic::codegen::http::Uri;
    #[derive(Debug, Clone)]
    pub struct TextGenerationServiceClient<T> {
        inner: tonic::client::Grpc<T>,
    }
    impl TextGenerationServiceClient<tonic::transport::Channel> {
        /// Attempt to create a new client by connecting to a given endpoint.
        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
        where
            D: TryInto<tonic::transport::Endpoint>,
            D::Error: Into<StdError>,
        {
            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
            Ok(Self::new(conn))
        }
    }
    impl<T> TextGenerationServiceClient<T>
    where
        T: tonic::client::GrpcService<tonic::body::BoxBody>,
        T::Error: Into<StdError>,
        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
    {
        pub fn new(inner: T) -> Self {
            let inner = tonic::client::Grpc::new(inner);
            Self { inner }
        }
        pub fn with_origin(inner: T, origin: Uri) -> Self {
            let inner = tonic::client::Grpc::with_origin(inner, origin);
            Self { inner }
        }
        pub fn with_interceptor<F>(
            inner: T,
            interceptor: F,
        ) -> TextGenerationServiceClient<InterceptedService<T, F>>
        where
            F: tonic::service::Interceptor,
            T::ResponseBody: Default,
            T: tonic::codegen::Service<
                http::Request<tonic::body::BoxBody>,
                Response = http::Response<
                    <T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
                >,
            >,
            <T as tonic::codegen::Service<
                http::Request<tonic::body::BoxBody>,
            >>::Error: Into<StdError> + Send + Sync,
        {
            TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
        }
        /// Compress requests with the given encoding.
        ///
        /// This requires the server to support it otherwise it might respond with an
        /// error.
        #[must_use]
        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.send_compressed(encoding);
            self
        }
        /// Enable decompressing responses.
        #[must_use]
        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.accept_compressed(encoding);
            self
        }
        /// Limits the maximum size of a decoded message.
        ///
        /// Default: `4MB`
        #[must_use]
        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
            self.inner = self.inner.max_decoding_message_size(limit);
            self
        }
        /// Limits the maximum size of an encoded message.
        ///
        /// Default: `usize::MAX`
        #[must_use]
        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
            self.inner = self.inner.max_encoding_message_size(limit);
            self
        }
        /// / Model Info
        pub async fn info(
            &mut self,
            request: impl tonic::IntoRequest<super::InfoRequest>,
        ) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Info",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
            self.inner.unary(req, path, codec).await
        }
        /// / Service discovery
        pub async fn service_discovery(
            &mut self,
            request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
        ) -> std::result::Result<
            tonic::Response<super::ServiceDiscoveryResponse>,
            tonic::Status,
        > {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/ServiceDiscovery",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(
                    GrpcMethod::new(
                        "generate.v2.TextGenerationService",
                        "ServiceDiscovery",
                    ),
                );
            self.inner.unary(req, path, codec).await
        }
        /// / Empties batch cache
        pub async fn clear_cache(
            &mut self,
            request: impl tonic::IntoRequest<super::ClearCacheRequest>,
        ) -> std::result::Result<
            tonic::Response<super::ClearCacheResponse>,
            tonic::Status,
        > {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/ClearCache",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(
                    GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
                );
            self.inner.unary(req, path, codec).await
        }
        /// / Remove requests from a cached batch
        pub async fn filter_batch(
            &mut self,
            request: impl tonic::IntoRequest<super::FilterBatchRequest>,
        ) -> std::result::Result<
            tonic::Response<super::FilterBatchResponse>,
            tonic::Status,
        > {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/FilterBatch",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(
                    GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
                );
            self.inner.unary(req, path, codec).await
        }
        /// / Warmup the model and compute max cache size
        pub async fn warmup(
            &mut self,
            request: impl tonic::IntoRequest<super::WarmupRequest>,
        ) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Warmup",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
            self.inner.unary(req, path, codec).await
        }
        /// / Prefill batch and decode first token
        pub async fn prefill(
            &mut self,
            request: impl tonic::IntoRequest<super::PrefillRequest>,
        ) -> std::result::Result<
            tonic::Response<super::PrefillResponse>,
            tonic::Status,
        > {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Prefill",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
            self.inner.unary(req, path, codec).await
        }
        /// / Decode token for a list of prefilled batches
        pub async fn decode(
            &mut self,
            request: impl tonic::IntoRequest<super::DecodeRequest>,
        ) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Decode",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
            self.inner.unary(req, path, codec).await
        }
        /// / Health check
        pub async fn health(
            &mut self,
            request: impl tonic::IntoRequest<super::HealthRequest>,
        ) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
            self.inner
                .ready()
                .await
                .map_err(|e| {
                    tonic::Status::new(
                        tonic::Code::Unknown,
                        format!("Service was not ready: {}", e.into()),
                    )
                })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static(
                "/generate.v2.TextGenerationService/Health",
            );
            let mut req = request.into_request();
            req.extensions_mut()
                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
            self.inner.unary(req, path, codec).await
        }
    }
}
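Note: the generated `TextGenerationServiceClient` above is the router's handle on each shard. For orientation, a client could be exercised as in the sketch below; the endpoint address is an assumption (in practice the router reaches shards over unix sockets), while `connect`, `health` and `info` come from the generated code above.

use generate::v2::text_generation_service_client::TextGenerationServiceClient;
use generate::v2::{HealthRequest, InfoRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Endpoint is hypothetical; substitute the shard's actual uds/http address.
    let mut client = TextGenerationServiceClient::connect("http://127.0.0.1:3000").await?;
    // HealthRequest and InfoRequest are the empty messages defined above.
    client.health(HealthRequest {}).await?;
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!("dtype={} device={} speculate={}", info.dtype, info.device_type, info.speculate);
    Ok(())
}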
router/client/src/pb/mod.rs | 6 (new file)

@@ -0,0 +1,6 @@
// This file is @generated by prost-build.
pub mod generate {
    pub mod v2 {
        include!("generate.v2.rs");
    }
}
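Note: both files are checked-in `prost-build`/`tonic-build` output (see the `@generated` headers); the build script that emits them is not part of this commit. For context, a typical generator for a `generate.v2` service looks roughly like the sketch below, where the proto paths and options are assumptions, not the repository's actual build.rs.

// build.rs (hypothetical) — regenerate src/pb/generate.v2.rs from the proto.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::configure()
        .build_server(false) // the router only needs the client half
        .out_dir("src/pb") // emit where mod.rs include!()s it
        .compile(&["../../proto/generate.proto"], &["../../proto"])?;
    Ok(())
}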
@@ -135,7 +135,7 @@ impl Infer {
     pub(crate) async fn tokenize(
         &self,
         request: GenerateRequest,
-    ) -> Result<Option<tokenizers::Encoding>, InferError> {
+    ) -> Result<tokenizers::Encoding, InferError> {
         // Tokenize request
         let inputs = request.inputs;
         let add_special_tokens = request.add_special_tokens;
@@ -150,7 +150,7 @@ impl Infer {
         })?;

         // Return Encoding
-        Ok(encoding.map(|(encoding, _)| encoding))
+        Ok(encoding.0)
     }

     /// Apply the chat template to the chat request
@@ -14,11 +14,91 @@ mod vertex;

 use crate::infer::{Infer, InferError};
 use crate::server::prepare_chat_input;
+use pyo3::prelude::*;
+use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;

+#[derive(Clone)]
+pub enum Tokenizer {
+    Python {
+        tokenizer_name: String,
+        revision: Option<String>,
+    },
+    Rust(tokenizers::Tokenizer),
+}
+
+impl Tokenizer {
+    fn into_owned<'a>(self, py: Python<'a>) -> OwnedTokenizer<'a> {
+        match self {
+            Self::Python {
+                tokenizer_name,
+                revision,
+            } => {
+                let pytok = || -> pyo3::PyResult<pyo3::Bound<'a, pyo3::PyAny>> {
+                    let transformers = py.import_bound("transformers")?;
+                    let auto = transformers.getattr("AutoTokenizer")?;
+                    let from_pretrained = auto.getattr("from_pretrained")?;
+                    let args = (tokenizer_name.to_string(),);
+                    let kwargs = if let Some(rev) = &revision {
+                        [("revision", rev.to_string())].into_py_dict_bound(py)
+                    } else {
+                        pyo3::types::PyDict::new_bound(py)
+                    };
+                    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
+                    Ok(tokenizer)
+                }()
+                .expect("Cannot load the tokenizer");
+                tracing::info!("Loaded a python tokenizer");
+                OwnedTokenizer::Python(pytok)
+            }
+            Self::Rust(tok) => OwnedTokenizer::Rust(tok),
+        }
+    }
+}
+
+pub enum OwnedTokenizer<'a> {
+    Python(pyo3::Bound<'a, pyo3::PyAny>),
+    Rust(tokenizers::Tokenizer),
+}
+
+impl<'a> OwnedTokenizer<'a> {
+    fn encode(
+        &self,
+        query: String,
+        add_special_tokens: bool,
+    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
+        match self {
+            Self::Python(pytok) => {
+                let py = pytok.py();
+                let kwargs = [
+                    ("text", query.into_py(py)),
+                    ("add_special_tokens", add_special_tokens.into_py(py)),
+                ]
+                .into_py_dict_bound(py);
+                let encode = pytok.getattr("encode")?;
+                let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
+                Ok(Encoding::new(
+                    input_ids,
+                    vec![],                           // type ids
+                    vec![],                           // tokens (strings)
+                    vec![],                           // words
+                    vec![],                           // offsets
+                    vec![],                           // special_tokens_mask
+                    vec![],                           // attention_mask
+                    vec![],                           // overflowing
+                    std::collections::HashMap::new(), // sequence_ranges
+                ))
+            }
+            Self::Rust(tok) => tok.encode(query, add_special_tokens),
+        }
+    }
+}
+
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
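Note: this enum is the core of the commit ("a tokenizer anywhere"): a `Tokenizer` is either a native `tokenizers::Tokenizer` or just a model name plus revision to be loaded through Python's `transformers.AutoTokenizer`, and it only becomes usable as an `OwnedTokenizer` once a thread holds the GIL. A sketch of how crate-internal code could drive it (the input string is illustrative):

fn count_tokens(tokenizer: Tokenizer) -> usize {
    // `into_owned` needs a GIL token even for the Rust variant, so both
    // backends are materialized through the same entry point.
    pyo3::Python::with_gil(|py| {
        let owned = tokenizer.into_owned(py);
        // The Python variant fills only the ids; the Rust variant also
        // carries offsets, which the HTTP tokenize routes rely on.
        let encoding = owned
            .encode("Hello there".to_string(), true)
            .expect("tokenization failed");
        encoding.get_ids().len()
    })
}

Since `into_owned` and `encode` are private, this pattern is only available inside the router crate, as in `tokenizer_worker` further down.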
@@ -19,7 +19,8 @@ use crate::{
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
     HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
     OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
-    TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
+    TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
+    Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -52,9 +53,8 @@ use std::convert::Infallible;
 use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::path::{Path, PathBuf};
+use std::path::Path;
 use thiserror::Error;
-use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
 use tokio::sync::oneshot;
@@ -161,40 +161,30 @@ async fn get_chat_tokenize(
     let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
     let input = generate_request.inputs.clone();
     let encoding = infer.tokenize(generate_request).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
-
-        let resp = ChatTokenizeResponse {
-            tokenize_response: TokenizeResponse(tokens),
-            templated_text: input,
-        };
-        Ok((HeaderMap::new(), Json(resp)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let tokens: Vec<SimpleToken> = encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .map(|(&id, &(start, stop))| {
+            let text = input
+                .chars()
+                .skip(start)
+                .take(stop - start)
+                .collect::<String>();
+            SimpleToken {
+                id,
+                text,
+                start,
+                stop,
+            }
+        })
+        .collect();

+    let resp = ChatTokenizeResponse {
+        tokenize_response: TokenizeResponse(tokens),
+        templated_text: input,
+    };
+    Ok((HeaderMap::new(), Json(resp)))
 }

 #[utoipa::path(
@@ -1458,35 +1448,25 @@ async fn tokenize(
 ) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
     let input = req.inputs.clone();
     let encoding = infer.tokenize(req).await?;
-    if let Some(encoding) = encoding {
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .map(|(&id, &(start, stop))| {
-                let text = input
-                    .chars()
-                    .skip(start)
-                    .take(stop - start)
-                    .collect::<String>();
-                SimpleToken {
-                    id,
-                    text,
-                    start,
-                    stop,
-                }
-            })
-            .collect();
-        Ok(Json(TokenizeResponse(tokens)))
-    } else {
-        Err((
-            StatusCode::NOT_FOUND,
-            Json(ErrorResponse {
-                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
-                error_type: "no fast tokenizer".to_string(),
-            }),
-        ))
-    }
+    let tokens: Vec<SimpleToken> = encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .map(|(&id, &(start, stop))| {
+            let text = input
+                .chars()
+                .skip(start)
+                .take(stop - start)
+                .collect::<String>();
+            SimpleToken {
+                id,
+                text,
+                start,
+                stop,
+            }
+        })
+        .collect();
+    Ok(Json(TokenizeResponse(tokens)))
 }

 /// Prometheus metrics scrape endpoint
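Note: with the fallback branch gone, `/tokenize` can no longer answer 404 "No fast tokenizer"; a Python-backed tokenizer serves the same route. A hedged client sketch against a locally running router (port inferred, payload field taken from the handler above):

// Hypothetical smoke test for the /tokenize route.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:8080/tokenize")
        .json(&serde_json::json!({ "inputs": "What is Deep Learning?" }))
        .send()
        .await?
        .error_for_status()?; // the old NOT_FOUND path no longer exists
    // Each SimpleToken carries id, text and offsets (start/stop).
    println!("{}", resp.text().await?);
    Ok(())
}

One caveat visible in the code itself: the Python fallback's `Encoding::new` fills only the ids, so the offset-driven zip above would yield an empty token list for that backend rather than an error.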
@@ -1687,7 +1667,6 @@ pub async fn run(

     // Load tokenizer and model info
     let (
-        tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
         preprocessor_config_filename,
@@ -1695,7 +1674,6 @@ pub async fn run(
         model_info,
     ) = match api {
         Type::None => (
-            Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
@@ -1709,10 +1687,6 @@ pub async fn run(
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));

-            let tokenizer_filename = match api_repo.get("tokenizer.json").await {
-                Ok(tokenizer_filename) => Some(tokenizer_filename),
-                Err(_) => get_base_tokenizer(&api, &api_repo).await,
-            };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@@ -1725,7 +1699,6 @@ pub async fn run(
                 None
             };
             (
-                tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
                 preprocessor_config_filename,
@@ -1740,7 +1713,6 @@ pub async fn run(
                 revision.clone().unwrap_or_else(|| "main".to_string()),
             ));
             (
-                repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
@@ -1762,21 +1734,22 @@ pub async fn run(
             HubTokenizerConfig::default()
         });

-    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+    let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
-        let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
+        pyo3::Python::with_gil(|py| -> PyResult<()> {
             let transformers = py.import_bound("transformers")?;
             let auto = transformers.getattr("AutoTokenizer")?;
             let from_pretrained = auto.getattr("from_pretrained")?;
             let args = (tokenizer_name.to_string(),);
-            let kwargs = [
-                (
-                    "revision",
-                    (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
-                ),
-                ("trust_remote_code", trust_remote_code.into_py(py)),
-            ]
-            .into_py_dict_bound(py);
+            let kwargs = if let Some(rev) = &revision {
+                [
+                    ("revision", rev.to_string().into_py(py)),
+                    ("trust_remote_code", trust_remote_code.into_py(py)),
+                ]
+                .into_py_dict_bound(py)
+            } else {
+                [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
+            };
             let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
             let save = tokenizer.getattr("save_pretrained")?;
             let args = ("out".to_string(),);
@@ -1785,16 +1758,18 @@ pub async fn run(
         })
         .inspect_err(|err| {
             tracing::error!("Failed to import python tokenizer {err}");
-        });
-        let filename = if convert.is_ok() {
-            // If we have correctly loaded and resaved with transformers
-            // We might have modified the tokenizer.json according to transformers
-            "out/tokenizer.json".into()
-        } else {
-            filename
-        };
-        Tokenizer::from_file(filename).ok()
-    });
+        })
+        .expect("We cannot load a tokenizer");
+        let filename = "out/tokenizer.json";
+        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
+            Tokenizer::Rust(tok)
+        } else {
+            Tokenizer::Python {
+                tokenizer_name: tokenizer_name.clone(),
+                revision: revision.clone(),
+            }
+        }
+    };

     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
@@ -1822,10 +1797,6 @@ pub async fn run(
         preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);

     tracing::info!("Using config {config:?}");
-    if tokenizer.is_none() {
-        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
-        tracing::warn!("Rust input length validation and truncation is disabled");
-    }

     // Only send usage stats when TGI is run in container and the function returns Some
     let is_container = matches!(usage_stats::is_container(), Ok(true));
@@ -1940,7 +1911,7 @@ async fn start(
     validation_workers: usize,
     api_key: Option<String>,
     config: Option<Config>,
-    (tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
+    (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig),
     (preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
     hostname: String,
     port: u16,
@@ -2400,30 +2371,6 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
     }
 }

-/// get base tokenizer
-pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
-    let config_filename = api_repo.get("config.json").await.ok()?;
-
-    // Open the file in read-only mode with buffer.
-    let file = File::open(config_filename).ok()?;
-    let reader = BufReader::new(file);
-
-    // Read the JSON contents of the file as an instance of `User`.
-    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
-
-    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
-        let api_base_repo = api.repo(Repo::with_revision(
-            base_model_id.to_string(),
-            RepoType::Model,
-            "main".to_string(),
-        ));
-
-        api_base_repo.get("tokenizer.json").await.ok()
-    } else {
-        None
-    }
-}
-
 /// get tokenizer_config from the Huggingface Hub
 pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
     let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
@@ -4,6 +4,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
 };
+use crate::{OwnedTokenizer, Tokenizer};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
@@ -13,7 +14,6 @@ use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
 use thiserror::Error;
-use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
@@ -30,14 +30,14 @@ pub struct Validation {
     max_total_tokens: usize,
     disable_grammar_support: bool,
     /// Channel to communicate with the background tokenization task
-    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
+    sender: mpsc::UnboundedSender<TokenizerRequest>,
 }

 impl Validation {
     #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         workers: usize,
-        tokenizer: Option<Tokenizer>,
+        tokenizer: Tokenizer,
         config: Option<Config>,
         preprocessor_config: Option<HubPreprocessorConfig>,
         max_best_of: usize,
@@ -47,8 +47,13 @@ impl Validation {
         max_total_tokens: usize,
         disable_grammar_support: bool,
     ) -> Self {
+        let workers = if let Tokenizer::Python { .. } = &tokenizer {
+            1
+        } else {
+            workers
+        };
         // If we have a fast tokenizer
-        let sender = if let Some(tokenizer) = tokenizer {
+        let sender = {
             // Create round robin channel
             let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
             let mut senders = Vec::with_capacity(workers);
@@ -75,9 +80,7 @@ impl Validation {
             // Create tokenization round robin task
             tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));

-            Some(validation_sender)
-        } else {
-            None
+            validation_sender
         };

         Self {
@@ -97,28 +100,25 @@ impl Validation {
         inputs: String,
         add_special_tokens: bool,
         truncate: Option<usize>,
-    ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
+    ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
         // If we have a fast tokenizer
-        if let Some(sender) = &self.sender {
-            // Create response channel
-            let (response_sender, response_receiver) = oneshot::channel();
-            // Send request to the background validation task
-            // Unwrap is safe here
-            sender
-                .send((
-                    (inputs, add_special_tokens, truncate),
-                    response_sender,
-                    Span::current(),
-                ))
-                .unwrap();
-
-            // Await on response channel
-            // Unwrap is safe here
-            let encoding = response_receiver.await.unwrap()?;
-            Ok(Some(encoding))
-        } else {
-            Ok(None)
-        }
+        // Create response channel
+        let (response_sender, response_receiver) = oneshot::channel();
+        // Send request to the background validation task
+        // Unwrap is safe here
+        let _ = &self
+            .sender
+            .send((
+                (inputs, add_special_tokens, truncate),
+                response_sender,
+                Span::current(),
+            ))
+            .unwrap();
+
+        // Await on response channel
+        // Unwrap is safe here
+        let encoding = response_receiver.await.unwrap()?;
+        Ok(encoding)
     }

     #[allow(clippy::type_complexity)]
@@ -131,76 +131,46 @@ impl Validation {
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self
-            .tokenize(inputs.clone(), add_special_tokens, truncate)
-            .await?
-        {
-            // Create response channel
-            let input_length = if let Some(truncate) = truncate {
-                std::cmp::min(encoding.len(), truncate)
-            } else {
-                encoding.len()
-            };
-
-            // Get total tokens
-            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-                max_new_tokens
-            } else {
-                self.max_total_tokens.saturating_sub(input_length) as u32
-            };
-            let total_tokens = input_length + max_new_tokens as usize;
-
-            // Validate MaxTotalTokens
-            if total_tokens > self.max_total_tokens {
-                return Err(ValidationError::MaxTotalTokens(
-                    self.max_total_tokens,
-                    input_length,
-                    max_new_tokens,
-                ));
-            }
-
-            // Validate InputLength
-            if input_length > self.max_input_length {
-                return Err(ValidationError::InputLength(
-                    self.max_input_length,
-                    input_length,
-                ));
-            }
-
-            let ids = encoding.get_ids();
-            let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
-
-            metrics::histogram!("tgi_request_input_length").record(input_length as f64);
-            Ok((inputs, Some(input_ids), input_length, max_new_tokens))
-        }
-        // Return inputs without validation
-        else {
-            // In this case, we don't know the real length in tokens of the inputs
-            // However, the inputs will be truncated by the python servers
-            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
-            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
-                max_new_tokens
-            } else if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens);
-            };
-            let mut input_length = truncate.unwrap_or(self.max_input_length);
-
-            // We don't have a tokenizer, therefore we have no idea how long is the query, let
-            // them through and hope for the best.
-            // Validate MaxNewTokens
-            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
-                input_length = input_length.saturating_sub(max_new_tokens as usize);
-            }
-
-            Ok((
-                vec![Chunk::Text(inputs)],
-                None,
-                input_length,
-                max_new_tokens,
-            ))
-        }
+        let (encoding, inputs) = self
+            .tokenize(inputs.clone(), add_special_tokens, truncate)
+            .await?;
+        // Create response channel
+        let input_length = if let Some(truncate) = truncate {
+            std::cmp::min(encoding.len(), truncate)
+        } else {
+            encoding.len()
+        };
+
+        // Get total tokens
+        let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
+            max_new_tokens
+        } else {
+            self.max_total_tokens.saturating_sub(input_length) as u32
+        };
+        let total_tokens = input_length + max_new_tokens as usize;
+
+        // Validate MaxTotalTokens
+        if total_tokens > self.max_total_tokens {
+            return Err(ValidationError::MaxTotalTokens(
+                self.max_total_tokens,
+                input_length,
+                max_new_tokens,
+            ));
+        }
+
+        // Validate InputLength
+        if input_length > self.max_input_length {
+            return Err(ValidationError::InputLength(
+                self.max_input_length,
+                input_length,
+            ));
+        }
+
+        let ids = encoding.get_ids();
+        let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
+
+        metrics::histogram!("tgi_request_input_length").record(input_length as f64);
+        Ok((inputs, Some(input_ids), input_length, max_new_tokens))
     }

     /// Validate a payload and get the number of tokens in the input
@@ -464,23 +434,26 @@ fn tokenizer_worker(
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
-    // Loop over requests
-    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
-        receiver.blocking_recv()
-    {
-        parent_span.in_scope(|| {
-            response_tx
-                .send(prepare_input(
-                    inputs,
-                    truncate,
-                    add_special_tokens,
-                    &tokenizer,
-                    config.as_ref(),
-                    preprocessor_config.as_ref(),
-                ))
-                .unwrap_or(())
-        })
-    }
+    pyo3::Python::with_gil(|py| {
+        let tokenizer = tokenizer.into_owned(py);
+        // Loop over requests
+        while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+            receiver.blocking_recv()
+        {
+            parent_span.in_scope(|| {
+                response_tx
+                    .send(prepare_input(
+                        inputs,
+                        truncate,
+                        add_special_tokens,
+                        &tokenizer,
+                        config.as_ref(),
+                        preprocessor_config.as_ref(),
+                    ))
+                    .unwrap_or(())
+            })
+        }
+    });
 }

 fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
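Note: wrapping the whole receive loop in `Python::with_gil` pins each Python-backed tokenizer to one GIL-holding thread for the worker's lifetime, which is presumably also why `Validation::new` above clamps `workers` to 1 for the Python variant. Stripped of the TGI types, the pattern is the classic "channel into a thread that owns the resource" sketch below (all names illustrative):

use std::sync::mpsc;

// One thread owns a resource that must not cross threads and serves
// requests over a channel, answering on per-request reply channels.
fn spawn_worker() -> mpsc::Sender<(String, mpsc::Sender<usize>)> {
    let (tx, rx) = mpsc::channel::<(String, mpsc::Sender<usize>)>();
    std::thread::spawn(move || {
        // In the real code: pyo3::Python::with_gil(|py| { let tok = tokenizer.into_owned(py); ... })
        let tokenize = |s: &str| s.split_whitespace().count();
        while let Ok((input, reply)) = rx.recv() {
            let _ = reply.send(tokenize(&input));
        }
    });
    tx
}

fn main() {
    let worker = spawn_worker();
    let (reply_tx, reply_rx) = mpsc::channel();
    worker
        .send(("we can have a tokenizer anywhere".to_string(), reply_tx))
        .unwrap();
    assert_eq!(reply_rx.recv().unwrap(), 6);
}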
@@ -612,7 +585,7 @@ fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
     add_special_tokens: bool,
-    tokenizer: &Tokenizer,
+    tokenizer: &OwnedTokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {