Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-28 13:32:10 +00:00)
feat(backend): bind incoming request to the server
parent d4aee42fd8
commit 612f2f939f
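In short: Backend::schedule now packages each validated request into an InferContext (input tokens, sampling and generation params, plus a per-request response channel) and pushes it onto an std::sync::mpsc backlog held by LlamaCppBackend; a dedicated scheduler thread pops contexts off that backlog and drives LlamaCppBackendImpl::generate, while the tokio channel meant to stream tokens back to the client is still commented out at this stage. The following is a minimal, self-contained sketch of that binding pattern, not code from the repository: the Request type, the echo-style "generation", and the std response channel are stand-ins for the real InferContext, the cxx FFI call, and the tokio stream.

use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::spawn;

// Stand-in for InferContext: everything one request needs, plus a channel
// to stream generated tokens back to the caller.
struct Request {
    input_tokens: Vec<u32>,
    max_new_tokens: usize,
    stream: Sender<u32>, // the real code uses a tokio UnboundedSender of InferStreamResponse
}

// Stand-in for scheduler_loop: block on the backlog until every sender is
// dropped, then exit (mirrors the "IPC channel is closed" branch).
fn scheduler_loop(backlog: Receiver<Request>) {
    while let Ok(ctx) = backlog.recv() {
        // Stand-in for LlamaCppBackendImpl::generate: "generate" by echoing
        // the prompt, capped at max_new_tokens, streaming each token out.
        for &token in ctx.input_tokens.iter().take(ctx.max_new_tokens) {
            if ctx.stream.send(token).is_err() {
                break; // client dropped its receiver
            }
        }
    }
}

fn main() {
    // Equivalent of LlamaCppBackend::new: spawn the scheduler thread and keep
    // the sending half of the backlog in the backend struct.
    let (backlog, receiver) = channel();
    let scheduler_handle = spawn(move || scheduler_loop(receiver));

    // Equivalent of Backend::schedule: bind one incoming request to the
    // scheduler by sending a context over the backlog.
    let (sx, rx) = channel();
    backlog
        .send(Request {
            input_tokens: vec![128000, 5159, 836, 374, 23809],
            max_new_tokens: 3,
            stream: sx,
        })
        .expect("scheduler thread is alive");

    for token in rx {
        println!("streamed token: {token}");
    }

    drop(backlog); // closing the backlog lets scheduler_loop return
    scheduler_handle.join().unwrap();
}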
@@ -2,18 +2,54 @@ use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
 use async_trait::async_trait;
-use cxx::{Exception, UniquePtr};
+use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
+use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
-use std::thread::spawn;
+use std::thread::{spawn, JoinHandle};
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{
+    ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
+};
+use text_generation_router::Token;
 use thiserror::Error;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::TryAcquireError;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::info;
+use tracing::{error, info};

 unsafe impl Send for LlamaCppBackendImpl {}

+impl From<&ValidParameters> for SamplingParams {
+    fn from(v: &ValidParameters) -> Self {
+        Self {
+            top_k: v.top_k,
+            top_p: v.top_p,
+            frequency_penalty: v.frequency_penalty,
+            repetition_penalty: v.repetition_penalty,
+            seed: v.seed,
+        }
+    }
+}
+
+impl From<&ValidStoppingParameters> for GenerationParams {
+    fn from(v: &ValidStoppingParameters) -> Self {
+        Self {
+            max_new_tokens: v.max_new_tokens,
+            ignore_eos_token: v.ignore_eos_token,
+        }
+    }
+}
+
+#[cfg_attr(debug_assertions, derive(Debug))]
+struct InferContext {
+    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+    pub(crate) input_tokens: Arc<Vec<u32>>,
+    pub(crate) generated_tokens: Vec<u32>,
+    pub(crate) generation_params: GenerationParams,
+    pub(crate) sampling_params: SamplingParams,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -23,7 +59,10 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }

-pub struct LlamaCppBackend {}
+pub struct LlamaCppBackend {
+    backlog: Sender<InferContext>,
+    scheduler_handle: JoinHandle<()>,
+}

 impl LlamaCppBackend {
     pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
@@ -34,7 +73,7 @@ impl LlamaCppBackend {
             ));
         }

-        let mut backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
+        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
             LlamaCppBackendError::ModelInitializationFailed(
                 path.to_path_buf(),
                 err.what().to_string(),
@@ -46,43 +85,100 @@ impl LlamaCppBackend {
             path.display()
         );

-        let j = spawn(|| scheduler_loop(backend));
-        j.join().ok();
-        Ok(Self {})
+        let (submitter, receiver) = channel();
+        let handle = spawn(|| scheduler_loop(backend, receiver));
+        Ok(Self {
+            backlog: submitter,
+            scheduler_handle: handle,
+        })
     }
 }

-fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {
-    println!("Scheduler loop");
-    let tokens = [128000u32, 5159, 836, 374, 23809];
-    let mut generated = vec![0u32; 16];
-    let generation_params = GenerationParams {
-        max_new_tokens: generated.len() as u32,
-    };
-    let sampling_params = SamplingParams::default();
-
-    match backend.pin_mut().generate(
-        &tokens,
-        &mut generated,
-        &generation_params,
-        &sampling_params,
-        |new_token_id: u32, is_eos: bool| println!("Generated {new_token_id} (is_eos: {is_eos})"),
-    ) {
-        Ok(n_tokens) => {
-            generated.truncate(n_tokens);
-            println!("Generated {} tokens -> {:?}", n_tokens, generated);
-        }
-        Err(err) => println!("Error: {}", err),
-    }
-}
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppBackendImpl>,
+    mut backlog: Receiver<InferContext>,
+) {
+    loop {
+        println!("Looping");
+        if let Ok(mut ctx) = backlog.recv() {
+            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
+            match backend.pin_mut().generate(
+                &ctx.input_tokens,
+                &mut ctx.generated_tokens,
+                &ctx.generation_params,
+                &ctx.sampling_params,
+                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
+                    let response = InferStreamResponse::Intermediate {
+                        token: Token {
+                            id: new_token_id,
+                            text: "".to_string(),
+                            logprob: new_token_logit,
+                            special: false,
+                        },
+                        top_tokens: vec![],
+                    };
+                    println!("Generated token: {response:?}");
+                    // let _ = tokio::spawn(async {
+                    //     match ctx.stream.send(Ok(response)) {
+                    //         Ok(_) => {}
+                    //         Err(ref err) => {
+                    //             error!(
+                    //                 "Failed to send back token to the client: {}",
+                    //                 err.to_string()
+                    //             );
+                    //         }
+                    //     }
+                    // });
+                },
+            ) {
+                Ok(n_tokens) => {
+                    unsafe {
+                        ctx.generated_tokens.set_len(n_tokens);
+                    }
+                    println!(
+                        "Generated {} tokens -> {:?}",
+                        n_tokens, &ctx.generated_tokens
+                    );
+                }
+                Err(err) => println!("Error: {}", err),
+            }
+        } else {
+            info!("IPC channel is closed, exiting the scheduler loop");
+            break;
+        }
+    }
+}

 #[async_trait]
 impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
-        Err(InferError::GenerationError("Not implemented yet".into()))
+        if let Some(input_ids) = request.input_ids {
+            let (sx, rx) = unbounded_channel();
+            let sampling_params = SamplingParams::from(&request.parameters);
+            let generation_params = GenerationParams::from(&request.stopping_parameters);
+
+            let ctx = InferContext {
+                stream: sx,
+                input_tokens: Arc::clone(&input_ids),
+                generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
+                generation_params,
+                sampling_params,
+            };
+
+            match self.backlog.send(ctx) {
+                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                Err(_) => Err(InferError::GenerationError(
+                    "Failed to sent the request".to_string(),
+                )),
+            }
+        } else {
+            Err(InferError::GenerationError(
+                "Unsupported modalities".to_string(),
+            ))
+        }
     }

     async fn health(&self, _: bool) -> bool {
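A note on the Ok(n_tokens) arm above: schedule() creates generated_tokens with Vec::with_capacity(max_new_tokens) and zero length, the C++ generate call fills that buffer, and the scheduler then commits the filled prefix with an unsafe set_len(n_tokens). The sketch below is illustrative and not from the repository; it shows the same capacity/set_len contract in pure Rust, with an in-place write standing in for the FFI side filling the buffer.

fn main() {
    let max_new_tokens = 16usize;
    let mut generated_tokens: Vec<u32> = Vec::with_capacity(max_new_tokens);

    // Stand-in for the C++ generate() call writing into the buffer:
    // initialize the first n_tokens slots of the spare capacity.
    let spare = generated_tokens.spare_capacity_mut();
    let n_tokens = 5.min(spare.len());
    for (i, slot) in spare.iter_mut().take(n_tokens).enumerate() {
        slot.write(1000 + i as u32);
    }

    // Safety: exactly the first `n_tokens` elements were initialized above and
    // `n_tokens <= capacity`, so committing the length is sound.
    unsafe { generated_tokens.set_len(n_tokens) };

    println!("Generated {} tokens -> {:?}", n_tokens, generated_tokens);
}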
@@ -16,11 +16,13 @@ impl Default for SamplingParams {

 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
+    #[derive(Debug, Copy, Clone)]
     struct GenerationParams {
         max_new_tokens: u32,
         ignore_eos_token: bool,
     }

+    #[derive(Debug, Copy, Clone)]
     struct SamplingParams {
         top_k: u32,
         top_p: f32,
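The second hunk only adds #[derive(Debug, Copy, Clone)] to the two cxx shared structs. For context (again illustrative, not part of the diff): Debug is what allows the scheduler loop's {ctx:?} print once InferContext derives Debug, and Copy/Clone keep these small parameter structs cheap to pass by value after being built from the validated request via From. The ValidStoppingParameters below is a local stand-in, not the router's real type.

// Stand-alone sketch of the From conversion plus the derives added above.
#[derive(Debug, Copy, Clone)]
struct GenerationParams {
    max_new_tokens: u32,
    ignore_eos_token: bool,
}

// Local stand-in for text_generation_router::validation::ValidStoppingParameters.
struct ValidStoppingParameters {
    max_new_tokens: u32,
    ignore_eos_token: bool,
}

impl From<&ValidStoppingParameters> for GenerationParams {
    fn from(v: &ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: v.max_new_tokens,
            ignore_eos_token: v.ignore_eos_token,
        }
    }
}

fn main() {
    let stopping = ValidStoppingParameters {
        max_new_tokens: 16,
        ignore_eos_token: false,
    };
    let generation_params = GenerationParams::from(&stopping);
    // Copy lets the params be stored in a context and still used afterwards
    // without an explicit clone.
    let for_context = generation_params;
    println!("{generation_params:?} / {for_context:?}");
}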