mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
small refactor to make router a bit more agnostic
This commit is contained in:
parent df71aafdcc
commit 4dbb342fe3
236  proto/v3/generate.proto  Normal file
@@ -0,0 +1,236 @@
+syntax = "proto3";
+
+package generate.v3;
+
+service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
+    /// Service discovery
+    rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
+    /// Empties batch cache
+    rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
+    /// Remove requests from a cached batch
+    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
+    /// Warmup the model and compute max cache size
+    rpc Warmup (WarmupRequest) returns (WarmupResponse);
+    /// Prefill batch and decode first token
+    rpc Prefill (PrefillRequest) returns (PrefillResponse);
+    /// Decode token for a list of prefilled batches
+    rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
+}
+
+message HealthRequest {}
+message HealthResponse {}
+
+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+    optional uint32 window_size = 4;
+    uint32 speculate = 5;
+}
+
+/// Empty request
+message ServiceDiscoveryRequest {}
+
+message ServiceDiscoveryResponse {
+    /// Other shards urls
+    repeated string urls = 1;
+}
+
+message ClearCacheRequest {
+    /// Optional batch id
+    optional uint64 id = 1;
+}
+
+/// Empty response
+message ClearCacheResponse {}
+
+enum GrammarType {
+    GRAMMAR_TYPE_NONE = 0;
+    GRAMMAR_TYPE_JSON = 1;
+    GRAMMAR_TYPE_REGEX = 2;
+}
+
+message NextTokenChooserParameters {
+    /// exponential scaling output probability distribution
+    float temperature = 1;
+    /// restricting to the k highest probability elements
+    uint32 top_k = 2;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float top_p = 3;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float typical_p = 4;
+    /// apply sampling on the logits
+    bool do_sample = 5;
+    /// random seed for sampling
+    uint64 seed = 6;
+    /// repetition penalty
+    float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
+    /// token watermarking using "A Watermark for Large Language Models"
+    bool watermark = 8;
+    /// grammar (applied if not empty)
+    string grammar = 10;
+    /// grammar type
+    GrammarType grammar_type = 11;
+}
+
+message StoppingCriteriaParameters {
+    /// Maximum number of generated tokens
+    uint32 max_new_tokens = 1;
+    /// Optional stopping sequences
+    repeated string stop_sequences = 2;
+    /// Ignore end of sequence token
+    /// used for benchmarking
+    bool ignore_eos_token = 3;
+}
+
+message Request {
+    /// Request ID
+    uint64 id = 1;
+    /// The generation context
+    string inputs = 2;
+    /// Context truncation
+    uint32 truncate = 3;
+    /// Next Token Chooser Parameters
+    NextTokenChooserParameters parameters = 4;
+    /// Stopping Criteria Parameters
+    StoppingCriteriaParameters stopping_parameters = 5;
+    /// Return prefill logprobs
+    bool prefill_logprobs = 6;
+    /// Return most likely n tokens
+    uint32 top_n_tokens = 7;
+}
+
+message Batch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests
+    repeated Request requests = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
+message CachedBatch {
+    /// Batch ID
+    uint64 id = 1;
+    /// Individual requests ids
+    repeated uint64 request_ids = 2;
+    /// Batch size (==len(requests))
+    uint32 size = 3;
+    /// Maximum number of tokens this batch will grow to
+    uint32 max_tokens = 4;
+}
+
+enum FinishReason {
+    FINISH_REASON_LENGTH = 0;
+    FINISH_REASON_EOS_TOKEN = 1;
+    FINISH_REASON_STOP_SEQUENCE = 2;
+}
+
+message GeneratedText {
+    /// Output
+    string text = 1;
+    /// Number of generated tokens
+    uint32 generated_tokens = 2;
+    /// Finish reason
+    FinishReason finish_reason = 3;
+    /// Seed
+    optional uint64 seed = 4;
+}
+
+message Tokens {
+    /// Token IDs
+    repeated uint32 ids = 1;
+    /// Logprobs
+    repeated float logprobs = 2;
+    /// tokens
+    repeated string texts = 3;
+    /// special
+    repeated bool is_special = 4;
+}
+
+message Generation {
+    /// Request ID
+    uint64 request_id = 1;
+    /// Prefill tokens (optional)
+    Tokens prefill_tokens = 2;
+    Tokens tokens = 3;
+    /// Complete generated text
+    optional GeneratedText generated_text = 4;
+    /// Top tokens
+    repeated Tokens top_tokens = 5;
+}
+
+message FilterBatchRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to keep
+    repeated uint64 request_ids = 2;
+}
+
+message FilterBatchResponse {
+    /// Filtered Batch (cached)
+    CachedBatch batch = 1;
+}
+
+
+message PrefillRequest {
+    /// Batch
+    Batch batch = 1;
+}
+
+message PrefillResponse {
+    /// Generation
+    repeated Generation generations = 1;
+    /// Next batch (cached)
+    optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+}
+
+message DecodeRequest {
+    /// Cached batches
+    repeated CachedBatch batches = 1;
+}
+
+message DecodeResponse {
+    /// Decodes
+    repeated Generation generations = 1;
+    /// Next batch (cached)
+    optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
+}
+
+message WarmupRequest {
+    /// Batch to warmup on
+    Batch batch = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
+}
+
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
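Note (illustrative, not part of this commit): once the build script shown further down has run tonic/prost codegen for the generate.v3 package, the service above can be exercised roughly as follows. The pb::generate::v3 path and the text_generation_service_client module name follow tonic's usual naming conventions and are assumptions here, not code from this diff.

use tonic::transport::Channel;

// Assumed codegen paths; adjust to the actual generated module layout.
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::InfoRequest;

async fn print_shard_info(uri: &'static str) -> Result<(), Box<dyn std::error::Error>> {
    // Connect to a single shard and call the Info RPC declared in the service above.
    let channel = Channel::from_static(uri).connect().await?;
    let mut client = TextGenerationServiceClient::new(channel);
    let info = client.info(InfoRequest {}).await?.into_inner();
    println!("dtype={} device={} speculate={}", info.dtype, info.device_type, info.speculate);
    Ok(())
}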
@@ -6,6 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
+async-trait = "^0.1"
 base64 = { workspace = true }
 futures = "^0.3"
 grpc-metadata = { path = "../grpc-metadata" }
@@ -1,19 +1,31 @@
 use std::fs;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    println!("cargo:rerun-if-changed=../../proto/generate.proto");
-    fs::create_dir("src/pb").unwrap_or(());
+    println!("cargo:rerun-if-changed=../../proto/**");
 
+    fs::create_dir_all("src/v2/pb").unwrap_or(());
     let mut config = prost_build::Config::new();
     config.protoc_arg("--experimental_allow_proto3_optional");
 
     tonic_build::configure()
         .build_client(true)
         .build_server(false)
-        .out_dir("src/pb")
+        .out_dir("src/v2/pb")
         .include_file("mod.rs")
         .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
         .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
 
+    fs::create_dir_all("src/v3/pb").unwrap_or(());
+    let mut config = prost_build::Config::new();
+    config.protoc_arg("--experimental_allow_proto3_optional");
+
+    tonic_build::configure()
+        .build_client(true)
+        .build_server(false)
+        .out_dir("src/v3/pb")
+        .include_file("mod.rs")
+        .compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
+        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
+
     Ok(())
 }
@@ -1,25 +1,32 @@
 //! Text Generation gRPC client library
 
-mod client;
-#[allow(clippy::derive_partial_eq_without_eq)]
-mod pb;
-mod sharded_client;
+pub mod v2;
+pub mod v3;
 
-use base64::{engine::general_purpose::STANDARD, Engine};
-pub use client::Client;
-pub use pb::generate::v2::input_chunk::Chunk;
-pub use pb::generate::v2::HealthResponse;
-pub use pb::generate::v2::Image;
-pub use pb::generate::v2::InfoResponse as ShardInfo;
-pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, Input, InputChunk,
-    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
-};
-pub use sharded_client::ShardedClient;
+use async_trait::async_trait;
 use thiserror::Error;
 use tonic::transport;
 use tonic::Status;
 
+#[async_trait]
+pub trait Health {
+    /// Check if a generate server is healthy by asking it to allocate a tensor on device
+    async fn device_health(&self) -> Result<()>;
+
+    /// Check if a generate server is healthy by doing a forward pass.
+    /// EXPENSIVE
+    async fn model_health(&self) -> Result<()>;
+}
+
+#[derive(Debug)]
+pub struct ShardInfo {
+    pub requires_padding: bool,
+    pub dtype: String,
+    pub device_type: String,
+    pub window_size: Option<u32>,
+    pub speculate: u32,
+}
+
 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
     #[error("Could not connect to Text Generation server: {0}")]
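Note (illustrative, not in this diff): the new Health trait is what lets the router hold an Arc<dyn Health + Send + Sync> instead of a concrete, protocol-versioned client. A minimal sketch of a conforming implementation inside this crate, assuming its Result<T> alias over ClientError:

use async_trait::async_trait;

// Hypothetical client used only to show the trait contract.
struct NoopClient;

#[async_trait]
impl Health for NoopClient {
    async fn device_health(&self) -> Result<()> {
        // Cheap probe: e.g. ask the shard to allocate a tensor on device.
        Ok(())
    }

    async fn model_health(&self) -> Result<()> {
        // Expensive probe: run a full forward pass.
        Ok(())
    }
}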
1  router/client/src/pb/.gitignore  vendored
@@ -1 +0,0 @@
-*.rs
13  router/client/src/v2/mod.rs  Normal file
@@ -0,0 +1,13 @@
+#[allow(clippy::derive_partial_eq_without_eq)]
+mod pb;
+
+mod client;
+mod sharded_client;
+
+pub use client::Client;
+pub use pb::generate::v2::HealthResponse;
+pub use pb::generate::v2::{
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+};
+pub use sharded_client::ShardedClient;
1  router/client/src/v2/pb/.gitignore  vendored  Normal file
@@ -0,0 +1 @@
+*
13  router/client/src/v3/mod.rs  Normal file
@@ -0,0 +1,13 @@
+#[allow(clippy::derive_partial_eq_without_eq)]
+mod pb;
+
+mod client;
+mod sharded_client;
+
+pub use client::Client;
+pub use pb::generate::v3::HealthResponse;
+pub use pb::generate::v3::{
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
+};
+pub use sharded_client::ShardedClient;
1  router/client/src/v3/pb/.gitignore  vendored  Normal file
@@ -0,0 +1 @@
+*
@@ -1,22 +1,18 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-use text_generation_client::{
-    Batch, Input, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
-};
-use text_generation_client::{Chunk, GrammarType as ProtoGrammarType};
+use text_generation_client::Health;
 
-// Note: Request ids and batch ids cannot collide.
-const LIVENESS_ID: u64 = u64::MAX;
-const BATCH_ID: u64 = u64::MAX;
-
-#[derive(Clone, Debug)]
-pub(crate) struct Health {
-    client: ShardedClient,
+#[derive(Clone)]
+pub(crate) struct HealthCheck {
+    client: Arc<dyn Health + Send + Sync>,
     generation_health: Arc<AtomicBool>,
 }
 
-impl Health {
-    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+impl HealthCheck {
+    pub(crate) fn new(
+        client: Arc<dyn Health + Send + Sync>,
+        generation_health: Arc<AtomicBool>,
+    ) -> Self {
         Self {
             client,
             generation_health,
@@ -24,52 +20,15 @@ impl Health {
     }
 
     pub(crate) async fn check(&mut self) -> bool {
-        if self.generation_health.load(Ordering::SeqCst) {
-            // Generation is healthy, we only check that the shards are answering gRPC calls
-            self.client.health().await.is_ok()
+        let value = if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards can allocate on device
+            self.client.device_health().await
         } else {
-            // Generation is unhealthy or have not sent any generation request yet
-
-            // Dummy batch of 1 token and 1 generated token
-            let liveness_request = Request {
-                id: LIVENESS_ID,
-                input_chunks: Some(Input {
-                    chunks: vec![Chunk::Text("liveness".into()).into()],
-                }),
-                inputs: "liveness".to_string(),
-                truncate: 10,
-                prefill_logprobs: false,
-                parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
-                    seed: 0,
-                    repetition_penalty: 1.0,
-                    frequency_penalty: 0.0,
-                    watermark: false,
-                    grammar: String::new(),
-                    grammar_type: ProtoGrammarType::None as i32,
-                }),
-                stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 1,
-                    stop_sequences: vec![],
-                    ignore_eos_token: false,
-                }),
-                top_n_tokens: 0,
-            };
-            let batch = Batch {
-                id: BATCH_ID,
-                requests: vec![liveness_request],
-                size: 1,
-                max_tokens: 2,
-            };
-            // Skips the queue
-            let value = self.client.prefill(batch).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            self.client.model_health().await
         }
+        .is_ok();
+        // Update generation health
+        self.generation_health.store(value, Ordering::SeqCst);
+        value
     }
 }
@@ -1,9 +1,9 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
 use crate::{
-    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
-    HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text,
-    TextMessage, Token,
+    ChatTemplateInputs, ChatTemplateVersions, Entry, FinishReason, GenerateRequest,
+    GenerateStreamResponse, HubProcessorConfig, HubTokenizerConfig, Message, MessageChunk,
+    PrefillToken, Queue, Text, TextMessage, Token,
 };
 use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
 use futures::future::try_join_all;
@@ -15,9 +15,8 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use text_generation_client::{
-    Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
-};
+use text_generation_client::v2::{Batch, CachedBatch, Generation, ShardedClient};
+use text_generation_client::{v2, ClientError};
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
@@ -232,16 +231,8 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(tokens) => {
-                    // Create Token objects
-                    // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                InferStreamResponse::Prefill(prefill_tokens) => {
+                    result_prefill = prefill_tokens;
                 }
                 // Push last token
                 InferStreamResponse::Intermediate { token, top_tokens } => {
@@ -792,6 +783,16 @@ fn send_responses(
     let mut stopped = false;
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
+        // Create Token objects
+        // We do that here instead of in the Python code as Rust for loops are faster
+        let prefill_tokens = prefill_tokens
+            .ids
+            .into_iter()
+            .zip(prefill_tokens.logprobs.into_iter())
+            .zip(prefill_tokens.texts.into_iter())
+            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+            .collect();
+
         // Send message
         entry
             .response_tx
@@ -842,7 +843,7 @@ fn send_responses(
         entry.response_tx.send(Ok(InferStreamResponse::End {
             token,
             top_tokens,
-            generated_text: generated_text.clone(),
+            generated_text: GeneratedText::from(generated_text.clone()),
             queued: entry.queue_time,
             start: entry.batch_time.unwrap(),
         }))?;
@@ -877,10 +878,36 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
     });
 }
 
+#[derive(Debug)]
+pub(crate) struct GeneratedText {
+    pub(crate) text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) finish_reason: FinishReason,
+    pub(crate) seed: Option<u64>,
+}
+
+impl From<v2::GeneratedText> for GeneratedText {
+    fn from(value: v2::GeneratedText) -> Self {
+        let v2_finish_reason = v2::FinishReason::try_from(value.finish_reason).unwrap();
+        let finish_reason = match v2_finish_reason {
+            v2::FinishReason::Length => FinishReason::Length,
+            v2::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            v2::FinishReason::StopSequence => FinishReason::StopSequence,
+        };
+
+        Self {
+            text: value.text,
+            generated_tokens: value.generated_tokens,
+            finish_reason,
+            seed: value.seed,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(Tokens),
+    Prefill(Vec<PrefillToken>),
     // Intermediate messages
     Intermediate {
         token: Token,
@@ -1355,11 +1382,11 @@ mod tests {
             chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
             input: ChatTemplateInputs {
                 messages: vec![
-                    TextMessage{
+                    TextMessage {
                         role: "system".to_string(),
                         content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
                     },
-                    TextMessage{
+                    TextMessage {
                         role: "user".to_string(),
                         content: "How many helicopters can a human eat in one sitting?".to_string(),
                     },
@@ -1087,7 +1087,7 @@ pub struct SimpleToken {
     stop: usize,
 }
 
-#[derive(Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema)]
 #[serde(rename_all(serialize = "snake_case"))]
 #[schema(example = "Length")]
 pub(crate) enum FinishReason {
@@ -12,7 +12,7 @@ use std::fs::File;
 use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
-use text_generation_client::{ClientError, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ClientError};
 use text_generation_router::config::Config;
 use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
 use thiserror::Error;
@@ -1,12 +1,17 @@
 use crate::infer::InferError;
 use crate::infer::InferStreamResponse;
-use crate::validation::ValidGenerateRequest;
+use crate::validation::{
+    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
+};
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::ChunksToString;
 use text_generation_client::Input;
 use text_generation_client::{Batch, Request};
+use text_generation_client::v2::{
+    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
+};
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
@@ -285,8 +290,12 @@ impl State {
                 }),
                 inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
-                parameters: Some(entry.request.parameters.clone()),
-                stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+                parameters: Some(NextTokenChooserParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                stopping_parameters: Some(StoppingCriteriaParameters::from(
+                    entry.request.stopping_parameters.clone(),
+                )),
                 top_n_tokens: entry.request.top_n_tokens,
             });
             // Set batch_time
@@ -355,6 +364,43 @@ enum QueueCommand {
     },
 }
 
+impl From<ValidParameters> for NextTokenChooserParameters {
+    fn from(value: ValidParameters) -> Self {
+        let (grammar, grammar_type) = match value.grammar {
+            None => (String::new(), GrammarType::None),
+
+            Some(grammar) => match grammar {
+                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
+                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
+            },
+        };
+
+        Self {
+            temperature: value.temperature,
+            top_k: value.top_k,
+            top_p: value.top_p,
+            typical_p: value.typical_p,
+            do_sample: value.do_sample,
+            seed: value.seed,
+            repetition_penalty: value.repetition_penalty,
+            frequency_penalty: value.frequency_penalty,
+            watermark: value.watermark,
+            grammar,
+            grammar_type: grammar_type.into(),
+        }
+    }
+}
+
+impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
+    fn from(value: ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: value.max_new_tokens,
+            stop_sequences: value.stop_sequences,
+            ignore_eos_token: value.ignore_eos_token,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
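Note (illustrative, not in this diff): the From impls above are the protocol boundary. Validation now produces the router-internal ValidParameters, and the conversion to the v2 proto message only happens when the queue builds a Batch. A usage sketch with made-up values:

fn example_conversion() -> NextTokenChooserParameters {
    // Router-internal, protocol-agnostic parameters (hypothetical values).
    let valid = ValidParameters {
        temperature: 1.0,
        top_k: 0,
        top_p: 1.0,
        typical_p: 1.0,
        do_sample: false,
        seed: 0,
        repetition_penalty: 1.0,
        frequency_penalty: 0.0,
        watermark: false,
        grammar: None,
    };
    // Convert to the wire type only at the gRPC boundary.
    NextTokenChooserParameters::from(valid)
}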
@@ -1,6 +1,6 @@
-use crate::config::Config;
 /// HTTP Server logic
-use crate::health::Health;
+use crate::config::Config;
+use crate::health::HealthCheck;
 use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
 use crate::validation::ValidationError;
 use crate::{
@@ -34,7 +34,7 @@ use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
-use text_generation_client::{ShardInfo, ShardedClient};
+use text_generation_client::{v2::ShardedClient, ShardInfo};
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -115,7 +115,9 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 )]
 #[instrument(skip(health))]
 /// Health check method
-async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+async fn health(
+    mut health: Extension<HealthCheck>,
+) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     match health.check().await {
         true => Ok(()),
         false => Err((
@@ -1482,7 +1484,7 @@ pub async fn run(
         grammar_support,
     );
     let generation_health = Arc::new(AtomicBool::new(false));
-    let health_ext = Health::new(client.clone(), generation_health.clone());
+    let health_ext = HealthCheck::new(Arc::new(client.clone()), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -1719,17 +1721,6 @@ async fn shutdown_signal() {
     opentelemetry::global::shutdown_tracer_provider();
 }
 
-impl From<i32> for FinishReason {
-    fn from(finish_reason: i32) -> Self {
-        let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
-        match finish_reason {
-            text_generation_client::FinishReason::Length => FinishReason::Length,
-            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
-            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
-        }
-    }
-}
-
 /// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
@@ -6,10 +6,6 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
-use text_generation_client::{
-    Chunk, GrammarType as ProtoGrammarType, Image, InputChunk, NextTokenChooserParameters,
-    StoppingCriteriaParameters,
-};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 // use tokenizers::TruncationDirection;
@@ -173,10 +169,6 @@ impl Validation {
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
             input_length = input_length.saturating_sub(max_new_tokens as usize);
-            // return Err(ValidationError::MaxNewTokens(
-            //     self.max_total_tokens - self.max_input_length,
-            //     max_new_tokens,
-            // ));
         }
 
         Ok((
@@ -327,13 +319,13 @@ impl Validation {
         // compiler and use that to build the FSM here.
 
         // Validate grammar and unpack the grammar and type for the proto message
-        let (grammar, grammar_type) = match grammar {
+        let grammar = match grammar {
             Some(grammar) => {
                 // Ensure that grammar is not set if it's not supported
                 if self.disable_grammar_support {
                     return Err(ValidationError::Grammar);
                 }
-                match grammar {
+                let valid_grammar = match grammar {
                     GrammarType::Json(json) => {
                         let json = match json {
                             // if value is a string, we need to parse it again to make sure its
@@ -350,20 +342,20 @@ impl Validation {
                             .compile(&json)
                             .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
-                        (
-                            // Serialize json to string
+                        // Serialize json to string
+                        ValidGrammar::Json(
                             serde_json::to_string(&json)
                                 .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
-                            ProtoGrammarType::Json.into(),
                         )
                     }
-                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
-                }
+                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),
+                };
+                Some(valid_grammar)
             }
-            None => (String::new(), ProtoGrammarType::None.into()),
+            None => None,
         };
 
-        let parameters = NextTokenChooserParameters {
+        let parameters = ValidParameters {
             temperature,
             repetition_penalty,
             frequency_penalty,
@@ -374,9 +366,8 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            grammar_type,
         };
-        let stopping_parameters = StoppingCriteriaParameters {
+        let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
             stop_sequences,
             ignore_eos_token: false,
@@ -458,6 +449,7 @@ fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
         _ => None,
     }
 }
+
 fn format_to_mimetype(format: ImageFormat) -> String {
     match format {
         ImageFormat::Png => "image/png",
@@ -636,14 +628,55 @@ type TokenizerRequest = (
     Span,
 );
 
+#[derive(Debug, Clone)]
+pub(crate) enum ValidGrammar {
+    Json(String),
+    Regex(String),
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidParameters {
+    /// / exponential scaling output probability distribution
+    pub temperature: f32,
+    /// / restricting to the k highest probability elements
+    pub top_k: u32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub top_p: f32,
+    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    pub typical_p: f32,
+    /// / apply sampling on the logits
+    pub do_sample: bool,
+    /// / random seed for sampling
+    pub seed: u64,
+    /// / repetition penalty
+    pub repetition_penalty: f32,
+    /// / frequency penalty
+    pub frequency_penalty: f32,
+    /// / token watermarking using "A Watermark for Large Language Models"
+    pub watermark: bool,
+    /// / grammar (applied if not empty)
+    pub grammar: Option<ValidGrammar>,
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct ValidStoppingParameters {
+    /// / Maximum number of generated tokens
+    pub max_new_tokens: u32,
+    /// / Optional stopping sequences
+    pub stop_sequences: Vec<String>,
+    /// / Ignore end of sequence token
+    /// / used for benchmarking
+    pub ignore_eos_token: bool,
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: Vec<InputChunk>,
     pub input_length: u32,
     pub truncate: u32,
     pub decoder_input_details: bool,
-    pub parameters: NextTokenChooserParameters,
-    pub stopping_parameters: StoppingCriteriaParameters,
+    pub parameters: ValidParameters,
+    pub stopping_parameters: ValidStoppingParameters,
     pub top_n_tokens: u32,
 }