mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

feat: fix token padding, enable warmup and process basic request

This commit is contained in:
parent d96eef2a02
commit 09ac4fb6eb
@@ -138,10 +138,39 @@ impl Paligemma {
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2VlVisionConfig {
    pub(crate) depth: usize,
    pub(crate) embed_dim: usize,
    pub(crate) mlp_ratio: usize,
    pub(crate) num_heads: usize,
    pub(crate) in_chans: usize,
    pub(crate) hidden_size: usize,
    pub(crate) patch_size: usize,
    pub(crate) spatial_merge_size: usize,
    pub(crate) spatial_patch_size: usize,
    pub(crate) temporal_patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Qwen2Vl {
    pub(crate) vision_config: Qwen2VlVisionConfig,
}

impl Qwen2Vl {
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
        // TODO: calculate number of features
        6000 / 4
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    Qwen2Vl(Qwen2Vl),
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
    Mistral,
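Note: get_number_of_features above is still a hard-coded placeholder (6000 / 4 = 1500 image tokens per image). For orientation, here is a hedged Python sketch of how Qwen2-VL-style models usually derive this count from the patch grid and the spatial merge factor; the function name mirrors the Rust method and the formula is an assumption about the eventual implementation, not what this commit ships:

# Hedged sketch: image-token count derived from Qwen2VlVisionConfig-style fields
# (patch_size, spatial_merge_size) instead of the hard-coded 6000 / 4 above.
def get_number_of_features(height: int, width: int,
                           patch_size: int = 14,
                           spatial_merge_size: int = 2) -> int:
    grid_h = height // patch_size
    grid_w = width // patch_size
    # every spatial_merge_size x spatial_merge_size block of patches
    # is merged into a single token by the vision tower
    return (grid_h * grid_w) // (spatial_merge_size**2)

print(get_number_of_features(336, 448))  # 24 * 32 // 4 = 192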
@@ -594,6 +594,10 @@ fn image_tokens(
        }
        Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
        Qwen2Vl(config) => format!(
            "<|vision_start|>{:?}<|vision_end|>",
            "<|image_pad|>".repeat(config.get_number_of_features(height, width))
        ),
        _ => unimplemented!("Images tokens are not supported for this model configuration"),
    }
}
@@ -620,7 +624,9 @@ fn prepare_input<T: TokenizerTrait>(
    use Config::*;
    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
    let (tokenizer_query, input_chunks) = match config {
        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
        Some(
            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
        ) => {
            let mut input_chunks = Vec::new();
            let mut tokenizer_query = String::with_capacity(inputs.len());
            let mut start = 0;
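Note: the RE pattern above locates markdown-style image placeholders of the form ![](url) in the prompt so the router can split the input into text and image chunks. A small Python illustration of what the same pattern matches (the example text and URL are made up):

import re

# Same pattern as the Rust `RE` above, written with Python's `re` for illustration.
IMAGE_RE = re.compile(r"!\[\]\([^\)]*\)")

text = "Describe this: ![](https://example.com/cat.png) in one sentence."
print(IMAGE_RE.findall(text))  # ['![](https://example.com/cat.png)']
print(IMAGE_RE.split(text))    # ['Describe this: ', ' in one sentence.']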
@@ -348,6 +348,10 @@ class Qwen2Model(torch.nn.Module):
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        # TODO: fix how N-D position_ids are handled

        if position_ids.dim() == 2:
            position_ids = position_ids.unsqueeze(0)

        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin_hack(
            position_ids, true_max_s, hidden_states.dtype
        )
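Context for the dim() == 2 branch above: Qwen2-VL's multimodal rotary embedding ("mrope") indexes positions along three channels (temporal, height, width), so position_ids may arrive as a plain 2-D (batch, seq_len) tensor for text-only requests or with an extra leading dimension for multimodal ones; the unsqueeze normalizes the 2-D case. A minimal sketch of what a text-only layout with all three channels looks like (the helper name is hypothetical):

import torch

# Hypothetical helper: for a text-only prompt every mrope channel
# (temporal, height, width) is just 0..seq_len-1.
def text_only_mrope_position_ids(batch_size: int, seq_len: int) -> torch.Tensor:
    pos = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    return pos.unsqueeze(0).expand(3, -1, -1)

print(text_only_mrope_position_ids(2, 5).shape)  # torch.Size([3, 2, 5])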
@@ -34,6 +34,7 @@ from text_generation_server.layers.layernorm import (
)
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    FastLinear,
)
from text_generation_server.layers.attention import (
@@ -352,6 +353,7 @@ class Qwen2VisionModel(nn.Module):
class Qwen2VLForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        self.hidden_size = config.hidden_size
@@ -364,6 +366,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        self.lm_head = FastLinear.load(
            prefix="lm_head", weights=weights, config=config, bias=False
        )
        self.device = weights.device

    def forward(
        self,
@@ -386,10 +392,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):

        # make an attention_mask that is the same size as the input_ids
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)

        # make an attention_mask that is (batch_size, sequence_length)
        attention_mask = torch.ones_like(
            input_ids, dtype=torch.bool, device=input_ids.device
        )
        inputs_embeds = self.text_model.embed_tokens(input_ids)

        # apply the visual model to the pixel values if they are provided
@@ -525,7 +531,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            mrope_position_deltas, device=input_ids.device
        ).unsqueeze(1)

        # TODO: adjust model to accept 2D position_ids
        outputs = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
@@ -541,4 +546,5 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            attention_mask=attention_mask,
        )

        return outputs, None
        logits = self.lm_head(outputs)
        return logits, None
@@ -68,7 +68,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
    elif config.model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens
    elif config.model_type == "qwen2_vl":
        return "<image>"
        num_pads = image_input.pixel_values.shape[0] // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
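The qwen2_vl branch above replaces the old flat "<image>" placeholder: pixel_values carries one row per vision patch, and the vision tower's 2x2 spatial merge collapses every four patches into one token, hence the shape[0] // 4. A small numeric sketch (the image size is illustrative, assuming the usual 14-pixel Qwen2-VL patches):

num_patches = 256                    # e.g. a 224x224 image -> (224 // 14) ** 2 patches
num_pads = num_patches // 4          # one <|image_pad|> per 2x2-merged patch block
padding = "<|image_pad|>" * num_pads
prompt_piece = f"<|vision_start|>{padding}<|vision_end|>"
print(num_pads)                      # 64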
@@ -183,10 +185,11 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                raise RuntimeError(f"Invalid chunk type {chunk_type}")

        if images:
            # TODO: REMOVE (this is for debugging purposes)
            images = images[0][0].resize(
                (images[0][0].width * 2, images[0][0].height * 2)
            )
            if images[0][0].width <= 20:
                # TODO: provide a better way to handle the issue of the prefill image being too small
                images = images[0][0].resize(
                    (images[0][0].width * 2, images[0][0].height * 2)
                )
            image_inputs = processor.image_processor(images, return_tensors="pt")
        else:
            image_inputs = None