Simplify image token lookup

This commit is contained in:
Daniël de Kok 2024-09-03 11:46:23 +00:00
parent bac2cf7655
commit 8c74ee4498
2 changed files with 8 additions and 10 deletions

View File

@ -7,7 +7,6 @@ pub struct LlavaNext {
pub(crate) text_config: TextConfig, pub(crate) text_config: TextConfig,
pub(crate) vision_config: VisionConfig, pub(crate) vision_config: VisionConfig,
pub(crate) image_grid_pinpoints: Vec<(usize, usize)>, pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
pub(crate) image_token_index: u32,
} }
fn get_anyres_image_grid_shape( fn get_anyres_image_grid_shape(
@ -113,9 +112,7 @@ pub struct ClipVisionModel {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct Idefics2 { pub struct Idefics2 {}
pub(crate) image_token_id: u32,
}
impl Idefics2 { impl Idefics2 {
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize { pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
@ -132,7 +129,6 @@ pub struct PaliTextConfig {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub struct Paligemma { pub struct Paligemma {
pub(crate) image_token_index: u32,
pub(crate) text_config: PaliTextConfig, pub(crate) text_config: PaliTextConfig,
} }

View File

@ -1,5 +1,5 @@
/// Payload validation logic /// Payload validation logic
use crate::config::{Config, Idefics2}; use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{ use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
@ -605,10 +605,12 @@ fn image_tokens(
fn image_id(config: &Config, tokenizer: &Tokenizer) -> u32 { fn image_id(config: &Config, tokenizer: &Tokenizer) -> u32 {
use Config::*; use Config::*;
match config { match config {
Idefics => tokenizer.token_to_id("<image>").unwrap(), // The configuration key for the image token id does not seem to
Idefics2(idefics) => idefics.image_token_id, // be standardized, but the image tag is. So let's use that to get
LlavaNext(llava) => llava.image_token_index, // the image token id.
Paligemma(paligemma) => paligemma.image_token_index, Idefics | Idefics2(_) | LlavaNext(_) | Paligemma(_) => {
tokenizer.token_to_id("<image>").unwrap()
}
_ => unimplemented!("Images tokens are not supported for this model configuration"), _ => unimplemented!("Images tokens are not supported for this model configuration"),
} }
} }