feat: support max_image_fetch_size to limit fetched image size (#3339)

* feat: support max_image_fetch_size to limit fetched image size

* fix: update model path for test

* fix: adjust model repo id for test again

* fix: apply clippy lints

* fix: clippy fix

* fix: avoid torch build isolation in docker

* fix: bump repo id in flash llama tests

* fix: temporarily avoid problematic repos in tests
Authored by drbh on 2025-11-18 12:29:21 -05:00; committed by GitHub
parent 85790a19a7
commit 24ee40d143
12 changed files with 78 additions and 8 deletions
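The overall shape of the change: a new `max_image_fetch_size` argument (default 1073741824 bytes, i.e. 1 GiB) is added to each router entry point and threaded through `run`/`start` into `Validation`, where images referenced by URL in prompts are now downloaded through a size-checked reader instead of being read without any bound. The remaining hunks are the test and build fixes listed in the squashed commits above.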

View File

@@ -157,6 +157,10 @@ struct Args {
/// Maximum payload size in bytes.
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
/// Maximum image fetch size in bytes.
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[tokio::main]
@@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
args.max_image_fetch_size,
args.prometheus_port,
)
.await?;
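For readers unfamiliar with clap's derive defaults, a minimal standalone sketch (not the router itself) of how an argument declared this way is consumed; with `long` the flag becomes `--max-image-fetch-size`, and, assuming clap's default mapping for a bare `env` attribute, the value can also come from a `MAX_IMAGE_FETCH_SIZE` environment variable:

// Minimal sketch, assuming a clap version with the `derive` and `env` features enabled.
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Maximum image fetch size in bytes (1 GiB by default).
    #[clap(default_value = "1073741824", long, env)]
    max_image_fetch_size: usize,
}

fn main() {
    let args = DemoArgs::parse();
    println!("image downloads capped at {} bytes", args.max_image_fetch_size);
}

With the default in place, a request that embeds an image URL larger than the cap is rejected during validation rather than buffered in full.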

View File

@@ -67,6 +67,8 @@ struct Args {
usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
executor_worker,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
// Launch Tokio runtime
@@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
if let Some(Commands::PrintSchema) = command {
@@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
if let Some(Commands::PrintSchema) = command {
@@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_handle(launcher):
with launcher("huggingface/llama-7b", num_shard=2) as handle:
with launcher("huggyllama/llama-7b", num_shard=2) as handle:
yield handle

View File

@@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
return flash_llama_fp8_handle.client
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private

View File

@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
return flash_llama_marlin24_handle.client
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private

View File

@@ -673,7 +673,7 @@ mod tests {
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
output_name.push_str(name);
}
output.push_str(arguments);
} else {
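The one-character change above is one of the clippy fixes mentioned in the commit message: `push_str` takes `&str`, and `name` is already a reference here, so the extra `&` is the kind of thing clippy's `needless_borrow` lint flags. A trivial standalone illustration (with hypothetical values):

fn main() {
    let name: &str = "get_current_weather";
    let mut output_name = String::new();
    // `String::push_str(&mut self, s: &str)`: `name` is already `&str`,
    // so writing `&name` would only add a needless extra reference.
    output_name.push_str(name);
    assert_eq!(output_name, "get_current_weather");
}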

View File

@@ -1523,6 +1523,7 @@ pub async fn run(
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// CORS allowed origins
@@ -1827,6 +1828,7 @@ pub async fn run(
compat_return_full_text,
allow_origin,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await;
@@ -1889,6 +1891,7 @@ async fn start(
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
@@ -1920,6 +1923,7 @@ async fn start(
max_input_tokens,
max_total_tokens,
disable_grammar_support,
max_image_fetch_size,
);
let infer = Infer::new(

View File

@@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
/// Payload validation logic
use std::cmp::min;
use std::io::Cursor;
use std::io::{Cursor, Read};
use std::iter;
use std::sync::Arc;
use thiserror::Error;
@@ -51,6 +51,7 @@ impl Validation {
max_input_length: usize,
max_total_tokens: usize,
disable_grammar_support: bool,
max_image_fetch_size: usize,
) -> Self {
let workers = if let Tokenizer::Python { .. } = &tokenizer {
1
@@ -78,6 +79,7 @@ impl Validation {
config_clone,
preprocessor_config_clone,
tokenizer_receiver,
max_image_fetch_size,
)
});
}
@@ -480,6 +482,7 @@ fn tokenizer_worker(
config: Option<Config>,
preprocessor_config: Option<HubPreprocessorConfig>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
max_image_fetch_size: usize,
) {
match tokenizer {
Tokenizer::Python {
@@ -503,6 +506,7 @@ fn tokenizer_worker(
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@@ -524,6 +528,7 @@ fn tokenizer_worker(
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
fn fetch_image(
input: &str,
max_image_fetch_size: usize,
) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
let response = reqwest::blocking::get(url)?;
// Check Content-Length header if present
if let Some(content_length) = response.content_length() {
if content_length as usize > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
content_length as usize,
max_image_fetch_size,
));
}
}
// Read the body with size limit to prevent unbounded memory allocation
let mut data = Vec::new();
let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
limited_reader.read_to_end(&mut data)?;
if data.len() > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
data.len(),
max_image_fetch_size,
));
}
let format = image::guess_format(&data)?;
// TODO Remove this clone
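For reference, a self-contained sketch of the size-limiting pattern this hunk introduces, assuming the blocking `reqwest` client; the function name and string-based error handling below are illustrative, not the router's API. The idea is to reject early when an honest `Content-Length` already exceeds the cap, and to still bound the actual read with `Read::take` (at `max + 1` bytes, so an over-long or unlabelled body is detected without buffering it in full):

use std::io::Read;

// 1 GiB, matching the router's default for max_image_fetch_size.
const DEFAULT_MAX_IMAGE_FETCH_SIZE: usize = 1_073_741_824;

fn fetch_bounded(url: &str, max_bytes: usize) -> Result<Vec<u8>, String> {
    let response = reqwest::blocking::get(url).map_err(|e| e.to_string())?;

    // Fast path: if the server declares a Content-Length above the cap, fail early.
    if let Some(len) = response.content_length() {
        if len as usize > max_bytes {
            return Err(format!("image of {len} bytes exceeds the {max_bytes}-byte limit"));
        }
    }

    // The header may be absent or wrong, so also cap the body read itself.
    // Reading at most `max_bytes + 1` bytes distinguishes "at the limit"
    // from "over the limit" without ever holding an unbounded body in memory.
    let mut data = Vec::new();
    response
        .take((max_bytes + 1) as u64)
        .read_to_end(&mut data)
        .map_err(|e| e.to_string())?;

    if data.len() > max_bytes {
        return Err(format!(
            "image of {} bytes exceeds the {max_bytes}-byte limit",
            data.len()
        ));
    }
    Ok(data)
}

fn main() {
    match fetch_bounded("https://example.com/image.png", DEFAULT_MAX_IMAGE_FETCH_SIZE) {
        Ok(bytes) => println!("fetched {} bytes", bytes.len()),
        Err(err) => eprintln!("{err}"),
    }
}

This mirrors the hunk above, except that the real code surfaces `ValidationError::ImageTooLarge` and `ImageReadError` instead of strings.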
@@ -787,6 +817,7 @@ fn prepare_input<T: TokenizerTrait>(
tokenizer: &T,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
max_image_fetch_size: usize,
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
@@ -805,7 +836,8 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) =
fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
@@ -990,6 +1022,10 @@ pub enum ValidationError {
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
#[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
ImageTooLarge(usize, usize),
#[error("Failed to read image data: {0}")]
ImageReadError(#[from] std::io::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
}
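A small sketch of how the two new variants surface, assuming the `thiserror` crate this enum already derives from; the message text comes straight from the `#[error]` attribute, and the variant names below are a stand-in, not the router's full enum:

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoValidationError {
    #[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
    ImageTooLarge(usize, usize),
    #[error("Failed to read image data: {0}")]
    ImageReadError(#[from] std::io::Error),
}

fn main() {
    let err = DemoValidationError::ImageTooLarge(2_000_000_000, 1_073_741_824);
    // Prints: Image size 2000000000 bytes exceeds maximum allowed size of 1073741824 bytes
    println!("{err}");
}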
@@ -1023,6 +1059,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let max_new_tokens = 10;
@@ -1058,6 +1095,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let max_new_tokens = 10;
@@ -1092,6 +1130,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1132,6 +1171,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1203,6 +1243,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@@ -1293,6 +1334,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let chunks = match validation
@@ -1349,6 +1391,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let (encoding, chunks) = match validation

View File

@@ -3,7 +3,7 @@ flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
build-flash-attention-v2-cuda:
pip install -U packaging wheel
pip install flash-attn==$(flash_att_v2_commit_cuda)
pip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda)
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
echo "Flash v2 installed"

View File

@@ -14,7 +14,7 @@ def get_test_model():
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = TestModel(
"test_model_id",