// text-generation-inference/backends/llamacpp/src/backend.rs

use crate::ffi::{create_llamacpp_backend, LlamaCppBackendImpl};
use async_trait::async_trait;
use cxx::UniquePtr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use thiserror::Error;
use tokio::runtime::Handle;
use tokio::task::spawn_blocking;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::info;

// SAFETY: NOTE(review): nothing in this file demonstrates that the C++ object behind
// `LlamaCppBackendImpl` may be moved across threads. This impl is what allows
// `LlamaCppBackend::new` to move the `UniquePtr<LlamaCppBackendImpl>` into the closure handed to
// `spawn_blocking`; confirm the C++ side holds no thread-affine state — TODO confirm.
unsafe impl Send for LlamaCppBackendImpl {}
#[derive(Debug, Error)]
pub enum LlamaCppBackendError {
#[error("Provided GGUF model path {0} doesn't exist")]
ModelFileDoesntExist(String),
#[error("Failed to initialize model from GGUF file {0}: {1}")]
ModelInitializationFailed(PathBuf, String),
2024-10-24 07:56:40 +00:00
}
/// Handle for the llama.cpp inference backend.
///
/// Carries no fields: `LlamaCppBackend::new` hands the native backend to `scheduler_loop`
/// instead of storing it here.
pub struct LlamaCppBackend {}
impl LlamaCppBackend {
pub fn new<P: AsRef<Path> + Send>(
model_path: P,
n_threads: u16,
) -> Result<Self, LlamaCppBackendError> {
let path = Arc::new(model_path.as_ref());
if !path.exists() {
return Err(LlamaCppBackendError::ModelFileDoesntExist(
path.display().to_string(),
));
}
let mut backend =
create_llamacpp_backend(path.to_str().unwrap(), n_threads).map_err(|err| {
LlamaCppBackendError::ModelInitializationFailed(
path.to_path_buf(),
err.what().to_string(),
)
})?;
info!(
"Successfully initialized llama.cpp backend from {}",
path.display()
);
spawn_blocking(move || scheduler_loop(backend));
Ok(Self {})
2024-10-24 07:56:40 +00:00
}
}
2024-10-04 08:42:31 +00:00
async fn scheduler_loop(mut backend: UniquePtr<LlamaCppBackendImpl>) {}
#[async_trait]
impl Backend for LlamaCppBackend {
2024-10-04 08:42:31 +00:00
fn schedule(
&self,
_request: ValidGenerateRequest,
2024-10-04 08:42:31 +00:00
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Err(InferError::GenerationError("Not implemented yet".into()))
}
async fn health(&self, _: bool) -> bool {
true
2024-10-04 08:42:31 +00:00
}
}