mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00

Fixing CI. (#2846)

commit 11ab329883
parent 6f0b8c947d
@@ -44,7 +44,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     }

     let mut config = cmake::Config::new(".");
-    config.uses_cxx11()
+    config
+        .uses_cxx11()
         .generator("Ninja")
         .profile(match is_debug {
            true => "Debug",
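
This hunk only reflows the builder chain so each call sits on its own line. For context, below is an illustrative sketch, not the project's actual build script, of how the `cmake` crate's builder API is typically driven from a Cargo build.rs; the link-search line is an assumption:

// Illustrative build.rs sketch (assumed layout, not the project's script):
// drive a CMake build from Cargo via the `cmake` crate's builder API.
fn main() {
    let is_debug = cfg!(debug_assertions);
    let dst = cmake::Config::new(".") // directory containing CMakeLists.txt
        .uses_cxx11() // request a C++11-capable toolchain
        .generator("Ninja") // emit Ninja build files instead of Makefiles
        .profile(match is_debug {
            true => "Debug",
            false => "Release",
        })
        .build(); // runs configure + build + install, returns the prefix
    println!("cargo:rustc-link-search=native={}/lib", dst.display());
}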
@@ -133,13 +133,16 @@ fn executor_status_looper(
                        Ok(decoded_token) => {
                            post_process_decoded_token(&tokenizer, ctx, decoded_token)
                        }
-                        Err(err) => Err(err)
+                        Err(err) => Err(err),
                    };

                    // Attempt to send back the response to the client
                    if let Err(_) = ctx.streamer.send(response) {
                        // Client has dropped, remove from tracked requests
-                        debug!("Client dropped - removing request {} from tracked requests", step.request_id);
+                        debug!(
+                            "Client dropped - removing request {} from tracked requests",
+                            step.request_id
+                        );
                        backend.as_mut().cancel(step.request_id);
                        let _ = in_flights.remove(&step.request_id);
                    }
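
The reflowed `debug!` call sits inside a pattern worth spelling out: sending on a Tokio unbounded channel fails only when every receiver has been dropped, so a failed `ctx.streamer.send(response)` is itself the signal that the client disconnected and the request can be cancelled. A standalone sketch with illustrative names:

// Sketch (illustrative signature): on an unbounded channel, send() has no
// capacity limit, so Err means the receiver is gone, i.e. the client dropped.
use tokio::sync::mpsc::UnboundedSender;

fn forward<T>(streamer: &UnboundedSender<T>, response: T) -> bool {
    // Returns false when the client has disconnected; the caller can then
    // cancel the in-flight request and stop tracking it.
    streamer.send(response).is_ok()
}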
@@ -160,11 +163,14 @@ fn executor_status_looper(
    }
}

-fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+fn post_process_decoded_token(
+    tokenizer: &Tokenizer,
+    ctx: &mut GenerationContext,
+    decoded_token: DecodedToken,
+) -> InferResult<InferStreamResponse> {
    match tokenizer.decode(&[decoded_token.id], false) {
        Ok(text) => {
-            let is_special =
-                tokenizer.get_added_vocabulary().is_special_token(&text);
+            let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
            let token = Token {
                id: decoded_token.id,
                text,
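
For reference, the two `tokenizers`-crate calls this function relies on can be exercised in isolation; the helper below is an illustrative sketch, not project code. `decode(&[id], false)` turns one token id back into text without skipping special tokens, and `is_special_token` then checks that text against the added vocabulary:

// Sketch using the Hugging Face `tokenizers` crate (helper name is ours):
// decode a single token id and flag whether it is a special token.
use tokenizers::Tokenizer;

fn decode_token(tokenizer: &Tokenizer, id: u32) -> tokenizers::Result<(String, bool)> {
    let text = tokenizer.decode(&[id], false)?; // false = keep special tokens
    let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
    Ok((text, is_special))
}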
@@ -248,7 +254,6 @@ unsafe impl Send for TensorRtLlmBackendImpl {}

pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);

-
impl TensorRtLlmBackendV2 {
    pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
        tokenizer: Tokenizer,
@@ -268,12 +273,7 @@ impl TensorRtLlmBackendV2 {

        // Executor looper is responsible for scheduling and pulling requests state at regular interval
        spawn_blocking(move || {
-            executor_status_looper(
-                max_inflight_requests,
-                tokenizer,
-                backend,
-                executor_receiver,
-            )
+            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
        });

        Ok(TensorRtLlmBackendV2(executor_sender))
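
The surrounding constructor follows a shape the reformat keeps intact: hand the receiving half of an unbounded channel to a looper running under `spawn_blocking`, and expose only the sender as the backend handle. A reduced sketch of that shape with stand-in types:

// Sketch (stand-in request type): spawn a blocking scheduling loop and keep
// the channel sender as the handle, mirroring TensorRtLlmBackendV2::new.
use tokio::sync::mpsc;
use tokio::task::spawn_blocking;

struct Backend(mpsc::UnboundedSender<String>);

fn new_backend() -> Backend {
    let (sender, mut receiver) = mpsc::unbounded_channel::<String>();
    spawn_blocking(move || {
        // The blocking loop stays off the async worker threads and exits
        // once the Backend, and with it the sender, is dropped.
        while let Some(request) = receiver.blocking_recv() {
            let _ = request; // hand the request to the executor here
        }
    });
    Backend(sender)
}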
@@ -7,9 +7,11 @@ use tracing::info;

use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::server::{
+    get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
+};
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
-use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};

/// App Configuration
#[derive(Parser, Debug)]
@@ -129,14 +131,14 @@ async fn get_tokenizer(
        _tokenizer_config_filename,
        _preprocessor_config_filename,
        _processor_config_filename,
-        _model_info
+        _model_info,
    ) = match api {
        Type::None => (
            Some(local_path.join("config.json")),
            Some(local_path.join("tokenizer_config.json")),
            Some(local_path.join("preprocessor_config.json")),
            Some(local_path.join("processor_config.json")),
-            None
+            None,
        ),
        Type::Api(api) => {
            let api_repo = api.repo(Repo::with_revision(
@@ -145,7 +147,6 @@ async fn get_tokenizer(
                revision.unwrap_or_else(|| "main").to_string(),
            ));

-
            let config_filename = api_repo.get("config.json").await.ok();
            let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
            let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@@ -176,7 +177,7 @@ async fn get_tokenizer(
                repo.get("tokenizer_config.json"),
                repo.get("preprocessor_config.json"),
                repo.get("processor_config.json"),
-                None
+                None,
            )
        }
    };
@@ -297,10 +298,11 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
        revision.as_deref(),
    )
    .await
-    .expect("Failed to retrieve tokenizer implementation") {
-        Tokenizer::Python { .. } => {
-            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
-        }
+    .expect("Failed to retrieve tokenizer implementation")
+    {
+        Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
+            "Failed to retrieve Rust based tokenizer".to_string(),
+        )),
        Tokenizer::Rust(tokenizer) => {
            info!("Successfully retrieved tokenizer {}", &tokenizer_name);
            let backend = TensorRtLlmBackendV2::new(
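
The rewritten match reads as a gate: the `Tokenizer::Python { .. }` arm is rejected with a `TensorRtLlmBackendError::Tokenizer` error, and only `Tokenizer::Rust(tokenizer)` goes on to build the backend. A minimal standalone sketch of that gate with stand-in types:

// Sketch with stand-in types: accept only the Rust variant and turn the
// Python variant into an error, as the match above does.
enum Tok {
    Python { _path: std::path::PathBuf },
    Rust(String),
}

fn require_rust(tok: Tok) -> Result<String, String> {
    match tok {
        Tok::Python { .. } => Err("Failed to retrieve Rust based tokenizer".to_string()),
        Tok::Rust(inner) => Ok(inner),
    }
}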
@@ -337,9 +339,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
                max_client_batch_size,
                usage_stats,
                payload_limit,
-            ).await?;
+            )
+            .await?;
            Ok(())
        }
    }

}
@@ -10,7 +10,7 @@
            "name": "Apache 2.0",
            "url": "https://www.apache.org/licenses/LICENSE-2.0"
        },
-        "version": "3.0.1-dev0"
+        "version": "3.0.2-dev0"
    },
    "paths": {
        "/": {