fix: resolve rebase issues and add test

This commit is contained in:
drbh 2024-12-12 18:31:33 +00:00
parent 71ed75a21b
commit e2b75a572f
3 changed files with 275 additions and 64 deletions

View File

@ -0,0 +1,19 @@
{
"choices": [
{
"delta": {
"content": "",
"role": "assistant"
},
"finish_reason": "stop",
"index": 0,
"logprobs": null
}
],
"created": 1733450914,
"id": "",
"model": "Qwen/Qwen2-VL-7B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.4.2-dev0-native",
"usage": null
}

View File

@ -0,0 +1,84 @@
import pytest
import json
import requests
@pytest.fixture(scope="module")
def qwen2_vl_handle(launcher):
with launcher(
"Qwen/Qwen2-VL-7B-Instruct",
max_input_length=10_000,
max_batch_prefill_tokens=10_000,
max_total_tokens=10_001,
cuda_graphs=[0],
) as handle:
yield handle
@pytest.fixture(scope="module")
async def qwen2_vl(qwen2_vl_handle):
await qwen2_vl_handle.health(300)
return qwen2_vl_handle.client
@pytest.mark.asyncio
async def test_qwen2_vl_simpl(qwen2_vl, response_snapshot):
responses = requests.post(
f"{qwen2_vl.base_url}/v1/chat/completions",
headers=qwen2_vl.headers,
json={
"model": "tgi",
"messages": [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4"
},
},
{
"type": "text",
"text": "Describe this video.",
},
],
},
],
"seed": 42,
"max_tokens": 100,
"stream": True,
},
)
# iterate over the response in chunks
count = 0
full_text = ""
last_response = None
for chunk in responses.iter_content(chunk_size=1024):
if chunk:
count += 1
# remove the "data: " prefix, trailing newline, and split the chunk into individual lines
lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
for line in lines:
if line == "[DONE]":
break
print("=", line)
try:
response = json.loads(line)
# print(response)
last_response = response
full_text += response["choices"][0]["delta"]["content"]
except json.JSONDecodeError:
pass
# assert count == 27
# assert response.usage == {
# "completion_tokens": 10,
# "prompt_tokens": 50,
# "total_tokens": 60,
# }
# assert (
# response.choices[0].message.content
# == "In a bustling city, a chicken named Cluck"
# )
assert last_response == response_snapshot

View File

