mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

Fixing the tests.

This commit is contained in:
parent b89b9fd016
commit 5bc1fe84eb
@@ -1,647 +0,0 @@
-// This file is @generated by prost-build.
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct HealthRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct HealthResponse {}
-/// / Empty request
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct InfoRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct InfoResponse {
-    #[prost(bool, tag = "1")]
-    pub requires_padding: bool,
-    #[prost(string, tag = "2")]
-    pub dtype: ::prost::alloc::string::String,
-    #[prost(string, tag = "3")]
-    pub device_type: ::prost::alloc::string::String,
-    #[prost(uint32, optional, tag = "4")]
-    pub window_size: ::core::option::Option<u32>,
-    #[prost(uint32, tag = "5")]
-    pub speculate: u32,
-}
-/// / Empty request
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ServiceDiscoveryRequest {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ServiceDiscoveryResponse {
-    /// / Other shards urls
-    #[prost(string, repeated, tag = "1")]
-    pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ClearCacheRequest {
-    /// / Optional batch id
-    #[prost(uint64, optional, tag = "1")]
-    pub id: ::core::option::Option<u64>,
-}
-/// / Empty response
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ClearCacheResponse {}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct NextTokenChooserParameters {
-    /// / exponential scaling output probability distribution
-    #[prost(float, tag = "1")]
-    pub temperature: f32,
-    /// / restricting to the k highest probability elements
-    #[prost(uint32, tag = "2")]
-    pub top_k: u32,
-    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
-    #[prost(float, tag = "3")]
-    pub top_p: f32,
-    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
-    #[prost(float, tag = "4")]
-    pub typical_p: f32,
-    /// / apply sampling on the logits
-    #[prost(bool, tag = "5")]
-    pub do_sample: bool,
-    /// / random seed for sampling
-    #[prost(uint64, tag = "6")]
-    pub seed: u64,
-    /// / repetition penalty
-    #[prost(float, tag = "7")]
-    pub repetition_penalty: f32,
-    /// / frequency penalty
-    #[prost(float, tag = "9")]
-    pub frequency_penalty: f32,
-    /// / token watermarking using "A Watermark for Large Language Models"
-    #[prost(bool, tag = "8")]
-    pub watermark: bool,
-    /// / grammar (applied if not empty)
-    #[prost(string, tag = "10")]
-    pub grammar: ::prost::alloc::string::String,
-    /// / grammar type
-    #[prost(enumeration = "GrammarType", tag = "11")]
-    pub grammar_type: i32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct StoppingCriteriaParameters {
-    /// / Maximum number of generated tokens
-    #[prost(uint32, tag = "1")]
-    pub max_new_tokens: u32,
-    /// / Optional stopping sequences
-    #[prost(string, repeated, tag = "2")]
-    pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-    /// / Ignore end of sequence token
-    /// / used for benchmarking
-    #[prost(bool, tag = "3")]
-    pub ignore_eos_token: bool,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Request {
-    /// / Request ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / The generation context
-    #[prost(string, tag = "2")]
-    pub inputs: ::prost::alloc::string::String,
-    /// / Context truncation
-    #[prost(uint32, tag = "3")]
-    pub truncate: u32,
-    /// / Next Token Chooser Parameters
-    #[prost(message, optional, tag = "4")]
-    pub parameters: ::core::option::Option<NextTokenChooserParameters>,
-    /// / Stopping Criteria Parameters
-    #[prost(message, optional, tag = "5")]
-    pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
-    /// / Return prefill logprobs
-    #[prost(bool, tag = "6")]
-    pub prefill_logprobs: bool,
-    /// / Return most likely n tokens
-    #[prost(uint32, tag = "7")]
-    pub top_n_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Batch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests
-    #[prost(message, repeated, tag = "2")]
-    pub requests: ::prost::alloc::vec::Vec<Request>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct CachedBatch {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub id: u64,
-    /// / Individual requests ids
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-    /// / Batch size (==len(requests))
-    #[prost(uint32, tag = "3")]
-    pub size: u32,
-    /// / Maximum number of tokens this batch will grow to
-    #[prost(uint32, tag = "4")]
-    pub max_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct GeneratedText {
-    /// / Output
-    #[prost(string, tag = "1")]
-    pub text: ::prost::alloc::string::String,
-    /// / Number of generated tokens
-    #[prost(uint32, tag = "2")]
-    pub generated_tokens: u32,
-    /// / Finish reason
-    #[prost(enumeration = "FinishReason", tag = "3")]
-    pub finish_reason: i32,
-    /// / Seed
-    #[prost(uint64, optional, tag = "4")]
-    pub seed: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Tokens {
-    /// / Token IDs
-    #[prost(uint32, repeated, tag = "1")]
-    pub ids: ::prost::alloc::vec::Vec<u32>,
-    /// / Logprobs
-    #[prost(float, repeated, tag = "2")]
-    pub logprobs: ::prost::alloc::vec::Vec<f32>,
-    /// / tokens
-    #[prost(string, repeated, tag = "3")]
-    pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
-    /// / special
-    #[prost(bool, repeated, tag = "4")]
-    pub is_special: ::prost::alloc::vec::Vec<bool>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct Generation {
-    /// / Request ID
-    #[prost(uint64, tag = "1")]
-    pub request_id: u64,
-    /// / Prefill tokens (optional)
-    #[prost(message, optional, tag = "2")]
-    pub prefill_tokens: ::core::option::Option<Tokens>,
-    #[prost(message, optional, tag = "3")]
-    pub tokens: ::core::option::Option<Tokens>,
-    /// / Complete generated text
-    #[prost(message, optional, tag = "4")]
-    pub generated_text: ::core::option::Option<GeneratedText>,
-    /// / Top tokens
-    #[prost(message, repeated, tag = "5")]
-    pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchRequest {
-    /// / Batch ID
-    #[prost(uint64, tag = "1")]
-    pub batch_id: u64,
-    /// / Requests to keep
-    #[prost(uint64, repeated, tag = "2")]
-    pub request_ids: ::prost::alloc::vec::Vec<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct FilterBatchResponse {
-    /// / Filtered Batch (cached)
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillRequest {
-    /// / Batch
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PrefillResponse {
-    /// / Generation
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeRequest {
-    /// / Cached batches
-    #[prost(message, repeated, tag = "1")]
-    pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DecodeResponse {
-    /// / Decodes
-    #[prost(message, repeated, tag = "1")]
-    pub generations: ::prost::alloc::vec::Vec<Generation>,
-    /// / Next batch (cached)
-    #[prost(message, optional, tag = "2")]
-    pub batch: ::core::option::Option<CachedBatch>,
-    /// / Forward elapsed time in nanoseconds
-    #[prost(uint64, tag = "3")]
-    pub forward_ns: u64,
-    /// / Decode elapsed time in nanoseconds
-    #[prost(uint64, tag = "4")]
-    pub decode_ns: u64,
-    /// / Total elapsed time in nanoseconds
-    #[prost(uint64, tag = "5")]
-    pub total_ns: u64,
-    /// / Concatenate elapsed time in nanoseconds
-    #[prost(uint64, optional, tag = "6")]
-    pub concat_ns: ::core::option::Option<u64>,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupRequest {
-    /// / Batch to warmup on
-    #[prost(message, optional, tag = "1")]
-    pub batch: ::core::option::Option<Batch>,
-    #[prost(uint32, tag = "2")]
-    pub max_input_length: u32,
-    #[prost(uint32, tag = "3")]
-    pub max_prefill_tokens: u32,
-    #[prost(uint32, tag = "4")]
-    pub max_total_tokens: u32,
-}
-#[allow(clippy::derive_partial_eq_without_eq)]
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct WarmupResponse {
-    /// / Maximum number of tokens supported by the model
-    #[prost(uint32, optional, tag = "1")]
-    pub max_supported_total_tokens: ::core::option::Option<u32>,
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum GrammarType {
-    None = 0,
-    Json = 1,
-    Regex = 2,
-}
-impl GrammarType {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            GrammarType::None => "GRAMMAR_TYPE_NONE",
-            GrammarType::Json => "GRAMMAR_TYPE_JSON",
-            GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "GRAMMAR_TYPE_NONE" => Some(Self::None),
-            "GRAMMAR_TYPE_JSON" => Some(Self::Json),
-            "GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
-            _ => None,
-        }
-    }
-}
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
-#[repr(i32)]
-pub enum FinishReason {
-    Length = 0,
-    EosToken = 1,
-    StopSequence = 2,
-}
-impl FinishReason {
-    /// String value of the enum field names used in the ProtoBuf definition.
-    ///
-    /// The values are not transformed in any way and thus are considered stable
-    /// (if the ProtoBuf definition does not change) and safe for programmatic use.
-    pub fn as_str_name(&self) -> &'static str {
-        match self {
-            FinishReason::Length => "FINISH_REASON_LENGTH",
-            FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
-            FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
-        }
-    }
-    /// Creates an enum from field names used in the ProtoBuf definition.
-    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
-        match value {
-            "FINISH_REASON_LENGTH" => Some(Self::Length),
-            "FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
-            "FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
-            _ => None,
-        }
-    }
-}
-/// Generated client implementations.
-pub mod text_generation_service_client {
-    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
-    use tonic::codegen::*;
-    use tonic::codegen::http::Uri;
-    #[derive(Debug, Clone)]
-    pub struct TextGenerationServiceClient<T> {
-        inner: tonic::client::Grpc<T>,
-    }
-    impl TextGenerationServiceClient<tonic::transport::Channel> {
-        /// Attempt to create a new client by connecting to a given endpoint.
-        pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
-        where
-            D: TryInto<tonic::transport::Endpoint>,
-            D::Error: Into<StdError>,
-        {
-            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
-            Ok(Self::new(conn))
-        }
-    }
-    impl<T> TextGenerationServiceClient<T>
-    where
-        T: tonic::client::GrpcService<tonic::body::BoxBody>,
-        T::Error: Into<StdError>,
-        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
-        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
-    {
-        pub fn new(inner: T) -> Self {
-            let inner = tonic::client::Grpc::new(inner);
-            Self { inner }
-        }
-        pub fn with_origin(inner: T, origin: Uri) -> Self {
-            let inner = tonic::client::Grpc::with_origin(inner, origin);
-            Self { inner }
-        }
-        pub fn with_interceptor<F>(
-            inner: T,
-            interceptor: F,
-        ) -> TextGenerationServiceClient<InterceptedService<T, F>>
-        where
-            F: tonic::service::Interceptor,
-            T::ResponseBody: Default,
-            T: tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-                Response = http::Response<
-                    <T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
-                >,
-            >,
-            <T as tonic::codegen::Service<
-                http::Request<tonic::body::BoxBody>,
-            >>::Error: Into<StdError> + Send + Sync,
-        {
-            TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
-        }
-        /// Compress requests with the given encoding.
-        ///
-        /// This requires the server to support it otherwise it might respond with an
-        /// error.
-        #[must_use]
-        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.send_compressed(encoding);
-            self
-        }
-        /// Enable decompressing responses.
-        #[must_use]
-        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
-            self.inner = self.inner.accept_compressed(encoding);
-            self
-        }
-        /// Limits the maximum size of a decoded message.
-        ///
-        /// Default: `4MB`
-        #[must_use]
-        pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_decoding_message_size(limit);
-            self
-        }
-        /// Limits the maximum size of an encoded message.
-        ///
-        /// Default: `usize::MAX`
-        #[must_use]
-        pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
-            self.inner = self.inner.max_encoding_message_size(limit);
-            self
-        }
-        /// / Model Info
-        pub async fn info(
-            &mut self,
-            request: impl tonic::IntoRequest<super::InfoRequest>,
-        ) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Info",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Service discovery
-        pub async fn service_discovery(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ServiceDiscoveryResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/ServiceDiscovery",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new(
-                        "generate.v2.TextGenerationService",
-                        "ServiceDiscovery",
-                    ),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Empties batch cache
-        pub async fn clear_cache(
-            &mut self,
-            request: impl tonic::IntoRequest<super::ClearCacheRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::ClearCacheResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/ClearCache",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Remove requests from a cached batch
-        pub async fn filter_batch(
-            &mut self,
-            request: impl tonic::IntoRequest<super::FilterBatchRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::FilterBatchResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/FilterBatch",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Warmup the model and compute max cache size
-        pub async fn warmup(
-            &mut self,
-            request: impl tonic::IntoRequest<super::WarmupRequest>,
-        ) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Warmup",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Prefill batch and decode first token
-        pub async fn prefill(
-            &mut self,
-            request: impl tonic::IntoRequest<super::PrefillRequest>,
-        ) -> std::result::Result<
-            tonic::Response<super::PrefillResponse>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Prefill",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Decode token for a list of prefilled batches
-        pub async fn decode(
-            &mut self,
-            request: impl tonic::IntoRequest<super::DecodeRequest>,
-        ) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Decode",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
-            self.inner.unary(req, path, codec).await
-        }
-        /// / Health check
-        pub async fn health(
-            &mut self,
-            request: impl tonic::IntoRequest<super::HealthRequest>,
-        ) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::new(
-                        tonic::Code::Unknown,
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/generate.v2.TextGenerationService/Health",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
-            self.inner.unary(req, path, codec).await
-        }
-    }
-}
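For orientation on what this hunk deletes: the generated tonic client exposes one async method per RPC of `generate.v2.TextGenerationService`. A minimal sketch of how such a client is driven — the endpoint address, module path, and runtime setup are illustrative assumptions, not code from this repository:

```rust
// Sketch only: assumes the generated module is reachable as `generate::v2`
// and that a shard is listening on the given (hypothetical) address.
use generate::v2::text_generation_service_client::TextGenerationServiceClient;
use generate::v2::{HealthRequest, InfoRequest};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `connect` comes from the generated impl over `tonic::transport::Channel`.
    let mut client = TextGenerationServiceClient::connect("http://127.0.0.1:3000").await?;

    // Each unary call mirrors a proto RPC one-to-one.
    let health = client.health(HealthRequest {}).await?;
    println!("health: {:?}", health.into_inner());

    let info = client.info(InfoRequest {}).await?.into_inner();
    println!("dtype={} device={}", info.dtype, info.device_type);
    Ok(())
}
```

The remaining RPCs (`prefill`, `decode`, `warmup`, `filter_batch`, `clear_cache`, `service_discovery`) follow the same unary call shape.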
@@ -1,6 +0,0 @@
-// This file is @generated by prost-build.
-pub mod generate {
-    pub mod v2 {
-        include!("generate.v2.rs");
-    }
-}
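Both deleted files are prost-build/tonic-build output rather than hand-written code. A sketch of the kind of `build.rs` step that produces them — the proto path and builder options are assumptions, not taken from this commit:

```rust
// build.rs — codegen sketch (tonic-build ~0.10/0.11 API; paths hypothetical).
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        // Emits `generate.v2.rs`, which the wrapper module above `include!`s.
        .compile(&["../proto/generate.proto"], &["../proto"])?;
    Ok(())
}
```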
@@ -1421,13 +1421,12 @@ impl Default for ModelsInfo {
 mod tests {
     use super::*;
     use serde_json::json;
-    use tokenizers::Tokenizer;
 
-    pub(crate) async fn get_tokenizer() -> Tokenizer {
+    pub(crate) fn get_tokenizer() -> Tokenizer {
         let api = hf_hub::api::sync::Api::new().unwrap();
         let repo = api.model("gpt2".to_string());
         let filename = repo.get("tokenizer.json").unwrap();
-        Tokenizer::from_file(filename).unwrap()
+        Tokenizer::Rust(tokenizers::Tokenizer::from_file(filename).unwrap())
     }
 
     #[test]
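The new `get_tokenizer` no longer returns `tokenizers::Tokenizer` directly but wraps it in a router-level `Tokenizer` enum, hence `Tokenizer::Rust(...)` and the dropped import. Only the `Rust` variant is visible in this diff; a sketch of what such a wrapper might look like, with any other variants being assumptions:

```rust
// Sketch only: the `Rust` variant is confirmed by the diff; the rest is assumed.
pub enum Tokenizer {
    /// A native `tokenizers` tokenizer, usable without a Python runtime.
    Rust(tokenizers::Tokenizer),
    /// Hypothetical fallback for tokenizers that only load through Python.
    Python {
        tokenizer_name: String,
        revision: Option<String>,
    },
}
```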
@@ -2515,10 +2515,11 @@ mod tests {
     use crate::TokenizerConfigToken;
     use crate::Tool;
 
+    use crate::tests::get_tokenizer;
     use serde_json::json;
 
-    #[test]
-    fn test_prepare_chat_input() {
+    #[tokio::test]
+    async fn test_prepare_chat_input() {
         // Mock Backend to avoid network requests
         struct MockBackend;
 
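`test_prepare_chat_input` becomes `async`, so the plain `#[test]` attribute is swapped for `#[tokio::test]`, which wraps the body in a runtime. Roughly what that attribute expands to, as a sketch:

```rust
// Approximate expansion of #[tokio::test] (default current-thread runtime).
#[test]
fn test_prepare_chat_input() {
    tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            // async test body goes here
        });
}
```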
@@ -2559,9 +2560,11 @@ mod tests {
             ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n        {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n".to_string())
         );
 
+        let tokenizer = get_tokenizer();
+
         let infer = Infer::new(
             backend,
-            Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
+            Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false),
             1,
             tokenizer_config,
             HubProcessorConfig::default(),
@@ -797,7 +797,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_validation_max_new_tokens() {
-        let tokenizer = None;
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_top_n_tokens = 4;
@@ -824,15 +824,15 @@ mod tests {
             .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
-            // Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            Ok((_s, _, 0, 10)) => (),
+            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
+            // Ok((_s, _, 0, 10)) => (),
             r => panic!("Unexpected not max new tokens: {r:?}"),
         }
     }
 
     #[tokio::test]
     async fn test_validation_input_length() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
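With a real tokenizer in place, `validate_input` can actually count prompt tokens, so the expected outcome flips from the permissive `Ok((_, _, 0, 10))` to a `MaxTotalTokens(6, 1, 10)` error: 1 prompt token for "Hello" plus 10 requested new tokens exceeds a total budget of 6. A sketch of the variant shape this match arm implies — field types and the error message text are assumptions:

```rust
// Hypothetical shape inferred from `ValidationError::MaxTotalTokens(6, 1, 10)`:
// (max_total_tokens, input_tokens, max_new_tokens).
#[derive(thiserror::Error, Debug)]
pub enum ValidationError {
    #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
    MaxTotalTokens(usize, usize, u32),
    // ... other variants elided
}
```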
@@ -866,7 +866,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_validation_best_of_sampling() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -906,7 +906,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_validation_top_p() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
@@ -977,7 +977,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_validation_top_n_tokens() {
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
         let max_best_of = 2;
         let max_stop_sequences = 3;
         let max_top_n_tokens = 4;
@@ -1062,7 +1062,7 @@ mod tests {
     async fn test_prepare_input_chunks() {
         let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
 
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
 
         let max_best_of = 2;
         let max_stop_sequence = 3;
@@ -1097,7 +1097,7 @@ mod tests {
         )
         .await
         {
-            Ok(Some((_encoding, chunks))) => chunks,
+            Ok((_encoding, chunks)) => chunks,
             _ => panic!("Unexpected tokenization failure"),
         };
 
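The `Ok(Some(...))` pattern becoming `Ok(...)` here, and again in the next hunk, follows from the tokenizer no longer being optional: once `Validation` always holds a tokenizer, tokenization always yields an encoding and the `Option` layer disappears. The implied signature change, sketched — the function name and types are assumptions from the test context:

```rust
// Before: a missing tokenizer yielded Ok(None).
// fn tokenize(&self, inputs: String, truncate: Option<usize>)
//     -> Result<Option<(Encoding, Vec<Chunk>)>, ValidationError>
//
// After: a tokenizer is always present, so the Option layer is gone.
// fn tokenize(&self, inputs: String, truncate: Option<usize>)
//     -> Result<(Encoding, Vec<Chunk>), ValidationError>
```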
@@ -1119,7 +1119,7 @@ mod tests {
     async fn test_idefics2_correct_n_fake_tokens() {
         let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
 
-        let tokenizer = Some(get_tokenizer().await);
+        let tokenizer = get_tokenizer();
 
         let max_best_of = 2;
         let max_stop_sequence = 3;
@@ -1157,7 +1157,7 @@ mod tests {
         )
         .await
         {
-            Ok(Some((encoding, chunks))) => (encoding, chunks),
+            Ok((encoding, chunks)) => (encoding, chunks),
             _ => panic!("Unexpected tokenization failure"),
         };
 