Compare commits

2 Commits
v3.3.6 ... main

Author SHA1 Message Date
drbh
24ee40d143
feat: support max_image_fetch_size to limit (#3339)
* feat: support max_image_fetch_size to limit

* fix: update model path for test

* fix: adjust model repo id for test again

* fix: apply clippy lints

* fix: clippy fix

* fix: avoid torch build isolation in docker

* fix: bump repo id in flash llama tests

* fix: temporarily avoid problematic repos in tests
2025-11-18 12:29:21 -05:00
Funtowicz Morgan
85790a19a7
misc(gha): expose action cache url and runtime as secrets (#2964)
* misc(gha): expose action cache url and runtime as secrets

* (CI): Move S3 Auth to OIDC

* Fix Typo

* change bucket name

* fix aws auth creds

* misc(gha): fix invalid syntax for secrets

* WIP: Add AWS session token

* Increase session time

* Remove actions_cache_url mount from Dockerfile

Removed an unused mount for actions_cache_url in the Dockerfile.

* WIP

---------

Co-authored-by: Guillaume LEGENDRE <glegendre01@gmail.com>
2025-11-17 10:50:10 +01:00
14 changed files with 93 additions and 16 deletions
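
The first commit threads a new max_image_fetch_size limit (default 1073741824 bytes, i.e. 1 GiB) from each router binary's CLI flags down into the validation layer, where fetched images larger than the limit are rejected. As a minimal sketch, the flag definition below mirrors the clap fields shown in the diffs that follow; the surrounding struct subset and the small main() are illustrative only and assume clap's derive and env features:

use clap::Parser;

// Illustrative subset of the router's CLI arguments; the two fields and their
// attributes are copied from the diff, the rest of this sketch is assumed.
#[derive(Parser, Debug)]
struct Args {
    /// Maximum payload size in bytes.
    #[clap(default_value = "2000000", long, env)]
    payload_limit: usize,
    /// Maximum image fetch size in bytes.
    #[clap(default_value = "1073741824", long, env)]
    max_image_fetch_size: usize,
}

fn main() {
    // With clap's `long` + `env` attributes this can be set either as
    // `--max-image-fetch-size 10485760` or via MAX_IMAGE_FETCH_SIZE=10485760.
    let args = Args::parse();
    println!("max_image_fetch_size = {}", args.max_image_fetch_size);
}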

View File

@ -175,6 +175,14 @@ jobs:
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: configure aws credentials
id: aws-creds
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502
with:
role-to-assume: ${{ secrets.AWS_ROLE_GITHUB_BUILDX_CACHE }}
role-duration-seconds: 18000
aws-region: us-east-1
output-credentials: true
# If pull request
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name == 'pull_request' }}
@ -204,6 +212,8 @@ jobs:
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
env:
DOCKER_BUILD_SUMMARY: false
with:
context: .
file: ${{ env.DOCKERFILE }}
@ -215,13 +225,14 @@ jobs:
PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
secrets: |
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
cache-from: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
cache-to: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
- name: Final
id: final
run: |

View File

@ -65,8 +65,6 @@ WORKDIR /usr/src/text-generation-inference
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
ARG actions_results_url
ARG actions_runtime_token
# Install Rust
ENV PATH="/root/.cargo/bin:$PATH"
@ -84,8 +82,6 @@ ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
ENV ACTIONS_RESULTS_URL=${actions_results_url}
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
@ -99,8 +95,8 @@ COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
ENV RUSTC_WRAPPER=sccache
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
RUN export CC=gcc-14 \
export CXX=g++-14 \
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \

View File

@ -157,6 +157,10 @@ struct Args {
/// Maximum payload size in bytes.
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
/// Maximum image fetch size in bytes.
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[tokio::main]
@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
args.max_image_fetch_size,
args.prometheus_port,
)
.await?;

View File

@ -67,6 +67,8 @@ struct Args {
usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
executor_worker,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
// Launch Tokio runtime
@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[derive(Debug, Subcommand)]
@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
if let Some(Commands::PrintSchema) = command {
@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@ -74,6 +74,8 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
#[derive(Debug, Subcommand)]
@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
} = args;
if let Some(Commands::PrintSchema) = command {
@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
max_client_batch_size,
usage_stats,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await?;

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_handle(launcher):
with launcher("huggingface/llama-7b", num_shard=2) as handle:
with launcher("huggyllama/llama-7b", num_shard=2) as handle:
yield handle

View File

@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
return flash_llama_fp8_handle.client
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private

View File

@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
return flash_llama_marlin24_handle.client
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snap
assert response == response_snapshot
@pytest.mark.skip(reason="Issue with the model access")
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private

View File

@ -673,7 +673,7 @@ mod tests {
let (name, arguments) = get_tool_call_content(&events[0]);
if let Some(name) = name {
assert_eq!(name, "get_current_weather");
output_name.push_str(&name);
output_name.push_str(name);
}
output.push_str(arguments);
} else {
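
The one-character change above is one of the clippy fixes from the first commit: push_str takes a &str, and name is already a string slice at that point, so the extra & is a needless borrow (presumably clippy's needless_borrow lint). A self-contained illustration of the same pattern, with names and types assumed rather than taken from the test:

fn main() {
    // `name` is assumed to already be a &str, as it appears to be in the test above.
    let name: &str = "get_current_weather";
    let mut output_name = String::new();
    // `push_str(&name)` would still compile via deref coercion, but clippy
    // flags the extra `&` as a needless borrow; passing `name` directly is idiomatic.
    output_name.push_str(name);
    assert_eq!(output_name, "get_current_weather");
}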

View File

@ -1523,6 +1523,7 @@ pub async fn run(
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// CORS allowed origins
@ -1827,6 +1828,7 @@ pub async fn run(
compat_return_full_text,
allow_origin,
payload_limit,
max_image_fetch_size,
prometheus_port,
)
.await;
@ -1889,6 +1891,7 @@ async fn start(
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
payload_limit: usize,
max_image_fetch_size: usize,
prometheus_port: u16,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
@ -1920,6 +1923,7 @@ async fn start(
max_input_tokens,
max_total_tokens,
disable_grammar_support,
max_image_fetch_size,
);
let infer = Infer::new(

View File

@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
/// Payload validation logic
use std::cmp::min;
use std::io::Cursor;
use std::io::{Cursor, Read};
use std::iter;
use std::sync::Arc;
use thiserror::Error;
@ -51,6 +51,7 @@ impl Validation {
max_input_length: usize,
max_total_tokens: usize,
disable_grammar_support: bool,
max_image_fetch_size: usize,
) -> Self {
let workers = if let Tokenizer::Python { .. } = &tokenizer {
1
@ -78,6 +79,7 @@ impl Validation {
config_clone,
preprocessor_config_clone,
tokenizer_receiver,
max_image_fetch_size,
)
});
}
@ -480,6 +482,7 @@ fn tokenizer_worker(
config: Option<Config>,
preprocessor_config: Option<HubPreprocessorConfig>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
max_image_fetch_size: usize,
) {
match tokenizer {
Tokenizer::Python {
@ -503,6 +506,7 @@ fn tokenizer_worker(
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@ -524,6 +528,7 @@ fn tokenizer_worker(
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
max_image_fetch_size,
))
.unwrap_or(())
})
@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
.to_string()
}
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
fn fetch_image(
input: &str,
max_image_fetch_size: usize,
) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
if input.starts_with("![](http://") || input.starts_with("![](https://") {
let url = &input["![](".len()..input.len() - 1];
let data = reqwest::blocking::get(url)?.bytes()?;
let response = reqwest::blocking::get(url)?;
// Check Content-Length header if present
if let Some(content_length) = response.content_length() {
if content_length as usize > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
content_length as usize,
max_image_fetch_size,
));
}
}
// Read the body with size limit to prevent unbounded memory allocation
let mut data = Vec::new();
let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
limited_reader.read_to_end(&mut data)?;
if data.len() > max_image_fetch_size {
return Err(ValidationError::ImageTooLarge(
data.len(),
max_image_fetch_size,
));
}
let format = image::guess_format(&data)?;
// TODO Remove this clone
@ -787,6 +817,7 @@ fn prepare_input<T: TokenizerTrait>(
tokenizer: &T,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
max_image_fetch_size: usize,
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
@ -805,7 +836,8 @@ fn prepare_input<T: TokenizerTrait>(
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
let (data, mimetype, height, width) =
fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
@ -990,6 +1022,10 @@ pub enum ValidationError {
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
#[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
ImageTooLarge(usize, usize),
#[error("Failed to read image data: {0}")]
ImageReadError(#[from] std::io::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
}
@ -1023,6 +1059,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let max_new_tokens = 10;
@ -1058,6 +1095,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let max_new_tokens = 10;
@ -1092,6 +1130,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@ -1132,6 +1171,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@ -1203,6 +1243,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
match validation
.validate(GenerateRequest {
@ -1293,6 +1334,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let chunks = match validation
@ -1349,6 +1391,7 @@ mod tests {
max_input_length,
max_total_tokens,
disable_grammar_support,
1024 * 1024 * 1024, // 1GB
);
let (encoding, chunks) = match validation
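
The heart of the feature is the bounded download in fetch_image above: reject early when the Content-Length header already exceeds the limit, then read the body through a reader capped at max_image_fetch_size + 1 bytes so a missing or understated length still cannot trigger unbounded buffering. A minimal standalone sketch of the same pattern, assuming reqwest's blocking client; the function name, URL, and error handling are illustrative:

use std::io::Read;

// Sketch of a size-limited HTTP fetch, mirroring the pattern used in fetch_image.
fn fetch_bounded(url: &str, max_size: usize) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let response = reqwest::blocking::get(url)?;

    // Fast path: reject using the Content-Length header when the server provides one.
    if let Some(len) = response.content_length() {
        if len as usize > max_size {
            return Err(format!("declared size {len} exceeds limit {max_size}").into());
        }
    }

    // Slow path: read at most max_size + 1 bytes so "exactly at the limit"
    // can be told apart from "over the limit" without unbounded allocation.
    let mut data = Vec::new();
    response.take((max_size + 1) as u64).read_to_end(&mut data)?;
    if data.len() > max_size {
        return Err(format!("body of {} bytes exceeds limit {max_size}", data.len()).into());
    }
    Ok(data)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical URL; any small image would do.
    let bytes = fetch_bounded("https://example.com/image.png", 1024 * 1024)?;
    println!("fetched {} bytes", bytes.len());
    Ok(())
}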

View File

@ -3,7 +3,7 @@ flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
build-flash-attention-v2-cuda:
pip install -U packaging wheel
pip install flash-attn==$(flash_att_v2_commit_cuda)
pip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda)
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
echo "Flash v2 installed"

View File

@ -14,7 +14,7 @@ def get_test_model():
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model = TestModel(
"test_model_id",