Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)
feat: support video input chunks and enable qwen2 vl to process video
commit b2c557594f (parent 3c07391e8e)
@@ -9,7 +9,7 @@ use thiserror::Error;
 use tonic::transport;
 use tonic::Status;
 
-pub use v3::{Chunk, Image, Input, InputChunk};
+pub use v3::{Chunk, Image, Input, InputChunk, Video};
 
 #[async_trait]
 pub trait Health {
@@ -79,8 +79,9 @@ impl ChunksToString for Vec<InputChunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
-            Some(Chunk::Video(url)) => {
-                output.push_str(&format!("<video>({})", url))
+            Some(Chunk::Video(Video { data, mimetype })) => {
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
             }
             // We don't create empty chunks, so this should be unreachable.
             None => unreachable!("Chunks should never be empty"),
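Note: the new Video arm mirrors the existing image handling, flattening the raw bytes into a `<video>(data:<mime>;base64,...)` placeholder string. A minimal Python sketch of that flattening, for illustration only (the `encode_video_chunk` helper below is not part of this commit):

```python
import base64

def encode_video_chunk(data: bytes, mimetype: str = "video/mp4") -> str:
    # Flatten a video chunk the way the new ChunksToString arm does:
    # "<video>(data:<mime>;base64,<payload>)"
    encoded = base64.b64encode(data).decode("ascii")
    return f"<video>(data:{mimetype};base64,{encoded})"

# Tiny fake payload, just to show the shape of the output string.
print(encode_video_chunk(b"\x00\x00\x00\x18ftypmp42"))
# <video>(data:video/mp4;base64,AAAAGGZ0eXBtcDQy)
```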
@@ -8,6 +8,6 @@ pub use client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters, Tokens,
+    StoppingCriteriaParameters, Tokens, Video,
 };
 pub use sharded_client::ShardedClient;
@@ -15,7 +15,7 @@ pub use grpc_client::Client;
 pub use pb::generate::v3::{
     input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
     HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
-    StoppingCriteriaParameters,
+    StoppingCriteriaParameters, Video,
 };
 pub use sharded_client::ShardedClient;
 
@@ -439,7 +439,10 @@ impl State {
                     data: image.data,
                     mimetype: image.mimetype,
                 }),
-                Chunk::Video(url) => client::Chunk::Video(url),
+                Chunk::Video(video) => client::Chunk::Video(client::Video {
+                    data: video.data,
+                    mimetype: video.mimetype,
+                }),
             }),
         })
         .collect(),
@@ -1922,6 +1922,24 @@
             ]
           }
         }
+      },
+      {
+        "type": "object",
+        "required": [
+          "video_url",
+          "type"
+        ],
+        "properties": {
+          "type": {
+            "type": "string",
+            "enum": [
+              "video_url"
+            ]
+          },
+          "video_url": {
+            "$ref": "#/components/schemas/Url"
+          }
+        }
       }
     ],
     "discriminator": {
|
@ -63,7 +63,6 @@ Options:
|
|||||||
|
|
||||||
Possible values:
|
Possible values:
|
||||||
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
- awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
|
||||||
- compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods
|
|
||||||
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
|
||||||
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
- exl2: Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)
|
||||||
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||||
|
@@ -1230,7 +1230,9 @@ impl From<Message> for TextMessage {
                 .map(|chunk| match chunk {
                     MessageChunk::Text { text } => text,
                     MessageChunk::ImageUrl { image_url } => format!("![]({})", image_url.url),
-                    MessageChunk::VideoUrl { video_url } => format!("![]({})", video_url.url),
+                    MessageChunk::VideoUrl { video_url } => {
+                        format!("<video>({})", video_url.url)
+                    }
                 })
                 .collect::<Vec<_>>()
                 .join(""),
@@ -536,6 +536,70 @@ fn format_to_mimetype(format: ImageFormat) -> String {
     .to_string()
 }
 
+pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
+    let (data, mimetype) =
+        if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
+            let url = &input["<video>(".len()..input.len() - 1];
+            let data = reqwest::blocking::get(url)?.bytes()?.to_vec();
+            (data, "video/mp4".to_string())
+        } else if input.starts_with("<video>(data:") {
+            let content = &input["<video>(data:".len()..input.len() - 1];
+            let tokens: Vec<_> = content.split(';').collect();
+            if tokens.len() != 2 {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let mimetype = tokens[0];
+            let content = tokens[1];
+            if !content.starts_with("base64,") {
+                return Err(ValidationError::InvalidVideoContent(content.to_string()));
+            }
+            let data = STANDARD.decode(&content["base64,".len()..])?;
+            (data, mimetype.to_string())
+        } else {
+            return Err(ValidationError::InvalidVideoContent(input.to_string()));
+        };
+
+    let mut cursor = Cursor::new(&data);
+    let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
+
+    let video_track = context
+        .tracks
+        .iter()
+        .find(|track| track.track_type == mp4parse::TrackType::Video)
+        .ok_or(ValidationError::NoVideoStream)?;
+
+    let video_info = video_track
+        .tkhd
+        .as_ref()
+        .ok_or(ValidationError::NoVideoStream)?;
+    let width = (video_info.width >> 16) as usize;
+    let height = (video_info.height >> 16) as usize;
+
+    // timescale units per second
+    let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
+
+    // TODO: revisit if we need duration in seconds
+    let _duration = video_track
+        .duration
+        .map(|d| d.0 as f32 / timescale)
+        .unwrap_or(0.0);
+
+    let time_to_sample = video_track
+        .stts
+        .as_ref()
+        .ok_or(ValidationError::NoVideoStream)?;
+
+    let num_samples = time_to_sample
+        .samples
+        .iter()
+        .map(|entry| entry.sample_count)
+        .sum::<u32>();
+
+    let total_frames = num_samples as f32;
+
+    Ok((data, mimetype, height, width, total_frames))
+}
+
 fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     if input.starts_with("![](http://") || input.starts_with("![](https://") {
         let url = &input["![](".len()..input.len() - 1];
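The new `fetch_video` accepts either a `<video>(http(s)://...)` URL or a `<video>(data:<mime>;base64,...)` placeholder, then probes the MP4 metadata for width, height, and frame count. A rough Python mirror of just the placeholder parsing, for experimentation (the function name is made up and the MP4 probing via `mp4parse` is omitted):

```python
import base64

def parse_video_placeholder(input_str: str) -> tuple[bytes, str]:
    # Accepts "<video>(http(s)://...)" or "<video>(data:<mime>;base64,<payload>)".
    if input_str.startswith(("<video>(http://", "<video>(https://")):
        url = input_str[len("<video>("):-1]
        # The Rust code downloads the bytes here and assumes video/mp4.
        raise NotImplementedError(f"would download {url} and return (bytes, 'video/mp4')")
    elif input_str.startswith("<video>(data:"):
        content = input_str[len("<video>(data:"):-1]
        mimetype, _, payload = content.partition(";")
        if not payload.startswith("base64,"):
            raise ValueError(f"invalid video content: {content}")
        data = base64.b64decode(payload[len("base64,"):])
        return data, mimetype
    raise ValueError(f"invalid video content: {input_str}")

data, mime = parse_video_placeholder("<video>(data:video/mp4;base64,AAAAGGZ0eXBtcDQy)")
assert mime == "video/mp4" and data.startswith(b"\x00\x00\x00\x18ftyp")
```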
@@ -624,6 +688,31 @@ fn image_tokens(
     }
 }
 
+fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
+    use Config::*;
+
+    match config {
+        // TOOD: improve to use the config to better estimate the number of tokens
+        Qwen2Vl(_config) => {
+            let video_fps = 30_f32;
+            let fps = 30_f32;
+            let min_frames = 16_f32;
+            let max_frames = 64_f32;
+            // make sure the frames are within the range and are even
+            let nframes = (total_frames / video_fps * fps)
+                .max(min_frames)
+                .min(max_frames);
+            let nframes = (nframes / 2.0).round() as usize * 2;
+            let num_tokens = nframes * height * width / 1541;
+            format!(
+                "<|vision_start|>{:?}<|vision_end|>",
+                "<|video_pad|>".repeat(num_tokens)
+            )
+        }
+        _ => unimplemented!("Video tokens are not supported for this model configuration"),
+    }
+}
+
 fn image_tokens_fixup(config: &Config, text: String) -> String {
     match config {
         Config::Idefics2(_) => {
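A quick worked example of the `video_tokens` estimate for Qwen2-VL, restated in Python under the constants hard-coded above (30 fps, a 16-64 frame clamp, the 1541 divisor); as the commit's own comment notes, this is a placeholder heuristic rather than a config-driven count:

```python
def estimate_qwen2_vl_video_tokens(height: int, width: int, total_frames: float) -> int:
    # Restates video_tokens() for Qwen2Vl: clamp the sampled frame count to
    # [16, 64], force it even, then scale by height * width / 1541.
    video_fps, fps = 30.0, 30.0
    min_frames, max_frames = 16.0, 64.0
    nframes = min(max(total_frames / video_fps * fps, min_frames), max_frames)
    nframes = round(nframes / 2) * 2
    return int(nframes * height * width // 1541)

# e.g. a 112x112 video with 300 frames keeps 64 frames -> ~520 pad tokens
print(estimate_qwen2_vl_video_tokens(112, 112, 300.0))
```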
@@ -646,7 +735,8 @@ fn prepare_input<T: TokenizerTrait>(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     // Add video regex
-    static VIDEO_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
+    static VIDEO_RE: Lazy<Regex> =
+        Lazy::new(|| Regex::new(r"<video>\((https?://[^\)]+)\)").unwrap());
 
     let (tokenizer_query, input_chunks) = match config {
         Some(
@@ -656,7 +746,7 @@ fn prepare_input<T: TokenizerTrait>(
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
 
-            // Process videos first
+            // handle video content first
             for chunk in VIDEO_RE.find_iter(&inputs) {
                 let chunk_start = chunk.start();
                 let chunk_end = chunk.end();
@@ -664,14 +754,15 @@ fn prepare_input<T: TokenizerTrait>(
                     input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
                     tokenizer_query.push_str(&inputs[start..chunk_start]);
                 }
-                let video_url = &inputs[chunk_start + 8..chunk_end - 1]; // Remove <video>( and )
-                input_chunks.push(Chunk::Video(video_url.to_string()));
-                // For videos, we use the default size as height/width don't matter for the initial processing
-                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, 1, 1));
+                let (data, mimetype, height, width, total_frames) =
+                    fetch_video(&inputs[chunk_start..chunk_end])?;
+                input_chunks.push(Chunk::Video(Video { data, mimetype }));
+                let video_tokens = video_tokens(config, height, width, total_frames);
+                tokenizer_query.push_str(&video_tokens);
                 start = chunk_end;
             }
 
-            // Process remaining content for images
+            // clip remaining inputs and process images
             let remaining_input = &inputs[start..];
             for chunk in RE.find_iter(remaining_input) {
                 let chunk_start = chunk.start() + start;
@@ -719,11 +810,17 @@ pub struct Image {
     pub mimetype: String,
 }
 
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub struct Video {
+    pub data: Vec<u8>,
+    pub mimetype: String,
+}
+
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum Chunk {
     Text(String),
     Image(Image),
-    Video(String),
+    Video(Video),
 }
 
 /// Convert input chunks to a stringly-typed input for backwards
@@ -742,8 +839,9 @@ impl ChunksToString for Vec<Chunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
-            Chunk::Video(url) => {
-                output.push_str(&format!("<video>({})", url))
+            Chunk::Video(Video { data, mimetype }) => {
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
             }
         });
         output
@@ -875,6 +973,10 @@ pub enum ValidationError {
     UnsupportedModality(&'static str),
     #[error("invalid video content: {0}")]
     InvalidVideoContent(String),
+    #[error("could not parse MP4 file")]
+    MP4Error,
+    #[error("no video stream found")]
+    NoVideoStream,
 }
 
 #[cfg(test)]
@@ -14,20 +14,12 @@
 # limitations under the License.
 """PyTorch Qwen2 VL model."""
 
-__all__ = ['Qwen2VLForConditionalGeneration', 'process_qwen_video']
-
 from typing import Dict, Optional, Tuple, List
 
-import os
-import tempfile
-import requests
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
-from contextlib import contextmanager
-from qwen_vl_utils import process_vision_info
-
 
 if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -445,38 +437,48 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         ).squeeze(1)
         vision_tokens = input_ids[vision_start_indices + 1]
 
-        # Count both image and video tokens
+        # only copy the sum of the image and video tokens GPU<->CPU
         image_count = (vision_tokens == self.image_token_id).sum().item()
         video_count = (vision_tokens == self.video_token_id).sum().item()
 
         current_pos = 0
         for _ in range(image_count + video_count):
-            # Find next vision token position (either image or video)
+            # copy the value position of the next image or video token from GPU<->CPU
             next_vision_pos = (
-                ((input_ids[current_pos:] == self.image_token_id) |
-                 (input_ids[current_pos:] == self.video_token_id))
+                (
+                    (input_ids[current_pos:] == self.image_token_id)
+                    | (input_ids[current_pos:] == self.video_token_id)
+                )
                 .nonzero()[0]
                 .item()
             )
 
-            # Determine if current token is video or image
-            is_video = input_ids[current_pos + next_vision_pos] == self.video_token_id
-            grid_thw = video_grid_thw[vision_index] if is_video else image_grid_thw[vision_index]
+            # TODO: revisit above to get all next_image_pos in one go to avoid copying in the loop
+            is_video = (
+                input_ids[current_pos + next_vision_pos] == self.video_token_id
+            )
+            grid_thw = (
+                video_grid_thw[vision_index]
+                if is_video
+                else image_grid_thw[vision_index]
+            )
 
             time_steps, height, width = grid_thw.clone()
             height //= self.spatial_merge_size
             width //= self.spatial_merge_size
 
-            # Calculate lengths and indices same as before
+            # calculate the length of the text and image tokens
             text_length = next_vision_pos - current_pos
-            start_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            start_idx = (
+                llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
+            )
 
-            # Text position ids
+            # text position ids
            text_pos_ids = torch.arange(text_length, device=d)
            text_pos_ids = text_pos_ids.view(1, -1).expand(3, -1) + start_idx
            llm_pos_ids_list.append(text_pos_ids)
 
-            # Vision position ids
+            # vision position ids
            t_indices = torch.arange(time_steps, device=d).repeat_interleave(
                height * width
            )
@@ -485,7 +487,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 .repeat_interleave(width)
                 .repeat(time_steps)
             )
-            w_indices = torch.arange(width, device=d).repeat(height * time_steps)
+            w_indices = torch.arange(width, device=d).repeat(
+                height * time_steps
+            )
 
             vision_pos_ids = (
                 torch.stack([t_indices, h_indices, w_indices])
@@ -499,7 +503,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
         # Handle remaining text if any
         if current_pos < batch_input_ids.size(1):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            st_idx = (
+                llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            )
             text_len = batch_input_ids.size(1) - current_pos
             llm_pos_ids_list.append(
                 torch.arange(text_len, device=d).view(1, -1).expand(3, -1) + st_idx
@@ -528,6 +534,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor],
         pixel_values: torch.FloatTensor = None,
+        video_pixel_values: torch.FloatTensor = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
         video_grid_thw: Optional[torch.LongTensor] = None,
         pixel_attention_mask=None,
@@ -538,15 +545,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     ):
         inputs_embeds = self.embed_tokens(input_ids)
 
+        if video_pixel_values is not None and len(video_pixel_values) > 0:
+            vision_embeds = self.visual(
+                video_pixel_values, grid_thw=video_grid_thw
+            ).squeeze(0)
+            vision_token_mask = input_ids == self.video_token_id
+            inputs_embeds[vision_token_mask] = vision_embeds
+
         # apply the visual model to the pixel values if they are provided
         if pixel_values is not None and len(pixel_values) > 0:
             vision_embeds = self.visual(
                 pixel_values,
-                grid_thw=torch.cat([image_grid_thw, video_grid_thw]) if video_grid_thw is not None else image_grid_thw
+                grid_thw=(
+                    torch.cat([image_grid_thw, video_grid_thw])
+                    if video_grid_thw is not None
+                    else image_grid_thw
+                ),
             ).squeeze(0)
 
-            # Apply embeddings to both image and video tokens
-            vision_token_mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id)
+            # Apply embeddings to image tokens
+            vision_token_mask = input_ids == self.image_token_id
             inputs_embeds[vision_token_mask] = vision_embeds
 
         hidden_states = self.text_model(
@@ -18,6 +18,8 @@ from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+from torchvision import io
+import math
 
 tracer = trace.get_tracer(__name__)
 
@@ -77,6 +79,20 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
 
 
+def video_text_replacement(processor, video_input, config) -> str:
+    if config.model_type == "qwen2_vl":
+        # num_pads = video_input['pixel_values'].size(0)
+        # num_pads = 1206
+
+        # import ipdb; ipdb.set_trace()
+        # num_pads = 9556 + 10
+        num_pads = video_input.pixel_values.shape[0] // 4
+        padding = "<|video_pad|>" * num_pads
+        return f"<|vision_start|>{padding}<|vision_end|>"
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
 def image_text_replacement_fixup(config, text: str) -> str:
     if config.model_type == "idefics2":
         return text.replace(
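`video_text_replacement` builds the prompt fragment from the processor output: one `<|video_pad|>` per four rows of the processed `pixel_values`, wrapped in vision start/end markers. An illustrative restatement (the helper name and the 6424-row example are made up):

```python
def qwen2_vl_video_prompt(num_pixel_rows: int) -> str:
    # Mirror of video_text_replacement for qwen2_vl: one <|video_pad|> token
    # per four rows of the processed pixel_values tensor.
    num_pads = num_pixel_rows // 4
    padding = "<|video_pad|>" * num_pads
    return f"<|vision_start|>{padding}<|vision_end|>"

# e.g. a video whose processed pixel_values has 6424 rows -> 1606 pad tokens
fragment = qwen2_vl_video_prompt(6424)
assert fragment.count("<|video_pad|>") == 1606
```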
@@ -139,29 +155,59 @@ def get_number_of_features(height: int, width: int, config) -> int:
     return unpadded_features + newline_features + base_features
 
 
+# copied from: https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+def smart_nframes(
+    fps: int,
+    nframes: int,
+    min_frames: int,
+    max_frames: int,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    if nframes:
+        nframes = round(nframes / 2) * 2
+    else:
+        min_frames = math.ceil(min_frames / 2) * 2
+        max_frames = math.floor(max_frames / 2) * 2
+        nframes = total_frames / video_fps * fps
+        nframes = min(max(nframes, min_frames), max_frames)
+        nframes = round(nframes / 2) * 2
+    if not (2 <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{2}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
 class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    video_pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
     image_grid_thw: Optional[torch.Tensor]
+    video_grid_thw: Optional[torch.Tensor]
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.video_pixel_values = None
         batch.pixel_attention_mask = None
         batch.image_sizes = None
         batch.image_grid_thw = None
+        batch.video_grid_thw = None
         return batch
 
     @classmethod
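`smart_nframes` is later called with `fps=30`, `nframes=None`, `min_frames=16`, `max_frames=64`; below is a standalone restatement of that default path for a quick sanity check (the function name and sample clips are illustrative, not part of the commit):

```python
import math

def pick_frame_count(total_frames: int, video_fps: float,
                     fps: int = 30, min_frames: int = 16, max_frames: int = 64) -> int:
    # Same arithmetic as the nframes=None branch of smart_nframes above.
    min_frames = math.ceil(min_frames / 2) * 2
    max_frames = math.floor(max_frames / 2) * 2
    nframes = total_frames / video_fps * fps
    nframes = min(max(nframes, min_frames), max_frames)
    return round(nframes / 2) * 2

print(pick_frame_count(total_frames=240, video_fps=24.0))  # 10 s clip -> 64 frames
print(pick_frame_count(total_frames=50, video_fps=30.0))   # short clip -> 50 frames
```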
@@ -172,6 +218,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         # can make the image splits the same size. And we need the final
         # sizes to insert correct number of image tokens.
         images = []
+        videos = []
         for r in requests:
             for chunk in r.input_chunks.chunks:
                 chunk_type = chunk.WhichOneof("chunk")
@@ -193,7 +240,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         images.append([image])
                 elif chunk_type == "video":
                     if config.model_type == "qwen2_vl":
-                        pass
+                        videos.append(chunk.video)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
@@ -202,6 +249,41 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         else:
             image_inputs = None
 
+        video_inputs = None
+        if videos:
+            try:
+                tensor_videos = []
+                video = videos[0]
+                video_buffer = BytesIO(video.data)
+                video, _audio, info = io.read_video(
+                    video_buffer,
+                    start_pts=0.0,
+                    end_pts=None,
+                    pts_unit="sec",
+                    output_format="TCHW",
+                )
+                total_frames, video_fps = video.size(0), info["video_fps"]
+                nframes = smart_nframes(
+                    fps=30,
+                    nframes=None,
+                    min_frames=16,
+                    max_frames=64,
+                    total_frames=total_frames,
+                    video_fps=video_fps,
+                )
+                idx = torch.linspace(0, total_frames - 1, nframes).round().long()
+                video = video[idx]
+                tensor_videos.append(video)
+                video_inputs = processor.image_processor(
+                    tensor_videos, return_tensors="pt"
+                )
+
+            except Exception as e:
+                print(f"Failed to process video: {e}")
+                pass
+        else:
+            video_inputs = None
+
         batch_inputs = []
         max_truncation = 0
         image_id = 0
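The decoded TCHW tensor is subsampled with evenly spaced indices before going through the image processor. A small sketch of just that index selection (helper name and frame counts are illustrative):

```python
import torch

def sample_frame_indices(total_frames: int, nframes: int) -> torch.Tensor:
    # Evenly spaced frame indices, as applied to the decoded TCHW video tensor.
    return torch.linspace(0, total_frames - 1, nframes).round().long()

# e.g. keep 8 of 240 frames, spread across the whole clip
print(sample_frame_indices(240, 8).tolist())
# [0, 34, 68, 102, 137, 171, 205, 239]
```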
@@ -216,10 +298,8 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                         processor, image_inputs, config, image_id
                     )
                     image_id += 1
-                elif chunk_type == "video" and config.model_type == "qwen2_vl":
-                    from text_generation_server.models.custom_modeling.qwen2_vl import process_qwen_video
-                    text, _ = process_qwen_video(chunk.video)
-                    full_text += text
+                elif chunk_type == "video":
+                    full_text += video_text_replacement(processor, video_inputs, config)
 
             full_text = image_text_replacement_fixup(config, full_text)
             batch_inputs.append(full_text)
@@ -232,7 +312,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             add_special_tokens=not config.model_type == "paligemma",
         )["input_ids"]
 
-        return batch_tokenized_inputs, image_inputs
+        return batch_tokenized_inputs, image_inputs, video_inputs
 
     @classmethod
     def from_pb_processor(
@@ -244,10 +324,23 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "VlmCausalLMBatch":
-        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
+        batch_tokenized_inputs, image_inputs, video_inputs = cls.batch_tokenized_inputs(
             pb.requests, tokenizer, processor, config
         )
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+        if video_inputs is not None:
+            if "pixel_values" in video_inputs:
+                batch.video_pixel_values = video_inputs["pixel_values"].to(
+                    device=device
+                )
+            if "image_grid_thw" in video_inputs:
+                batch.video_grid_thw = video_inputs["image_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
+        else:
+            batch.video_pixel_values = None
+            batch.video_grid_thw = None
+
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
             if "pixel_attention_mask" in image_inputs:
@@ -264,6 +357,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
             else:
                 batch.image_grid_thw = None
+            if "video_grid_thw" in image_inputs:
+                batch.video_grid_thw = image_inputs["video_grid_thw"].to(device=device)
+            else:
+                batch.video_grid_thw = None
         else:
             batch.pixel_values = None
             batch.pixel_attention_mask = None
@@ -271,6 +368,7 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
             batch.image_grid_thw = None
         return batch
 
+
 class VlmCausalLM(FlashCausalLM):
     def __init__(
         self,
@@ -373,7 +471,7 @@ class VlmCausalLM(FlashCausalLM):
         if self.model.config.model_type == "qwen2_vl":
             if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
-                    input_ids, batch.image_grid_thw
+                    input_ids, batch.image_grid_thw, batch.video_grid_thw
                 )
                 batch.position_ids = position_ids
 
@@ -426,20 +524,26 @@ class VlmCausalLM(FlashCausalLM):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                video_pixel_values=batch.video_pixel_values,
                 pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
                 image_grid_thw=batch.image_grid_thw,
+                video_grid_thw=batch.video_grid_thw,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.video_pixel_values is not None:
+                batch.video_pixel_values = None
             if batch.pixel_attention_mask is not None:
                 batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             if batch.image_grid_thw is not None:
                 batch.image_grid_thw = None
+            if batch.video_grid_thw is not None:
+                batch.video_grid_thw = None
             return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph