Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
WIP

This commit is contained in:
parent e4ab855480
commit e6c524c66b

Changed files: Cargo.lock (generated), 18 changed lines
Cargo.lock:

@@ -3884,6 +3884,12 @@ dependencies = [
  "unicode-segmentation",
 ]
 
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+
 [[package]]
 name = "strsim"
 version = "0.11.1"

@@ -4175,6 +4181,7 @@ dependencies = [
  "tracing",
  "tracing-opentelemetry 0.21.0",
  "tracing-subscriber",
+ "twox-hash",
  "ureq",
  "utoipa",
  "utoipa-swagger-ui",

@@ -4776,6 +4783,17 @@ version = "0.2.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
 
+[[package]]
+name = "twox-hash"
+version = "1.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
+dependencies = [
+ "cfg-if",
+ "rand",
+ "static_assertions",
+]
+
 [[package]]
 name = "typenum"
 version = "1.17.0"

@@ -68,14 +68,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
-        if prefix_caching.is_none() {
-            if config.vision_config.is_some() {
-                tracing::info!("Disabling prefix caching because of VLM model");
-                prefix_caching = Some("0".to_string());
-            } else if config.is_encoder_decoder {
-                tracing::info!("Disabling prefix caching because of seq2seq model");
-                prefix_caching = Some("0".to_string());
-            }
+        if prefix_caching.is_none() && config.is_encoder_decoder {
+            tracing::info!("Disabling prefix caching because of seq2seq model");
+            prefix_caching = Some("0".to_string());
         }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {

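The hunk above drops the vision-model branch: prefix caching is now only disabled automatically for encoder-decoder (seq2seq) models, which fits the rest of this commit, where image placeholder tokens are made cache-safe by hashing. A minimal sketch of the narrowed check, assuming a stripped-down Config with only the one field used here and eprintln! standing in for the launcher's tracing::info!:

```rust
// Hedged sketch: simplified stand-in for the launcher's resolve logic.
// `Config` here is a reduced illustration, not the launcher's real struct.
struct Config {
    is_encoder_decoder: bool,
}

fn resolve_prefix_caching(config: &Option<Config>) -> Option<String> {
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    if let Some(config) = config {
        // Only seq2seq models disable prefix caching by default now;
        // the old vision-model branch is gone.
        if prefix_caching.is_none() && config.is_encoder_decoder {
            eprintln!("Disabling prefix caching because of seq2seq model");
            prefix_caching = Some("0".to_string());
        }
    }
    prefix_caching
}

fn main() {
    let cfg = Some(Config { is_encoder_decoder: true });
    // With USE_PREFIX_CACHING unset, a seq2seq config yields Some("0").
    println!("{:?}", resolve_prefix_caching(&cfg));
}
```
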
@@ -126,7 +121,6 @@ struct RawConfig {
     hidden_size: Option<usize>,
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
-    vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
 }
 

@@ -144,7 +138,6 @@ struct Config {
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
-    vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
 }
 

@@ -172,14 +165,12 @@ impl From<RawConfig> for Config {
             }
         });
         let model_type = other.model_type;
-        let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
-            vision_config,
             is_encoder_decoder,
         }
     }

@@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
 ] }
 csv = "1.3.0"
 ureq = "=2.9"
+twox-hash = "1.6.3"
 
 
 [build-dependencies]

@@ -129,6 +129,7 @@ pub struct PaliTextConfig {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Paligemma {
+    pub(crate) image_token_index: u32,
     pub(crate) text_config: PaliTextConfig,
 }
 

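The new image_token_index field on Paligemma is picked up from the model's config.json through the Serialize/Deserialize derives already on the struct. A hedged, self-contained sketch of that round trip; the num_image_tokens field on PaliTextConfig and the literal values are illustrative assumptions, not taken from this diff, and serde (with the derive feature) plus serde_json are assumed to be available:

```rust
use serde::Deserialize;

// Reduced copies of the structs touched above, just for the example.
#[derive(Debug, Deserialize)]
struct PaliTextConfig {
    num_image_tokens: usize, // assumed field; the real struct may differ
}

#[derive(Debug, Deserialize)]
struct Paligemma {
    image_token_index: u32,
    text_config: PaliTextConfig,
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative values only, not taken from any real config.json.
    let raw = r#"{"image_token_index": 257152, "text_config": {"num_image_tokens": 256}}"#;
    let cfg: Paligemma = serde_json::from_str(raw)?;
    println!("image token id = {}", cfg.image_token_index);
    Ok(())
}
```
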
@@ -6,17 +6,23 @@ use crate::{
 };
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{ImageFormat, ImageReader};
+use itertools::Itertools;
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
+use std::collections::HashMap;
+use std::hash::Hasher;
 use std::io::Cursor;
 use std::iter;
 use std::sync::Arc;
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
+use tokenizers::Encoding;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};
+use twox_hash::xxh3::HasherExt;
+use twox_hash::Xxh3Hash128;
 use {once_cell::sync::Lazy, regex::Regex};
 
 /// Validation

@@ -596,6 +602,45 @@ fn image_tokens(
         }
     }
 }
+
+fn image_id(config: &Config) -> u32 {
+    use Config::*;
+    match config {
+        Paligemma(pali_gemma) => pali_gemma.image_token_index,
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
+fn n_image_tokens(
+    config: &Config,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
+    height: usize,
+    width: usize,
+) -> usize {
+    use Config::*;
+    use HubPreprocessorConfig::*;
+    match config {
+        Idefics => 1,
+        Idefics2(config) => {
+            let repeats = if matches!(
+                preprocessor_config,
+                Some(Idefics2Processor(Idefics2Preprocessor {
+                    do_image_splitting: true,
+                    ..
+                }))
+            ) {
+                5
+            } else {
+                1
+            };
+
+            config.get_number_of_features(height, width) * repeats
+        }
+        Paligemma(config) => config.get_number_of_features(height, width),
+        LlavaNext(config) => config.get_number_of_features(height, width),
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {

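n_image_tokens mirrors image_tokens but returns only how many placeholder ids an image expands to, which is exactly what the hash-rewriting step later in this commit needs to know. A small illustration of the Idefics2 branch; the factor 5 is taken from the diff, while reading it as four image crops plus the original under do_image_splitting is an assumption:

```rust
// Minimal stand-in for the repeat factor used in n_image_tokens above.
fn idefics2_repeats(do_image_splitting: bool) -> usize {
    if do_image_splitting {
        5
    } else {
        1
    }
}

fn main() {
    // With splitting enabled each image contributes 5x the base feature count.
    let base_features = 64; // illustrative value, not from the diff
    assert_eq!(idefics2_repeats(true) * base_features, 320);
    assert_eq!(idefics2_repeats(false) * base_features, 64);
    println!("ok");
}
```
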
@@ -617,8 +662,10 @@ fn prepare_input(
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
-    let (tokenizer_query, input_chunks) = match config {
+    let (tokenizer_query, input_chunks, image_token_id, image_hashes, image_lens) = match config {
         Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+            let mut image_hashes = Vec::new();
+            let mut image_lens = Vec::new();
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;

@@ -630,8 +677,15 @@ fn prepare_input(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                input_chunks.push(Chunk::Image(Image { data, mimetype }));
+
                 tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
+
+                let mut hasher = Xxh3Hash128::default();
+                hasher.write(&data);
+                image_hashes.push(hasher.finish_ext());
+                image_lens.push(n_image_tokens(config, preprocessor_config, height, width));
+
+                input_chunks.push(Chunk::Image(Image { data, mimetype }));
                 start = chunk_end;
             }
             if start != inputs.len() {

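Each fetched image is now hashed with 128-bit XXH3 before its chunk is pushed (the push moves data, so hashing has to happen first), and the number of placeholder tokens for that image is recorded alongside. A hedged, standalone sketch of the same hashing pattern, using the twox-hash 1.6.x calls that appear in the diff (Xxh3Hash128::default, Hasher::write, HasherExt::finish_ext); the byte buffer is a made-up stand-in for real image data:

```rust
use std::hash::Hasher;
use twox_hash::xxh3::HasherExt;
use twox_hash::Xxh3Hash128;

fn hash_image_bytes(data: &[u8]) -> u128 {
    // Streaming 128-bit XXH3: feed the bytes, then read the wide digest.
    let mut hasher = Xxh3Hash128::default();
    hasher.write(data);
    hasher.finish_ext()
}

fn main() {
    let fake_image_bytes = vec![0u8; 1024]; // stand-in for fetched image data
    let digest = hash_image_bytes(&fake_image_bytes);
    // The validation code later truncates this to u32 to use it as a token id.
    println!("hash = {:#x}, truncated id = {}", digest, digest as u32);
}
```

The 128-bit digest is wider than a token id; the validation code below narrows it with as u32 when it overwrites the first placeholder token of the image.
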
@@ -641,9 +695,15 @@ fn prepare_input(
 
             tokenizer_query = image_tokens_fixup(config, tokenizer_query);
 
-            (tokenizer_query, input_chunks)
+            (
+                tokenizer_query,
+                input_chunks,
+                image_id(&config),
+                image_hashes,
+                image_lens,
+            )
         }
-        _ => (inputs.clone(), vec![Chunk::Text(inputs)]),
+        _ => (inputs.clone(), vec![Chunk::Text(inputs)], 0, vec![], vec![]),
     };
 
     // Get the number of tokens in the input

@@ -651,6 +711,35 @@ fn prepare_input(
         .encode(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
+    tracing::info!("encoding before hash: {:?}", encoding.get_ids());
+
+    // Replace image tokens by hashes. The first token of an image
+    // must be specific to the image for prefix caching.
+    let mut token_ids = encoding.get_ids().to_owned();
+    let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
+    for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
+        let image_token = iter.next().ok_or(ValidationError::Tokenizer(
+            "Image token not found".to_string(),
+        ))?;
+        *image_token = *image_hash as u32;
+        // Skip the remaining tokens of the current image.
+        iter = iter.dropping(n_tokens - 1);
+    }
+
+    let encoding = Encoding::new(
+        token_ids,
+        encoding.get_type_ids().to_owned(),
+        encoding.get_tokens().to_owned(),
+        encoding.get_word_ids().to_owned(),
+        encoding.get_offsets().to_owned(),
+        encoding.get_special_tokens_mask().to_owned(),
+        encoding.get_attention_mask().to_owned(),
+        encoding.get_overflowing().to_owned(),
+        HashMap::new(),
+    );
+
+    tracing::info!("encoding after hash: {:?}", encoding.get_ids());
+
     Ok((encoding, input_chunks))
 }
 

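This last hunk is the core of the change: after tokenization, the first placeholder id of each image run is overwritten with an id derived from that image's content hash, so prompts that share an image can share prefix-cache entries while different images can no longer look identical to the cache. A hedged, self-contained sketch of the same rewrite on a plain Vec<u32>, using itertools' dropping just as the diff does; all token values are invented for illustration:

```rust
use itertools::Itertools;

/// Overwrite the first placeholder id of each image run with a per-image id
/// (here the truncated content hash), leaving the remaining placeholders alone.
fn rewrite_image_tokens(
    token_ids: &mut [u32],
    image_token_id: u32,
    image_hashes: &[u128],
    image_lens: &[usize],
) -> Result<(), String> {
    let mut iter = token_ids.iter_mut().filter(|id| **id == image_token_id);
    for (image_hash, n_tokens) in image_hashes.iter().zip(image_lens.iter()) {
        let first = iter.next().ok_or_else(|| "Image token not found".to_string())?;
        *first = *image_hash as u32;
        // Skip the rest of this image's placeholders before the next image.
        iter = iter.dropping(n_tokens - 1);
    }
    Ok(())
}

fn main() {
    // Invented ids: 7 is the image placeholder, 1/2/3 are text tokens.
    let mut ids = vec![1, 7, 7, 7, 2, 7, 7, 7, 3];
    rewrite_image_tokens(&mut ids, 7, &[0xdead_beef, 0xcafe_f00d], &[3, 3]).unwrap();
    assert_eq!(ids, vec![1, 0xdead_beef, 7, 7, 2, 0xcafe_f00d, 7, 7, 3]);
    println!("{:?}", ids);
}
```

The rest of each run keeps the original placeholder id, which is why n_image_tokens has to report the exact per-image length: dropping(n_tokens - 1) must leave the iterator positioned on the first placeholder of the next image.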