diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 9ee61ce6e..ea72dcaee 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -157,6 +157,10 @@ struct Args {
     /// Maximum payload size in bytes.
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+
+    /// Maximum image fetch size in bytes.
+    #[clap(default_value = "1073741824", long, env)]
+    max_image_fetch_size: usize,
 }
 
 #[tokio::main]
@@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
         args.max_client_batch_size,
         args.usage_stats,
         args.payload_limit,
+        args.max_image_fetch_size,
         args.prometheus_port,
     )
     .await?;
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 543f8e6e3..26d1dbf8c 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -67,6 +67,8 @@ struct Args {
     usage_stats: UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+    #[clap(default_value = "1073741824", long, env)]
+    max_image_fetch_size: usize,
 }
 
 async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         executor_worker,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
     } = args;
 
     // Launch Tokio runtime
@@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
         prometheus_port,
     )
     .await?;
diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs
index 60b5d52bb..575e62a02 100644
--- a/backends/v2/src/main.rs
+++ b/backends/v2/src/main.rs
@@ -74,6 +74,8 @@ struct Args {
     usage_stats: usage_stats::UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+    #[clap(default_value = "1073741824", long, env)]
+    max_image_fetch_size: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
         prometheus_port,
     )
     .await?;
diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs
index 44e63853e..a71ec3d09 100644
--- a/backends/v3/src/main.rs
+++ b/backends/v3/src/main.rs
@@ -74,6 +74,8 @@ struct Args {
     usage_stats: usage_stats::UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+    #[clap(default_value = "1073741824", long, env)]
+    max_image_fetch_size: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
         max_client_batch_size,
         usage_stats,
         payload_limit,
+        max_image_fetch_size,
         prometheus_port,
     )
     .await?;
diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py
index bf49dc0b4..f24215a08 100644
--- a/integration-tests/models/test_flash_llama.py
+++ b/integration-tests/models/test_flash_llama.py
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_llama_handle(launcher):
-    with launcher("huggingface/llama-7b", num_shard=2) as handle:
+    with launcher("huggyllama/llama-7b", num_shard=2) as handle:
         yield handle
diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py
index 1980846d3..1d065f5d6 100644
--- a/integration-tests/models/test_flash_llama_fp8.py
+++ b/integration-tests/models/test_flash_llama_fp8.py
@@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
     return flash_llama_fp8_handle.client
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
@@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
     assert response == response_snapshot
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
@@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
     assert response == response_snapshot
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
diff --git a/integration-tests/models/test_flash_llama_marlin_24.py b/integration-tests/models/test_flash_llama_marlin_24.py
index 3eb94f02e..bd364ecf9 100644
--- a/integration-tests/models/test_flash_llama_marlin_24.py
+++ b/integration-tests/models/test_flash_llama_marlin_24.py
@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
     return flash_llama_marlin24_handle.client
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
@@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
     assert response == response_snapshot
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
@@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snap
     assert response == response_snapshot
 
 
+@pytest.mark.skip(reason="Issue with the model access")
 @pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
diff --git a/router/src/chat.rs b/router/src/chat.rs
index d5824fea0..824b775b3 100644
--- a/router/src/chat.rs
+++ b/router/src/chat.rs
@@ -673,7 +673,7 @@ mod tests {
             let (name, arguments) = get_tool_call_content(&events[0]);
             if let Some(name) = name {
                 assert_eq!(name, "get_current_weather");
-                output_name.push_str(&name);
+                output_name.push_str(name);
             }
             output.push_str(arguments);
         } else {
diff --git a/router/src/server.rs b/router/src/server.rs
index 97a0cea25..7f0bf74eb 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1523,6 +1523,7 @@ pub async fn run(
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
     payload_limit: usize,
+    max_image_fetch_size: usize,
     prometheus_port: u16,
 ) -> Result<(), WebServerError> {
     // CORS allowed origins
@@ -1827,6 +1828,7 @@ pub async fn run(
         compat_return_full_text,
         allow_origin,
         payload_limit,
+        max_image_fetch_size,
         prometheus_port,
     )
     .await;
@@ -1889,6 +1891,7 @@ async fn start(
     compat_return_full_text: bool,
     allow_origin: Option<AllowOrigin>,
     payload_limit: usize,
+    max_image_fetch_size: usize,
     prometheus_port: u16,
 ) -> Result<(), WebServerError> {
     // Determine the server port based on the feature and environment variable.
@@ -1920,6 +1923,7 @@ async fn start(
         max_input_tokens,
         max_total_tokens,
         disable_grammar_support,
+        max_image_fetch_size,
     );
 
     let infer = Infer::new(
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 7717f373e..b32f5f8b5 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
 use serde_json::Value;
 /// Payload validation logic
 use std::cmp::min;
-use std::io::Cursor;
+use std::io::{Cursor, Read};
 use std::iter;
 use std::sync::Arc;
 use thiserror::Error;
@@ -51,6 +51,7 @@ impl Validation {
         max_input_length: usize,
         max_total_tokens: usize,
         disable_grammar_support: bool,
+        max_image_fetch_size: usize,
     ) -> Self {
         let workers = if let Tokenizer::Python { .. } = &tokenizer {
             1
@@ -78,6 +79,7 @@ impl Validation {
                 config_clone,
                 preprocessor_config_clone,
                 tokenizer_receiver,
+                max_image_fetch_size,
             )
         });
     }
@@ -480,6 +482,7 @@ fn tokenizer_worker(
     config: Option<Config>,
     preprocessor_config: Option<HubPreprocessorConfig>,
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+    max_image_fetch_size: usize,
 ) {
     match tokenizer {
         Tokenizer::Python {
@@ -503,6 +506,7 @@ fn tokenizer_worker(
                         &tokenizer,
                         config.as_ref(),
                         preprocessor_config.as_ref(),
+                        max_image_fetch_size,
                     ))
                     .unwrap_or(())
                 })
@@ -524,6 +528,7 @@ fn tokenizer_worker(
                         &tokenizer,
                         config.as_ref(),
                         preprocessor_config.as_ref(),
+                        max_image_fetch_size,
                     ))
                     .unwrap_or(())
                 })
