From 1149186794a919d58cf5d43c7a497d81555f20c5 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 4 Nov 2024 23:01:57 +0100
Subject: [PATCH] feat(backend): expose tokenizer to the GenerationContext to decode token

---
 backends/llamacpp/src/backend.rs | 65 +++++++++++++++++++++----------
 backends/llamacpp/src/main.rs    | 21 ++++++++---
 2 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 531a07dc..08fac675 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -13,6 +13,7 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -54,6 +55,7 @@ pub(crate) struct GenerationContext {
 pub(crate) struct InferContext {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender,
+    pub(crate) tokenizer: Tokenizer,
     pub(crate) generation: GenerationContext,
 }

@@ -72,7 +74,10 @@ pub struct LlamaCppBackend {
 }

 impl LlamaCppBackend {
-    pub fn new + Send>(model_path: P) -> Result {
+    pub fn new + Send>(
+        model_path: P,
+        tokenizer: Tokenizer,
+    ) -> Result {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -93,7 +98,7 @@ impl LlamaCppBackend {
         );

         let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
+        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
@@ -110,19 +115,25 @@ fn llama_generate_callback(
 ) -> bool {
     info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");

-    // Decode token
-    let token = Token {
-        id: new_token_id,
-        text: "".to_string(),
-        logprob: new_token_logit,
-        special: false,
-    };
-
     let ctx = unsafe { &mut *ctx };

     // Append the new token to the generated ones
     ctx.generation.generated_tokens.push(new_token_id);

+    // Decode token
+    let token = match ctx.tokenizer.decode(&[new_token_id], false) {
+        Ok(text) => {
+            let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
+            Token {
+                id: new_token_id,
+                text,
+                logprob: new_token_logit,
+                special,
+            }
+        }
+        Err(_) => panic!("Failed to decode token"),
+    };
+
     // Create the streamed response
     let response = match is_final {
         false => InferStreamResponse::Intermediate {
@@ -131,21 +142,26 @@ fn llama_generate_callback(
         },
         true => {
             // Decode the whole text
-            let text = String::new();
+            match ctx
+                .tokenizer
+                .decode(&ctx.generation.generated_tokens, false)
+            {
+                Ok(text) => InferStreamResponse::End {
+                    token,
+                    top_tokens: vec![],
+                    generated_text: GeneratedText {
+                        text,
+                        generated_tokens: n_generated_tokens as u32,
+                        finish_reason: FinishReason::Length,
+                        seed: Some(ctx.generation.sampling_params.seed),
+                    },
+                    start: ctx.start,
+                    queued: ctx.start,
+                },
+                Err(_) => panic!("Failed to decode token"),
+            }

             // Stream end response
-            InferStreamResponse::End {
-                token,
-                top_tokens: vec![],
-                generated_text: GeneratedText {
-                    text,
-                    generated_tokens: n_generated_tokens as u32,
-                    finish_reason: FinishReason::Length,
-                    seed: Some(ctx.generation.sampling_params.seed),
-                },
-                start: ctx.start,
-                queued: ctx.start,
-            }
         }
     };

@@ -162,6 +178,7 @@ fn llama_generate_callback(

 unsafe fn scheduler_loop(
     mut backend: UniquePtr,
+    tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -170,6 +187,7 @@ unsafe fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
+            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
@@ -179,6 +197,7 @@ unsafe fn scheduler_loop(
             let ctx = Box::new(InferContext {
                 start,
                 stream,
+                tokenizer,
                 generation,
             });

diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index f128a6a3..c5d735ab 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -4,6 +4,7 @@ use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackend
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
+use tokenizers::FromPretrainedParameters;

 /// App Configuration
 #[derive(Parser, Debug)]
struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    num_model_instance: u16,
-    #[clap(default_value = "bigscience/bloom", long, env)]
+    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
+    // num_model_instance: u16,
+    #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
     tokenizer_config_path: Option,
@@ -94,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        num_model_instance,
+        // num_model_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -153,7 +154,17 @@ async fn main() -> Result<(), RouterError> {
         }
     }

-    let backend = LlamaCppBackend::new(gguf_path)?;
+    let auth_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();
+    let options = FromPretrainedParameters {
+        revision: revision.clone().unwrap_or("main".to_string()),
+        user_agent: Default::default(),
+        auth_token,
+    };
+    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+        .expect("Failed to retrieve tokenizer");
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;

     // Run server
     server::run(