From 8a820d882d956b80f397a1401f025a944c956b12 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:09:27 +0200 Subject: [PATCH] rebase --- Cargo.lock | 42 +++++++++++++------------- Cargo.toml | 6 ++-- benchmark/src/generation.rs | 4 +-- benchmark/src/lib.rs | 2 +- benchmark/src/main.rs | 2 +- proto/generate.proto | 25 +-------------- proto/v3/generate.proto | 25 ++++++++++++++- router/client/src/lib.rs | 13 ++++++-- router/client/src/v2/client.rs | 28 +++-------------- router/client/src/v3/client.rs | 37 +++++++++++++++++++---- router/client/src/v3/mod.rs | 6 ++-- router/client/src/v3/sharded_client.rs | 5 ++- router/src/infer/v2/queue.rs | 9 ++---- router/src/infer/v3/queue.rs | 8 +++-- router/src/validation.rs | 8 ++--- 15 files changed, 118 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 413ff8ab..b5de8576 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" dependencies = [ "gimli", ] @@ -350,9 +350,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" dependencies = [ "addr2line", "cc", @@ -1138,9 +1138,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "glob" @@ -1396,9 +1396,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d8d52be92d09acc2e01dddb7fde3ad983fc6489c7db4837e605bc3fca4cb63e" +checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" dependencies = [ "bytes", "futures-util", @@ -1938,11 +1938,10 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" dependencies = [ - "lazy_static", "libc", "log", "openssl", @@ -2168,9 +2167,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.32.2" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" dependencies = [ "memchr", ] @@ -2563,9 +2562,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.84" +version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" dependencies = [ "unicode-ident", ] @@ -3554,6 +3553,7 @@ dependencies = [ name = "text-generation-client" version = "2.0.5-dev0" dependencies = [ + "async-trait", "base64 0.22.1", "futures", "grpc-metadata", @@ -3752,9 +3752,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -3781,9 +3781,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", @@ -4733,9 +4733,9 @@ checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "winnow" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3c52e9c97a68071b23e836c9380edae937f17b9c4667bd021973efc689f618d" +checksum = "86c949fede1d13936a99f14fafd3e76fd642b556dd2ce96287fbe2e0151bfac6" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 16dd9423..22597b50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ hf-hub = { version = "0.3.1", features = ["tokio"] } [profile.release] debug = 1 incremental = true -lto = "fat" -opt-level = 3 -codegen-units = 1 +#lto = "fat" +#opt-level = 3 +#codegen-units = 1 panic = "abort" diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index f49d786a..27b74249 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -1,9 +1,9 @@ use std::time::{Duration, Instant}; -use text_generation_client::v2::{ +use text_generation_client::v3::{ Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, }; -use text_generation_client::ClientError; +use text_generation_client::{Chunk, ClientError, Input}; use tokenizers::{Tokenizer, TruncationDirection}; use tokio::sync::{broadcast, mpsc}; diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index 048c86af..c33d64e6 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -8,7 +8,7 @@ use crate::app::App; use crate::event::Event; use crossterm::ExecutableCommand; use std::io; -use text_generation_client::v2::{GrammarType, NextTokenChooserParameters, ShardedClient}; +use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient}; use tokenizers::Tokenizer; use tokio::sync::{broadcast, mpsc}; use tui::backend::CrosstermBackend; diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 1e45c1dd..b9d80b7a 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -4,7 +4,7 @@ /// and: https://github.com/orhun/rust-tui-template use clap::Parser; use std::path::Path; -use text_generation_client::v2::ShardedClient; +use text_generation_client::v3::ShardedClient; use tokenizers::{FromPretrainedParameters, Tokenizer}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; diff --git a/proto/generate.proto b/proto/generate.proto index f568d01c..6351e37f 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -51,27 +51,6 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} -message Image { - /// Binary image data. - bytes data = 1; - - /// Image MIME type. - string mimetype = 2; -} - -message InputChunk { - oneof chunk { - /// Plain text data - string text = 1; - /// Image data - Image image = 2; - } -} - -message Input { - repeated InputChunk chunks = 1; - } - enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; @@ -116,9 +95,7 @@ message StoppingCriteriaParameters { message Request { /// Request ID uint64 id = 1; - /// The generation context as chunks - Input input_chunks = 8; - /// The generation context, stringified input_chunks + /// The generation context string inputs = 2; /// Context truncation uint32 truncate = 3; diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index e594c607..ca2908c9 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -51,6 +51,27 @@ message ClearCacheRequest { /// Empty response message ClearCacheResponse {} +message Image { + /// Binary image data. + bytes data = 1; + + /// Image MIME type. + string mimetype = 2; +} + +message InputChunk { + oneof chunk { + /// Plain text data + string text = 1; + /// Image data + Image image = 2; + } +} + +message Input { + repeated InputChunk chunks = 1; + } + enum GrammarType { GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_JSON = 1; @@ -95,7 +116,9 @@ message StoppingCriteriaParameters { message Request { /// Request ID uint64 id = 1; - /// The generation context + /// The generation context as chunks + Input input_chunks = 8; + /// The generation context, stringified input_chunks string inputs = 2; /// Context truncation uint32 truncate = 3; diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index 0663e301..45bee10c 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -4,10 +4,13 @@ pub mod v2; pub mod v3; use async_trait::async_trait; +use base64::{engine::general_purpose::STANDARD, Engine}; use thiserror::Error; use tonic::transport; use tonic::Status; +pub use v3::{Chunk, Image, Input, InputChunk}; + #[async_trait] pub trait Health { /// Check if a generate server is healthy by asking it to allocate a tensor on device @@ -47,12 +50,12 @@ impl From for ClientError { impl From for ClientError { fn from(err: transport::Error) -> Self { - Self::Connection(err.to_string()) + let err = Self::Connection(err.to_string()); + tracing::error!("{err}"); + err } } -pub type Result = std::result::Result; - // Small convenience re-wrapping of `Chunk`. impl From for InputChunk { fn from(chunk: Chunk) -> Self { @@ -82,3 +85,7 @@ impl ChunksToString for Vec { output } } + +static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; + +pub type Result = std::result::Result; diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs index b31857fc..9a2e6ac7 100644 --- a/router/client/src/v2/client.rs +++ b/router/client/src/v2/client.rs @@ -1,7 +1,8 @@ /// Single shard Client use crate::v2::pb; -use crate::Result; +use crate::{ClientError, Result}; +use crate::WARMUP_IMAGE_BASE64; use grpc_metadata::InjectTelemetryContext; use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v2::*; @@ -10,8 +11,6 @@ use std::time::Duration; use tonic::transport::{Channel, Uri}; use tracing::instrument; -static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII="; - /// Text Generation Inference gRPC client #[derive(Debug, Clone)] pub struct Client { @@ -46,7 +45,9 @@ impl Client { #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); - let response = self.stub.service_discovery(request).await?; + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v2 interface".to_string()) + })?; let urls = response .into_inner() .urls @@ -117,22 +118,6 @@ impl Client { while n_tokens < max_prefill_tokens { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); - let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); - if n_tokens == 0 { - input_chunks.push( - Chunk::Image(Image { - // Safe unwrap, because we control the data. - data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), - mimetype: "image/jpeg;base64".to_string(), - }) - .into(), - ); - } - - // Send stringly-typed inputs for compatibility for backends that haven't - // been updated to support chunks. let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { @@ -145,9 +130,6 @@ impl Client { requests.push(Request { id: 0, - input_chunks: Some(Input { - chunks: input_chunks, - }), inputs, // We truncate the input on the server side to be sure that it has the correct size truncate, diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index d4155909..1f3a89a0 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -1,7 +1,8 @@ +use crate::v3::{pb, Chunk}; +use crate::{ClientError, Result, WARMUP_IMAGE_BASE64}; /// Single shard Client -use crate::v3::pb; -use crate::Result; - +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use grpc_metadata::InjectTelemetryContext; use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient; use pb::generate::v3::*; @@ -44,7 +45,9 @@ impl Client { #[instrument(skip(self))] pub async fn service_discovery(&mut self) -> Result> { let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); - let response = self.stub.service_discovery(request).await?; + let response = self.stub.service_discovery(request).await.map_err(|_| { + ClientError::Connection("Server does not support v3 interface".to_string()) + })?; let urls = response .into_inner() .urls @@ -115,18 +118,40 @@ impl Client { while n_tokens < max_prefill_tokens { let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut input_chunks = Vec::new(); + input_chunks + .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + if n_tokens == 0 { + input_chunks.push( + Chunk::Image(Image { + // Safe unwrap, because we control the data. + data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(), + mimetype: "image/jpeg;base64".to_string(), + }) + .into(), + ); + } + + // Send stringly-typed inputs for compatibility for backends that haven't + // been updated to support chunks. + let mut inputs = String::new(); inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. - inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=)"); + inputs.push_str(&format!( + "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})", + )); } requests.push(Request { id: 0, - // We truncate the input on the server side to be sure that it has the correct size inputs, + input_chunks: Some(Input { + chunks: input_chunks, + }), + // We truncate the input on the server side to be sure that it has the correct size truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { diff --git a/router/client/src/v3/mod.rs b/router/client/src/v3/mod.rs index 7d551c13..4a1296a2 100644 --- a/router/client/src/v3/mod.rs +++ b/router/client/src/v3/mod.rs @@ -5,9 +5,9 @@ mod client; mod sharded_client; pub use client::Client; -pub use pb::generate::v3::HealthResponse; pub use pb::generate::v3::{ - Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse, - NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens, + input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, + HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request, + StoppingCriteriaParameters, Tokens, }; pub use sharded_client::ShardedClient; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 73ae356e..9b4f74d8 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -2,7 +2,7 @@ use crate::{v3, Health, ShardInfo}; use crate::{ClientError, Result}; -use crate::v3::InfoResponse; +use crate::v3::{Chunk, InfoResponse, Input}; use async_trait::async_trait; use futures::future::join_all; use tonic::transport::Uri; @@ -217,6 +217,9 @@ impl Health for ShardedClient { let liveness_request = Request { id: u64::MAX, inputs: "liveness".to_string(), + input_chunks: Some(Input { + chunks: vec![Chunk::Text("liveness".into()).into()], + }), truncate: 10, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 4a041ea7..3725c03e 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -5,12 +5,10 @@ use crate::validation::{ use nohash_hasher::{BuildNoHashHasher, IntMap}; use std::cmp::min; use std::collections::VecDeque; -use text_generation_client::ChunksToString; -use text_generation_client::Input; -use text_generation_client::{Batch, Request}; use text_generation_client::v2::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; +use text_generation_client::ChunksToString; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -283,9 +281,6 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - input_chunks: Some(Input { - chunks: entry.request.inputs.clone(), - }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, parameters: Some(NextTokenChooserParameters::from( @@ -309,7 +304,7 @@ impl State { // Empty batch if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); + tracing::debug!("Filtered out all entries"); return None; } diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index f13cf936..b926f329 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -8,6 +8,7 @@ use std::collections::VecDeque; use text_generation_client::v3::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; +use text_generation_client::{ChunksToString, Input}; use tokio::sync::{mpsc, oneshot}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -280,7 +281,10 @@ impl State { batch_requests.push(Request { id, prefill_logprobs: entry.request.decoder_input_details, - inputs: entry.request.inputs.clone(), + inputs: entry.request.inputs.chunks_to_string(), + input_chunks: Some(Input { + chunks: entry.request.inputs.clone(), + }), truncate: entry.request.truncate, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), @@ -406,7 +410,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { - inputs: String::new(), + inputs: vec![], input_length: 0, truncate: 0, decoder_input_details: false, diff --git a/router/src/validation.rs b/router/src/validation.rs index c321c33b..bb9ad318 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,16 +1,16 @@ -use crate::config::Config; /// Payload validation logic +use crate::config::Config; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{GenerateParameters, GenerateRequest, GrammarType}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use image::{io::Reader as ImageReader, ImageFormat}; use jsonschema::{Draft, JSONSchema}; use rand::{thread_rng, Rng}; use serde_json::Value; use std::io::Cursor; +use text_generation_client::{Chunk, Image, InputChunk}; use thiserror::Error; use tokenizers::tokenizer::Tokenizer; -// use tokenizers::TruncationDirection; -use base64::{engine::general_purpose::STANDARD, Engine}; -use image::{io::Reader as ImageReader, ImageFormat}; use tokio::sync::mpsc; use tokio::sync::oneshot; use tracing::{instrument, Span};