Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-27 21:12:07 +00:00)
chore(rebase): fix invalid references

parent f5b9ee368a
commit d73401ac73
@@ -5,14 +5,13 @@ use std::path::Path;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
-use log::warn;
-use tokenizers::{Encoding, Tokenizer};
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info};
+use tracing::{debug, error, warn};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -285,7 +284,6 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
 unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
-    tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
     post_processor_looper: JoinHandle<()>,
     executor: UnboundedSender<GenerationContext>,
@@ -320,10 +318,9 @@ impl TensorRtLlmBackendV2 {
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let tokenizer_ = tokenizer.clone();
         let post_processor_looper = spawn_blocking(move || {
             post_processor_looper(
-                tokenizer_,
+                tokenizer,
                 512,
                 max_inflight_requests,
                 post_processor_receiver,
@@ -331,7 +328,6 @@ impl TensorRtLlmBackendV2 {
         });
 
         Ok(TensorRtLlmBackendV2 {
-            tokenizer,
             executor_looper,
             post_processor_looper,
             executor: executor_sender,
@@ -358,7 +354,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(()),
+                Chunk::Text(_) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
            },
        }
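The looper hunks above drop the `tokenizer: Tokenizer` struct field and the `tokenizer.clone()` because the tokenizer can simply be moved into the post-processor thread. A minimal sketch of that ownership pattern, with a plain `String` standing in for the real `tokenizers::Tokenizer`:

use tokio::task::spawn_blocking;

#[tokio::main]
async fn main() {
    // Stand-in for the tokenizer; any owned value behaves the same way.
    let tokenizer = String::from("tokenizer state");

    // `move` transfers ownership of `tokenizer` into the blocking task,
    // so no `.clone()` and no extra struct field are needed to keep it alive.
    let post_processor_looper = spawn_blocking(move || {
        // ... decode received token batches with `tokenizer` here ...
        println!("post-processing with {tokenizer}");
    });

    post_processor_looper
        .await
        .expect("post-processor task panicked");
}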
@@ -8,7 +8,7 @@ use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::{create_post_processor, get_base_tokenizer};
+use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig};
 
@@ -125,10 +125,10 @@ async fn get_tokenizer(
     // Load tokenizer and model info
     let (
         tokenizer_filename,
-        config_filename,
+        _config_filename,
         tokenizer_config_filename,
-        preprocessor_config_filename,
-        processor_config_filename,
+        _preprocessor_config_filename,
+        _processor_config_filename,
     ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
@@ -184,25 +184,8 @@ async fn get_tokenizer(
     } else {
         tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
     };
-    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-        tracing::warn!("Could not find tokenizer config locally and no API specified");
-        HubTokenizerConfig::default()
-    });
 
-    tokenizer_filename.and_then(|filename| {
-        let mut tokenizer = Tokenizer::from_file(filename).ok();
-        if let Some(tokenizer) = &mut tokenizer {
-            if let Some(class) = &tokenizer_config.tokenizer_class {
-                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast" {
-                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
-                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
-                        tokenizer.with_post_processor(post_processor);
-                    }
-                }
-            }
-        }
-        tokenizer
-    })
+    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
 }
 
 #[tokio::main]
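The `get_tokenizer` hunk drops the Llama-specific `create_post_processor` override and reduces loading to a single `and_then`. A minimal sketch of that pattern (the function name and signature here are illustrative, not the router's actual API):

use std::path::PathBuf;

use tokenizers::Tokenizer;

fn load_tokenizer(tokenizer_filename: Option<PathBuf>) -> Option<Tokenizer> {
    // `Tokenizer::from_file` returns a Result; `.ok()` folds a load failure
    // into `None`, so a missing or unreadable tokenizer.json simply yields `None`.
    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}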