@@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
     .to_string()
 }
 
-fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
+fn fetch_image(
+    input: &str,
+    max_image_fetch_size: usize,
+) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
     if input.starts_with("![](http://") || input.starts_with("![](https://") {
         let url = &input["![](".len()..input.len() - 1];
-        let data = reqwest::blocking::get(url)?.bytes()?;
+        let response = reqwest::blocking::get(url)?;
+
+        // Check Content-Length header if present
+        if let Some(content_length) = response.content_length() {
+            if content_length as usize > max_image_fetch_size {
+                return Err(ValidationError::ImageTooLarge(
+                    content_length as usize,
+                    max_image_fetch_size,
+                ));
+            }
+        }
+
+        // Read the body with size limit to prevent unbounded memory allocation
+        let mut data = Vec::new();
+        let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
+        limited_reader.read_to_end(&mut data)?;
+
+        if data.len() > max_image_fetch_size {
+            return Err(ValidationError::ImageTooLarge(
+                data.len(),
+                max_image_fetch_size,
+            ));
+        }
 
         let format = image::guess_format(&data)?;
         // TODO Remove this clone
@@ -787,6 +817,7 @@ fn prepare_input(
     tokenizer: &T,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
+    max_image_fetch_size: usize,
 ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
@@ -805,7 +836,8 @@ fn prepare_input(
                 input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
                 tokenizer_query.push_str(&inputs[start..chunk_start]);
             }
-            let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+            let (data, mimetype, height, width) =
+                fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
             input_chunks.push(Chunk::Image(Image { data, mimetype }));
             tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
             start = chunk_end;
@@ -990,6 +1022,10 @@ pub enum ValidationError {
     InvalidImageContent(String),
     #[error("Could not fetch image: {0}")]
{0}")] FailedFetchImage(#[from] reqwest::Error), + #[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")] + ImageTooLarge(usize, usize), + #[error("Failed to read image data: {0}")] + ImageReadError(#[from] std::io::Error), #[error("{0} modality is not supported")] UnsupportedModality(&'static str), } @@ -1023,6 +1059,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); let max_new_tokens = 10; @@ -1058,6 +1095,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); let max_new_tokens = 10; @@ -1092,6 +1130,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); match validation .validate(GenerateRequest { @@ -1132,6 +1171,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); match validation .validate(GenerateRequest { @@ -1203,6 +1243,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); match validation .validate(GenerateRequest { @@ -1293,6 +1334,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); let chunks = match validation @@ -1349,6 +1391,7 @@ mod tests { max_input_length, max_total_tokens, disable_grammar_support, + 1024 * 1024 * 1024, // 1GB ); let (encoding, chunks) = match validation diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index 9a946d97f..2097af3fd 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -3,7 +3,7 @@ flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd build-flash-attention-v2-cuda: pip install -U packaging wheel - pip install flash-attn==$(flash_att_v2_commit_cuda) + pip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda) install-flash-attention-v2-cuda: build-flash-attention-v2-cuda echo "Flash v2 installed" diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py index 8441e8c6e..d048b9cc1 100644 --- a/server/tests/models/test_model.py +++ b/server/tests/models/test_model.py @@ -14,7 +14,7 @@ def get_test_model(): def generate_token(self, batch): raise NotImplementedError - tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") + tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") model = TestModel( "test_model_id",