downloading videos

This commit is contained in:
Miquel Farre 2024-11-14 11:36:11 +00:00 committed by drbh
parent c7c2fdae8c
commit b9c8152ac6

View File

@ -1,7 +1,9 @@
import torch
import requests
from PIL import Image
from io import BytesIO
from contextlib import contextmanager
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -218,8 +220,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
)
image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Based on Qwen2VL's video token format
full_text += f"<video>{chunk.video}</video>"
# Download and process video in a temporary context
with cls.temp_video_download(chunk.video) as local_path:
# Now the video is available at local_path for processing
full_text += f"<video>{local_path}</video>"
full_text = image_text_replacement_fixup(config, full_text)
@ -272,6 +276,26 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_grid_thw = None
return batch
@staticmethod
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to a temporary file and cleans it up after use."""
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
try:
# Download video
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
tmp_file.flush()
yield tmp_file.name
finally:
# Clean up temp file
try:
os.unlink(tmp_file.name)
except OSError:
pass
class VlmCausalLM(FlashCausalLM):
def __init__(