Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-28 13:32:10 +00:00)
feat(backend): expose tokenizer to the GenerationContext to decode token
This commit is contained in:
parent 1473259f84
commit 1149186794
@@ -13,6 +13,7 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -54,6 +55,7 @@ pub(crate) struct GenerationContext {
 pub(crate) struct InferContext {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender<InferResult>,
+    pub(crate) tokenizer: Tokenizer,
     pub(crate) generation: GenerationContext,
 }
 
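With this field, every in-flight request owns its own tokenizers::Tokenizer handle, so the FFI callback can decode ids without reaching for global state. A minimal sketch of the same pattern, assuming only the tokenizers crate (the struct and helper below are illustrative, not items from this repository):

    use tokenizers::Tokenizer;

    // Hypothetical per-request context mirroring InferContext: it owns a
    // Tokenizer so generated ids can be turned back into text while streaming.
    struct RequestContext {
        tokenizer: Tokenizer,
        generated_tokens: Vec<u32>,
    }

    impl RequestContext {
        // Decode everything generated so far into a single string.
        fn decoded_text(&self) -> tokenizers::Result<String> {
            self.tokenizer.decode(&self.generated_tokens, false)
        }
    }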
@@ -72,7 +74,10 @@ pub struct LlamaCppBackend {
 }
 
 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(
+        model_path: P,
+        tokenizer: Tokenizer,
+    ) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -93,7 +98,7 @@ impl LlamaCppBackend {
         );
 
         let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
+        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
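The tokenizer is now moved into the scheduler thread together with the backend. A standalone sketch of handing an owned Tokenizer to a worker thread; the local tokenizer.json path and the token ids are assumptions for illustration, not values from this repository:

    use std::thread;
    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Assumption: a serialized tokenizer is available locally as tokenizer.json.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;

        // Move the owned tokenizer into the long-lived worker, mirroring
        // spawn(|| scheduler_loop(backend, tokenizer, receiver)) above.
        let handle = thread::spawn(move || {
            let ids = [1u32, 2, 3]; // placeholder ids, model-dependent
            match tokenizer.decode(&ids, false) {
                Ok(text) => println!("decoded: {text}"),
                Err(err) => eprintln!("decode failed: {err}"),
            }
        });

        handle.join().expect("worker thread panicked");
        Ok(())
    }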
@@ -110,19 +115,25 @@ fn llama_generate_callback(
 ) -> bool {
     info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
 
-    // Decode token
-    let token = Token {
-        id: new_token_id,
-        text: "".to_string(),
-        logprob: new_token_logit,
-        special: false,
-    };
-
     let ctx = unsafe { &mut *ctx };
 
     // Append the new token to the generated ones
     ctx.generation.generated_tokens.push(new_token_id);
 
+    // Decode token
+    let token = match ctx.tokenizer.decode(&[new_token_id], false) {
+        Ok(text) => {
+            let special = ctx.tokenizer.get_added_vocabulary().is_special_token(&text);
+            Token {
+                id: new_token_id,
+                text,
+                logprob: new_token_logit,
+                special,
+            }
+        }
+        Err(_) => panic!("Failed to decode token"),
+    };
+
     // Create the streamed response
     let response = match is_final {
         false => InferStreamResponse::Intermediate {
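Instead of emitting a placeholder Token with empty text, the callback now decodes each new id on the spot and flags it as special when the decoded string belongs to the tokenizer's added vocabulary. A self-contained sketch of that per-token step, assuming an already-loaded tokenizer:

    use tokenizers::Tokenizer;

    // Per-token step used by the callback: decode the freshly generated id and
    // ask the added vocabulary whether the decoded string is a special token.
    fn decode_one(tokenizer: &Tokenizer, new_token_id: u32) -> tokenizers::Result<(String, bool)> {
        let text = tokenizer.decode(&[new_token_id], false)?;
        let special = tokenizer.get_added_vocabulary().is_special_token(&text);
        Ok((text, special))
    }

Decoding ids one at a time does not always round-trip byte-level pieces cleanly, which is presumably why the end-of-stream branch below re-decodes the full id sequence to build the final text.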
@@ -131,10 +142,11 @@ fn llama_generate_callback(
         },
         true => {
             // Decode the whole text
-            let text = String::new();
-
-            // Stream end response
-            InferStreamResponse::End {
+            match ctx
+                .tokenizer
+                .decode(&ctx.generation.generated_tokens, false)
+            {
+                Ok(text) => InferStreamResponse::End {
                 token,
                 top_tokens: vec![],
                 generated_text: GeneratedText {
@@ -145,7 +157,11 @@ fn llama_generate_callback(
                 },
                 start: ctx.start,
                 queued: ctx.start,
+                },
+                Err(_) => panic!("Failed to decode token"),
             }
+
+            // Stream end response
         }
     };
 
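At end of stream the whole generated_tokens buffer is decoded in one pass to produce the final text, with the same panic-on-error behaviour as the per-token path. A minimal standalone equivalent (the panic mirrors the diff; a fallible return type would work just as well):

    use tokenizers::Tokenizer;

    // End-of-stream step: rebuild the complete text from every generated id,
    // mirroring ctx.tokenizer.decode(&ctx.generation.generated_tokens, false).
    fn final_text(tokenizer: &Tokenizer, generated_tokens: &[u32]) -> String {
        tokenizer
            .decode(generated_tokens, false)
            .unwrap_or_else(|err| panic!("Failed to decode token: {err}"))
    }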
@@ -162,6 +178,7 @@ fn llama_generate_callback(
 
 unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
+    tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -170,6 +187,7 @@ unsafe fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
+            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
@@ -179,6 +197,7 @@ unsafe fn scheduler_loop(
             let ctx = Box::new(InferContext {
                 start,
                 stream,
+                tokenizer,
                 generation,
             });
 
@@ -4,6 +4,7 @@ use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackend
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
 use thiserror::Error;
+use tokenizers::FromPretrainedParameters;
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -36,9 +37,9 @@ struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    num_model_instance: u16,
-    #[clap(default_value = "bigscience/bloom", long, env)]
+    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
+    // num_model_instance: u16,
+    #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
     tokenizer_config_path: Option<String>,
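On the CLI side, num_model_instance is commented out for now and tokenizer_name loses its bigscience/bloom default, becoming a required flag. A trimmed-down, hypothetical Args showing the effect of required = true, assuming clap with the derive and env features as the router already uses:

    use clap::Parser;
    use std::path::PathBuf;

    // Hypothetical, reduced Args: clap now rejects invocations that omit
    // --tokenizer-name instead of silently falling back to a default model id.
    #[derive(Parser, Debug)]
    struct Args {
        #[clap(long, env, help = "Path to GGUF model file(s) to load")]
        gguf_path: PathBuf,
        #[clap(long, env, required = true)]
        tokenizer_name: String,
    }

    fn main() {
        let args = Args::parse();
        println!("{args:?}");
    }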
@@ -94,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        num_model_instance,
+        // num_model_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -153,7 +154,17 @@ async fn main() -> Result<(), RouterError> {
         }
     }
 
-    let backend = LlamaCppBackend::new(gguf_path)?;
+    let auth_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();
+    let options = FromPretrainedParameters {
+        revision: revision.clone().unwrap_or("main".to_string()),
+        user_agent: Default::default(),
+        auth_token,
+    };
+    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+        .expect("Failed to retrieve tokenizer");
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
 
     // Run server
     server::run(
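Finally, the router resolves an auth token from HF_TOKEN (falling back to HUGGING_FACE_HUB_TOKEN), downloads the tokenizer from the Hub, and hands it to LlamaCppBackend::new alongside the GGUF path. A standalone sketch of just the tokenizer-fetching part, using a placeholder model id and assuming the tokenizers crate is built with its http feature:

    use tokenizers::{FromPretrainedParameters, Tokenizer};

    fn main() {
        // Same token resolution as the diff: prefer HF_TOKEN, fall back to the
        // older HUGGING_FACE_HUB_TOKEN, otherwise go unauthenticated.
        let auth_token = std::env::var("HF_TOKEN")
            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
            .ok();

        let options = FromPretrainedParameters {
            revision: "main".to_string(),
            user_agent: Default::default(),
            auth_token,
        };

        // "gpt2" is only a placeholder identifier for this sketch; the router
        // passes the --tokenizer-name value instead.
        let tokenizer = Tokenizer::from_pretrained("gpt2", Some(options))
            .expect("Failed to retrieve tokenizer");
        println!("vocabulary size: {}", tokenizer.get_vocab_size(true));
    }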