fix: adjust video process, reduce to 1 fps and adjust tensor shape

drbh 2024-11-25 16:40:32 -05:00
parent 36e095b38d
commit bc5e202d2c
5 changed files with 46 additions and 24 deletions


@@ -79,14 +79,19 @@ impl ChunksToString for Vec<InputChunk> {
                 let encoded = STANDARD.encode(data);
                 output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
             }
-            Some(Chunk::Video(video)) => {
-                let encoded = STANDARD.encode(&video.as_bytes());
-                output.push_str(&format!("<video>(data:{};base64,{})", video.mimetype, encoded))
-            }
-            // Some(Chunk::Video(Video { data, mimetype })) => {
-            //     let encoded = STANDARD.encode(data);
-            //     output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
-            // }
+            Some(Chunk::Video(Video {
+                data,
+                mimetype,
+                width,
+                frames: _,
+            })) => {
+                // TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
+                let encoded = STANDARD.encode(data);
+                output.push_str(&format!(
+                    r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
+                    width, mimetype, encoded, mimetype
+                ));
+            }
             // We don't create empty chunks, so this should be unreachable.
             None => unreachable!("Chunks should never be empty"),
         });
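
The rewritten arm emits a well-formed HTML5 `<video>` tag with a base64 data URI (the old code produced the malformed string `<video>(data:...)`). For reference, a minimal Python sketch of the same serialization; the helper name is illustrative, not part of the codebase:

```python
import base64

def video_chunk_to_string(data: bytes, mimetype: str, width: int) -> str:
    # Same transformation as the Rust arm above: raw bytes -> base64 data URI
    # wrapped in an HTML5 <video> tag, with the MIME type repeated on <source>.
    encoded = base64.b64encode(data).decode("ascii")
    return (
        f'<video width="{width}">'
        f'<source src="data:{mimetype};base64,{encoded}" type="{mimetype}">'
        "</video>"
    )
```

As the TODO notes, base64 inflates the payload to roughly 4/3 of the raw byte size, so inlining whole videos into the prompt string can get very large.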


@@ -440,8 +440,10 @@ impl State {
             mimetype: image.mimetype,
         }),
         Chunk::Video(video) => client::Chunk::Video(client::Video {
-            data: video.frames,
+            data: video.data,
             mimetype: video.mimetype,
+            width: video.width,
+            frames: video.num_frames,
         }),
     }),
 })


@@ -65,11 +65,17 @@ message Image {
 }

 message Video {
-  /// Binary video data.
+  /// Binary video data (array of RGB data)
   bytes data = 1;
   /// Video MIME type.
   string mimetype = 2;
+  /// Video width
+  uint32 width = 3;
+  /// Total number of frames
+  uint32 frames = 4;
 }

 message InputChunk {
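
The message carries no height field: for tightly packed 8-bit RGB frames, the height is fully determined by the payload size, the width, and the frame count, which is exactly how the Python batch code below recovers it. A sketch of that invariant (the helper name is illustrative):

```python
def implied_height(data: bytes, width: int, frames: int) -> int:
    # Packed RGB uses 3 bytes per pixel, so each frame holds
    # width * height * 3 bytes.
    bytes_per_frame = len(data) // frames
    return bytes_per_frame // (3 * width)

# Four 640x360 RGB frames: 4 * 640 * 360 * 3 bytes -> height 360.
assert implied_height(bytes(4 * 640 * 360 * 3), 640, 4) == 360
```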


@@ -18,7 +18,6 @@ from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.metadata_kernels import block_tables_to_ragged
-from torchvision import io
 import math

 tracer = trace.get_tracer(__name__)
@@ -240,7 +239,27 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                     images.append([image])
                 elif chunk_type == "video":
                     if config.model_type == "qwen2_vl":
-                        videos.append(chunk.video)
+                        video_frame_buf = np.frombuffer(
+                            chunk.video.data, dtype=np.uint8
+                        )
+                        num_bytes = len(video_frame_buf)
+                        bytes_per_frame = num_bytes // chunk.video.frames
+                        height = bytes_per_frame // 3 // chunk.video.width
+                        # iterate over the buffer with a stride of one frame
+                        frames = []
+                        for i in range(chunk.video.frames):
+                            frame = video_frame_buf[
+                                i * bytes_per_frame : (i + 1) * bytes_per_frame
+                            ]
+                            frame = frame.reshape(height, chunk.video.width, 3)
+                            frames.append(frame)
+                        video_frame_buf = np.stack(frames)
+                        frame_nchw_tensor = torch.from_numpy(video_frame_buf).permute(
+                            0, 3, 1, 2
+                        )
+                        videos.append(frame_nchw_tensor)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk_type}")
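
Because the frames sit back to back in the buffer, the per-frame slicing loop above is equivalent to a single reshape over the whole buffer. A vectorized sketch of the same math (hypothetical helper, not the repo's code):

```python
import numpy as np
import torch

def unpack_video_nchw(data: bytes, width: int, num_frames: int) -> torch.Tensor:
    # Raw packed RGB bytes -> uint8 tensor of shape (frames, 3, height, width).
    buf = np.frombuffer(data, dtype=np.uint8)
    height = len(buf) // num_frames // 3 // width
    # One reshape replaces the stride-and-stack loop; copy() because
    # np.frombuffer returns a read-only view.
    nhwc = buf.reshape(num_frames, height, width, 3).copy()
    return torch.from_numpy(nhwc).permute(0, 3, 1, 2)  # NHWC -> NCHW
```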
@@ -252,20 +271,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
         video_inputs = None
         if videos:
             try:
-                video = videos[0]
-                # Frames are already sampled and resized
-                frames = [
-                    torch.from_numpy(np.frombuffer(frame, dtype=np.uint8).reshape(video.height, video.width, 3))
-                    for frame in video.frames
-                ]
-                video_tensor = torch.stack(frames).permute(0, 3, 1, 2)  # NHWC -> NCHW
-                # Apply any additional preprocessing required by the model
-                tensor_videos = [video_tensor]
                 video_inputs = processor.image_processor(
-                    tensor_videos, return_tensors="pt"
+                    videos,
+                    return_tensors="pt",
                 )
             except Exception as e:
                 print(f"Failed to process video: {e}")
                 pass
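
The "1 fps" part of the commit title happens upstream of these hunks, on the side that encodes the video into the raw RGB payload (one of the five changed files, not shown in this view). For context, a hedged sketch of what such sampling could look like; OpenCV is an assumed dependency here, and the function name, default width, and resize policy are illustrative only:

```python
import cv2  # assumed dependency for this sketch; not imported by the diff
import numpy as np

def sample_video_1fps(path: str, width: int = 224) -> tuple[bytes, int, int]:
    # Decode a video, keep roughly one frame per second, resize to a fixed
    # width, and pack the frames as raw RGB bytes -- the layout the Video
    # proto message above carries. Returns (data, width, frames).
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS is unknown
    step = max(int(round(fps)), 1)  # one kept frame per ~second of video
    frames, idx = [], 0
    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        if idx % step == 0:
            h, w = frame_bgr.shape[:2]
            height = int(h * width / w)  # preserve aspect ratio
            resized = cv2.resize(frame_bgr, (width, height))
            frames.append(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
        idx += 1
    cap.release()
    data = np.stack(frames).astype(np.uint8).tobytes()
    return data, width, len(frames)
```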