feat: fix token padding, enable warmup and process basic request

David Holtz 2024-10-24 19:57:47 +00:00 committed by drbh
parent d96eef2a02
commit 09ac4fb6eb
5 changed files with 60 additions and 12 deletions

View File

@@ -138,10 +138,39 @@ impl Paligemma {
     }
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2VlVisionConfig {
+    pub(crate) depth: usize,
+    pub(crate) embed_dim: usize,
+    pub(crate) mlp_ratio: usize,
+    pub(crate) num_heads: usize,
+    pub(crate) in_chans: usize,
+    pub(crate) hidden_size: usize,
+    pub(crate) patch_size: usize,
+    pub(crate) spatial_merge_size: usize,
+    pub(crate) spatial_patch_size: usize,
+    pub(crate) temporal_patch_size: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Qwen2Vl {
+    pub(crate) vision_config: Qwen2VlVisionConfig,
+}
+
+impl Qwen2Vl {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        // TODO: calculate number of features
+        6000 / 4
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub enum Config {
+    Qwen2Vl(Qwen2Vl),
     LlavaNext(LlavaNext),
     ClipVisionModel(ClipVisionModel),
     Mistral,
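
Note on the placeholder above: 6000 / 4 is a stand-in until the real feature count is wired up. For Qwen2-VL the number of <|image_pad|> tokens is normally derived from the patch grid of the resized image, with a 2x2 spatial merge collapsing four patches into one feature. The sketch below is a hypothetical illustration of that calculation, assuming 14px patches and merge size 2 as in the vision config above; the processor's exact resizing rules are not reproduced here.

# Hypothetical sketch, not the implementation in this commit: number of image
# placeholder tokens from the image size and Qwen2VlVisionConfig-style fields.
def get_number_of_features(height, width, patch_size=14, spatial_merge_size=2):
    # One placeholder token per spatial_merge_size x spatial_merge_size block of patches.
    grid_h = height // patch_size
    grid_w = width // patch_size
    return (grid_h * grid_w) // (spatial_merge_size**2)

# Example: a 420x560 image -> 30x40 patches -> 300 "<|image_pad|>" tokens.
print(get_number_of_features(420, 560))  # 300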

View File

@@ -594,6 +594,10 @@ fn image_tokens(
         }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
+        Qwen2Vl(config) => format!(
+            "<|vision_start|>{:?}<|vision_end|>",
+            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
+        ),
         _ => unimplemented!("Images tokens are not supported for this model configuration"),
     }
 }
@@ -620,7 +624,9 @@ fn prepare_input<T: TokenizerTrait>(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(
+            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
+        ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;

View File

@@ -348,6 +348,10 @@ class Qwen2Model(torch.nn.Module):
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         # TODO: fix how N-D position_ids are handled
+        if position_ids.dim() == 2:
+            position_ids = position_ids.unsqueeze(0)
+
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
             position_ids, true_max_s, hidden_states.dtype
         )
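
For context on the guard above: text-only batches reach this code path with flat position ids, while Qwen2-VL's multimodal rotary embedding ("mrope") indexes positions per section (temporal, height, width), hence the extra leading dimension. The helper below is a hypothetical illustration of that reshaping, not the get_cos_sin_hack contract, which this diff does not show.

import torch

def normalize_position_ids(position_ids, n_sections=3):
    # Hypothetical helper: give plain text-only position ids a leading
    # dimension for the three mrope sections (temporal, height, width).
    if position_ids.dim() == 1:   # (seq_len,) -> (1, seq_len)
        position_ids = position_ids.unsqueeze(0)
    if position_ids.dim() == 2:   # (batch, seq_len) -> (1, batch, seq_len)
        position_ids = position_ids.unsqueeze(0)
    # Text tokens use the same position in every section, so broadcasting suffices.
    return position_ids.expand(n_sections, *position_ids.shape[1:])

print(normalize_position_ids(torch.arange(5)).shape)  # torch.Size([3, 1, 5])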

View File

@@ -34,6 +34,7 @@ from text_generation_server.layers.layernorm import (
 )
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
+    TensorParallelRowLinear,
     FastLinear,
 )
 from text_generation_server.layers.attention import (
@@ -352,6 +353,7 @@ class Qwen2VisionModel(nn.Module):
 class Qwen2VLForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+        self.config = config
         config.vision_config.quantize = None
         config.vision_config.speculator = config.speculator
         self.hidden_size = config.hidden_size
@@ -364,6 +366,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix="visual", config=config.vision_config, weights=weights
         )
         self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
+        self.lm_head = FastLinear.load(
+            prefix="lm_head", weights=weights, config=config, bias=False
+        )
+        self.device = weights.device
 
     def forward(
         self,
@@ -386,10 +392,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         cross_attention_states: Optional[torch.Tensor] = None,
         image_indices=None,
     ):
-        # make an attention_mask that is the same size as the input_ids
-        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+        # make an attention_mask that is (batch_size, sequence_length)
+        attention_mask = torch.ones_like(
+            input_ids, dtype=torch.bool, device=input_ids.device
+        )
 
         inputs_embeds = self.text_model.embed_tokens(input_ids)
 
         # apply the visual model to the pixel values if they are provided
@@ -525,7 +531,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             mrope_position_deltas, device=input_ids.device
         ).unsqueeze(1)
 
-        # TODO: adjust model to accept 2D position_ids
         outputs = self.text_model(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -541,4 +546,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             attention_mask=attention_mask,
         )
-        return outputs, None
+        logits = self.lm_head(outputs)
+        return logits, None
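
With lm_head wired into forward, the model now returns vocabulary logits instead of raw hidden states, so the caller can pick the next token directly. Below is a toy illustration of that final step with made-up shapes, not the server's sampling code.

import torch
import torch.nn as nn

# Toy stand-ins: 8 token positions, hidden size 16, vocabulary of 32.
hidden_states = torch.randn(8, 16)
lm_head = nn.Linear(16, 32, bias=False)  # plays the role of FastLinear("lm_head")

logits = lm_head(hidden_states)          # (seq_len, vocab_size)
next_token = logits[-1].argmax(dim=-1)   # greedy choice for the last position
print(logits.shape, int(next_token))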

View File

@@ -68,7 +68,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str:
     elif config.model_type == "paligemma":
         return "<image>" * config.text_config.num_image_tokens
     elif config.model_type == "qwen2_vl":
-        return "<image>"
+        num_pads = image_input.pixel_values.shape[0] // 4
+        padding = "<|image_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
     else:
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
@@ -183,10 +185,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            # TODO: REMOVE (this is for debugging purposes)
-            images = images[0][0].resize(
-                (images[0][0].width * 2, images[0][0].height * 2)
-            )
+            if images[0][0].width <= 20:
+                # TODO: provide a better way to handle the issue of the prefill image being too small
+                images = images[0][0].resize(
+                    (images[0][0].width * 2, images[0][0].height * 2)
+                )
             image_inputs = processor.image_processor(images, return_tensors="pt")
         else:
             image_inputs = None