From 09ac4fb6eba6f29bf1a82962958221ff34531a8d Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Thu, 24 Oct 2024 19:57:47 +0000
Subject: [PATCH] feat: fix token padding, enable warmup and process basic request

---
 router/src/config.rs                          | 29 +++++++++++++++++++
 router/src/validation.rs                      |  8 ++++-
 .../custom_modeling/flash_qwen2_modeling.py   |  4 +++
 .../models/custom_modeling/qwen2_vl.py        | 18 ++++++++----
 .../models/vlm_causal_lm.py                   | 13 +++++----
 5 files changed, 60 insertions(+), 12 deletions(-)

diff --git a/router/src/config.rs b/router/src/config.rs
index ce066ad0..7fc27f96 100644
--- a/router/src/config.rs
+++ b/router/src/config.rs
@@ -138,10 +138,39 @@ impl Paligemma {
     }
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2VlVisionConfig {
+    pub(crate) depth: usize,
+    pub(crate) embed_dim: usize,
+    pub(crate) mlp_ratio: usize,
+    pub(crate) num_heads: usize,
+    pub(crate) in_chans: usize,
+    pub(crate) hidden_size: usize,
+    pub(crate) patch_size: usize,
+    pub(crate) spatial_merge_size: usize,
+    pub(crate) spatial_patch_size: usize,
+    pub(crate) temporal_patch_size: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2Vl {
+    pub(crate) vision_config: Qwen2VlVisionConfig,
+}
+
+impl Qwen2Vl {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        // TODO: calculate number of features
+        6000 / 4
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub enum Config {
+    Qwen2Vl(Qwen2Vl),
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 8159ede4..5b2a153c 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -594,6 +594,10 @@ fn image_tokens(
         }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        Qwen2Vl(config) => format!(
+            "<|vision_start|>{:?}<|vision_end|>",
+            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
+        ),
         _ => unimplemented!("Images tokens are not supported for this model configuration"),
     }
 }
@@ -620,7 +624,9 @@ fn prepare_input(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(
+            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
+        ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index e9be22b1..7ae43256 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -348,6 +348,10 @@ class Qwen2Model(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         # TODO: fix how N-D position_ids are handled
+
+        if position_ids.dim() == 2:
+            position_ids = position_ids.unsqueeze(0)
+
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
             position_ids, true_max_s, hidden_states.dtype
         )
diff --git a/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
index 28217d7d..ac66695a 100644
--- a/server/text_generation_server/models/custom_modeling/qwen2_vl.py
+++ b/server/text_generation_server/models/custom_modeling/qwen2_vl.py
@@ -34,6 +34,7 @@ from text_generation_server.layers.layernorm import (
 )
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
+    TensorParallelRowLinear,
     FastLinear,
 )
 from text_generation_server.layers.attention import (
@@ -352,6 +353,7 @@ class Qwen2VisionModel(nn.Module):
 class Qwen2VLForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+        self.config = config
         config.vision_config.quantize = None
         config.vision_config.speculator = config.speculator
         self.hidden_size = config.hidden_size
@@ -364,6 +366,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+        self.lm_head = FastLinear.load(
+            prefix="lm_head", weights=weights, config=config, bias=False
+        )
+        self.device = weights.device
 
     def forward(
         self,
@@ -386,10 +392,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
     ):
-
-        # make an attention_mask that is the same size as the input_ids
-        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-
+        # make an attention_mask that is (batch_size, sequence_length)
+        attention_mask = torch.ones_like(
+            input_ids, dtype=torch.bool, device=input_ids.device
+        )
         inputs_embeds = self.text_model.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
@@ -525,7 +531,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 mrope_position_deltas, device=input_ids.device
             ).unsqueeze(1)
 
-        # TODO: adjust model to accept 2D position_ids
         outputs = self.text_model(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -541,4 +546,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             attention_mask=attention_mask,
         )
 
-        return outputs, None
+        logits = self.lm_head(outputs)
+        return logits, None
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 1b8e7f88..7625c305 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -68,7 +68,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
     elif config.model_type == "paligemma":
         return "<image>" * config.text_config.num_image_tokens
     elif config.model_type == "qwen2_vl":
-        return ""
+        num_pads = image_input.pixel_values.shape[0] // 4
+        padding = "<|image_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
@@ -183,10 +185,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            # TODO: REMOVE (this is for debugging purposes)
-            images = images[0][0].resize(
-                (images[0][0].width * 2, images[0][0].height * 2)
-            )
+            if images[0][0].width <= 20:
+                # TODO: provide a better way to handle the issue of the prefill image being too small
+                images = images[0][0].resize(
+                    (images[0][0].width * 2, images[0][0].height * 2)
+                )
             image_inputs = processor.image_processor(images, return_tensors="pt")
         else:
             image_inputs = None