Fixing CI. (#2846)

Nicolas Patry 2024-12-16 10:58:15 +01:00 committed by GitHub
parent 6f0b8c947d
commit 11ab329883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 39 deletions

View File

@@ -44,7 +44,8 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
     }

     let mut config = cmake::Config::new(".");
-    config.uses_cxx11()
+    config
+        .uses_cxx11()
         .generator("Ninja")
         .profile(match is_debug {
             true => "Debug",

@@ -57,12 +58,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
         .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
         .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);

     // Allow to override which Python to use ...
     if let Some(python3) = option_env!("Python3_EXECUTABLE") {
         config.define("Python3_EXECUTABLE", python3);
     }

     config.build();

     // Additional transitive CMake dependencies
     let deps_folder = out_dir.join("build").join("_deps");
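The hunk above only reflows an existing cmake::Config call chain. For readers unfamiliar with the pattern, here is a minimal build.rs sketch of that builder style, assuming the cmake crate as a build dependency and a CMakeLists.txt in the crate root; the Python3_EXECUTABLE override comes from the diff, while build_backend's simplified signature and the link-search line are illustrative.

// build.rs sketch: drive a CMake build from Cargo, mirroring the builder chain in the hunk above.
use std::path::PathBuf;

fn build_backend(is_debug: bool, out_dir: &PathBuf) -> PathBuf {
    let mut config = cmake::Config::new(".");
    config
        .uses_cxx11()
        .generator("Ninja")
        .profile(match is_debug {
            true => "Debug",
            false => "Release",
        })
        .out_dir(out_dir);

    // Allow overriding which Python interpreter CMake resolves;
    // option_env! reads the variable when build.rs itself is compiled.
    if let Some(python3) = option_env!("Python3_EXECUTABLE") {
        config.define("Python3_EXECUTABLE", python3);
    }

    // Configure + build + install, returning the install prefix.
    config.build()
}

fn main() {
    let out_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap());
    let install_prefix = build_backend(cfg!(debug_assertions), &out_dir);
    println!(
        "cargo:rustc-link-search=native={}",
        install_prefix.join("lib").display()
    );
}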

View File

@@ -133,13 +133,16 @@ fn executor_status_looper(
                         Ok(decoded_token) => {
                             post_process_decoded_token(&tokenizer, ctx, decoded_token)
                         }
-                        Err(err) => Err(err)
+                        Err(err) => Err(err),
                     };

                     // Attempt to send back the response to the client
                     if let Err(_) = ctx.streamer.send(response) {
                         // Client has dropped, remove from tracked requests
-                        debug!("Client dropped - removing request {} from tracked requests", step.request_id);
+                        debug!(
+                            "Client dropped - removing request {} from tracked requests",
+                            step.request_id
+                        );
                         backend.as_mut().cancel(step.request_id);
                         let _ = in_flights.remove(&step.request_id);
                     }

@@ -160,11 +163,14 @@ fn executor_status_looper(
     }
 }

-fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext, decoded_token: DecodedToken) -> InferResult<InferStreamResponse> {
+fn post_process_decoded_token(
+    tokenizer: &Tokenizer,
+    ctx: &mut GenerationContext,
+    decoded_token: DecodedToken,
+) -> InferResult<InferStreamResponse> {
     match tokenizer.decode(&[decoded_token.id], false) {
         Ok(text) => {
-            let is_special =
-                tokenizer.get_added_vocabulary().is_special_token(&text);
+            let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
             let token = Token {
                 id: decoded_token.id,
                 text,

@@ -186,7 +192,7 @@ fn post_process_decoded_token(tokenizer: &Tokenizer, ctx: &mut GenerationContext
             let generated_text = GeneratedText {
                 text: text.unwrap(),
                 generated_tokens: ctx.tokens.len() as u32,
                 finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
                 seed: None,
             };

@@ -248,7 +254,6 @@ unsafe impl Send for TensorRtLlmBackendImpl {}
 pub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);

 impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
         tokenizer: Tokenizer,

@@ -268,12 +273,7 @@ impl TensorRtLlmBackendV2 {
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
         spawn_blocking(move || {
-            executor_status_looper(
-                max_inflight_requests,
-                tokenizer,
-                backend,
-                executor_receiver,
-            )
+            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
         });

         Ok(TensorRtLlmBackendV2(executor_sender))
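The changes above are formatting-only, but the hunks do show the backend's core wiring: a blocking status looper spawned with tokio's spawn_blocking and fed through an unbounded channel held by TensorRtLlmBackendV2. As a reading aid, here is a small self-contained sketch of that pattern under the same tokio assumptions; GenerationContext's fields and the looper body are placeholders, and only the channel-plus-spawn_blocking structure comes from the diff.

use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::spawn_blocking;

// Placeholder for the per-request state that travels through the channel.
struct GenerationContext {
    prompt: String,
}

// Blocking loop: in the real backend this schedules requests on the executor
// and pulls ready tokens; here it simply drains the backlog.
fn executor_status_looper(max_inflight_requests: usize, mut backlog: UnboundedReceiver<GenerationContext>) {
    let mut in_flight: Vec<String> = Vec::with_capacity(max_inflight_requests);
    while let Some(ctx) = backlog.blocking_recv() {
        in_flight.push(ctx.prompt);
        println!("tracking {} in-flight request(s)", in_flight.len());
    }
}

// Thin handle over the sending side, analogous to TensorRtLlmBackendV2.
struct BackendHandle(UnboundedSender<GenerationContext>);

impl BackendHandle {
    fn new(max_inflight_requests: usize) -> Self {
        let (sender, receiver) = unbounded_channel();
        // The looper blocks, so it must not run on the async worker threads.
        spawn_blocking(move || executor_status_looper(max_inflight_requests, receiver));
        Self(sender)
    }
}

#[tokio::main]
async fn main() {
    let backend = BackendHandle::new(128);
    backend
        .0
        .send(GenerationContext { prompt: "hello".into() })
        .expect("looper is still running");
    // Give the blocking task a moment to drain before exiting.
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}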

View File

@@ -7,9 +7,11 @@ use tracing::info;
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
+use text_generation_router::server::{
+    get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
+};
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig, Tokenizer};
-use text_generation_router::server::{get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer};

 /// App Configuration
 #[derive(Parser, Debug)]

@@ -129,14 +131,14 @@ async fn get_tokenizer(
         _tokenizer_config_filename,
         _preprocessor_config_filename,
         _processor_config_filename,
-        _model_info
+        _model_info,
     ) = match api {
         Type::None => (
             Some(local_path.join("config.json")),
             Some(local_path.join("tokenizer_config.json")),
             Some(local_path.join("preprocessor_config.json")),
             Some(local_path.join("processor_config.json")),
-            None
+            None,
         ),
         Type::Api(api) => {
             let api_repo = api.repo(Repo::with_revision(

@@ -145,7 +147,6 @@ async fn get_tokenizer(
                 revision.unwrap_or_else(|| "main").to_string(),
             ));
-
             let config_filename = api_repo.get("config.json").await.ok();
             let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
             let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();

@@ -176,7 +177,7 @@ async fn get_tokenizer(
                 repo.get("tokenizer_config.json"),
                 repo.get("preprocessor_config.json"),
                 repo.get("processor_config.json"),
-                None
+                None,
             )
         }
     };

@@ -200,14 +201,14 @@ async fn get_tokenizer(
             py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
             Ok(())
         })
         .inspect_err(|err| {
             tracing::error!("Failed to import python tokenizer {err}");
         })
         .or_else(|err| {
             let out = legacy_tokenizer_handle(config_filename.as_ref());
             out.ok_or(err)
         })
         .expect("We cannot load a tokenizer");

         let filename = "out/tokenizer.json";
         if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
             Tokenizer::Rust(tok)

@@ -297,10 +298,11 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         revision.as_deref(),
     )
     .await
-    .expect("Failed to retrieve tokenizer implementation") {
-        Tokenizer::Python { .. } => {
-            Err(TensorRtLlmBackendError::Tokenizer("Failed to retrieve Rust based tokenizer".to_string()))
-        }
+    .expect("Failed to retrieve tokenizer implementation")
+    {
+        Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
+            "Failed to retrieve Rust based tokenizer".to_string(),
+        )),
         Tokenizer::Rust(tokenizer) => {
             info!("Successfully retrieved tokenizer {}", &tokenizer_name);
             let backend = TensorRtLlmBackendV2::new(

@@ -337,9 +339,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
                 max_client_batch_size,
                 usage_stats,
                 payload_limit,
-            ).await?;
-
+            )
+            .await?;
             Ok(())
         }
     }
 }
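The get_tokenizer hunks above mostly re-wrap an existing fallback chain across line breaks. For context, here is a small stand-alone sketch of that inspect_err / or_else / expect pattern on a Result, with hypothetical loader functions standing in for the Python and legacy tokenizer paths from the diff.

// Hypothetical stand-ins for the two loading strategies used in the diff.
fn py_resolve_tokenizer(name: &str) -> Result<String, String> {
    Err(format!("no python runtime available for {name}"))
}

fn legacy_tokenizer_handle(config: Option<&str>) -> Option<String> {
    config.map(|c| format!("legacy tokenizer built from {c}"))
}

fn main() {
    let config_filename = Some("config.json");

    // Try the primary path, log its failure, then fall back to the legacy path.
    let tokenizer = py_resolve_tokenizer("my-model")
        .inspect_err(|err| eprintln!("Failed to import python tokenizer {err}"))
        .or_else(|err| {
            // ok_or converts the legacy Option back into a Result,
            // re-using the original error when the fallback also fails.
            let out = legacy_tokenizer_handle(config_filename);
            out.ok_or(err)
        })
        .expect("We cannot load a tokenizer");

    println!("{tokenizer}");
}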

View File

@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "3.0.1-dev0"
+    "version": "3.0.2-dev0"
   },
   "paths": {
     "/": {