mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
fix: resolve rebase issues and add test
This commit is contained in:
parent
71ed75a21b
commit
e2b75a572f
@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "",
|
||||||
|
"role": "assistant"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1733450914,
|
||||||
|
"id": "",
|
||||||
|
"model": "Qwen/Qwen2-VL-7B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.4.2-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
}
|
84
integration-tests/models/test_flash_qwen2_vl_video.py
Normal file
84
integration-tests/models/test_flash_qwen2_vl_video.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def qwen2_vl_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"Qwen/Qwen2-VL-7B-Instruct",
|
||||||
|
max_input_length=10_000,
|
||||||
|
max_batch_prefill_tokens=10_000,
|
||||||
|
max_total_tokens=10_001,
|
||||||
|
cuda_graphs=[0],
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def qwen2_vl(qwen2_vl_handle):
|
||||||
|
await qwen2_vl_handle.health(300)
|
||||||
|
return qwen2_vl_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qwen2_vl_simpl(qwen2_vl, response_snapshot):
|
||||||
|
responses = requests.post(
|
||||||
|
f"{qwen2_vl.base_url}/v1/chat/completions",
|
||||||
|
headers=qwen2_vl.headers,
|
||||||
|
json={
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "video_url",
|
||||||
|
"video_url": {
|
||||||
|
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/360/Big_Buck_Bunny_360_10s_1MB.mp4"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this video.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"seed": 42,
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# iterate over the response in chunks
|
||||||
|
count = 0
|
||||||
|
full_text = ""
|
||||||
|
last_response = None
|
||||||
|
for chunk in responses.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
count += 1
|
||||||
|
# remove the "data: " prefix, trailing newline, and split the chunk into individual lines
|
||||||
|
lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
print("=", line)
|
||||||
|
try:
|
||||||
|
response = json.loads(line)
|
||||||
|
# print(response)
|
||||||
|
last_response = response
|
||||||
|
full_text += response["choices"][0]["delta"]["content"]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
# assert count == 27
|
||||||
|
# assert response.usage == {
|
||||||
|
# "completion_tokens": 10,
|
||||||
|
# "prompt_tokens": 50,
|
||||||
|
# "total_tokens": 60,
|
||||||
|
# }
|
||||||
|
# assert (
|
||||||
|
# response.choices[0].message.content
|
||||||
|
# == "In a bustling city, a chicken named Cluck"
|
||||||
|
# )
|
||||||
|
assert last_response == response_snapshot
|
@ -21,6 +21,11 @@ use tokio::sync::mpsc;
|
|||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tracing::{instrument, Span};
|
use tracing::{instrument, Span};
|
||||||
use {once_cell::sync::Lazy, regex::Regex};
|
use {once_cell::sync::Lazy, regex::Regex};
|
||||||
|
// video processing
|
||||||
|
use ffmpeg_next::format::Pixel;
|
||||||
|
use ffmpeg_next::media::Type;
|
||||||
|
use ffmpeg_next::software::scaling::{context::Context, flag::Flags};
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
|
static DEFAULT_GENERATION_LENGTH: u32 = 1024;
|
||||||
|
|
||||||
@ -536,7 +541,11 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32), ValidationError> {
|
pub fn fetch_video(
|
||||||
|
input: &str,
|
||||||
|
target_width: u32,
|
||||||
|
target_height: u32,
|
||||||
|
) -> Result<ProcessedVideo, ValidationError> {
|
||||||
let (data, mimetype) =
|
let (data, mimetype) =
|
||||||
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
|
if input.starts_with("<video>(http://") || input.starts_with("<video>(https://") {
|
||||||
let url = &input["<video>(".len()..input.len() - 1];
|
let url = &input["<video>(".len()..input.len() - 1];
|
||||||
@ -544,7 +553,7 @@ pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32),
|
|||||||
(data, "video/mp4".to_string())
|
(data, "video/mp4".to_string())
|
||||||
} else if input.starts_with("<video>(data:") {
|
} else if input.starts_with("<video>(data:") {
|
||||||
let content = &input["<video>(data:".len()..input.len() - 1];
|
let content = &input["<video>(data:".len()..input.len() - 1];
|
||||||
let tokens: Vec<_> = content.split(';').collect();
|
let tokens: Vec<&str> = content.split(';').collect();
|
||||||
if tokens.len() != 2 {
|
if tokens.len() != 2 {
|
||||||
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
return Err(ValidationError::InvalidVideoContent(content.to_string()));
|
||||||
}
|
}
|
||||||
@ -559,45 +568,101 @@ pub fn fetch_video(input: &str) -> Result<(Vec<u8>, String, usize, usize, f32),
|
|||||||
return Err(ValidationError::InvalidVideoContent(input.to_string()));
|
return Err(ValidationError::InvalidVideoContent(input.to_string()));
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut cursor = Cursor::new(&data);
|
// init ffmpeg
|
||||||
let context = mp4parse::read_mp4(&mut cursor).map_err(|_| ValidationError::MP4Error)?;
|
ffmpeg_next::init().map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
let video_track = context
|
// create temporary file for ffmpeg input
|
||||||
.tracks
|
let mut temp_file = tempfile::NamedTempFile::new().map_err(ValidationError::IoError)?;
|
||||||
.iter()
|
temp_file
|
||||||
.find(|track| track.track_type == mp4parse::TrackType::Video)
|
.write_all(&data)
|
||||||
.ok_or(ValidationError::NoVideoStream)?;
|
.map_err(ValidationError::IoError)?;
|
||||||
|
|
||||||
let video_info = video_track
|
let mut ictx =
|
||||||
.tkhd
|
ffmpeg_next::format::input(&temp_file.path()).map_err(|_| ValidationError::FFmpegError)?;
|
||||||
.as_ref()
|
|
||||||
.ok_or(ValidationError::NoVideoStream)?;
|
|
||||||
let width = (video_info.width >> 16) as usize;
|
|
||||||
let height = (video_info.height >> 16) as usize;
|
|
||||||
|
|
||||||
// timescale units per second
|
let input = ictx
|
||||||
let timescale = video_track.timescale.map(|t| t.0 as f32).unwrap_or(600.0);
|
.streams()
|
||||||
|
.best(Type::Video)
|
||||||
|
.ok_or(ValidationError::FFmpegError)?;
|
||||||
|
let video_stream_index = input.index();
|
||||||
|
|
||||||
// TODO: revisit if we need duration in seconds
|
let context_decoder = ffmpeg_next::codec::context::Context::from_parameters(input.parameters())
|
||||||
let _duration = video_track
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
.duration
|
let mut decoder = context_decoder
|
||||||
.map(|d| d.0 as f32 / timescale)
|
.decoder()
|
||||||
.unwrap_or(0.0);
|
.video()
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
let time_to_sample = video_track
|
let width = target_width;
|
||||||
.stts
|
let height = target_height;
|
||||||
.as_ref()
|
|
||||||
.ok_or(ValidationError::NoVideoStream)?;
|
|
||||||
|
|
||||||
let num_samples = time_to_sample
|
let mut scaler = Context::get(
|
||||||
.samples
|
decoder.format(),
|
||||||
.iter()
|
decoder.width(), // original width
|
||||||
.map(|entry| entry.sample_count)
|
decoder.height(),
|
||||||
.sum::<u32>();
|
Pixel::RGB24,
|
||||||
|
width, // target width
|
||||||
|
height,
|
||||||
|
Flags::BILINEAR,
|
||||||
|
)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
let total_frames = num_samples as f32;
|
let mut frame_index = 0;
|
||||||
|
let mut captured_frame_index = 0;
|
||||||
|
let mut frames = vec![];
|
||||||
|
|
||||||
Ok((data, mimetype, height, width, total_frames))
|
let mut receive_and_process_decoded_frames = |decoder: &mut ffmpeg_next::decoder::Video,
|
||||||
|
raw_fps: f32|
|
||||||
|
-> Result<(), ffmpeg_next::Error> {
|
||||||
|
let mut decoded = ffmpeg_next::util::frame::video::Video::empty();
|
||||||
|
let fps = raw_fps.floor();
|
||||||
|
while decoder.receive_frame(&mut decoded).is_ok() {
|
||||||
|
let mut rgb_frame = ffmpeg_next::util::frame::video::Video::empty();
|
||||||
|
scaler.run(&decoded, &mut rgb_frame)?;
|
||||||
|
if frame_index as f32 % fps == 0.0 {
|
||||||
|
captured_frame_index += 1;
|
||||||
|
// Create new buffer without padding
|
||||||
|
let mut frame_data =
|
||||||
|
Vec::with_capacity((rgb_frame.width() * rgb_frame.height() * 3) as usize);
|
||||||
|
let src_data = rgb_frame.data(0);
|
||||||
|
let row_size = rgb_frame.width() as usize * 3;
|
||||||
|
|
||||||
|
// Copy each row without padding
|
||||||
|
for y in 0..rgb_frame.height() as usize {
|
||||||
|
let start = y * rgb_frame.stride(0) as usize;
|
||||||
|
let end = start + row_size;
|
||||||
|
frame_data.extend_from_slice(&src_data[start..end]);
|
||||||
|
}
|
||||||
|
frames.push(frame_data);
|
||||||
|
}
|
||||||
|
frame_index += 1;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
|
||||||
|
for (stream, packet) in ictx.packets() {
|
||||||
|
// Floor the fps to get a whole number
|
||||||
|
let fps = (stream.rate().numerator() as f32 / stream.rate().denominator() as f32).floor();
|
||||||
|
|
||||||
|
if stream.index() == video_stream_index {
|
||||||
|
decoder
|
||||||
|
.send_packet(&packet)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
receive_and_process_decoded_frames(&mut decoder, fps)
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decoder
|
||||||
|
.send_eof()
|
||||||
|
.map_err(|_| ValidationError::FFmpegError)?;
|
||||||
|
|
||||||
|
Ok(ProcessedVideo {
|
||||||
|
mimetype,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
frames,
|
||||||
|
sampled_frames: captured_frame_index,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||||
@ -688,22 +753,18 @@ fn image_tokens(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn video_tokens(config: &Config, height: usize, width: usize, total_frames: f32) -> String {
|
fn video_tokens(config: &Config, height: u32, width: u32, sampled_frames: f32) -> String {
|
||||||
use Config::*;
|
use Config::*;
|
||||||
|
|
||||||
match config {
|
match config {
|
||||||
// TOOD: improve to use the config to better estimate the number of tokens
|
// TOOD: improve to use the config to better estimate the number of tokens
|
||||||
Qwen2Vl(_config) => {
|
Qwen2Vl(_config) => {
|
||||||
let video_fps = 30_f32;
|
let min_frames = 2_f32;
|
||||||
let fps = 30_f32;
|
let max_frames = 256_f32;
|
||||||
let min_frames = 16_f32;
|
|
||||||
let max_frames = 64_f32;
|
|
||||||
// make sure the frames are within the range and are even
|
// make sure the frames are within the range and are even
|
||||||
let nframes = (total_frames / video_fps * fps)
|
let nframes = (sampled_frames).max(min_frames).min(max_frames);
|
||||||
.max(min_frames)
|
|
||||||
.min(max_frames);
|
|
||||||
let nframes = (nframes / 2.0).round() as usize * 2;
|
let nframes = (nframes / 2.0).round() as usize * 2;
|
||||||
let num_tokens = nframes * height * width / 1541;
|
let num_tokens = nframes * height as usize * width as usize / 1541;
|
||||||
format!(
|
format!(
|
||||||
"<|vision_start|>{:?}<|vision_end|>",
|
"<|vision_start|>{:?}<|vision_end|>",
|
||||||
"<|video_pad|>".repeat(num_tokens)
|
"<|video_pad|>".repeat(num_tokens)
|
||||||
@ -754,37 +815,66 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (data, mimetype, height, width, total_frames) =
|
let processed_video = match config {
|
||||||
fetch_video(&inputs[chunk_start..chunk_end])?;
|
Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) => {
|
||||||
input_chunks.push(Chunk::Video(Video { data, mimetype }));
|
let default_target_width = 224;
|
||||||
let video_tokens = video_tokens(config, height, width, total_frames);
|
let default_target_height = 224;
|
||||||
|
fetch_video(
|
||||||
|
&inputs[chunk_start..chunk_end],
|
||||||
|
default_target_width,
|
||||||
|
default_target_height,
|
||||||
|
)?
|
||||||
|
}
|
||||||
|
Qwen2Vl(_) => {
|
||||||
|
let target_width = 360;
|
||||||
|
let target_height = 420;
|
||||||
|
fetch_video(&inputs[chunk_start..chunk_end], target_width, target_height)?
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
unreachable!("Video tokens are not supported for this model configuration")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
input_chunks.push(Chunk::Video(Video {
|
||||||
|
data: processed_video.frames.iter().flatten().cloned().collect(),
|
||||||
|
mimetype: processed_video.mimetype.clone(),
|
||||||
|
width: processed_video.width,
|
||||||
|
height: processed_video.height,
|
||||||
|
num_frames: processed_video.frames.len() as u32,
|
||||||
|
}));
|
||||||
|
let video_tokens = video_tokens(
|
||||||
|
config,
|
||||||
|
processed_video.height,
|
||||||
|
processed_video.width,
|
||||||
|
processed_video.sampled_frames as f32,
|
||||||
|
);
|
||||||
tokenizer_query.push_str(&video_tokens);
|
tokenizer_query.push_str(&video_tokens);
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
// clip remaining inputs and process images
|
// handle image content after video content
|
||||||
let remaining_input = &inputs[start..];
|
for chunk in RE.find_iter(&inputs) {
|
||||||
for chunk in RE.find_iter(remaining_input) {
|
let chunk_start = chunk.start();
|
||||||
let chunk_start = chunk.start() + start;
|
let chunk_end = chunk.end();
|
||||||
let chunk_end = chunk.end() + start;
|
|
||||||
if chunk_start != start {
|
if chunk_start != start {
|
||||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||||
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
input_chunks.push(Chunk::Image(Image {
|
||||||
|
data,
|
||||||
|
mimetype: mimetype.clone(),
|
||||||
|
}));
|
||||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add any remaining text
|
|
||||||
if start != inputs.len() {
|
if start != inputs.len() {
|
||||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
|
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..]);
|
tokenizer_query.push_str(&inputs[start..]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply any necessary token fixups
|
|
||||||
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
|
tokenizer_query = image_tokens_fixup(config, tokenizer_query);
|
||||||
|
|
||||||
(tokenizer_query, input_chunks)
|
(tokenizer_query, input_chunks)
|
||||||
}
|
}
|
||||||
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
|
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
|
||||||
@ -797,7 +887,6 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
|
|
||||||
Ok((encoding, input_chunks))
|
Ok((encoding, input_chunks))
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizerRequest = (
|
type TokenizerRequest = (
|
||||||
(String, bool, Option<usize>),
|
(String, bool, Option<usize>),
|
||||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
||||||
@ -810,16 +899,21 @@ pub struct Image {
|
|||||||
pub mimetype: String,
|
pub mimetype: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct ProcessedVideo {
|
||||||
|
mimetype: String,
|
||||||
|
height: u32,
|
||||||
|
width: u32,
|
||||||
|
frames: Vec<Vec<u8>>, // RGB frames
|
||||||
|
sampled_frames: usize,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub struct Video {
|
pub struct Video {
|
||||||
pub data: Vec<u8>,
|
pub data: Vec<u8>,
|
||||||
pub mimetype: String,
|
pub mimetype: String,
|
||||||
}
|
pub width: u32,
|
||||||
|
pub height: u32,
|
||||||
impl Video {
|
pub num_frames: u32,
|
||||||
pub fn as_bytes(&self) -> Vec<u8> {
|
|
||||||
self.frames.iter().flatten().cloned().collect()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
@ -845,9 +939,19 @@ impl ChunksToString for Vec<Chunk> {
|
|||||||
let encoded = STANDARD.encode(data);
|
let encoded = STANDARD.encode(data);
|
||||||
output.push_str(&format!("", mimetype, encoded))
|
output.push_str(&format!("", mimetype, encoded))
|
||||||
}
|
}
|
||||||
Chunk::Video(Video { data, mimetype }) => {
|
Chunk::Video(Video {
|
||||||
|
data,
|
||||||
|
mimetype,
|
||||||
|
width,
|
||||||
|
height: _,
|
||||||
|
num_frames: _,
|
||||||
|
}) => {
|
||||||
|
// TODO: revisit if we should limit video support to v3 - to avoid sending very large base64 strings
|
||||||
let encoded = STANDARD.encode(data);
|
let encoded = STANDARD.encode(data);
|
||||||
output.push_str(&format!("<video>(data:{};base64,{})", mimetype, encoded))
|
output.push_str(&format!(
|
||||||
|
r#"<video width="{}"><source src="data:{};base64,{}" type="{}"></video>"#,
|
||||||
|
width, mimetype, encoded, mimetype
|
||||||
|
));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
output
|
output
|
||||||
@ -983,6 +1087,10 @@ pub enum ValidationError {
|
|||||||
MP4Error,
|
MP4Error,
|
||||||
#[error("no video stream found")]
|
#[error("no video stream found")]
|
||||||
NoVideoStream,
|
NoVideoStream,
|
||||||
|
#[error("io error: {0}")]
|
||||||
|
IoError(#[from] std::io::Error),
|
||||||
|
#[error("ffmpeg error")]
|
||||||
|
FFmpegError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
Loading…
Reference in New Issue
Block a user