mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-25 12:02:08 +00:00
feat(backend): expose tokenizer to the GenerationContext to decode token
This commit is contained in:
parent 1473259f84
commit 1149186794
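
For orientation before the diff: this change threads a tokenizers::Tokenizer from the router's main.rs down into the llama.cpp backend's scheduler loop and per-request InferContext, so the generation callback can decode each new token id into text (and detect special tokens) instead of streaming an empty string. A minimal, standalone sketch of that decode step, assuming the tokenizers crate with its http feature and using "gpt2" purely as an example model id:

    use tokenizers::Tokenizer;

    fn main() {
        // Hypothetical example id; the backend loads the tokenizer named on its CLI instead.
        let tokenizer = Tokenizer::from_pretrained("gpt2", None).expect("Failed to retrieve tokenizer");
        let new_token_id: u32 = 464;
        // Second argument is skip_special_tokens; false keeps special tokens in the output.
        let text = tokenizer
            .decode(&[new_token_id], false)
            .expect("Failed to decode token");
        let special = tokenizer.get_added_vocabulary().is_special_token(&text);
        println!("{new_token_id} -> {text:?} (special = {special})");
    }
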
@@ -13,6 +13,7 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -54,6 +55,7 @@ pub(crate) struct GenerationContext {
 pub(crate) struct InferContext {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender<InferResult>,
+    pub(crate) tokenizer: Tokenizer,
     pub(crate) generation: GenerationContext,
 }
 
@@ -72,7 +74,10 @@ pub struct LlamaCppBackend {
 }
 
 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(
+        model_path: P,
+        tokenizer: Tokenizer,
+    ) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -93,7 +98,7 @@ impl LlamaCppBackend {
         );
 
         let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
+        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
@@ -110,19 +115,25 @@ fn llama_generate_callback(
 ) -> bool {
     info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
 
-    // Decode token
-    let token = Token {
-        id: new_token_id,
-        text: "".to_string(),
-        logprob: new_token_logit,
-        special: false,
-    };
-
     let ctx = unsafe { &mut *ctx };
 
     // Append the new token to the generated ones
     ctx.generation.generated_tokens.push(new_token_id);
 
+    // Decode token
+    let token = match ctx.tokenizer.decode(&[new_token_id], false) {
+        Ok(text) => {
+            let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
+            Token {
+                id: new_token_id,
+                text,
+                logprob: new_token_logit,
+                special,
+            }
+        }
+        Err(_) => panic!("Failed to decode token"),
+    };
+
     // Create the streamed response
     let response = match is_final {
         false => InferStreamResponse::Intermediate {
@@ -131,21 +142,26 @@ fn llama_generate_callback(
         },
         true => {
             // Decode the whole text
-            let text = String::new();
-
-            // Stream end response
-            InferStreamResponse::End {
-                token,
-                top_tokens: vec![],
-                generated_text: GeneratedText {
-                    text,
-                    generated_tokens: n_generated_tokens as u32,
-                    finish_reason: FinishReason::Length,
-                    seed: Some(ctx.generation.sampling_params.seed),
-                },
-                start: ctx.start,
-                queued: ctx.start,
-            }
+            match ctx
+                .tokenizer
+                .decode(&ctx.generation.generated_tokens, false)
+            {
+                Ok(text) => InferStreamResponse::End {
+                    token,
+                    top_tokens: vec![],
+                    generated_text: GeneratedText {
+                        text,
+                        generated_tokens: n_generated_tokens as u32,
+                        finish_reason: FinishReason::Length,
+                        seed: Some(ctx.generation.sampling_params.seed),
+                    },
+                    start: ctx.start,
+                    queued: ctx.start,
+                },
+                Err(_) => panic!("Failed to decode token"),
+            }
         }
     };
 
@@ -162,6 +178,7 @@ fn llama_generate_callback(
 
 unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
+    tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -170,6 +187,7 @@ unsafe fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
+            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
@@ -179,6 +197,7 @@ unsafe fn scheduler_loop(
             let ctx = Box::new(InferContext {
                 start,
                 stream,
+                tokenizer,
                 generation,
             });
 
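The hunks above clone the tokenizer once per request and store it in the boxed InferContext, which travels through the C++ side as a raw pointer and is re-borrowed inside llama_generate_callback. A stripped-down sketch of that ownership pattern, with hypothetical names rather than the backend's actual types:

    // Per-request state handed to the callback as *mut DemoContext.
    struct DemoContext {
        generated: Vec<u32>,
    }

    // Stand-in for the FFI callback: re-borrow the context and record the new token.
    fn demo_callback(ctx: *mut DemoContext, new_token_id: u32) {
        let ctx = unsafe { &mut *ctx };
        ctx.generated.push(new_token_id);
    }

    fn main() {
        let ctx = Box::new(DemoContext { generated: Vec::new() });
        // Ownership is parked as a raw pointer for the duration of the request.
        let raw = Box::into_raw(ctx);
        demo_callback(raw, 42);
        // Reclaim the box once generation finishes so it is dropped normally.
        let ctx = unsafe { Box::from_raw(raw) };
        assert_eq!(ctx.generated, vec![42]);
    }
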
@@ -4,6 +4,7 @@ use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackend
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
+use tokenizers::FromPretrainedParameters;
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -36,9 +37,9 @@ struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    num_model_instance: u16,
-    #[clap(default_value = "bigscience/bloom", long, env)]
+    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
+    // num_model_instance: u16,
+    #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
     tokenizer_config_path: Option<String>,
@@ -94,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        num_model_instance,
+        // num_model_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -153,7 +154,17 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new(gguf_path)?;
+    let auth_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();
+    let options = FromPretrainedParameters {
+        revision: revision.clone().unwrap_or("main".to_string()),
+        user_agent: Default::default(),
+        auth_token,
+    };
+    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+        .expect("Failed to retrieve tokenizer");
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
 
     // Run server
     server::run(
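
The main.rs hunk above fetches the tokenizer from the Hugging Face Hub before constructing the backend. A self-contained sketch of that loading step, assuming the tokenizers crate's http feature; the identifier "gpt2" is only an example:

    use tokenizers::{FromPretrainedParameters, Tokenizer};

    fn main() {
        let params = FromPretrainedParameters {
            revision: "main".to_string(),
            user_agent: Default::default(),
            // Reuse an HF token from the environment when one is set, as the router does.
            auth_token: std::env::var("HF_TOKEN").ok(),
        };
        let tokenizer = Tokenizer::from_pretrained("gpt2", Some(params))
            .expect("Failed to retrieve tokenizer");
        println!("vocab size: {}", tokenizer.get_vocab_size(true));
    }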