From 2b0d99c1cf7b7af3cd2590387a2aa11b6e42bb44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?=
Date: Thu, 6 Feb 2025 10:08:18 +0000
Subject: [PATCH] Thanks cargo fmt
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Adrien Gallouët
---
 backends/llamacpp/build.rs       |  3 +-
 backends/llamacpp/src/backend.rs | 75 ++++++++++++++++++--------------
 backends/llamacpp/src/main.rs    | 22 ++++------
 3 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/backends/llamacpp/build.rs b/backends/llamacpp/build.rs
index 1b1c3718..aa2a0d87 100644
--- a/backends/llamacpp/build.rs
+++ b/backends/llamacpp/build.rs
@@ -1,5 +1,4 @@
-
-use bindgen::callbacks::{ParseCallbacks, ItemInfo};
+use bindgen::callbacks::{ItemInfo, ParseCallbacks};
 use std::collections::HashMap;
 use std::env;
 use std::path::PathBuf;
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 5d5eab43..81f7b9f4 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -7,21 +7,21 @@ mod llamacpp {
 }
 use async_trait::async_trait;
 use std::ffi::CString;
+use std::mem::replace;
+use std::str::FromStr;
 use std::sync::{mpsc, Once};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{ValidGenerateRequest};
+use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::{watch, oneshot};
+use tokio::sync::{oneshot, watch};
 use tokio::task::{spawn, spawn_blocking};
-use tokio::time::{Duration, Instant, timeout};
+use tokio::time::{timeout, Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn, error, trace};
-use tracing::{instrument};
-use std::str::FromStr;
-use std::mem::replace;
+use tracing::instrument;
+use tracing::{debug, error, info, trace, warn};
 
 #[derive(Debug, Clone, Copy)]
 pub enum LlamacppSplitMode {
@@ -39,7 +39,7 @@ impl FromStr for LlamacppSplitMode {
             _ => match s.parse::<usize>() {
                 Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
                 Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
-            }
+            },
         }
     }
 }
@@ -175,23 +175,23 @@ impl LlamacppRequest {
     fn new(
         from: &ValidGenerateRequest,
         tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
-    ) -> Option<Self>{
+    ) -> Option<Self> {
         from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
-                input_ids: input_ids.iter().map(|&x| x as i32).collect(),
-                top_k: from.parameters.top_k as _,
-                top_p: from.parameters.top_p as _,
-                typical_p: from.parameters.typical_p as _,
-                min_keep: 0, // disabled
-                temp: from.parameters.temperature as _,
-                seed: from.parameters.seed as _,
-                penalty_last_n: 64, // 0 = disabled, -1 = context size
-                penalty_repeat: from.parameters.repetition_penalty as _,
-                penalty_freq: from.parameters.frequency_penalty as _,
-                penalty_present: 0.0, // disabled
-                max_new_tokens: from.stopping_parameters.max_new_tokens as _,
-                tx,
-                time: Instant::now(),
-            })
+            input_ids: input_ids.iter().map(|&x| x as i32).collect(),
+            top_k: from.parameters.top_k as _,
+            top_p: from.parameters.top_p as _,
+            typical_p: from.parameters.typical_p as _,
+            min_keep: 0, // disabled
+            temp: from.parameters.temperature as _,
+            seed: from.parameters.seed as _,
+            penalty_last_n: 64, // 0 = disabled, -1 = context size
+            penalty_repeat: from.parameters.repetition_penalty as _,
+            penalty_freq: from.parameters.frequency_penalty as _,
+            penalty_present: 0.0, // disabled
+            max_new_tokens: from.stopping_parameters.max_new_tokens as _,
+            tx,
+            time: Instant::now(),
+        })
     }
 }
 
@@ -241,7 +241,7 @@ impl Llamacpp {
             llamacpp::model_load_from_file(gguf.as_ptr(), params)
         };
         if model.is_null() {
-            return Err(BackendError::Llamacpp("Failed to load model".to_string()))
+            return Err(BackendError::Llamacpp("Failed to load model".to_string()));
         }
         let ctx = unsafe {
             let mut params = llamacpp::context_default_params();
@@ -260,7 +260,7 @@ impl Llamacpp {
             llamacpp::init_from_model(model, params)
        };
         if ctx.is_null() {
-            return Err(BackendError::Llamacpp("Failed to init context".to_string()))
+            return Err(BackendError::Llamacpp("Failed to init context".to_string()));
         }
         let vocab = unsafe {
             llamacpp::model_get_vocab(model)
@@ -444,8 +444,11 @@ impl LlamacppBackend {
     pub fn new(
         conf: LlamacppConfig,
         tokenizer: Tokenizer,
-    ) -> (Self, oneshot::Receiver<Result<(), BackendError>>, watch::Sender<bool>) {
-
+    ) -> (
+        Self,
+        oneshot::Receiver<Result<(), BackendError>>,
+        watch::Sender<bool>,
+    ) {
         // Setup llama & export logs, once and for all
         INIT.call_once(|| unsafe {
             llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
@@ -489,7 +492,7 @@ impl LlamacppBackend {
                         if requests.len() == conf.max_batch_size {
                             flush(&mut requests, &mut n_tokens);
                         }
-                    },
+                    }
                     Ok(None) => break, // closed
                     Err(_) => flush(&mut requests, &mut n_tokens), // timeout
                 }
@@ -498,8 +501,14 @@ impl LlamacppBackend {
 
         spawn_blocking(move || {
             let mut llamacpp = match Llamacpp::new(conf) {
-                Ok(v) => { let _ = ok_tx.send(Ok(())); v },
-                Err(e) => { let _ = ok_tx.send(Err(e)); return; },
+                Ok(v) => {
+                    let _ = ok_tx.send(Ok(()));
+                    v
+                }
+                Err(e) => {
+                    let _ = ok_tx.send(Err(e));
+                    return;
+                }
             };
 
             let vocab = tokenizer.get_added_vocabulary();
@@ -522,7 +531,7 @@ impl LlamacppBackend {
                     _ => {
                         let _ = request.tx.send(Err(InferError::IncompleteGeneration));
                         continue;
-                    },
+                    }
                 };
 
                 let last_pos = request.input_ids.len() - 1;
@@ -570,7 +579,7 @@ impl LlamacppBackend {
                             let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
                             seq.running = false;
                             continue;
-                        },
+                        }
                     };
 
                     let special = vocab.is_special_token(&piece);
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index 762764a7..1919580d 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -1,12 +1,15 @@
 mod backend;
 
-use backend::{LlamacppNuma, LlamacppGGMLType, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
-use clap::{Parser};
+use backend::{
+    BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
+    LlamacppSplitMode,
+};
+use clap::Parser;
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
-use tokenizers::{Tokenizer, FromPretrainedParameters};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tokio::sync::oneshot::error::RecvError;
-use tracing::{warn, error};
+use tracing::{error, warn};
 
 /// Backend Configuration
 #[derive(Parser, Debug)]
@@ -161,11 +164,7 @@ struct Args {
 async fn main() -> Result<(), RouterError> {
     let args = Args::parse();
 
-    logging::init_logging(
-        args.otlp_endpoint,
-        args.otlp_service_name,
-        args.json_output
-    );
+    logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
 
     let n_threads = match args.n_threads {
         Some(0) | None => num_cpus::get(),
@@ -218,10 +217,7 @@ async fn main() -> Result<(), RouterError> {
             token,
             ..Default::default()
         };
-        Tokenizer::from_pretrained(
-            args.model_id.clone(),
-            Some(params)
-        )?
+        Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
     };
 
     let (backend, ok, shutdown) = LlamacppBackend::new(