Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-27 10:20:17 +00:00)
feat(backend): wrap Arc tokenizer to avoid duplicating
parent 57b215467b
commit 6f059c4b5d
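
In short, the change shares one tokenizer across the backend instead of duplicating it: the tokenizer is wrapped in an Arc so every consumer holds a cheap reference-counted handle, and the per-request scheduler context borrows it rather than cloning it. Below is a minimal, self-contained sketch of that sharing pattern; the Tokenizer stub and the worker count are placeholders, not the backend's real types.

    use std::sync::Arc;
    use std::thread;

    // Placeholder standing in for tokenizers::Tokenizer; the real type is
    // expensive to duplicate, which is what the Arc wrapping avoids.
    struct Tokenizer;

    fn main() {
        // One tokenizer instance, shared by reference counting.
        let tokenizer = Arc::new(Tokenizer);

        // Each worker receives a cheap Arc handle instead of a deep copy.
        let workers: Vec<_> = (0..2)
            .map(|_| {
                let tokenizer = Arc::clone(&tokenizer);
                thread::spawn(move || {
                    // ... encode/decode with `tokenizer` here ...
                    let _ = &*tokenizer;
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }
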
@@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
     pub(crate) sampling_params: SamplingParams,
 }
 
-pub(crate) struct InferContext {
+pub(crate) struct InferContext<'a> {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender<InferResult>,
-    pub(crate) tokenizer: Tokenizer,
+    pub(crate) tokenizer: &'a Tokenizer,
     pub(crate) generation: GenerationContext,
 }
 
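The lifetime parameter added above ties each InferContext to the tokenizer it points at: the context can only borrow a tokenizer that outlives it, which is exactly the situation in the scheduler loop, where the Arc-owned tokenizer lives for the whole loop. A reduced sketch of that relationship (names simplified, not the backend's actual code):

    use std::sync::Arc;

    struct Tokenizer; // placeholder

    struct InferContext<'a> {
        tokenizer: &'a Tokenizer,
    }

    fn scheduler_loop(tokenizer: Arc<Tokenizer>) {
        // The Arc is owned by this function, so a borrow of it is valid for
        // the whole body and the context cannot outlive the tokenizer.
        let ctx = InferContext {
            tokenizer: &tokenizer,
        };
        let _ = ctx;
    }

    fn main() {
        scheduler_loop(Arc::new(Tokenizer));
    }
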
@@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-// pub struct LlamaCppBackend {
-//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-//     _scheduler_handle: JoinHandle<()>,
-// }
-
 struct LlamaCppWorker {
     sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     handle: JoinHandle<()>,
@@ -95,7 +90,7 @@ impl LlamaCppBackend {
 
     pub fn new<P: AsRef<Path>>(
         model_path: P,
-        tokenizer: Tokenizer,
+        tokenizer: Arc<Tokenizer>,
         num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
         let shared_path = Arc::new(model_path);
@@ -110,7 +105,7 @@ impl LlamaCppBackend {
             0 => {
                 let worker = Self::allocate_worker(path)?;
                 let (sender, receiver) = channel();
-                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
                 LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
             }
             _ => panic!("No supported yet"),
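The only change in the hunk above is the `move` on the closure: with `tokenizer` now an Arc<Tokenizer> rather than a freshly cloned Tokenizer, the spawned thread must take ownership of the handle (and of the receiver) instead of borrowing them from new(). A minimal sketch of why the keyword matters; the worker and channel types are simplified stand-ins:

    use std::sync::mpsc::{channel, Receiver, Sender};
    use std::sync::Arc;
    use std::thread::{spawn, JoinHandle};

    struct Tokenizer; // placeholder

    fn scheduler_loop(_tokenizer: Arc<Tokenizer>, _backlog: Receiver<()>) {}

    fn new_backend(tokenizer: Arc<Tokenizer>) -> (Sender<()>, JoinHandle<()>) {
        let (sender, receiver) = channel();
        // Without `move`, the closure would only borrow `tokenizer` and
        // `receiver`, and those borrows cannot outlive new_backend; `move`
        // transfers ownership into the thread, satisfying the 'static bound
        // that spawn places on its closure.
        let handle = spawn(move || scheduler_loop(tokenizer, receiver));
        (sender, handle)
    }

    fn main() {
        let (_sender, handle) = new_backend(Arc::new(Tokenizer));
        handle.join().unwrap();
    }
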
@@ -186,7 +181,7 @@ fn llama_generate_callback(
 
 fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
-    tokenizer: Tokenizer,
+    tokenizer: Arc<Tokenizer>,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -195,17 +190,15 @@ fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
 
             // Creating the whole InferContext and pushing it to the heap
-            {
             let ctx = Box::new(InferContext {
                 start,
                 stream,
-                tokenizer,
+                tokenizer: &tokenizer,
                 generation,
             });
 
@@ -226,7 +219,6 @@ fn scheduler_loop(
                 // Make sure we re-keep track of the OpaqueStream box
                 let _ = Box::from_raw(boxed_ctx);
             }
-            }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
             break;
 
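Context for the Box::from_raw in the hunk above (unchanged by this commit): the boxed InferContext is handed to the C++ side as a raw pointer, presumably produced with Box::into_raw on the Rust side, and reclaimed here so the allocation is freed exactly once. A minimal sketch of that into_raw/from_raw pairing, independent of the FFI layer:

    struct InferContext {
        value: u32, // stand-in for the real fields
    }

    fn main() {
        // Leak the box into a raw pointer; ownership is now manual.
        let boxed_ctx: *mut InferContext = Box::into_raw(Box::new(InferContext { value: 42 }));

        // ... the raw pointer would travel through the C++ callback here ...

        // Reclaim ownership so Rust drops the allocation exactly once.
        let ctx = unsafe { Box::from_raw(boxed_ctx) };
        assert_eq!(ctx.value, 42);
    }
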
@@ -33,7 +33,7 @@ mod ffi {
     }
 
     extern "Rust" {
-        type InferContext;
+        type InferContext<'a>;
     }
 
     unsafe extern "C++" {
 
@@ -1,5 +1,6 @@
 use clap::{Parser, Subcommand};
 use std::path::PathBuf;
+use std::sync::Arc;
 use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
@@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
         user_agent: Default::default(),
         auth_token,
     };
-    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
-        .expect("Failed to retrieve tokenizer");
+    let tokenizer = Arc::new(
+        tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+            .expect("Failed to retrieve tokenizer"),
+    );
     let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
 
     // Run server