Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-05-03 07:52:06 +00:00)

feat(backend): refactor the callback to handle intermediate and end inference message

Commit 5b7a951389, parent 11c593dc69
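At a high level, the backend no longer fills a caller-provided output buffer: it exposes a stream(...) entry point that reports every decoded token through a callback, and the Rust callback now receives a *mut InferContext plus an is_final flag so it can emit either an Intermediate or an End message on the per-request channel. Below is a minimal, self-contained Rust sketch of that dispatch pattern; Msg and Ctx are simplified stand-ins for InferStreamResponse and InferContext, not the real types.

    use std::sync::mpsc::Sender;

    // Simplified stand-in for InferStreamResponse: either a token in flight or the end of the stream.
    enum Msg {
        Intermediate { id: u32, logprob: f32 },
        End { id: u32, n_tokens: usize },
    }

    // Simplified stand-in for InferContext: accumulated tokens plus the response channel.
    struct Ctx {
        generated: Vec<u32>,
        stream: Sender<Msg>,
    }

    // Per-token callback: append the token, then decide between intermediate and end message.
    fn on_token(ctx: &mut Ctx, id: u32, logprob: f32, is_final: bool, n_tokens: usize) {
        ctx.generated.push(id);
        let msg = if is_final {
            Msg::End { id, n_tokens }
        } else {
            Msg::Intermediate { id, logprob }
        };
        let _ = ctx.stream.send(msg);
    }

    fn main() {
        let (tx, rx) = std::sync::mpsc::channel();
        let mut ctx = Ctx { generated: Vec::new(), stream: tx };
        for (i, &id) in [42u32, 7, 13].iter().enumerate() {
            on_token(&mut ctx, id, -0.5, i == 2, i + 1);
        }
        drop(ctx); // closes the sender so the receive loop below terminates
        for msg in rx {
            match msg {
                Msg::Intermediate { id, .. } => println!("intermediate token {id}"),
                Msg::End { id, n_tokens } => println!("end of stream at token {id} ({n_tokens} total)"),
            }
        }
    }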
@@ -114,9 +114,6 @@ namespace huggingface::tgi::backends::llamacpp {
             auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
             auto new_token_logits = 0.0f; // TODO: return logit

-            // Push the token to the generated vector on Rust side
-            generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-
             // Handle termination cases
             const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
             const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
@@ -150,10 +147,15 @@ namespace huggingface::tgi::backends::llamacpp {
     ) {
         // TODO: Should we provide a way to change this value?
         auto generated = std::vector<llama_token>(2 << 8);

-        auto nTokensGenerated = generate(tokens, generated, generation_params, sampling_params, callback);
-        if (nTokensGenerated.has_value())
-            generated.resize(*nTokensGenerated);
+        auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
+                                  size_t num_generated_tokens) {
+            generated.emplace_back(new_token_id);
+
+            if (callback.has_value())
+                (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+        };
+
+        auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
         return generated;
     }
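The hunk above rebuilds the vector-returning generate(...) helper on top of stream(...): an inner_callback accumulates every token into the local vector and then forwards it to the caller's optional callback. The same adapter pattern is sketched below in Rust, under the assumption of a hypothetical stream_tokens function standing in for the backend's stream().

    // Hypothetical streaming source: emits a few tokens and flags the last one as final.
    fn stream_tokens(mut on_token: impl FnMut(u32, f32, bool, usize)) -> usize {
        let fake = [(5u32, -0.1f32), (9, -0.3), (2, -0.7)];
        for (i, (id, logit)) in fake.iter().enumerate() {
            on_token(*id, *logit, i + 1 == fake.len(), i + 1);
        }
        fake.len()
    }

    // Collect every token while still forwarding each one to the caller's optional callback,
    // mirroring how inner_callback wraps the user callback in the C++ hunk above.
    fn generate_all(user_cb: Option<&dyn Fn(u32, f32, bool, usize)>) -> Vec<u32> {
        let mut generated = Vec::with_capacity(512);
        let n = stream_tokens(|id, logit, is_final, n_so_far| {
            generated.push(id); // accumulate locally
            if let Some(cb) = user_cb {
                cb(id, logit, is_final, n_so_far); // forward to the caller
            }
        });
        debug_assert_eq!(generated.len(), n);
        generated
    }

    fn main() {
        let tokens = generate_all(Some(&|id: u32, _logit: f32, is_final: bool, n: usize| {
            println!("token {id} (#{n}, final={is_final})");
        }));
        println!("{tokens:?}");
    }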
@@ -168,25 +170,24 @@ namespace huggingface::tgi::backends::llamacpp {
         llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
     }

-    std::expected<std::size_t, backend_error_t>
-    single_worker_backend_t::generate(
+    std::expected<size_t, backend_error_t>
+    single_worker_backend_t::stream(
             std::span<const llama_token> tokens,
-            std::span<llama_token> out,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback
+            const llama_decode_callback &callback
     ) {
-        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
+        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
     }

     std::expected<size_t, backend_error_t>
-    multi_worker_backend_t::generate(
-            std::span<const llama_token>,
-            std::span<llama_token>,
+    multi_worker_backend_t::stream(
+            std::span<const llama_token> tokens,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback) {
-        SPDLOG_ERROR("Not implemented yet");
-        return 0uz;
+            const llama_decode_callback &callback
+    ) {
+        SPDLOG_WARN("Not implemented for multi_worker_t");
+        return 0;
     }
 }
@@ -69,7 +69,6 @@ namespace huggingface::tgi::backends::llamacpp {
         generation_params_t generation_params;
         sampling_params_t sampling_params;
         std::span<const llama_token> input_tokens;
-        std::span<llama_token> generated_tokens;
     };

     /**
@@ -125,25 +124,9 @@ namespace huggingface::tgi::backends::llamacpp {
        /**
         *
         * @param tokens
-        * @params out
-        * @param params
-        * @param maxNewTokens
-        * @return
-        */
-       [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-       virtual std::expected<size_t, backend_error_t> generate(
-               std::span<const llama_token> input_tokens,
-               std::span<llama_token> generated_tokens,
-               const generation_params_t &generation_params,
-               const sampling_params_t &sampling_params,
-               const std::optional<llama_decode_callback> &callback
-       ) = 0;
-
-       /**
-        *
-        * @param tokens
-        * @param params
-        * @param maxNewTokens
+        * @param generation_params
+        * @param sampling_params
+        * @param callback
         * @return
         */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
@@ -153,6 +136,22 @@ namespace huggingface::tgi::backends::llamacpp {
                const sampling_params_t &sampling_params,
                const std::optional<llama_decode_callback> &callback = std::nullopt
        );

+       /**
+        *
+        * @param tokens
+        * @param generation_params
+        * @param sampling_params
+        * @params callback
+        * @return
+        */
+       [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+       virtual std::expected<size_t, backend_error_t> stream(
+               std::span<const llama_token> tokens,
+               const generation_params_t &generation_params,
+               const sampling_params_t &sampling_params,
+               const llama_decode_callback &callback
+       ) = 0;
    };

@@ -174,16 +173,11 @@ namespace huggingface::tgi::backends::llamacpp {
    public:
        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);

-       using backend_base_t::generate;
-
-       std::expected<size_t, backend_error_t>
-       generate(
+       std::expected<size_t, backend_error_t> stream(
                std::span<const llama_token> tokens,
-               std::span<llama_token> out,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
-               const std::optional<llama_decode_callback> &callback
-       ) override;
+               const llama_decode_callback &callback) override;
    };

    class multi_worker_backend_t : backend_base_t {
@@ -191,13 +185,11 @@ namespace huggingface::tgi::backends::llamacpp {
        llama_context_ptr mContext_;

    public:
-       std::expected<size_t, backend_error_t> generate(
-               std::span<const llama_token>,
-               std::span<llama_token>,
+       std::expected<size_t, backend_error_t> stream(
+               std::span<const llama_token> tokens,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
-               const std::optional<llama_decode_callback> &callback
-       ) override;
+               const llama_decode_callback &callback) override;
    };
 }

@@ -28,23 +28,20 @@ namespace huggingface::tgi::backends::llamacpp {

    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
    template<typename T>
-   concept has_emplace_generate = requires(
+   concept has_stream_method = requires(
            T t,
            std::span<const llama_token> input_tokens,
-           std::span<llama_token> generated_tokens,
            const generation_params_t &generation_params,
            const sampling_params_t &sampling_params,
            llama_decode_callback callback
    ) {
        {
-           t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+           t.stream(input_tokens, generation_params, sampling_params, callback)
        } -> std::same_as<std::expected<size_t, backend_error_t>>;
    };

-   static_assert(has_emplace_generate<single_worker_backend_t>,
-                 "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
-   static_assert(has_emplace_generate<multi_worker_backend_t>,
-                 "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+   static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
+   static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");

    class llama_cpp_backend_exception_t : std::exception {

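The has_stream_method concept above is a compile-time guarantee that every backend variant exposes the new streaming signature, enforced by the two static_asserts. A rough Rust analog of the same guarantee, expressed as a trait bound, is sketched below; TokenStream and BackendError are illustrative names, not part of the codebase.

    #[derive(Debug)]
    struct BackendError(String);

    // Trait standing in for the has_stream_method concept: every backend must offer stream().
    trait TokenStream {
        fn stream<F: FnMut(i32, f32, bool, usize)>(
            &mut self,
            tokens: &[i32],
            on_token: F,
        ) -> Result<usize, BackendError>;
    }

    struct SingleWorker;

    impl TokenStream for SingleWorker {
        fn stream<F: FnMut(i32, f32, bool, usize)>(
            &mut self,
            tokens: &[i32],
            mut on_token: F,
        ) -> Result<usize, BackendError> {
            // Toy implementation: echo the prompt back token by token, flagging the last one as final.
            for (i, &t) in tokens.iter().enumerate() {
                on_token(t, 0.0, i + 1 == tokens.len(), i + 1);
            }
            Ok(tokens.len())
        }
    }

    // Equivalent of the static_asserts: this function only compiles if the bound holds.
    fn assert_streams<T: TokenStream>() {}

    fn main() {
        assert_streams::<SingleWorker>();
        let mut backend = SingleWorker;
        let n = backend
            .stream(&[1, 2, 3], |id, _logit, is_final, _n| {
                println!("{id} final={is_final}");
            })
            .unwrap();
        println!("generated {n} tokens");
    }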
@@ -64,29 +61,25 @@ namespace huggingface::tgi::backends::llamacpp {

        size_t stream(
                rust::Slice<const uint32_t> input_tokens,
-               rust::Slice <uint32_t> generated_tokens,
                const generation_params_t generation_params,
                const sampling_params_t &sampling_params,
-               OpaqueStream *stream,
-               rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
+               InferContext *ctx,
+               rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
        ) {
            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-           auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
+           auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
                    -> std::expected<size_t, backend_error_t> {

-               auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
-                   callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
+               auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
+                   callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
                };

                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                auto input_tokens_v =
                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
-               auto generated_tokens_v =
-                       std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());

-               return backend.generate(
+               return backend.stream(
                        input_tokens_v,
-                       generated_tokens_v,
                        generation_params,
                        sampling_params,
                        context_forwarding_callback
@@ -1,7 +1,6 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
-use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
@@ -14,12 +13,13 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info};

-type BoxedOpaqueStream = Box<OpaqueStream>;
+type InferResult = Result<InferStreamResponse, InferError>;

 unsafe impl Send for LlamaCppBackendImpl {}

@@ -45,14 +45,19 @@ impl From<&ValidStoppingParameters> for GenerationParams {
 }

 #[cfg_attr(debug_assertions, derive(Debug))]
-struct InferContext {
-    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+struct GenerationContext {
     pub(crate) input_tokens: Arc<Vec<u32>>,
     pub(crate) generated_tokens: Vec<u32>,
     pub(crate) generation_params: GenerationParams,
     pub(crate) sampling_params: SamplingParams,
 }

+pub(crate) struct InferContext {
+    pub(crate) start: Instant,
+    pub(crate) stream: UnboundedSender<InferResult>,
+    pub(crate) generation: GenerationContext,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -63,7 +68,7 @@ pub enum LlamaCppBackendError {
 }

 pub struct LlamaCppBackend {
-    backlog: Sender<InferContext>,
+    backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
 }

@@ -98,81 +103,96 @@ impl LlamaCppBackend {
 }

 fn llama_generate_callback(
-    channel: *mut OpaqueStream,
+    ctx: *mut InferContext,
     new_token_id: u32,
     new_token_logit: f32,
-    is_eos: bool,
+    is_final: bool,
     n_generated_tokens: usize,
 ) {
-    let response = InferStreamResponse::Intermediate {
-        token: Token {
-            id: new_token_id,
-            text: "".to_string(),
-            logprob: new_token_logit,
-            special: false,
-        },
-        top_tokens: vec![],
-    };
-    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})");
+    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");

-    unsafe {
-        if let Err(ref err) = (*channel).0.send(Ok(response)) {
-            error!(
-                "Failed to send back token to the client: {}",
-                err.to_string()
-            );
-        };
+    // Decode token
+    let token = Token {
+        id: new_token_id,
+        text: "".to_string(),
+        logprob: new_token_logit,
+        special: false,
+    };
+
+    let ctx = unsafe { &mut *ctx };
+
+    // Append the new token to the generated ones
+    ctx.generation.generated_tokens.push(new_token_id);
+
+    // Create the streamed response
+    let response = match is_final {
+        false => InferStreamResponse::Intermediate {
+            token,
+            top_tokens: vec![],
+        },
+        true => {
+            // Decode the whole text
+            let text = String::new();
+
+            // Stream end response
+            InferStreamResponse::End {
+                token,
+                top_tokens: vec![],
+                generated_text: GeneratedText {
+                    text,
+                    generated_tokens: n_generated_tokens as u32,
+                    finish_reason: FinishReason::Length,
+                    seed: Some(ctx.generation.sampling_params.seed),
+                },
+                start: ctx.start,
+                queued: ctx.start,
+            }
+        }
+    };
+
+    // Send back to the client
+    if let Err(ref err) = ctx.stream.send(Ok(response)) {
+        error!("Failed to send back the response to the client, cancelling request");
+        // TODO: cancel the request
     }
 }

 unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
-    backlog: Receiver<InferContext>,
+    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
+    // This loop will mostly decode single token at every step, so no need to rely on parallelism
+    tokenizers::utils::parallelism::set_parallelism(false);
+
     loop {
-        if let Ok(mut ctx) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
-            let stream_ptr = Box::into_raw(stream);
-            let result = backend.pin_mut().stream(
-                &ctx.input_tokens,
-                &mut ctx.generated_tokens,
-                ctx.generation_params,
-                &ctx.sampling_params,
-                stream_ptr,
-                llama_generate_callback,
-            );
+            let generation_params = generation.generation_params; // copy
+            let sampling_params = generation.sampling_params; // copy
+            let input_tokens = Arc::clone(&generation.input_tokens);

-            // Make sure we re-keep track of the OpaqueStream box
-            let stream = Box::from_raw(stream_ptr);
+            // Creating the whole InferContext and pushing it to the heap
+            {
+                let ctx = Box::new(InferContext {
+                    start,
+                    stream,
+                    generation,
+                });

-            match result {
-                Ok(n_tokens) => {
-                    unsafe {
-                        ctx.generated_tokens.set_len(n_tokens);
-                    }
+                let boxed_ctx = Box::into_raw(ctx);

-                    let _ = stream.0.send(Ok(InferStreamResponse::End {
-                        token: Token {
-                            id: ctx.generated_tokens[n_tokens - 1],
-                            text: "".to_string(),
-                            logprob: 0.0,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                        generated_text: GeneratedText {
-                            text: "".to_string(),
-                            generated_tokens: n_tokens as u32,
-                            finish_reason: FinishReason::Length,
-                            seed: Some(ctx.sampling_params.seed),
-                        },
-                        start,
-                        queued: start,
-                    }));
-
-                    debug!("Generated {n_tokens} tokens -> {:?}", ctx.generated_tokens);
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
                 }
-                Err(err) => println!("Error: {err}"),
+
+                // Make sure we re-keep track of the OpaqueStream box
+                let _ = Box::from_raw(boxed_ctx);
             }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
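The scheduler loop above moves the InferContext to the heap with Box::new, leaks it to a raw pointer via Box::into_raw for the duration of the FFI call, and reclaims ownership with Box::from_raw once stream(...) returns so the context is dropped normally. A standalone sketch of that round trip follows; fake_cpp_stream is a stand-in for backend.pin_mut().stream(...), not the real bridge function.

    // Simplified context owned by Rust but handed to "C++" as a raw pointer.
    struct InferCtx {
        generated: Vec<u32>,
    }

    // Stand-in for the C++ side: it only ever sees the raw pointer and invokes the callback with it.
    fn fake_cpp_stream(ctx: *mut InferCtx, callback: unsafe fn(*mut InferCtx, u32)) {
        for id in [11u32, 22, 33] {
            unsafe { callback(ctx, id) };
        }
    }

    // Mirrors llama_generate_callback: a safe fn that re-borrows the raw pointer internally.
    fn on_token(ctx: *mut InferCtx, id: u32) {
        // Safety: the caller keeps the Box alive until fake_cpp_stream returns.
        let ctx = unsafe { &mut *ctx };
        ctx.generated.push(id);
    }

    fn main() {
        let boxed = Box::new(InferCtx { generated: Vec::new() });
        let raw = Box::into_raw(boxed);          // give up ownership for the FFI call
        fake_cpp_stream(raw, on_token);
        let ctx = unsafe { Box::from_raw(raw) }; // take ownership back so the context is freed
        println!("{:?}", ctx.generated);
    }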
@@ -186,21 +206,20 @@ impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
         request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<UnboundedReceiverStream<InferResult>, InferError> {
         if let Some(input_ids) = request.input_ids {
             let (sx, rx) = unbounded_channel();
             let sampling_params = SamplingParams::from(&request.parameters);
             let generation_params = GenerationParams::from(&request.stopping_parameters);

-            let ctx = InferContext {
-                stream: sx,
+            let ctx = GenerationContext {
                 input_tokens: Arc::clone(&input_ids),
                 generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
                 generation_params,
                 sampling_params,
             };

-            match self.backlog.send(ctx) {
+            match self.backlog.send((ctx, sx)) {
                 Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
                 Err(_) => Err(InferError::GenerationError(
                     "Failed to sent the request".to_string(),
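schedule(...) above now pairs the GenerationContext with the per-request response sender and pushes the tuple onto the backlog channel, where the scheduler thread picks it up. Below is a simplified sketch of that hand-off using std channels only; the types are stand-ins for the real InferStreamResponse plumbing.

    use std::sync::mpsc::{channel, Sender};
    use std::thread;

    struct GenerationCtx {
        input_tokens: Vec<u32>,
    }

    type InferResult = Result<u32, String>; // stand-in for Result<InferStreamResponse, InferError>

    // Producer side: build the context, create the per-request channel, enqueue the pair.
    fn schedule(
        backlog: &Sender<(GenerationCtx, Sender<InferResult>)>,
        input_tokens: Vec<u32>,
    ) -> std::sync::mpsc::Receiver<InferResult> {
        let (sx, rx) = channel();
        backlog
            .send((GenerationCtx { input_tokens }, sx))
            .expect("scheduler thread is gone");
        rx
    }

    fn main() {
        let (backlog_tx, backlog_rx) = channel::<(GenerationCtx, Sender<InferResult>)>();

        // Consumer side: one request at a time, streaming results back on the per-request sender.
        let scheduler = thread::spawn(move || {
            while let Ok((generation, stream)) = backlog_rx.recv() {
                for &t in &generation.input_tokens {
                    let _ = stream.send(Ok(t));
                }
            }
        });

        let rx = schedule(&backlog_tx, vec![3, 1, 4]);
        for msg in rx {
            println!("streamed: {msg:?}");
        }

        drop(backlog_tx); // close the backlog so the scheduler loop exits
        scheduler.join().unwrap();
    }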
@@ -1,6 +1,5 @@
+use crate::backend::InferContext;
 use crate::ffi::SamplingParams;
-use text_generation_router::infer::{InferError, InferStreamResponse};
-use tokio::sync::mpsc::UnboundedSender;

 pub mod backend;

@@ -16,8 +15,6 @@ impl Default for SamplingParams {
     }
 }

-struct OpaqueStream(UnboundedSender<Result<InferStreamResponse, InferError>>);
-
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
     #[derive(Debug, Copy, Clone)]
|
|||||||
}
|
}
|
||||||
|
|
||||||
extern "Rust" {
|
extern "Rust" {
|
||||||
type OpaqueStream;
|
type InferContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C++" {
|
unsafe extern "C++" {
|
||||||
@@ -66,11 +63,10 @@ mod ffi {
         unsafe fn stream(
             self: Pin<&mut LlamaCppBackendImpl>,
             tokens: &[u32],
-            generated: &mut [u32],
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
-            stream: *mut OpaqueStream,
-            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool, usize),
+            stream: *mut InferContext,
+            callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
         ) -> Result<usize>;
     }
 }