This commit is contained in:
OlivierDehaene 2024-06-03 18:09:27 +02:00
parent b9dffbd512
commit 8a820d882d
15 changed files with 118 additions and 102 deletions

42
Cargo.lock generated
View File

@ -4,9 +4,9 @@ version = 3
[[package]]
name = "addr2line"
version = "0.21.0"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678"
dependencies = [
"gimli",
]
@ -350,9 +350,9 @@ dependencies = [
[[package]]
name = "backtrace"
version = "0.3.71"
version = "0.3.72"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d"
checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11"
dependencies = [
"addr2line",
"cc",
@ -1138,9 +1138,9 @@ dependencies = [
[[package]]
name = "gimli"
version = "0.28.1"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
[[package]]
name = "glob"
@ -1396,9 +1396,9 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.4"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d8d52be92d09acc2e01dddb7fde3ad983fc6489c7db4837e605bc3fca4cb63e"
checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56"
dependencies = [
"bytes",
"futures-util",
@ -1938,11 +1938,10 @@ dependencies = [
[[package]]
name = "native-tls"
version = "0.2.11"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
@ -2168,9 +2167,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.32.2"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441"
checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e"
dependencies = [
"memchr",
]
@ -2563,9 +2562,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.84"
version = "1.0.85"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
dependencies = [
"unicode-ident",
]
@ -3554,6 +3553,7 @@ dependencies = [
name = "text-generation-client"
version = "2.0.5-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
"futures",
"grpc-metadata",
@ -3752,9 +3752,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.37.0"
version = "1.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a"
dependencies = [
"backtrace",
"bytes",
@ -3781,9 +3781,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.2.0"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a"
dependencies = [
"proc-macro2",
"quote",
@ -4733,9 +4733,9 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
[[package]]
name = "winnow"
version = "0.6.8"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d"
checksum = "86c949fede1d13936a99f14fafd3e76fd642b556dd2ce96287fbe2e0151bfac6"
dependencies = [
"memchr",
]

View File

@ -22,7 +22,7 @@ hf-hub = { version = "0.3.1", features = ["tokio"] }
[profile.release]
debug = 1
incremental = true
lto = "fat"
opt-level = 3
codegen-units = 1
#lto = "fat"
#opt-level = 3
#codegen-units = 1
panic = "abort"

View File

@ -1,9 +1,9 @@
use std::time::{Duration, Instant};
use text_generation_client::v2::{
use text_generation_client::v3::{
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use text_generation_client::ClientError;
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};

View File

@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::v2::{GrammarType, NextTokenChooserParameters, ShardedClient};
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;

View File

@ -4,7 +4,7 @@
/// and: https://github.com/orhun/rust-tui-template
use clap::Parser;
use std::path::Path;
use text_generation_client::v2::ShardedClient;
use text_generation_client::v3::ShardedClient;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;

View File

@ -51,27 +51,6 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message Input {
repeated InputChunk chunks = 1;
}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
@ -116,9 +95,7 @@ message StoppingCriteriaParameters {
message Request {
/// Request ID
uint64 id = 1;
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
/// The generation context
string inputs = 2;
/// Context truncation
uint32 truncate = 3;

View File

@ -51,6 +51,27 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message Input {
repeated InputChunk chunks = 1;
}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
@ -95,7 +116,9 @@ message StoppingCriteriaParameters {
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
/// The generation context as chunks
Input input_chunks = 8;
/// The generation context, stringified input_chunks
string inputs = 2;
/// Context truncation
uint32 truncate = 3;

View File

@ -4,10 +4,13 @@ pub mod v2;
pub mod v3;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
@ -47,12 +50,12 @@ impl From<Status> for ClientError {
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
Self::Connection(err.to_string())
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
pub type Result<T> = std::result::Result<T, ClientError>;
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
@ -82,3 +85,7 @@ impl ChunksToString for Vec<InputChunk> {
output
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -1,7 +1,8 @@
/// Single shard Client
use crate::v2::pb;
use crate::Result;
use crate::{ClientError, Result};
use crate::WARMUP_IMAGE_BASE64;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
@ -10,8 +11,6 @@ use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
@ -46,7 +45,9 @@ impl Client {
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
@ -117,22 +118,6 @@ impl Client {
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
@ -145,9 +130,6 @@ impl Client {
requests.push(Request {
id: 0,
input_chunks: Some(Input {
chunks: input_chunks,
}),
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,

View File

@ -1,7 +1,8 @@
use crate::v3::{pb, Chunk};
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
/// Single shard Client
use crate::v3::pb;
use crate::Result;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
@ -44,7 +45,9 @@ impl Client {
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v3 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
@ -115,18 +118,40 @@ impl Client {
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str("![]()");
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
// We truncate the input on the server side to be sure that it has the correct size
inputs,
input_chunks: Some(Input {
chunks: input_chunks,
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {

View File

@ -5,9 +5,9 @@ mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::HealthResponse;
pub use pb::generate::v3::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;

View File

@ -2,7 +2,7 @@
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v3::InfoResponse;
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
@ -217,6 +217,9 @@ impl Health for ShardedClient {
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {

View File

@ -5,12 +5,10 @@ use crate::validation::{
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use text_generation_client::{Batch, Request};
use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -283,9 +281,6 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from(
@ -309,7 +304,7 @@ impl State {
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
tracing::debug!("Filtered out all entries");
return None;
}

View File

@ -8,6 +8,7 @@ use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::{ChunksToString, Input};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -280,7 +281,10 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
inputs: entry.request.inputs.clone(),
inputs: entry.request.inputs.chunks_to_string(),
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
@ -406,7 +410,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: String::new(),
inputs: vec![],
input_length: 0,
truncate: 0,
decoder_input_details: false,

View File

@ -1,16 +1,16 @@
use crate::config::Config;
/// Payload validation logic
use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
// use tokenizers::TruncationDirection;
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};