@ -21,6 +21,11 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
use {once_cell::sync::Lazy, regex::Regex};
// video processing
use ffmpeg_next::format::Pixel;
use ffmpeg_next::media::Type;
use ffmpeg_next::software::scaling::{context::Context, flag::Flags};
use std::io::Write;
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
@ -536,7 +541,11 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
pub fn fetch_video(
input: &str,
target_width: u32,
target_height: u32,
) -> Result<ProcessedVideo, ValidationError> {
let (data, mimetype) =
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
let url = &input["<video>(".len()..input.len() - 1];
@ -544,7 +553,7 @@ pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32),
(data, "video/mp4".to_string())
} else if input.starts_with("<video>(data:") {
let content = &input["<video>(data:".len()..input.len() - 1];
let tokens: Vec<_> = content.split(';').collect();
let tokens: Vec<&str> = content.split(';').collect();
if tokens.len() != 2 {
return Err(ValidationError::InvalidVideoContent(content.to_string()));
}
@ -559,45 +568,101 @@ pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32),
return Err(ValidationError::InvalidVideoContent(input.to_string()));
};
let mut cursor = Cursor::new(&data);
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
// init ffmpeg
ffmpeg_next::init().map_err(|_| ValidationError::FFmpegError)?;
let video_track = context
.tracks
.iter()
.find(|track| track.track_type == mp4parse::TrackType::Video)
.ok_or(ValidationError::NoVideoStream)?;
// create temporary file for ffmpeg input
let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
temp_file
.write_all(&data)
.map_err(ValidationError::IoError)?;
let video_info = video_track
.tkhd
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let width = (video_info.width >> 16) as usize;
let height = (video_info.height >> 16) as usize;
let mut ictx =
ffmpeg_next::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;
// timescale units per second
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
let input = ictx
.streams()
.best(Type::Video)
.ok_or(ValidationError::FFmpegError)?;
let video_stream_index = input.index();
// TODO: revisit if we need duration in seconds
let _duration = video_track
.duration
.map(|d| d.0 as f32 / timescale)
.unwrap_or(0.0);
let context_decoder = ffmpeg_next::codec::context::Context::from_parameters(input.parameters())
.map_err(|_| ValidationError::FFmpegError)?;
let mut decoder = context_decoder
.decoder()
.video()
.map_err(|_| ValidationError::FFmpegError)?;
let time_to_sample = video_track
.stts
.as_ref()
.ok_or(ValidationError::NoVideoStream)?;
let width = target_width;
let height = target_height;
let num_samples = time_to_sample
.samples
.iter()
.map(|entry| entry.sample_count)
.sum::<u32>();
let mut scaler = Context::get(
decoder.format(),
decoder.width(), // original width
decoder.height(),
Pixel::RGB24,
width, // target width
height,
Flags::BILINEAR,
)
.map_err(|_| ValidationError::FFmpegError)?;
let total_frames = num_samples as f32;
let mut frame_index = 0;
let mut captured_frame_index = 0;
let mut frames = vec![];
Ok((data, mimetype, height, width, total_frames))
let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg_next::decoder::Video,
raw_fps: f32|
-> Result<(), ffmpeg_next::Error> {
let mut decoded = ffmpeg_next::util::frame::video::Video::empty();
let fps = raw_fps.floor();
while decoder.receive_frame(&mut decoded).is_ok() {
let mut rgb_frame = ffmpeg_next::util::frame::video::Video::empty();
scaler.run(&decoded, &mut rgb_frame)?;
if frame_index as f32 % fps == 0.0 {
captured_frame_index += 1;
// Create new buffer without padding
let mut frame_data =
Vec::with_capacity((rgb_frame.width() * rgb_frame.height() * 3) as usize);
let src_data = rgb_frame.data(0);
let row_size = rgb_frame.width() as usize * 3;
// Copy each row without padding
for y in 0..rgb_frame.height() as usize {
let start = y * rgb_frame.stride(0) as usize;
let end = start + row_size;
frame_data.extend_from_slice(&src_data[start..end]);
}
frames.push(frame_data);
}
frame_index += 1;
}
Ok(())
};
for (stream, packet) in ictx.packets() {
// Floor the fps to get a whole number
let fps = (stream.rate().numerator() as f32 / stream.rate().denominator() as f32).floor();
if stream.index() == video_stream_index {
decoder
.send_packet(&packet)
.map_err(|_| ValidationError::FFmpegError)?;
receive_and_process_decoded_frames(&mut decoder, fps)
.map_err(|_| ValidationError::FFmpegError)?;
}
}
decoder
.send_eof()
.map_err(|_| ValidationError::FFmpegError)?;
Ok(ProcessedVideo {
mimetype,
height,
width,
frames,
sampled_frames: captured_frame_index,
})
}
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
@ -688,22 +753,18 @@ fn image_tokens(
}
}
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
fn video_tokens(config: &Config, height: u32, width: u32, sampled_frames: f32) -> String {
use Config::*;
match config {
// TOOD: improve to use the config to better estimate the number of tokens
Qwen2Vl(_config) => {
let video_fps = 30_f32;
let fps = 30_f32;
let min_frames = 16_f32;
let max_frames = 64_f32;
let min_frames = 2_f32;
let max_frames = 256_f32;
// make sure the frames are within the range and are even
let nframes = (total_frames / video_fps * fps)
.max(min_frames)
.min(max_frames);
let nframes = (sampled_frames).max(min_frames).min(max_frames);
let nframes = (nframes / 2.0).round() as usize * 2;
let num_tokens = nframes * height * width / 1541;
let num_tokens = nframes * height as usize * width as usize / 1541;
format!(
"<|vision_start|>{:?}<|vision_end|>",
"<|video_pad|>".repeat(num_tokens)
@ -754,37 +815,66 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width, total_frames) =
fetch_video(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Video(Video { data, mimetype }));
let video_tokens = video_tokens(config, height, width, total_frames);
let processed_video = match config {
Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
let default_target_width = 224;
let default_target_height = 224;
fetch_video(
&inputs[chunk_start..chunk_end],
default_target_width,
default_target_height,
)?
}
Qwen2Vl(_) => {
let target_width = 360;
let target_height = 420;
fetch_video(&inputs[chunk_start..chunk_end], target_width, target_height)?
}
_ => {
unreachable!("Video tokens are not supported for this model configuration")
}
};
input_chunks.push(Chunk::Video(Video {
data: processed_video.frames.iter().flatten().cloned().collect(),
mimetype: processed_video.mimetype.clone(),
width: processed_video.width,
height: processed_video.height,
num_frames: processed_video.frames.len() as u32,
}));
let video_tokens = video_tokens(
config,
processed_video.height,
processed_video.width,
processed_video.sampled_frames as f32,
);
tokenizer_query.push_str(&video_tokens);
start = chunk_end;
}
// clip remaining inputs and process images
let remaining_input = &inputs[start..];
for chunk in RE.find_iter(remaining_input) {
let chunk_start = chunk.start() + start;
let chunk_end = chunk.end() + start;
// handle image content after video content
for chunk in RE.find_iter(&inputs) {
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }));
input_chunks.push(Chunk::Image(Image {
data,
mimetype: mimetype.clone(),
}));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
}
// Add any remaining text
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
tokenizer_query.push_str(&inputs[start..]);
}
// Apply any necessary token fixups
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
(tokenizer_query, input_chunks)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
@ -797,7 +887,6 @@ fn prepare_input<T: TokenizerTrait>(
Ok((encoding, input_chunks))
}
type TokenizerRequest = (
(String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
@ -810,16 +899,21 @@ pub struct Image {
pub mimetype: String,
}
pub struct ProcessedVideo {
mimetype: String,
height: u32,
width: u32,
frames: Vec<Vec<u8>>, // RGB frames
sampled_frames: usize,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Video {
pub data: Vec<u8>,
pub mimetype: String,
}
impl Video {
pub fn as_bytes(&self) -> Vec<u8> {
self.frames.iter().flatten().cloned().collect()
}
pub width: u32,
pub height: u32,
pub num_frames: u32,
}
#[derive(Debug, Clone, Eq, PartialEq)]
@ -845,9 +939,19 @@ impl ChunksToString for Vec<Chunk> {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
Chunk::Video(Video { data, mimetype }) => {
Chunk::Video(Video {
data,
mimetype,
width,
height: _,
num_frames: _,
}) => {
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
let encoded = STANDARD.encode(data);
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
output.push_str(&format!(
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
width, mimetype, encoded, mimetype
));
}
});
output
@ -983,6 +1087,10 @@ pub enum ValidationError {
MP4Error,
#[error("no video stream found")]
NoVideoStream,
#[error("io error: {0}")]
IoError(#[from] std::io::Error),
#[error("ffmpeg error")]
FFmpegError,
}
#[cfg(test)]