Thanks cargo fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Adrien Gallouët 2025-02-06 10:08:18 +00:00
parent fb81c0d1c4
commit 2b0d99c1cf
3 changed files with 52 additions and 48 deletions


@@ -1,5 +1,4 @@
-use bindgen::callbacks::{ParseCallbacks, ItemInfo};
+use bindgen::callbacks::{ItemInfo, ParseCallbacks};
 use std::collections::HashMap;
 use std::env;
 use std::path::PathBuf;


@@ -7,21 +7,21 @@ mod llamacpp {
 }
 use async_trait::async_trait;
 use std::ffi::CString;
+use std::mem::replace;
+use std::str::FromStr;
 use std::sync::{mpsc, Once};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{ValidGenerateRequest};
+use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::{watch, oneshot};
+use tokio::sync::{oneshot, watch};
 use tokio::task::{spawn, spawn_blocking};
-use tokio::time::{Duration, Instant, timeout};
+use tokio::time::{timeout, Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn, error, trace};
-use tracing::{instrument};
-use std::str::FromStr;
-use std::mem::replace;
+use tracing::instrument;
+use tracing::{debug, error, info, trace, warn};
 
 #[derive(Debug, Clone, Copy)]
 pub enum LlamacppSplitMode {
@@ -39,7 +39,7 @@ impl FromStr for LlamacppSplitMode {
             _ => match s.parse::<usize>() {
                 Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
                 Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
-            }
+            },
         }
     }
 }
@@ -241,7 +241,7 @@ impl Llamacpp {
             llamacpp::model_load_from_file(gguf.as_ptr(), params)
         };
         if model.is_null() {
-            return Err(BackendError::Llamacpp("Failed to load model".to_string()))
+            return Err(BackendError::Llamacpp("Failed to load model".to_string()));
         }
         let ctx = unsafe {
             let mut params = llamacpp::context_default_params();
@@ -260,7 +260,7 @@ impl Llamacpp {
             llamacpp::init_from_model(model, params)
         };
         if ctx.is_null() {
-            return Err(BackendError::Llamacpp("Failed to init context".to_string()))
+            return Err(BackendError::Llamacpp("Failed to init context".to_string()));
         }
         let vocab = unsafe {
             llamacpp::model_get_vocab(model)
@@ -444,8 +444,11 @@ impl LlamacppBackend {
     pub fn new(
         conf: LlamacppConfig,
         tokenizer: Tokenizer,
-    ) -> (Self, oneshot::Receiver<Result<(),BackendError>>, watch::Sender<bool>) {
+    ) -> (
+        Self,
+        oneshot::Receiver<Result<(), BackendError>>,
+        watch::Sender<bool>,
+    ) {
         // Setup llama & export logs, once and for all
         INIT.call_once(|| unsafe {
             llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
@@ -489,7 +492,7 @@ impl LlamacppBackend {
                     if requests.len() == conf.max_batch_size {
                         flush(&mut requests, &mut n_tokens);
                     }
-                },
+                }
                 Ok(None) => break, // closed
                 Err(_) => flush(&mut requests, &mut n_tokens), // timeout
             }
@@ -498,8 +501,14 @@ impl LlamacppBackend {
         spawn_blocking(move || {
             let mut llamacpp = match Llamacpp::new(conf) {
-                Ok(v) => { let _ = ok_tx.send(Ok(())); v },
-                Err(e) => { let _ = ok_tx.send(Err(e)); return; },
+                Ok(v) => {
+                    let _ = ok_tx.send(Ok(()));
+                    v
+                }
+                Err(e) => {
+                    let _ = ok_tx.send(Err(e));
+                    return;
+                }
             };
             let vocab = tokenizer.get_added_vocabulary();
@@ -522,7 +531,7 @@ impl LlamacppBackend {
                     _ => {
                         let _ = request.tx.send(Err(InferError::IncompleteGeneration));
                         continue;
-                    },
+                    }
                 };
                 let last_pos = request.input_ids.len() - 1;
@@ -570,7 +579,7 @@ impl LlamacppBackend {
                         let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
                         seq.running = false;
                         continue;
-                    },
+                    }
                 };
                 let special = vocab.is_special_token(&piece);


@@ -1,12 +1,15 @@
 mod backend;
 
-use backend::{LlamacppNuma, LlamacppGGMLType, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
-use clap::{Parser};
+use backend::{
+    BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
+    LlamacppSplitMode,
+};
+use clap::Parser;
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
-use tokenizers::{Tokenizer, FromPretrainedParameters};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tokio::sync::oneshot::error::RecvError;
-use tracing::{warn, error};
+use tracing::{error, warn};
 
 /// Backend Configuration
 #[derive(Parser, Debug)]
@@ -161,11 +164,7 @@ struct Args {
 async fn main() -> Result<(), RouterError> {
     let args = Args::parse();
 
-    logging::init_logging(
-        args.otlp_endpoint,
-        args.otlp_service_name,
-        args.json_output
-    );
+    logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
 
     let n_threads = match args.n_threads {
         Some(0) | None => num_cpus::get(),
@@ -218,10 +217,7 @@ async fn main() -> Result<(), RouterError> {
             token,
             ..Default::default()
         };
-        Tokenizer::from_pretrained(
-            args.model_id.clone(),
-            Some(params)
-        )?
+        Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
     };
 
     let (backend, ok, shutdown) = LlamacppBackend::new(