Idefics2: sync added image tokens with transformers

Before this change, the number of image tokens that the router reserved did not
match the number of image tokens that transformers inserts for Idefics2. Fixes #2029.

While at it, also remove all the image token handling duplication
in `prepare_input`.
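
For illustration, a minimal, self-contained sketch (not the router code itself) of the per-image placeholder string this change makes the router build for Idefics2. The helper name is made up for the example; the 64-token sequence length and the 5x repetition under `do_image_splitting` are the values used in the diff below:

    // Sketch only: mirrors the placeholder construction added in this commit.
    fn idefics2_placeholder(do_image_splitting: bool) -> String {
        const FAKE: &str = "<fake_token_around_image>";
        const IMAGE: &str = "<image>";
        let image_seq_len = 64; // per (sub-)image, see get_number_of_features below

        let mut block = String::new();
        block.push_str(FAKE);
        for _ in 0..image_seq_len {
            block.push_str(IMAGE);
        }
        block.push_str(FAKE);

        if do_image_splitting {
            // With image splitting, the whole <fake>...<fake> block is repeated
            // once per sub-image (5 in total, per the diff), which is what this
            // commit syncs with transformers.
            block.repeat(5)
        } else {
            block
        }
    }

    fn main() {
        let with_split = idefics2_placeholder(true);
        assert_eq!(with_split.matches("<image>").count(), 320);
        assert_eq!(with_split.matches("<fake_token_around_image>").count(), 10);

        let without_split = idefics2_placeholder(false);
        assert_eq!(without_split.matches("<image>").count(), 64);
        assert_eq!(without_split.matches("<fake_token_around_image>").count(), 2);
    }

The old behavior reserved a flat 320 `<image>` tokens wrapped in a single pair of fake tokens, which matches neither layout above.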
Author: Daniël de Kok
Date:   2024-06-20 09:21:58 +02:00
parent  cdbf802860
commit  9ce4552bae

10 changed files with 6103 additions and 5917 deletions

View File

@@ -8,61 +8,61 @@
   "tokens": [
     {
      "id": 330,
-      "logprob": -0.13000488,
+      "logprob": -0.08660889,
      "special": false,
      "text": " A"
    },
    {
      "id": 13088,
-      "logprob": -0.6713867,
+      "logprob": -0.7089844,
      "special": false,
      "text": " chicken"
    },
    {
      "id": 349,
-      "logprob": -0.2980957,
+      "logprob": -0.32885742,
      "special": false,
      "text": " is"
    },
    {
      "id": 6398,
-      "logprob": -0.060638428,
+      "logprob": -0.05126953,
      "special": false,
      "text": " sitting"
    },
    {
      "id": 356,
-      "logprob": -0.27319336,
+      "logprob": -0.35229492,
      "special": false,
      "text": " on"
    },
    {
      "id": 264,
-      "logprob": -0.140625,
+      "logprob": -0.12561035,
      "special": false,
      "text": " a"
    },
    {
      "id": 17972,
-      "logprob": -0.040405273,
+      "logprob": -0.038085938,
      "special": false,
      "text": " pile"
    },
    {
      "id": 302,
-      "logprob": -0.0002708435,
+      "logprob": -0.00018656254,
      "special": false,
      "text": " of"
    },
    {
      "id": 2445,
-      "logprob": -0.095336914,
+      "logprob": -0.07293701,
      "special": false,
      "text": " money"
    },
    {
      "id": 28723,
-      "logprob": -0.0068359375,
+      "logprob": -0.004852295,
      "special": false,
      "text": "."
    }

View File

@@ -8,115 +8,115 @@
   "tokens": [
    {
      "id": 415,
-      "logprob": -0.04421997,
+      "logprob": -0.039886475,
      "special": false,
      "text": " The"
    },
    {
      "id": 12072,
-      "logprob": -0.13500977,
+      "logprob": -0.1430664,
      "special": false,
      "text": " cow"
    },
    {
      "id": 349,
-      "logprob": -0.06750488,
+      "logprob": -0.056488037,
      "special": false,
      "text": " is"
    },
    {
      "id": 6328,
-      "logprob": -0.6352539,
+      "logprob": -0.6855469,
      "special": false,
      "text": " standing"
    },
    {
      "id": 356,
-      "logprob": -0.16186523,
+      "logprob": -0.1685791,
      "special": false,
      "text": " on"
    },
    {
      "id": 272,
-      "logprob": -0.5078125,
+      "logprob": -0.50097656,
      "special": false,
      "text": " the"
    },
    {
      "id": 10305,
-      "logprob": -0.017913818,
+      "logprob": -0.017303467,
      "special": false,
      "text": " beach"
    },
    {
      "id": 304,
-      "logprob": -1.5205078,
+      "logprob": -1.3564453,
      "special": false,
      "text": " and"
    },
    {
      "id": 272,
-      "logprob": -0.029174805,
+      "logprob": -0.017868042,
      "special": false,
      "text": " the"
    },
    {
      "id": 13088,
-      "logprob": -0.003479004,
+      "logprob": -0.0027103424,
      "special": false,
      "text": " chicken"
    },
    {
      "id": 349,
-      "logprob": -0.0035095215,
+      "logprob": -0.003156662,
      "special": false,
      "text": " is"
    },
    {
      "id": 6398,
-      "logprob": -0.3088379,
+      "logprob": -0.37304688,
      "special": false,
      "text": " sitting"
    },
    {
      "id": 356,
-      "logprob": -0.027755737,
+      "logprob": -0.034576416,
      "special": false,
      "text": " on"
    },
    {
      "id": 264,
-      "logprob": -0.31884766,
+      "logprob": -0.29418945,
      "special": false,
      "text": " a"
    },
    {
      "id": 17972,
-      "logprob": -0.047943115,
+      "logprob": -0.042877197,
      "special": false,
      "text": " pile"
    },
    {
      "id": 302,
-      "logprob": -0.0002925396,
+      "logprob": -0.00028443336,
      "special": false,
      "text": " of"
    },
    {
      "id": 2445,
-      "logprob": -0.02935791,
+      "logprob": -0.023223877,
      "special": false,
      "text": " money"
    },
    {
      "id": 28723,
-      "logprob": -0.031219482,
+      "logprob": -0.018157959,
      "special": false,
      "text": "."
    },
    {
      "id": 32002,
-      "logprob": -0.00034475327,
+      "logprob": -0.00018393993,
      "special": true,
      "text": "<end_of_utterance>"
    },

View File

@@ -112,7 +112,7 @@ pub struct Idefics2 {}
 impl Idefics2 {
     pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
-        320
+        64
     }
 }
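
The old hard-coded 320 corresponds to 64 tokens per sub-image times the 5 sub-images produced when image splitting is enabled; after this change the per-sub-image value (64) lives here, and the splitting factor is applied in the router's new image_tokens helper added further down in this commit. A quick sanity check of that arithmetic, as a sketch:

    fn main() {
        let per_sub_image = 64;
        let sub_images_with_splitting = 5; // the repeat factor used in this commit
        assert_eq!(per_sub_image * sub_images_with_splitting, 320);
    }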

View File

@@ -70,6 +70,25 @@ impl HubTokenizerConfig {
     }
 }

+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "processor_class")]
+pub enum HubPreprocessorConfig {
+    Idefics2Processor(Idefics2Preprocessor),
+}
+
+impl HubPreprocessorConfig {
+    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {
+        let content = std::fs::read_to_string(filename).ok()?;
+        serde_json::from_str(&content).ok()
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Idefics2Preprocessor {
+    #[serde(default)]
+    do_image_splitting: bool,
+}
+
 #[derive(Debug, Clone, Deserialize, Default)]
 pub struct HubProcessorConfig {
     pub chat_template: Option<ChatTemplateVersions>,

View File

@@ -13,7 +13,9 @@ use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::{Path, PathBuf};
 use text_generation_router::config::Config;
-use text_generation_router::{server, HubModelInfo, HubProcessorConfig, HubTokenizerConfig};
+use text_generation_router::{
+    server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
+};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
@@ -209,6 +211,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_filename,
         config_filename,
         tokenizer_config_filename,
+        preprocessor_config_filename,
         processor_config_filename,
         model_info,
     ) = match api {
@@ -216,6 +219,7 @@ async fn main() -> Result<(), RouterError> {
             Some(local_path.join("tokenizer.json")),
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
+            Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
             None,
         ),
@@ -232,6 +236,7 @@ async fn main() -> Result<(), RouterError> {
             };
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
+            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
             let processor_config_filename = api_repo.get("processor_config.json").await.ok();

             let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
@@ -244,6 +249,7 @@ async fn main() -> Result<(), RouterError> {
                 tokenizer_filename,
                 config_filename,
                 tokenizer_config_filename,
+                preprocessor_config_filename,
                 processor_config_filename,
                 model_info,
             )
@@ -258,6 +264,7 @@ async fn main() -> Result<(), RouterError> {
                 repo.get("tokenizer.json"),
                 repo.get("config.json"),
                 repo.get("tokenizer_config.json"),
+                repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
                 None,
             )
@@ -295,6 +302,8 @@ async fn main() -> Result<(), RouterError> {
         HubTokenizerConfig::default()
     });

+    let preprocessor_config =
+        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
     let processor_config = processor_config_filename
         .and_then(HubProcessorConfig::from_file)
         .unwrap_or_default();
@@ -356,6 +365,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_authtoken,
         ngrok_edge,
         tokenizer_config,
+        preprocessor_config,
         processor_config,
         messages_api_enabled,
         disable_grammar_support,

View File

@@ -12,9 +12,9 @@ use crate::kserve::{
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
-    GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
-    Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
-    Usage, Validation,
+    GenerateResponse, GrammarType, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig,
+    HubTokenizerConfig, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse,
+    Token, TokenizeResponse, Usage, Validation,
 };
 use crate::{
     ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1421,6 +1421,7 @@ pub async fn run(
     _ngrok_authtoken: Option<String>,
     _ngrok_edge: Option<String>,
     tokenizer_config: HubTokenizerConfig,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     processor_config: HubProcessorConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
@@ -1634,6 +1635,7 @@ pub async fn run(
         validation_workers,
         tokenizer,
         config,
+        preprocessor_config,
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,

View File

@@ -1,13 +1,16 @@
 /// Payload validation logic
 use crate::config::Config;
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use crate::{
+    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
+};
 use base64::{engine::general_purpose::STANDARD, Engine};
 use image::{io::Reader as ImageReader, ImageFormat};
 use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::io::Cursor;
+use std::iter;
 use text_generation_client::{Chunk, Image, InputChunk};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -36,6 +39,7 @@ impl Validation {
         workers: usize,
         tokenizer: Option<Tokenizer>,
         config: Option<Config>,
+        preprocessor_config: Option<HubPreprocessorConfig>,
         max_best_of: usize,
         max_stop_sequences: usize,
         max_top_n_tokens: u32,
@@ -53,12 +57,18 @@ impl Validation {
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
                 let config_clone = config.clone();
+                let preprocessor_config_clone = preprocessor_config.clone();
                 let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
                 senders.push(tokenizer_sender);

                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
+                    tokenizer_worker(
+                        tokenizer_clone,
+                        config_clone,
+                        preprocessor_config_clone,
+                        tokenizer_receiver,
+                    )
                 });
             }
@@ -420,13 +430,20 @@ async fn round_robin_task(
 fn tokenizer_worker(
     tokenizer: Tokenizer,
     config: Option<Config>,
+    preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
     while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
-                .send(prepare_input(inputs, truncate, &tokenizer, &config))
+                .send(prepare_input(
+                    inputs,
+                    truncate,
+                    &tokenizer,
+                    config.as_ref(),
+                    preprocessor_config.as_ref(),
+                ))
                 .unwrap_or(())
         })
     }
@@ -506,16 +523,59 @@ fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), Validatio
     }
 }

+fn image_tokens(
+    config: &Config,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
+    height: usize,
+    width: usize,
+) -> String {
+    use Config::*;
+    use HubPreprocessorConfig::*;
+
+    match config {
+        Idefics => "<image>".to_string(),
+        Idefics2(config) => {
+            let slots = config.get_number_of_features(height, width);
+
+            const FAKE: &str = "<fake_token_around_image>";
+            const IMAGE: &str = "<image>";
+            const FAKE_LEN: usize = FAKE.len();
+            const IMAGE_LEN: usize = IMAGE.len();
+
+            let mut tokens = String::with_capacity(2 * FAKE_LEN + slots * IMAGE_LEN);
+            tokens.push_str(FAKE);
+            tokens.extend(iter::repeat(IMAGE).take(slots));
+            tokens.push_str(FAKE);
+
+            if matches!(
+                preprocessor_config,
+                Some(Idefics2Processor(Idefics2Preprocessor {
+                    do_image_splitting: true,
+                    ..
+                }))
+            ) {
+                tokens = tokens.repeat(5);
+            }
+
+            tokens
+        }
+        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        _ => unimplemented!("Images tokens are not supported for this model configuration"),
+    }
+}
+
 /// Get input length and optionally truncate it
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
     tokenizer: &Tokenizer,
-    config: &Option<Config>,
+    config: Option<&Config>,
+    preprocessor_config: Option<&HubPreprocessorConfig>,
 ) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
+    use Config::*;
+
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(Config::LlavaNext(config)) => {
+        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
@@ -527,82 +587,8 @@ fn prepare_input(
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
                 let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Paligemma(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics2(config)) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = config.get_number_of_features(height, width);
-                tokenizer_query.push_str("<fake_token_around_image>");
-                tokenizer_query.push_str(&"<image>".repeat(slots));
-                tokenizer_query.push_str("<fake_token_around_image>");
-                input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
-                start = chunk_end;
-            }
-            if start != inputs.len() {
-                input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
-                tokenizer_query.push_str(&inputs[start..]);
-            }
-            (tokenizer_query, input_chunks)
-        }
-        Some(Config::Idefics) => {
-            let mut input_chunks = Vec::new();
-            let mut tokenizer_query = String::with_capacity(inputs.len());
-            let mut start = 0;
-            for chunk in RE.find_iter(&inputs) {
-                let chunk_start = chunk.start();
-                let chunk_end = chunk.end();
-                if chunk_start != start {
-                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
-                    tokenizer_query.push_str(&inputs[start..chunk_start]);
-                }
-                let (data, mimetype, _height, _width) =
-                    fetch_image(&inputs[chunk_start..chunk_end])?;
-                let slots = 1;
-                tokenizer_query.push_str(&"<image>".repeat(slots));
                 input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
+                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
                 start = chunk_end;
             }
             if start != inputs.len() {
@@ -766,6 +752,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -800,6 +787,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -833,6 +821,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -871,6 +860,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,
@@ -938,6 +928,7 @@ mod tests {
             workers,
             tokenizer,
             config,
+            None,
             max_best_of,
             max_stop_sequences,
             max_top_n_tokens,
@@ -1023,6 +1014,7 @@ mod tests {
             workers,
             tokenizer,
             Some(config),
+            None,
             max_best_of,
             max_stop_sequence,
             max_top_n_tokens,

View File

@@ -39,7 +39,9 @@ class PaliGemmaBatch(VlmCausalLMBatch):
                         # TODO do_convert_RGB should be on by default ?
                         image = image.convert("RGB")
                         image_input = processor.image_processor(image, return_tensors="pt")
-                        full_text += image_text_replacement(image_input, config, image_id)
+                        full_text += image_text_replacement(
+                            processor, image_input, config, image_id
+                        )
                         image_inputs.append(image_input)
                     else:
                         raise RuntimeError(f"Invalid chunk type {chunk_type}")

View File

@@ -39,15 +39,14 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


-def image_text_replacement(image_input, config, image_id) -> str:
+def image_text_replacement(processor, image_input, config, image_id) -> str:
     if config.model_type == "idefics2":
-        # TODO technically depends on image splitting which is not implemented.
-        num_features = 320
-        return (
-            "<fake_token_around_image>"
-            + "<image>" * num_features
-            + "<fake_token_around_image>"
-        )
+        image_seq_len = 64
+        image_str = f"<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>"
+        if processor.image_processor.do_image_splitting:
+            image_str *= 5
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
@@ -168,7 +167,9 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 if chunk_type == "text":
                     full_text += chunk.text
                 elif chunk_type == "image":
-                    full_text += image_text_replacement(image_inputs, config, image_id)
+                    full_text += image_text_replacement(
+                        processor, image_inputs, config, image_id
+                    )
                     image_id += 1

             batch_inputs.append(full_text)