mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00

Rebase.

This commit is contained in:
  parent ddbbf6b50c
  commit bc0a33e1c9

.gitignore (vendored): 2 changes
@@ -3,6 +3,8 @@ target
 router/tokenizer.json
 *__pycache__*
 
+backends/v3/src/client/pb
+
 # ROCm auto-generated files
 *.hip
 server/exllamav2_kernels/exllamav2_kernels/hip/
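Note: the new `backends/v3/src/client/pb` entry keeps protobuf bindings out of version control; judging from the `mod pb;` declaration later in this commit, that directory is presumably emitted at build time by the gRPC code generator.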
Cargo.lock (generated, new file): 4963 lines
File diff suppressed because it is too large
@@ -1,11 +1,12 @@
 [workspace]
 members = [
     # "benchmark",
     "backends/v3",
     # "backends/client",
     "backends/grpc-metadata",
+    # "backends/trtllm",
     "launcher"
-, "backends/trtllm"]
+]
 resolver = "2"
 
 [workspace.package]
@@ -18,6 +19,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
 base64 = "0.22.0"
 tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
+metrics = { version = "0.23.0" }
+metrics-exporter-prometheus = { version = "0.15.1", features = [] }
 
 [profile.release]
 incremental = true
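Note: pinning `metrics` and `metrics-exporter-prometheus` once in the workspace manifest (evidently the `[workspace.dependencies]` table, given the neighbouring `base64`/`tokenizers`/`hf-hub` entries) lets the member crates below switch to `metrics = { workspace = true }`, so the 0.21 → 0.23 upgrade is driven from a single place.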
@@ -2,8 +2,8 @@ use std::future::Future;
 use std::path::Path;
 use std::pin::{pin, Pin};
 use std::str::FromStr;
-use std::sync::{Arc, OnceLock};
 use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, OnceLock};
 use std::task::{Context, Poll};
 use std::time::Duration;
 
@@ -13,15 +13,15 @@ use log::{error, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
-use tokio::time::{Instant, sleep};
-use tokio_stream::{Stream, StreamExt};
+use tokio::time::{sleep, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span};
+use tokio_stream::{Stream, StreamExt};
+use tracing::{instrument, span, Level};
 
-use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
+use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -24,8 +24,8 @@ grpc-metadata = { path = "../grpc-metadata" }
 futures = "0.3.28"
 hf-hub = { workspace = true }
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
-metrics = "0.21.1"
-metrics-exporter-prometheus = { version = "0.12.1", features = [] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"
@@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
         let batch_size = batch.size;
         let batch_max_tokens = batch.max_tokens;
         let mut batches = vec![batch];
-        metrics::gauge!("tgi_batch_current_size", batch_size as f64);
-        metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
+        metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
+        metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
         let min_size = if waiting_tokens >= max_waiting_tokens {
             // If we didn't onboard any new requests since >= max_waiting_tokens, we try
@@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
             {
                 // Tracking metrics
                 if min_size.is_some() {
-                    metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
+                    metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
+                        .increment(1);
                 } else {
-                    metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
+                    metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
+                        .increment(1);
                 }
 
                 entries.iter_mut().for_each(|(_, entry)| {
@@ -218,8 +220,8 @@ pub(crate) async fn batching_task(
                 .await;
                 waiting_tokens += 1;
             }
-            metrics::gauge!("tgi_batch_current_size", 0.0);
-            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
+            metrics::gauge!("tgi_batch_current_size").set(0.0);
+            metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
         }
     }
 }
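The recurring change in these hunks is the `metrics` 0.21 → 0.23 macro API: the macros no longer take the value as an argument but return a handle with `increment`/`set`/`record` methods. A minimal self-contained sketch (metric names taken from the hunks above, values illustrative; without an installed recorder these calls are no-ops):

    use metrics::{counter, gauge, histogram};

    fn record_batch_metrics(batch_size: usize, forward_secs: f64) {
        // metrics 0.21 style (removed above):
        //   metrics::gauge!("tgi_batch_current_size", batch_size as f64);
        //   metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
        // metrics 0.23 style: the macro resolves the metric and returns a handle.
        gauge!("tgi_batch_current_size").set(batch_size as f64);
        counter!("tgi_batch_concat", "reason" => "backpressure").increment(1);
        histogram!("tgi_batch_forward_duration", "method" => "prefill").record(forward_secs);
    }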
@@ -232,7 +234,7 @@ async fn prefill(
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
-    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
+    metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
 
     match client.prefill(batch).await {
         Ok((generations, next_batch, timings)) => {
@@ -243,18 +245,22 @@ async fn prefill(
             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;
 
-            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
-            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
-            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+            metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
+                .record(timings.forward.as_secs_f64());
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
+                .record(timings.decode.as_secs_f64());
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
+                .record(start_filtering_time.elapsed().as_secs_f64());
+            metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
+                .record(start_time.elapsed().as_secs_f64());
+            metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
-            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
             None
         }
     }
@@ -268,7 +274,7 @@ async fn decode(
 ) -> Option<CachedBatch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
-    metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
+    metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
 
     match client.decode(batches).await {
         Ok((generations, next_batch, timings)) => {
@@ -280,13 +286,18 @@ async fn decode(
             let next_batch = filter_batch(client, next_batch, entries).await;
 
             if let Some(concat_duration) = timings.concat {
-                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+                metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
+                    .record(concat_duration.as_secs_f64());
             }
-            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
-            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
-            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
+            metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
+                .record(timings.forward.as_secs_f64());
+            metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
+                .record(timings.decode.as_secs_f64());
+            metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
+                .record(start_filtering_time.elapsed().as_secs_f64());
+            metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
+                .record(start_time.elapsed().as_secs_f64());
+            metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
             next_batch
         }
         // If we have an error, we discard the whole batch
@@ -295,7 +306,7 @@ async fn decode(
             let _ = client.clear_cache(Some(id)).await;
         }
         send_errors(err, entries);
-        metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
+        metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
         None
     }
 }
@@ -353,7 +364,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
             tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
             err
         }).unwrap_or(true);
         if stopped {
@@ -369,7 +380,7 @@ fn send_responses(
 ) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
     if entry.response_tx.is_closed() {
-        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+        metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
         return Ok(true);
     }
 
@@ -395,7 +406,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
-    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
+    metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
         .into_iter()
@@ -460,7 +471,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
-        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
+        metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
         tracing::error!("{err}");
 
         // unwrap_or is valid here as we don't care if the receiver is gone.
@@ -8,10 +8,10 @@ use tonic::Status;
 #[allow(clippy::derive_partial_eq_without_eq)]
 mod pb;
 
-mod client;
+mod grpc_client;
 mod sharded_client;
 
-pub use client::Client;
+pub use grpc_client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
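Note: renaming the inner module from `client` to `grpc_client` removes the `crate::client::client::…` path stutter visible in the next hunk, while the `pub use grpc_client::Client;` re-export keeps the public `Client` path unchanged for downstream code.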
@@ -2,7 +2,7 @@ use crate::client::{ClientError, Result};
 /// Multi shard Client
 use crate::client::{Health, ShardInfo};
 
-use crate::client::client::{DecodeTimings, PrefillTimings};
+use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
 use crate::client::{
     Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
     NextTokenChooserParameters, Request, StoppingCriteriaParameters,
@@ -30,6 +30,7 @@ pub struct BackendInfo {
     pub max_batch_size: Option<usize>,
 }
 
+#[allow(clippy::too_many_arguments)]
 pub async fn connect_backend(
     max_input_tokens: usize,
     max_total_tokens: usize,
@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, Subcommand};
 use text_generation_router::server;
 use text_generation_router_v3::{connect_backend, V3Error};
 use thiserror::Error;
@@ -7,6 +7,9 @@ use thiserror::Error;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+    #[command(subcommand)]
+    command: Option<Commands>,
+
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
     #[clap(default_value = "2", long, env)]
@@ -44,6 +47,8 @@ struct Args {
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]
+    api_key: Option<String>,
+    #[clap(long, env)]
     json_output: bool,
     #[clap(long, env)]
     otlp_endpoint: Option<String>,
@@ -65,12 +70,18 @@ struct Args {
     max_client_batch_size: usize,
 }
 
+#[derive(Debug, Subcommand)]
+enum Commands {
+    PrintSchema,
+}
+
 #[tokio::main]
 async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
     let Args {
+        command,
         max_concurrent_requests,
         max_best_of,
         max_stop_sequences,
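For context, a minimal sketch of how the optional clap subcommand added here parses (most arguments omitted; this is not the router's full CLI, and it assumes clap with the `derive` and `env` features):

    use clap::{Parser, Subcommand};

    #[derive(Parser, Debug)]
    struct Args {
        #[command(subcommand)]
        command: Option<Commands>,
        #[clap(default_value = "128", long, env)]
        max_concurrent_requests: usize,
    }

    #[derive(Debug, Subcommand)]
    enum Commands {
        /// Print the API schema and exit instead of serving.
        PrintSchema,
    }

    fn main() {
        // `router print-schema`                  -> command == Some(Commands::PrintSchema)
        // `router --max-concurrent-requests 256` -> command == None (normal serving path)
        let args = Args::parse();
        println!("{:?}", args.command);
    }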
@@ -89,6 +100,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_config_path,
         revision,
         validation_workers,
+        api_key,
         json_output,
         otlp_endpoint,
         otlp_service_name,
@@ -101,8 +113,19 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
     } = args;
 
+    let print_schema_command = match command {
+        Some(Commands::PrintSchema) => true,
+        None => {
+            // only init logging if we are not running the print schema command
+            text_generation_router::logging::init_logging(
+                otlp_endpoint,
+                otlp_service_name,
+                json_output,
+            );
+            false
+        }
+    };
     // Launch Tokio runtime
-    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
 
     // Validate args
     if max_input_tokens >= max_total_tokens {
@@ -151,6 +174,7 @@ async fn main() -> Result<(), RouterError> {
         max_input_tokens,
         max_total_tokens,
         validation_workers,
+        api_key,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -163,6 +187,7 @@ async fn main() -> Result<(), RouterError> {
         messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
+        print_schema_command,
     )
     .await?;
     Ok(())
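Note: moving logging init inside the `match` means the new `print-schema` subcommand runs with logging disabled, keeping stdout clean for the schema dump; the resulting `print_schema_command` flag is then passed into the server (presumably `server::run`, the call ending in `.await?` above), where it triggers the `println!("{}", api_doc)` / `exit(0)` path visible in a later hunk.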
@@ -17,8 +17,8 @@ futures = "0.3.28"
 hf-hub = { workspace = true }
 itertools = "0.10"
 jsonschema = { version = "0.17.1", features = ["draft202012"] }
-metrics = "0.23.0"
-metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+metrics = { workspace = true }
+metrics-exporter-prometheus = { workspace = true }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.13.0"
@@ -48,6 +48,7 @@ base64 = { workspace = true }
 sysinfo = "0.30.13"
 uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
 csv = "1.3.0"
+ureq = "=2.9"
 
 
 [build-dependencies]
@@ -1,5 +1,7 @@
 use crate::infer::InferError;
-use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
+use crate::{
+    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
+};
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
 
@@ -19,8 +21,8 @@ pub(crate) struct ChatTemplate {
 impl ChatTemplate {
     pub(crate) fn new(
         template: String,
-        bos_token: Option<String>,
-        eos_token: Option<String>,
+        bos_token: Option<TokenizerConfigToken>,
+        eos_token: Option<TokenizerConfigToken>,
     ) -> Self {
         let mut env = Box::new(Environment::new());
         // enable things like .strip() or .capitalize()
@@ -38,8 +40,8 @@ impl ChatTemplate {
 
         Self {
             template,
-            bos_token,
-            eos_token,
+            bos_token: bos_token.map(|token| token.as_str().to_string()),
+            eos_token: eos_token.map(|token| token.as_str().to_string()),
             use_default_tool_template,
         }
     }
@@ -52,9 +54,9 @@ impl ChatTemplate {
         if self.use_default_tool_template {
             if let Some(last_message) = messages.last_mut() {
                 if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
-                    last_message.content.push(MessageChunk::Text(Text {
+                    last_message.content.push(MessageChunk::Text {
                         text: format!("\n---\n{}\n{}", tool_prompt, tools),
-                    }));
+                    });
                 }
             }
         }
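The new `TokenizerConfigToken` parameter accounts for `tokenizer_config.json` files that store special tokens either as a bare string or as an object, with `as_str()` normalizing both. A hedged sketch of the shape this implies (the real definition lives elsewhere in the router crate; this assumes serde with the `derive` feature):

    use serde::Deserialize;

    // Sketch only: bos_token may appear as "<s>" or as {"content": "<s>", ...}.
    #[derive(Debug, Clone, Deserialize)]
    #[serde(untagged)]
    enum TokenizerConfigToken {
        String(String),
        Object { content: String },
    }

    impl TokenizerConfigToken {
        fn as_str(&self) -> &str {
            match self {
                TokenizerConfigToken::String(s) => s,
                TokenizerConfigToken::Object { content } => content,
            }
        }
    }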
@@ -1,46 +1,62 @@
 use crate::infer::InferError;
-use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
+use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 
 pub(crate) struct ToolGrammar {}
 
 impl ToolGrammar {
+    // find a tool by name
+    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
+        tools
+            .iter()
+            .find(|tool| tool.function.name == name)
+            .cloned()
+            .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
+    }
+
     pub fn apply(
         tools: Option<Vec<Tool>>,
-        tool_choice: Option<ToolType>,
+        tool_choice: ToolChoice,
     ) -> Result<Option<Tools>, InferError> {
-        if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
-            // let tool_prompt = tool_prompt.unwrap_or_default();
+        // if no tools are provided, we return None
+        let tools = match tools {
+            Some(tools) if !tools.is_empty() => tools,
+            _ => return Ok(None),
+        };
+
+        let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
+
+        // if tools are provided and no tool_choice we default to the OneOf
         let tools_to_use = match tool_choice {
             ToolType::FunctionName(name) => {
-                vec![req_tools
-                    .iter()
-                    .find(|tool| tool.function.name == *name)
-                    .unwrap_or_else(|| panic!("Tool with name {} not found", name))
-                    .clone()]
+                vec![Self::find_tool_by_name(&tools, &name)?]
             }
-            ToolType::OneOf => req_tools.to_owned(),
+            ToolType::Function { function } => {
+                vec![Self::find_tool_by_name(&tools, &function.name)?]
+            }
+            ToolType::OneOf => tools,
+            ToolType::NoTool => return Ok(None),
         };
 
         // adds the error notification function for LLM feedback if required
         let mut text_response_properties = Map::new();
         text_response_properties.insert(
             "error".to_string(),
-            json!({
+            serde_json::json!({
                 "type": "string",
                 "description": "The error or issue to notify"
             }),
         );
         text_response_properties.insert(
             "_name".to_string(),
-            json!({
+            serde_json::json!({
                 "type": "string",
                 "const": "notify_error"
            }),
         );
 
-        let functions: HashMap<String, Value> = tools_to_use
+        let functions: HashMap<String, serde_json::Value> = tools_to_use
             .iter()
             .map(|tool| {
                 let func = tool.function.clone();
@@ -91,7 +107,7 @@ impl ToolGrammar {
                 })
                 .chain([(
                     "notify_error".to_string(),
-                    json!({
+                    serde_json::json!({
                         "properties": text_response_properties,
                         "required": ["error", "_name"],
                         "type": "object"
@@ -114,9 +130,6 @@ impl ToolGrammar {
             },
         };
 
-            return Ok(Some(tools));
-        }
-        // Err(InferError::ToolError("No tools provided".to_string()))
-        Ok(None)
+        Ok(Some(tools))
     }
 }
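The refactor replaces an in-place `panic!` lookup with a fallible helper and adds explicit arms for the `Function` and `NoTool` choices. A standalone sketch of the lookup behavior (types pared down to what the hunk shows; the tool name is illustrative):

    #[derive(Clone, Debug)]
    struct Function { name: String }
    #[derive(Clone, Debug)]
    struct Tool { function: Function }

    // Mirrors find_tool_by_name above: a missing tool is now a recoverable error,
    // not a panic that would take down the request handler.
    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, String> {
        tools
            .iter()
            .find(|tool| tool.function.name == name)
            .cloned()
            .ok_or_else(|| format!("Tool with name {} not found", name))
    }

    fn main() {
        let tools = vec![Tool { function: Function { name: "get_weather".into() } }];
        assert!(find_tool_by_name(&tools, "get_weather").is_ok());
        assert!(find_tool_by_name(&tools, "nonexistent").is_err());
    }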
@@ -35,9 +35,9 @@ use futures::stream::StreamExt;
 use futures::stream::{FuturesOrdered, FuturesUnordered};
 use futures::Stream;
 use futures::TryStreamExt;
-use http::header::AUTHORIZATION;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
+use http::header::AUTHORIZATION;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use serde_json::Value;
 use std::convert::Infallible;
@@ -46,6 +46,7 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use thiserror::Error;
+use tokenizers::processors::template::TemplateProcessing;
 use tokenizers::Tokenizer;
 use tokio::select;
 use tokio::signal;
@@ -1406,15 +1407,16 @@ pub async fn run(
     max_input_tokens: usize,
     max_total_tokens: usize,
     validation_workers: usize,
-    addr: SocketAddr,
-    allow_origin: Option<AllowOrigin>,
     api_key: Option<String>,
+    tokenizer_name: String,
+    tokenizer_config_path: Option<String>,
+    revision: Option<String>,
+    hostname: String,
+    port: u16,
+    cors_allow_origin: Option<Vec<String>>,
     ngrok: bool,
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
-    tokenizer_config: HubTokenizerConfig,
-    preprocessor_config: Option<HubPreprocessorConfig>,
-    processor_config: HubProcessorConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
     max_client_batch_size: usize,
@@ -1507,6 +1509,16 @@ pub async fn run(
         println!("{}", api_doc);
         std::process::exit(0);
     }
+    // CORS allowed origins
+    // map to go inside the option and then map to parse from String to HeaderValue
+    // Finally, convert to AllowOrigin
+    let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
+        AllowOrigin::list(
+            cors_allow_origin
+                .iter()
+                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
+        )
+    });
 
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
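Note: `run()` now accepts raw CLI inputs — `hostname`, `port`, and `cors_allow_origin` as plain strings, plus `tokenizer_name`/`tokenizer_config_path`/`revision` — and derives `AllowOrigin` (hunk above) and the Hub config files itself, instead of requiring every backend binary to prebuild `SocketAddr`, `AllowOrigin`, and `HubTokenizerConfig` values before calling it.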
@@ -1564,6 +1576,7 @@ pub async fn run(
         tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
+        preprocessor_config_filename,
         processor_config_filename,
         model_info,
     ) = match api {
@@ -1571,6 +1584,7 @@ pub async fn run(
             Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
             None,
         ),
@@ -1587,6 +1601,7 @@ pub async fn run(
             };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();
 
             let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
@@ -1599,6 +1614,7 @@ pub async fn run(
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
+                preprocessor_config_filename,
                 processor_config_filename,
                 model_info,
             )
@@ -1613,13 +1629,40 @@ pub async fn run(
                 repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
+                repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
                 None,
             )
         }
     };
-    let tokenizer: Option<Tokenizer> =
-        tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
+    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
+    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
+    {
+        HubTokenizerConfig::from_file(filename)
+    } else {
+        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
+    };
+    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
+        tracing::warn!("Could not find tokenizer config locally and no API specified");
+        HubTokenizerConfig::default()
+    });
+
+    let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
+        let mut tokenizer = Tokenizer::from_file(filename).ok();
+        if let Some(tokenizer) = &mut tokenizer {
+            if let Some(class) = &tokenizer_config.tokenizer_class {
+                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
+                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
+                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
+                        tokenizer.with_post_processor(post_processor);
+                    }
+                }
+            }
+        }
+        tokenizer
+    });
 
     let config: Option<Config> = config_filename.and_then(|filename| {
         std::fs::read_to_string(filename)
             .ok()
@@ -1638,22 +1681,13 @@ pub async fn run(
         pipeline_tag: None,
     });
 
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    {
-        HubTokenizerConfig::from_file(filename)
-    } else {
-        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    };
-    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-        tracing::warn!("Could not find tokenizer config locally and no API specified");
-        HubTokenizerConfig::default()
-    });
-
     let processor_config = processor_config_filename
         .and_then(HubProcessorConfig::from_file)
         .unwrap_or_default();
 
+    let preprocessor_config: Option<HubPreprocessorConfig> =
+        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
+
     tracing::info!("Using config {config:?}");
     if tokenizer.is_none() {
         tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
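Note: the `HubTokenizerConfig` load is removed here because it moved above the `Tokenizer` construction (previous hunk): the new tokenizer closure consults `tokenizer_config.tokenizer_class` before deciding whether to install the Llama post-processor.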
@@ -2107,3 +2141,74 @@ pub enum WebServerError {
     #[error("Axum error: {0}")]
     Axum(#[from] axum::BoxError),
 }
+
+/// Create a post_processor for the LlamaTokenizer
+fn create_post_processor(
+    tokenizer: &Tokenizer,
+    tokenizer_config: &HubTokenizerConfig,
+) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
+    let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
+    let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
+
+    let bos_token = tokenizer_config.bos_token.as_ref();
+    let eos_token = tokenizer_config.eos_token.as_ref();
+
+    if add_bos_token && bos_token.is_none() {
+        panic!("add_bos_token = true but bos_token is None");
+    }
+
+    if add_eos_token && eos_token.is_none() {
+        panic!("add_eos_token = true but eos_token is None");
+    }
+
+    let mut single = Vec::new();
+    let mut pair = Vec::new();
+    let mut special_tokens = Vec::new();
+
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            let bos_token_id = tokenizer
+                .token_to_id(bos.as_str())
+                .expect("Should have found the bos token id");
+            special_tokens.push((bos.as_str(), bos_token_id));
+            single.push(format!("{}:0", bos.as_str()));
+            pair.push(format!("{}:0", bos.as_str()));
+        }
+    }
+
+    single.push("$A:0".to_string());
+    pair.push("$A:0".to_string());
+
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            let eos_token_id = tokenizer
+                .token_to_id(eos.as_str())
+                .expect("Should have found the eos token id");
+            special_tokens.push((eos.as_str(), eos_token_id));
+            single.push(format!("{}:0", eos.as_str()));
+            pair.push(format!("{}:0", eos.as_str()));
+        }
+    }
+
+    if add_bos_token {
+        if let Some(bos) = bos_token {
+            pair.push(format!("{}:1", bos.as_str()));
+        }
+    }
+
+    pair.push("$B:1".to_string());
+
+    if add_eos_token {
+        if let Some(eos) = eos_token {
+            pair.push(format!("{}:1", eos.as_str()));
+        }
+    }
+
+    let post_processor = TemplateProcessing::builder()
+        .try_single(single)?
+        .try_pair(pair)?
+        .special_tokens(special_tokens)
+        .build()?;
+
+    Ok(post_processor)
+}
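For a Llama-style config with `add_bos_token = true`, `add_eos_token = false`, and `bos_token = "<s>"`, the builder above reduces to something like the following sketch (token id 1 is illustrative; the real id comes from `tokenizer.token_to_id`):

    use tokenizers::processors::template::TemplateProcessing;

    fn llama_style_post_processor() -> Result<TemplateProcessing, Box<dyn std::error::Error>> {
        // single: BOS prepended to the one sequence; pair: BOS before each of $A and $B,
        // mirroring the python LlamaTokenizerFast post-processor override.
        let post_processor = TemplateProcessing::builder()
            .try_single("<s>:0 $A:0")?
            .try_pair("<s>:0 $A:0 <s>:1 $B:1")?
            .special_tokens(vec![("<s>", 1)])
            .build()?;
        Ok(post_processor)
    }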
@@ -5,13 +5,12 @@ use crate::{
     GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
 };
 use base64::{engine::general_purpose::STANDARD, Engine};
-use image::{io::Reader as ImageReader, ImageFormat};
+use image::{ImageFormat, ImageReader};
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
 use std::iter;
-use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::mpsc;
@@ -181,11 +180,7 @@ impl Validation {
             input_length = input_length.saturating_sub(max_new_tokens as usize);
         }
 
-        Ok((
-            vec![Chunk::Text(inputs).into()],
-            input_length,
-            max_new_tokens,
-        ))
+        Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
     }
 }
 
@@ -589,7 +584,7 @@ fn prepare_input(
     tokenizer: &Tokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
-) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
@@ -601,16 +596,16 @@ fn prepare_input(
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
                 if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
+                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
+                input_chunks.push(Chunk::Image(Image { data, mimetype }));
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
             if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
+                input_chunks.push(Chunk::Text(inputs[start..].to_string()));
                 tokenizer_query.push_str(&inputs[start..]);
             }
 
@@ -618,7 +613,7 @@ fn prepare_input(
 
             (tokenizer_query, input_chunks)
         }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
+        _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
     };
 
     // Get the number of tokens in the input
@@ -784,8 +779,7 @@ pub enum ValidationError {
     #[error("Could not fetch image: {0}")]
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
-    UnsupportedModality(&'static str)
-
+    UnsupportedModality(&'static str),
 }
 
 #[cfg(test)]
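Note: with the protobuf-level `InputChunk` import dropped, validation now builds and returns router-native `Chunk` values directly, which is why the `.into()` conversions throughout `prepare_input` disappear; the translation to the wire type presumably happens at the gRPC client boundary instead (the generated `pb` module ignored at the top of this commit).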