chore(rebase): fix invalid references

This commit is contained in:
Morgan Funtowicz 2024-10-21 21:44:28 +02:00
parent f5b9ee368a
commit d73401ac73
2 changed files with 9 additions and 30 deletions

View File

@ -5,14 +5,13 @@ use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use log::warn;
use tokenizers::{Encoding, Tokenizer};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info};
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@ -285,7 +284,6 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
tokenizer: Tokenizer,
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
@ -320,10 +318,9 @@ impl TensorRtLlmBackendV2 {
});
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
let tokenizer_ = tokenizer.clone();
let post_processor_looper = spawn_blocking(move || {
post_processor_looper(
tokenizer_,
tokenizer,
512,
max_inflight_requests,
post_processor_receiver,
@ -331,7 +328,6 @@ impl TensorRtLlmBackendV2 {
});
Ok(TensorRtLlmBackendV2 {
tokenizer,
executor_looper,
post_processor_looper,
executor: executor_sender,
@ -358,7 +354,7 @@ impl TensorRtLlmBackendV2 {
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(()),
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}

View File

@ -8,7 +8,7 @@ use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::{create_post_processor, get_base_tokenizer};
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
@ -125,10 +125,10 @@ async fn get_tokenizer(
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
_config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
@ -184,25 +184,8 @@ async fn get_tokenizer(
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
})
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}
#[tokio::main]