downloading videos

This commit is contained in:
Miquel Farre 2024-11-14 11:36:11 +00:00 committed by drbh
parent c7c2fdae8c
commit b9c8152ac6

View File

@ -1,7 +1,9 @@
import torch import torch
import requests
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from contextlib import contextmanager
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict from typing import Iterable, Optional, Tuple, List, Type, Dict
@ -218,8 +220,10 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
) )
image_id += 1 image_id += 1
elif chunk_type == "video" and config.model_type == "qwen2_vl": elif chunk_type == "video" and config.model_type == "qwen2_vl":
# Based on Qwen2VL's video token format # Download and process video in a temporary context
full_text += f"<video>{chunk.video}</video>" with cls.temp_video_download(chunk.video) as local_path:
# Now the video is available at local_path for processing
full_text += f"<video>{local_path}</video>"
full_text = image_text_replacement_fixup(config, full_text) full_text = image_text_replacement_fixup(config, full_text)
@ -271,7 +275,27 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
batch.image_sizes = None batch.image_sizes = None
batch.image_grid_thw = None batch.image_grid_thw = None
return batch return batch
@staticmethod
@contextmanager
def temp_video_download(url: str) -> str:
"""Downloads video to a temporary file and cleans it up after use."""
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(url)[1], delete=False) as tmp_file:
try:
# Download video
with requests.get(url, stream=True) as r:
r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
if chunk:
tmp_file.write(chunk)
tmp_file.flush()
yield tmp_file.name
finally:
# Clean up temp file
try:
os.unlink(tmp_file.name)
except OSError:
pass
class VlmCausalLM(FlashCausalLM): class VlmCausalLM(FlashCausalLM):
def __init__( def __init__(