feat(backend): expose tokenizer to the GenerationContext to decode tokens

Morgan Funtowicz 2024-11-04 23:01:57 +01:00
parent 1473259f84
commit 1149186794
2 changed files with 58 additions and 28 deletions
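
The change threads a `tokenizers::Tokenizer` into the generation callback so every generated token id can be decoded to text (and flagged as special) as it streams out, instead of emitting empty token text. Below is a minimal, self-contained sketch of that decoding pattern using the same `tokenizers` crate calls as the diff (`decode`, `get_added_vocabulary().is_special_token`); the `Token` struct and the `gpt2` model id are illustrative stand-ins, not the router's actual types.

```rust
use tokenizers::Tokenizer;

/// Illustrative stand-in for the router's `Token` struct seen in the diff.
#[derive(Debug)]
struct Token {
    id: u32,
    text: String,
    logprob: f32,
    special: bool,
}

/// Decode a single generated token id the same way the new callback does.
fn decode_token(tokenizer: &Tokenizer, id: u32, logprob: f32) -> Token {
    // `false` = do not skip special tokens, so control tokens stay visible.
    let text = tokenizer
        .decode(&[id], false)
        .expect("Failed to decode token");
    // A token is "special" if the added vocabulary marks its text as a special token.
    let special = tokenizer.get_added_vocabulary().is_special_token(&text);
    Token { id, text, logprob, special }
}

fn main() {
    // Illustrative model id; requires the `http` feature of the `tokenizers` crate.
    let tokenizer = Tokenizer::from_pretrained("gpt2", None).expect("Failed to load tokenizer");
    println!("{:?}", decode_token(&tokenizer, 50256, -0.42));
}
```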


@ -13,6 +13,7 @@ use text_generation_router::validation::{
};
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
+use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
@ -54,6 +55,7 @@ pub(crate) struct GenerationContext {
pub(crate) struct InferContext {
pub(crate) start: Instant,
pub(crate) stream: UnboundedSender<InferResult>,
+pub(crate) tokenizer: Tokenizer,
pub(crate) generation: GenerationContext,
}
@ -72,7 +74,10 @@ pub struct LlamaCppBackend {
}
impl LlamaCppBackend {
-pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+pub fn new<P: AsRef<Path> + Send>(
+    model_path: P,
+    tokenizer: Tokenizer,
+) -> Result<Self, LlamaCppBackendError> {
let path = Arc::new(model_path.as_ref());
if !path.exists() {
return Err(LlamaCppBackendError::ModelFileDoesntExist(
@ -93,7 +98,7 @@ impl LlamaCppBackend {
);
let (submitter, receiver) = channel();
-let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
+let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
Ok(Self {
backlog: submitter,
scheduler_handle: handle,
@ -110,19 +115,25 @@ fn llama_generate_callback(
) -> bool {
info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
-// Decode token
-let token = Token {
-    id: new_token_id,
-    text: "".to_string(),
-    logprob: new_token_logit,
-    special: false,
-};
let ctx = unsafe { &mut *ctx };
// Append the new token to the generated ones
ctx.generation.generated_tokens.push(new_token_id);
+// Decode token
+let token = match ctx.tokenizer.decode(&[new_token_id], false) {
+    Ok(text) => {
+        let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
+        Token {
+            id: new_token_id,
+            text,
+            logprob: new_token_logit,
+            special,
+        }
+    }
+    Err(_) => panic!("Failed to decode token"),
+};
// Create the streamed response
let response = match is_final {
false => InferStreamResponse::Intermediate {
@ -131,21 +142,26 @@ fn llama_generate_callback(
},
true => {
// Decode the whole text
-let text = String::new();
+match ctx
+    .tokenizer
+    .decode(&ctx.generation.generated_tokens, false)
+{
+    Ok(text) => InferStreamResponse::End {
+        token,
+        top_tokens: vec![],
+        generated_text: GeneratedText {
+            text,
+            generated_tokens: n_generated_tokens as u32,
+            finish_reason: FinishReason::Length,
+            seed: Some(ctx.generation.sampling_params.seed),
+        },
+        start: ctx.start,
+        queued: ctx.start,
+    },
+    Err(_) => panic!("Failed to decode token"),
+}
-// Stream end response
-InferStreamResponse::End {
-    token,
-    top_tokens: vec![],
-    generated_text: GeneratedText {
-        text,
-        generated_tokens: n_generated_tokens as u32,
-        finish_reason: FinishReason::Length,
-        seed: Some(ctx.generation.sampling_params.seed),
-    },
-    start: ctx.start,
-    queued: ctx.start,
-}
}
};
@ -162,6 +178,7 @@ fn llama_generate_callback(
unsafe fn scheduler_loop(
mut backend: UniquePtr<LlamaCppBackendImpl>,
+tokenizer: Tokenizer,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// This loop will mostly decode a single token at every step, so no need to rely on parallelism
@ -170,6 +187,7 @@ unsafe fn scheduler_loop(
loop {
if let Ok((generation, stream)) = backlog.recv() {
let start = Instant::now();
+let tokenizer = tokenizer.clone();
let generation_params = generation.generation_params; // copy
let sampling_params = generation.sampling_params; // copy
let input_tokens = Arc::clone(&generation.input_tokens);
@ -179,6 +197,7 @@ unsafe fn scheduler_loop(
let ctx = Box::new(InferContext {
start,
stream,
+tokenizer,
generation,
});
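
For context, here is a minimal, hypothetical sketch (not the actual TGI types) of the ownership pattern the scheduler change above relies on: the tokenizer is cloned per request into the boxed `InferContext` that is later handed to the llama.cpp callback, so decoding needs no shared state. `GenerationContext`, `InferContext`, and `make_context` are simplified stand-ins for the structs in the diff.

```rust
use tokenizers::Tokenizer;
use tokio::time::Instant;

/// Trimmed-down stand-in for the request-scoped generation state.
struct GenerationContext {
    generated_tokens: Vec<u32>,
}

/// Trimmed-down stand-in for the boxed context passed to the callback.
struct InferContext {
    start: Instant,
    tokenizer: Tokenizer,
    generation: GenerationContext,
}

/// Each request gets its own `Tokenizer` clone, so the boxed context owns
/// everything needed to decode tokens for that request.
fn make_context(tokenizer: &Tokenizer, generation: GenerationContext) -> Box<InferContext> {
    Box::new(InferContext {
        start: Instant::now(),
        tokenizer: tokenizer.clone(),
        generation,
    })
}
```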


@ -4,6 +4,7 @@ use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackend
use text_generation_router::server::ApiDoc;
use text_generation_router::{server, usage_stats};
use thiserror::Error;
+use tokenizers::FromPretrainedParameters;
/// App Configuration
#[derive(Parser, Debug)]
@ -36,9 +37,9 @@ struct Args {
port: u16,
#[clap(long, env, help = "Path to GGUF model file(s) to load")]
gguf_path: PathBuf,
#[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
num_model_instance: u16,
#[clap(default_value = "bigscience/bloom", long, env)]
// #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
// num_model_instance: u16,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
@ -94,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
hostname,
port,
gguf_path,
-num_model_instance,
+// num_model_instance,
tokenizer_name,
tokenizer_config_path,
revision,
@ -153,7 +154,17 @@ async fn main() -> Result<(), RouterError> {
}
}
-let backend = LlamaCppBackend::new(gguf_path)?;
+let auth_token = std::env::var("HF_TOKEN")
+    .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+    .ok();
+let options = FromPretrainedParameters {
+    revision: revision.clone().unwrap_or("main".to_string()),
+    user_agent: Default::default(),
+    auth_token,
+};
+let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+    .expect("Failed to retrieve tokenizer");
+let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
// Run server
server::run(