Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)
feat(backend): avoid dropping the boxed stream at the end of the callback
Parent: 612f2f939f
Commit: b50dcddbb8
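In short: the tokio response channel is no longer captured by the streaming callback closure. The Rust scheduler now wraps the sender in an OpaqueStream, boxes it, and hands the raw pointer to the C++ stream() entry point together with a plain function pointer; the box is only reconstructed (and the sender dropped) after the whole generation call returns, instead of being dropped at the end of a callback invocation. Below is a minimal, self-contained sketch of that ownership round trip, using std mpsc and hypothetical names (OpaqueChannel, emit_token, fake_ffi_stream) in place of the real FFI types.

    // Minimal sketch of the ownership pattern adopted here. The names
    // OpaqueChannel, emit_token and fake_ffi_stream are stand-ins, not part
    // of the real code base.
    use std::sync::mpsc::{channel, Sender};

    // Wrapper around the response channel; the foreign side only ever sees it
    // as an opaque pointer.
    struct OpaqueChannel(Sender<u32>);

    // Plays the role of llama_generate_callback: it gets the raw pointer back
    // from the foreign side and forwards the token through the channel.
    fn emit_token(channel: *mut OpaqueChannel, token_id: u32) {
        unsafe {
            if let Err(err) = (*channel).0.send(token_id) {
                eprintln!("failed to forward token: {err}");
            }
        }
    }

    // Stand-in for the C++ stream() entry point: it invokes the callback
    // repeatedly with the pointer it was handed, but never owns the channel.
    fn fake_ffi_stream(channel: *mut OpaqueChannel, callback: fn(*mut OpaqueChannel, u32)) {
        for token_id in [1, 2, 3] {
            callback(channel, token_id);
        }
    }

    fn main() {
        let (tx, rx) = channel();

        // Box the wrapper and leak it to a raw pointer for the duration of the
        // call; ownership stays on this side, so nothing is dropped inside the
        // callback itself.
        let stream_ptr = Box::into_raw(Box::new(OpaqueChannel(tx)));

        fake_ffi_stream(stream_ptr, emit_token);

        // Re-box the pointer once the call has returned: the sender is dropped
        // here, after generation finished, not at the end of each callback.
        unsafe { drop(Box::from_raw(stream_ptr)) };

        // The receiver observed every token forwarded through the callback.
        assert_eq!(rx.iter().collect::<Vec<_>>(), vec![1, 2, 3]);
    }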
C++ FFI layer:

@@ -21,6 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {
 
 
 #include "backends/llamacpp/src/lib.rs.h"
+#include "rust/cxx.h"
 
 
 namespace huggingface::tgi::backends::llamacpp {
@@ -61,17 +62,22 @@ namespace huggingface::tgi::backends::llamacpp {
 
         explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
 
-        size_t generate(
+        size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
                 rust::Slice <uint32_t> generated_tokens,
-                const generation_params_t &generation_params,
+                const generation_params_t generation_params,
                 const sampling_params_t &sampling_params,
-                rust::Fn<void(uint32_t, float_t, bool)> callback
+                OpaqueStream *stream,
+                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool)> callback
         ) {
             // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
+            static auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
                     -> std::expected<size_t, backend_error_t> {
 
+                auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos){
+                    callback(stream, new_token_id, logits, is_eos);
+                };
+
                 // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                 auto input_tokens_v =
                         std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
@@ -79,7 +85,12 @@ namespace huggingface::tgi::backends::llamacpp {
                         std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());
 
                 return backend.generate(
-                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
+                        input_tokens_v,
+                        generated_tokens_v,
+                        generation_params,
+                        sampling_params,
+                        context_forwarding_callback
+                );
             };
 
             if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
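On the C++ side, generate() becomes stream(): it now takes the OpaqueStream pointer plus a rust::Fn that expects that pointer as its first argument, and a small context_forwarding_callback lambda re-attaches the pointer before delegating to the inner backend, which keeps its original (token, logit, is_eos) callback shape. A rough Rust rendering of the same adapter idea, with illustrative names only:

    // Rough Rust analogue of context_forwarding_callback: adapt a
    // context-carrying callback to the context-free shape the inner backend
    // expects by capturing the context in a closure. All names are illustrative.
    type Ctx = usize; // stand-in for the OpaqueStream* pointer

    fn inner_generate(mut on_token: impl FnMut(u32, f32, bool)) {
        // The inner backend only knows about (token_id, logit, is_eos).
        on_token(42, -0.25, true);
    }

    fn stream(ctx: Ctx, callback: fn(Ctx, u32, f32, bool)) {
        // Re-attach the context before forwarding each event.
        let forward = move |token_id: u32, logit: f32, is_eos: bool| callback(ctx, token_id, logit, is_eos);
        inner_generate(forward);
    }

    fn main() {
        stream(7, |ctx, token_id, logit, is_eos| {
            println!("ctx={ctx} token={token_id} logit={logit} eos={is_eos}");
        });
    }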
Rust backend module:

@@ -1,23 +1,27 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
+use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
-use text_generation_router::Token;
+use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::TryAcquireError;
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info};
 
+type BoxedOpaqueStream = Box<OpaqueStream>;
+
 unsafe impl Send for LlamaCppBackendImpl {}
 
 impl From<&ValidParameters> for SamplingParams {
@@ -86,7 +90,7 @@ impl LlamaCppBackend {
         );
 
         let (submitter, receiver) = channel();
-        let handle = spawn(|| scheduler_loop(backend, receiver));
+        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
@@ -94,47 +98,59 @@ impl LlamaCppBackend {
     }
 }
 
-fn scheduler_loop(
+fn llama_generate_callback(
+    channel: *mut OpaqueStream,
+    new_token_id: u32,
+    new_token_logit: f32,
+    is_eos: bool,
+) {
+    let response = InferStreamResponse::Intermediate {
+        token: Token {
+            id: new_token_id,
+            text: "".to_string(),
+            logprob: new_token_logit,
+            special: false,
+        },
+        top_tokens: vec![],
+    };
+    println!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
+
+    unsafe {
+        if let Err(ref err) = (*channel).0.send(Ok(response)) {
+            error!(
+                "Failed to send back token to the client: {}",
+                err.to_string()
+            );
+        }
+    }
+}
+
+unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
     mut backlog: Receiver<InferContext>,
 ) {
     loop {
-        println!("Looping");
         if let Ok(mut ctx) = backlog.recv() {
-            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
-            match backend.pin_mut().generate(
+            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
+            let stream_ptr = Box::into_raw(stream);
+            let result = backend.pin_mut().stream(
                 &ctx.input_tokens,
                 &mut ctx.generated_tokens,
-                &ctx.generation_params,
+                ctx.generation_params,
                 &ctx.sampling_params,
-                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
-                    let response = InferStreamResponse::Intermediate {
-                        token: Token {
-                            id: new_token_id,
-                            text: "".to_string(),
-                            logprob: new_token_logit,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                    };
-                    println!("Generated token: {response:?}");
-                    // let _ = tokio::spawn(async {
-                    //     match ctx.stream.send(Ok(response)) {
-                    //         Ok(_) => {}
-                    //         Err(ref err) => {
-                    //             error!(
-                    //                 "Failed to send back token to the client: {}",
-                    //                 err.to_string()
-                    //             );
-                    //         }
-                    //     }
-                    // });
-                },
-            ) {
+                stream_ptr,
+                llama_generate_callback,
+            );
+
+            // Make sure we re-keep track of the OpaqueStream box
+            let _ = Box::from_raw(stream_ptr);
+
+            match result {
                 Ok(n_tokens) => {
                     unsafe {
                         ctx.generated_tokens.set_len(n_tokens);
                     }
 
                     println!(
                         "Generated {} tokens -> {:?}",
                         n_tokens, &ctx.generated_tokens
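In the new scheduler_loop, llama_generate_callback only borrows the channel through the raw pointer, so the loop is responsible for keeping the OpaqueStream box alive for the entire stream() call and reclaiming it exactly once afterwards with Box::from_raw. The commit does this with an explicit into_raw/from_raw pair; a defensive variant (a sketch with hypothetical names, not what the commit ships) would wrap the leaked pointer in a guard whose Drop re-boxes it even on early return or panic:

    // Hypothetical RAII guard around Box::into_raw / Box::from_raw. The real
    // code simply calls Box::from_raw after the FFI call returns; this guard is
    // just an alternative shape for the same ownership rule.
    struct BoxGuard<T> {
        ptr: *mut T,
    }

    impl<T> BoxGuard<T> {
        fn new(value: T) -> Self {
            Self { ptr: Box::into_raw(Box::new(value)) }
        }

        fn as_ptr(&self) -> *mut T {
            self.ptr
        }
    }

    impl<T> Drop for BoxGuard<T> {
        fn drop(&mut self) {
            // Safety: `ptr` was produced by Box::into_raw in `new` and is
            // reclaimed exactly once, here.
            unsafe { drop(Box::from_raw(self.ptr)) };
        }
    }

    fn main() {
        let guard = BoxGuard::new(String::from("response channel stand-in"));
        let ptr = guard.as_ptr();
        // ... the raw pointer would be handed across the FFI boundary here ...
        unsafe { println!("still reachable through the raw pointer: {}", &*ptr) };
        // `guard` goes out of scope here, re-boxing the pointer and dropping the value.
    }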
Rust crate root / cxx bridge:

@@ -1,4 +1,6 @@
 use crate::ffi::SamplingParams;
+use text_generation_router::infer::{InferError, InferStreamResponse};
+use tokio::sync::mpsc::UnboundedSender;
 
 pub mod backend;
 
@@ -14,6 +16,8 @@ impl Default for SamplingParams {
     }
 }
 
+struct OpaqueStream(UnboundedSender<Result<InferStreamResponse, InferError>>);
+
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
     #[derive(Debug, Copy, Clone)]
@@ -31,6 +35,10 @@ mod ffi {
         seed: u64,
     }
 
+    extern "Rust" {
+        type OpaqueStream;
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
 
@@ -47,13 +55,22 @@ mod ffi {
         #[rust_name = "create_single_worker_backend"]
         fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
 
-        fn generate(
+        // fn generate(
+        //     self: Pin<&mut LlamaCppBackendImpl>,
+        //     tokens: &[u32],
+        //     generated: &mut [u32],
+        //     generation_params: GenerationParams,
+        //     sampling_params: &SamplingParams,
+        // ) -> Result<usize>;
+
+        unsafe fn stream(
             self: Pin<&mut LlamaCppBackendImpl>,
             tokens: &[u32],
             generated: &mut [u32],
-            generation_params: &GenerationParams,
+            generation_params: GenerationParams,
             sampling_params: &SamplingParams,
-            callback: fn(u32, f32, bool),
+            stream: *mut OpaqueStream,
+            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
         ) -> Result<usize>;
     }
 }
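Finally, the cxx bridge exposes OpaqueStream to C++ as an opaque Rust type and declares stream() as unsafe, since the call now crosses the boundary with a raw *mut OpaqueStream and a bare function pointer. A minimal sketch of the opaque-Rust-type pattern in a cxx bridge (it needs only the cxx crate on the Rust side; every name other than cxx itself is illustrative):

    // Sketch of exposing an opaque Rust type through a cxx bridge, mirroring
    // how OpaqueStream is declared. Only Rust items are exported here, so no
    // C++ source is required to compile this crate.
    #[cxx::bridge]
    mod ffi_sketch {
        extern "Rust" {
            type OpaqueCounter;

            fn make_counter() -> Box<OpaqueCounter>;
            fn bump(counter: &mut OpaqueCounter) -> u32;
        }
    }

    // The concrete type stays fully defined on the Rust side; C++ only ever
    // handles it behind a pointer or Box.
    pub struct OpaqueCounter {
        value: u32,
    }

    fn make_counter() -> Box<OpaqueCounter> {
        Box::new(OpaqueCounter { value: 0 })
    }

    fn bump(counter: &mut OpaqueCounter) -> u32 {
        counter.value += 1;
        counter.value
    }

    fn main() {
        // The same functions remain directly usable from Rust.
        let mut counter = make_counter();
        assert_eq!(bump(&mut counter), 1);
    }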