mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-11-18 23:15:59 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24ee40d143 | ||
|
|
85790a19a7 |
15
.github/workflows/build.yaml
vendored
15
.github/workflows/build.yaml
vendored
@ -175,6 +175,14 @@ jobs:
|
|||||||
registry: docker.io
|
registry: docker.io
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
- name: configure aws credentials
|
||||||
|
id: aws-creds
|
||||||
|
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502
|
||||||
|
with:
|
||||||
|
role-to-assume: ${{ secrets.AWS_ROLE_GITHUB_BUILDX_CACHE }}
|
||||||
|
role-duration-seconds: 18000
|
||||||
|
aws-region: us-east-1
|
||||||
|
output-credentials: true
|
||||||
# If pull request
|
# If pull request
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
- name: Extract metadata (tags, labels) for Docker
|
||||||
if: ${{ github.event_name == 'pull_request' }}
|
if: ${{ github.event_name == 'pull_request' }}
|
||||||
@ -204,6 +212,8 @@ jobs:
|
|||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
id: build-and-push
|
id: build-and-push
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
|
env:
|
||||||
|
DOCKER_BUILD_SUMMARY: false
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ${{ env.DOCKERFILE }}
|
file: ${{ env.DOCKERFILE }}
|
||||||
@ -215,13 +225,14 @@ jobs:
|
|||||||
PLATFORM=${{ env.PLATFORM }}
|
PLATFORM=${{ env.PLATFORM }}
|
||||||
build_type=${{ env.BUILD_TYPE }}
|
build_type=${{ env.BUILD_TYPE }}
|
||||||
sccache_gha_enabled=on
|
sccache_gha_enabled=on
|
||||||
|
secrets: |
|
||||||
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
|
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
|
||||||
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
|
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
|
||||||
target: ${{ env.TARGET }}
|
target: ${{ env.TARGET }}
|
||||||
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
||||||
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=max,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
|
cache-from: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
|
||||||
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL_EXTENSION }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
|
cache-to: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
|
||||||
- name: Final
|
- name: Final
|
||||||
id: final
|
id: final
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@ -65,8 +65,6 @@ WORKDIR /usr/src/text-generation-inference
|
|||||||
ARG cuda_arch_list
|
ARG cuda_arch_list
|
||||||
ARG build_type
|
ARG build_type
|
||||||
ARG sccache_gha_enabled
|
ARG sccache_gha_enabled
|
||||||
ARG actions_results_url
|
|
||||||
ARG actions_runtime_token
|
|
||||||
|
|
||||||
# Install Rust
|
# Install Rust
|
||||||
ENV PATH="/root/.cargo/bin:$PATH"
|
ENV PATH="/root/.cargo/bin:$PATH"
|
||||||
@ -84,8 +82,6 @@ ENV CUDA_ARCH_LIST=${cuda_arch_list}
|
|||||||
|
|
||||||
# SCCACHE Specifics args - before finding a better, more generic, way...
|
# SCCACHE Specifics args - before finding a better, more generic, way...
|
||||||
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
|
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
|
||||||
ENV ACTIONS_RESULTS_URL=${actions_results_url}
|
|
||||||
ENV ACTIONS_RUNTIME_TOKEN=${actions_runtime_token}
|
|
||||||
|
|
||||||
COPY Cargo.lock Cargo.lock
|
COPY Cargo.lock Cargo.lock
|
||||||
COPY Cargo.toml Cargo.toml
|
COPY Cargo.toml Cargo.toml
|
||||||
@ -99,8 +95,8 @@ COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
|||||||
|
|
||||||
ENV RUSTC_WRAPPER=sccache
|
ENV RUSTC_WRAPPER=sccache
|
||||||
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
|
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
|
||||||
RUN export CC=gcc-14 \
|
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
|
||||||
export CXX=g++-14 \
|
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
|
||||||
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
|
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
|
||||||
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
|
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
|
||||||
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
|
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
|
||||||
|
|||||||
@ -157,6 +157,10 @@ struct Args {
|
|||||||
/// Maximum payload size in bytes.
|
/// Maximum payload size in bytes.
|
||||||
#[clap(default_value = "2000000", long, env)]
|
#[clap(default_value = "2000000", long, env)]
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
|
||||||
|
/// Maximum image fetch size in bytes.
|
||||||
|
#[clap(default_value = "1073741824", long, env)]
|
||||||
|
max_image_fetch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
args.max_client_batch_size,
|
args.max_client_batch_size,
|
||||||
args.usage_stats,
|
args.usage_stats,
|
||||||
args.payload_limit,
|
args.payload_limit,
|
||||||
|
args.max_image_fetch_size,
|
||||||
args.prometheus_port,
|
args.prometheus_port,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@ -67,6 +67,8 @@ struct Args {
|
|||||||
usage_stats: UsageStatsLevel,
|
usage_stats: UsageStatsLevel,
|
||||||
#[clap(default_value = "2000000", long, env)]
|
#[clap(default_value = "2000000", long, env)]
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
#[clap(default_value = "1073741824", long, env)]
|
||||||
|
max_image_fetch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
|
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
|
||||||
@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||||||
executor_worker,
|
executor_worker,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@ -74,6 +74,8 @@ struct Args {
|
|||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
#[clap(default_value = "2000000", long, env)]
|
#[clap(default_value = "2000000", long, env)]
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
#[clap(default_value = "1073741824", long, env)]
|
||||||
|
max_image_fetch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if let Some(Commands::PrintSchema) = command {
|
if let Some(Commands::PrintSchema) = command {
|
||||||
@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@ -74,6 +74,8 @@ struct Args {
|
|||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
#[clap(default_value = "2000000", long, env)]
|
#[clap(default_value = "2000000", long, env)]
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
#[clap(default_value = "1073741824", long, env)]
|
||||||
|
max_image_fetch_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
#[derive(Debug, Subcommand)]
|
||||||
@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if let Some(Commands::PrintSchema) = command {
|
if let Some(Commands::PrintSchema) = command {
|
||||||
@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
|
|||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import pytest
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_handle(launcher):
|
def flash_llama_handle(launcher):
|
||||||
with launcher("huggingface/llama-7b", num_shard=2) as handle:
|
with launcher("huggyllama/llama-7b", num_shard=2) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
|
|||||||
return flash_llama_fp8_handle.client
|
return flash_llama_fp8_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
|||||||
@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
|
|||||||
return flash_llama_marlin24_handle.client
|
return flash_llama_marlin24_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snap
|
|||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Issue with the model access")
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
|
|||||||
@ -673,7 +673,7 @@ mod tests {
|
|||||||
let (name, arguments) = get_tool_call_content(&events[0]);
|
let (name, arguments) = get_tool_call_content(&events[0]);
|
||||||
if let Some(name) = name {
|
if let Some(name) = name {
|
||||||
assert_eq!(name, "get_current_weather");
|
assert_eq!(name, "get_current_weather");
|
||||||
output_name.push_str(&name);
|
output_name.push_str(name);
|
||||||
}
|
}
|
||||||
output.push_str(arguments);
|
output.push_str(arguments);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -1523,6 +1523,7 @@ pub async fn run(
|
|||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
usage_stats_level: usage_stats::UsageStatsLevel,
|
usage_stats_level: usage_stats::UsageStatsLevel,
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
prometheus_port: u16,
|
prometheus_port: u16,
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
@ -1827,6 +1828,7 @@ pub async fn run(
|
|||||||
compat_return_full_text,
|
compat_return_full_text,
|
||||||
allow_origin,
|
allow_origin,
|
||||||
payload_limit,
|
payload_limit,
|
||||||
|
max_image_fetch_size,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
@ -1889,6 +1891,7 @@ async fn start(
|
|||||||
compat_return_full_text: bool,
|
compat_return_full_text: bool,
|
||||||
allow_origin: Option<AllowOrigin>,
|
allow_origin: Option<AllowOrigin>,
|
||||||
payload_limit: usize,
|
payload_limit: usize,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
prometheus_port: u16,
|
prometheus_port: u16,
|
||||||
) -> Result<(), WebServerError> {
|
) -> Result<(), WebServerError> {
|
||||||
// Determine the server port based on the feature and environment variable.
|
// Determine the server port based on the feature and environment variable.
|
||||||
@ -1920,6 +1923,7 @@ async fn start(
|
|||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
max_image_fetch_size,
|
||||||
);
|
);
|
||||||
|
|
||||||
let infer = Infer::new(
|
let infer = Infer::new(
|
||||||
|
|||||||
@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::io::Cursor;
|
use std::io::{Cursor, Read};
|
||||||
use std::iter;
|
use std::iter;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -51,6 +51,7 @@ impl Validation {
|
|||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let workers = if let Tokenizer::Python { .. } = &tokenizer {
|
let workers = if let Tokenizer::Python { .. } = &tokenizer {
|
||||||
1
|
1
|
||||||
@ -78,6 +79,7 @@ impl Validation {
|
|||||||
config_clone,
|
config_clone,
|
||||||
preprocessor_config_clone,
|
preprocessor_config_clone,
|
||||||
tokenizer_receiver,
|
tokenizer_receiver,
|
||||||
|
max_image_fetch_size,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -480,6 +482,7 @@ fn tokenizer_worker(
|
|||||||
config: Option<Config>,
|
config: Option<Config>,
|
||||||
preprocessor_config: Option<HubPreprocessorConfig>,
|
preprocessor_config: Option<HubPreprocessorConfig>,
|
||||||
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
) {
|
) {
|
||||||
match tokenizer {
|
match tokenizer {
|
||||||
Tokenizer::Python {
|
Tokenizer::Python {
|
||||||
@ -503,6 +506,7 @@ fn tokenizer_worker(
|
|||||||
&tokenizer,
|
&tokenizer,
|
||||||
config.as_ref(),
|
config.as_ref(),
|
||||||
preprocessor_config.as_ref(),
|
preprocessor_config.as_ref(),
|
||||||
|
max_image_fetch_size,
|
||||||
))
|
))
|
||||||
.unwrap_or(())
|
.unwrap_or(())
|
||||||
})
|
})
|
||||||
@ -524,6 +528,7 @@ fn tokenizer_worker(
|
|||||||
&tokenizer,
|
&tokenizer,
|
||||||
config.as_ref(),
|
config.as_ref(),
|
||||||
preprocessor_config.as_ref(),
|
preprocessor_config.as_ref(),
|
||||||
|
max_image_fetch_size,
|
||||||
))
|
))
|
||||||
.unwrap_or(())
|
.unwrap_or(())
|
||||||
})
|
})
|
||||||
@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
|
|||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
fn fetch_image(
|
||||||
|
input: &str,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
|
) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
|
||||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||||
let url = &input["![](".len()..input.len() - 1];
|
let url = &input["![](".len()..input.len() - 1];
|
||||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
let response = reqwest::blocking::get(url)?;
|
||||||
|
|
||||||
|
// Check Content-Length header if present
|
||||||
|
if let Some(content_length) = response.content_length() {
|
||||||
|
if content_length as usize > max_image_fetch_size {
|
||||||
|
return Err(ValidationError::ImageTooLarge(
|
||||||
|
content_length as usize,
|
||||||
|
max_image_fetch_size,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the body with size limit to prevent unbounded memory allocation
|
||||||
|
let mut data = Vec::new();
|
||||||
|
let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
|
||||||
|
limited_reader.read_to_end(&mut data)?;
|
||||||
|
|
||||||
|
if data.len() > max_image_fetch_size {
|
||||||
|
return Err(ValidationError::ImageTooLarge(
|
||||||
|
data.len(),
|
||||||
|
max_image_fetch_size,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let format = image::guess_format(&data)?;
|
let format = image::guess_format(&data)?;
|
||||||
// TODO Remove this clone
|
// TODO Remove this clone
|
||||||
@ -787,6 +817,7 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
tokenizer: &T,
|
tokenizer: &T,
|
||||||
config: Option<&Config>,
|
config: Option<&Config>,
|
||||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||||
|
max_image_fetch_size: usize,
|
||||||
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
||||||
use Config::*;
|
use Config::*;
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
@ -805,7 +836,8 @@ fn prepare_input<T: TokenizerTrait>(
|
|||||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||||
}
|
}
|
||||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
let (data, mimetype, height, width) =
|
||||||
|
fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
|
||||||
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
||||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||||
start = chunk_end;
|
start = chunk_end;
|
||||||
@ -990,6 +1022,10 @@ pub enum ValidationError {
|
|||||||
InvalidImageContent(String),
|
InvalidImageContent(String),
|
||||||
#[error("Could not fetch image: {0}")]
|
#[error("Could not fetch image: {0}")]
|
||||||
FailedFetchImage(#[from] reqwest::Error),
|
FailedFetchImage(#[from] reqwest::Error),
|
||||||
|
#[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
|
||||||
|
ImageTooLarge(usize, usize),
|
||||||
|
#[error("Failed to read image data: {0}")]
|
||||||
|
ImageReadError(#[from] std::io::Error),
|
||||||
#[error("{0} modality is not supported")]
|
#[error("{0} modality is not supported")]
|
||||||
UnsupportedModality(&'static str),
|
UnsupportedModality(&'static str),
|
||||||
}
|
}
|
||||||
@ -1023,6 +1059,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
@ -1058,6 +1095,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
@ -1092,6 +1130,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
@ -1132,6 +1171,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
@ -1203,6 +1243,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
@ -1293,6 +1334,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
|
|
||||||
let chunks = match validation
|
let chunks = match validation
|
||||||
@ -1349,6 +1391,7 @@ mod tests {
|
|||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
|
1024 * 1024 * 1024, // 1GB
|
||||||
);
|
);
|
||||||
|
|
||||||
let (encoding, chunks) = match validation
|
let (encoding, chunks) = match validation
|
||||||
|
|||||||
@ -3,7 +3,7 @@ flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
|
|||||||
|
|
||||||
build-flash-attention-v2-cuda:
|
build-flash-attention-v2-cuda:
|
||||||
pip install -U packaging wheel
|
pip install -U packaging wheel
|
||||||
pip install flash-attn==$(flash_att_v2_commit_cuda)
|
pip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda)
|
||||||
|
|
||||||
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
|
||||||
echo "Flash v2 installed"
|
echo "Flash v2 installed"
|
||||||
|
|||||||
@ -14,7 +14,7 @@ def get_test_model():
|
|||||||
def generate_token(self, batch):
|
def generate_token(self, batch):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
|
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
|
||||||
model = TestModel(
|
model = TestModel(
|
||||||
"test_model_id",
|
"test_model_id",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user