Thanks cargo fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
This commit is contained in:
Adrien Gallouët 2025-02-06 10:08:18 +00:00
parent fb81c0d1c4
commit 2b0d99c1cf
No known key found for this signature in database
3 changed files with 52 additions and 48 deletions

View File

@ -1,5 +1,4 @@
use bindgen::callbacks::{ParseCallbacks, ItemInfo};
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
use std::collections::HashMap;
use std::env;
use std::path::PathBuf;

View File

@ -7,21 +7,21 @@ mod llamacpp {
}
use async_trait::async_trait;
use std::ffi::CString;
use std::mem::replace;
use std::str::FromStr;
use std::sync::{mpsc, Once};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{ValidGenerateRequest};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::{watch, oneshot};
use tokio::sync::{oneshot, watch};
use tokio::task::{spawn, spawn_blocking};
use tokio::time::{Duration, Instant, timeout};
use tokio::time::{timeout, Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, info, warn, error, trace};
use tracing::{instrument};
use std::str::FromStr;
use std::mem::replace;
use tracing::instrument;
use tracing::{debug, error, info, trace, warn};
#[derive(Debug, Clone, Copy)]
pub enum LlamacppSplitMode {
@ -39,7 +39,7 @@ impl FromStr for LlamacppSplitMode {
_ => match s.parse::<usize>() {
Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
}
},
}
}
}
@ -175,23 +175,23 @@ impl LlamacppRequest {
fn new(
from: &ValidGenerateRequest,
tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
) -> Option<Self>{
) -> Option<Self> {
from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
input_ids: input_ids.iter().map(|&x| x as i32).collect(),
top_k: from.parameters.top_k as _,
top_p: from.parameters.top_p as _,
typical_p: from.parameters.typical_p as _,
min_keep: 0, // disabled
temp: from.parameters.temperature as _,
seed: from.parameters.seed as _,
penalty_last_n: 64, // 0 = disabled, -1 = context size
penalty_repeat: from.parameters.repetition_penalty as _,
penalty_freq: from.parameters.frequency_penalty as _,
penalty_present: 0.0, // disabled
max_new_tokens: from.stopping_parameters.max_new_tokens as _,
tx,
time: Instant::now(),
})
input_ids: input_ids.iter().map(|&x| x as i32).collect(),
top_k: from.parameters.top_k as _,
top_p: from.parameters.top_p as _,
typical_p: from.parameters.typical_p as _,
min_keep: 0, // disabled
temp: from.parameters.temperature as _,
seed: from.parameters.seed as _,
penalty_last_n: 64, // 0 = disabled, -1 = context size
penalty_repeat: from.parameters.repetition_penalty as _,
penalty_freq: from.parameters.frequency_penalty as _,
penalty_present: 0.0, // disabled
max_new_tokens: from.stopping_parameters.max_new_tokens as _,
tx,
time: Instant::now(),
})
}
}
@ -241,7 +241,7 @@ impl Llamacpp {
llamacpp::model_load_from_file(gguf.as_ptr(), params)
};
if model.is_null() {
return Err(BackendError::Llamacpp("Failed to load model".to_string()))
return Err(BackendError::Llamacpp("Failed to load model".to_string()));
}
let ctx = unsafe {
let mut params = llamacpp::context_default_params();
@ -260,7 +260,7 @@ impl Llamacpp {
llamacpp::init_from_model(model, params)
};
if ctx.is_null() {
return Err(BackendError::Llamacpp("Failed to init context".to_string()))
return Err(BackendError::Llamacpp("Failed to init context".to_string()));
}
let vocab = unsafe {
llamacpp::model_get_vocab(model)
@ -444,8 +444,11 @@ impl LlamacppBackend {
pub fn new(
conf: LlamacppConfig,
tokenizer: Tokenizer,
) -> (Self, oneshot::Receiver<Result<(),BackendError>>, watch::Sender<bool>) {
) -> (
Self,
oneshot::Receiver<Result<(), BackendError>>,
watch::Sender<bool>,
) {
// Setup llama & export logs, once and for all
INIT.call_once(|| unsafe {
llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
@ -489,7 +492,7 @@ impl LlamacppBackend {
if requests.len() == conf.max_batch_size {
flush(&mut requests, &mut n_tokens);
}
},
}
Ok(None) => break, // closed
Err(_) => flush(&mut requests, &mut n_tokens), // timeout
}
@ -498,8 +501,14 @@ impl LlamacppBackend {
spawn_blocking(move || {
let mut llamacpp = match Llamacpp::new(conf) {
Ok(v) => { let _ = ok_tx.send(Ok(())); v },
Err(e) => { let _ = ok_tx.send(Err(e)); return; },
Ok(v) => {
let _ = ok_tx.send(Ok(()));
v
}
Err(e) => {
let _ = ok_tx.send(Err(e));
return;
}
};
let vocab = tokenizer.get_added_vocabulary();
@ -522,7 +531,7 @@ impl LlamacppBackend {
_ => {
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
continue;
},
}
};
let last_pos = request.input_ids.len() - 1;
@ -570,7 +579,7 @@ impl LlamacppBackend {
let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
seq.running = false;
continue;
},
}
};
let special = vocab.is_special_token(&piece);

View File

@ -1,12 +1,15 @@
mod backend;
use backend::{LlamacppNuma, LlamacppGGMLType, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
use clap::{Parser};
use backend::{
BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
LlamacppSplitMode,
};
use clap::Parser;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::{Tokenizer, FromPretrainedParameters};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokio::sync::oneshot::error::RecvError;
use tracing::{warn, error};
use tracing::{error, warn};
/// Backend Configuration
#[derive(Parser, Debug)]
@ -161,11 +164,7 @@ struct Args {
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
logging::init_logging(
args.otlp_endpoint,
args.otlp_service_name,
args.json_output
);
logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
let n_threads = match args.n_threads {
Some(0) | None => num_cpus::get(),
@ -218,10 +217,7 @@ async fn main() -> Result<(), RouterError> {
token,
..Default::default()
};
Tokenizer::from_pretrained(
args.model_id.clone(),
Some(params)
)?
Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
};
let (backend, ok, shutdown) = LlamacppBackend::new(