Mirror of https://github.com/huggingface/text-generation-inference.git

commit 2b0d99c1cf
parent fb81c0d1c4

    Thanks cargo fmt

    Signed-off-by: Adrien Gallouët <angt@huggingface.co>
@@ -1,5 +1,4 @@
-use bindgen::callbacks::{ParseCallbacks, ItemInfo};
+use bindgen::callbacks::{ItemInfo, ParseCallbacks};
 use std::collections::HashMap;
 use std::env;
 use std::path::PathBuf;

@@ -7,21 +7,21 @@ mod llamacpp {
 }
 use async_trait::async_trait;
 use std::ffi::CString;
+use std::mem::replace;
+use std::str::FromStr;
 use std::sync::{mpsc, Once};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{ValidGenerateRequest};
+use text_generation_router::validation::ValidGenerateRequest;
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::{watch, oneshot};
+use tokio::sync::{oneshot, watch};
 use tokio::task::{spawn, spawn_blocking};
-use tokio::time::{Duration, Instant, timeout};
+use tokio::time::{timeout, Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, info, warn, error, trace};
-use tracing::{instrument};
-use std::str::FromStr;
-use std::mem::replace;
+use tracing::instrument;
+use tracing::{debug, error, info, trace, warn};
 
 #[derive(Debug, Clone, Copy)]
 pub enum LlamacppSplitMode {

@@ -39,7 +39,7 @@ impl FromStr for LlamacppSplitMode {
             _ => match s.parse::<usize>() {
                 Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
                 Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
-            }
+            },
         }
     }
 }

@@ -241,7 +241,7 @@ impl Llamacpp {
             llamacpp::model_load_from_file(gguf.as_ptr(), params)
         };
         if model.is_null() {
-            return Err(BackendError::Llamacpp("Failed to load model".to_string()))
+            return Err(BackendError::Llamacpp("Failed to load model".to_string()));
         }
         let ctx = unsafe {
             let mut params = llamacpp::context_default_params();

@@ -260,7 +260,7 @@ impl Llamacpp {
             llamacpp::init_from_model(model, params)
         };
         if ctx.is_null() {
-            return Err(BackendError::Llamacpp("Failed to init context".to_string()))
+            return Err(BackendError::Llamacpp("Failed to init context".to_string()));
         }
         let vocab = unsafe {
             llamacpp::model_get_vocab(model)

@@ -444,8 +444,11 @@ impl LlamacppBackend {
     pub fn new(
         conf: LlamacppConfig,
         tokenizer: Tokenizer,
-    ) -> (Self, oneshot::Receiver<Result<(),BackendError>>, watch::Sender<bool>) {
+    ) -> (
+        Self,
+        oneshot::Receiver<Result<(), BackendError>>,
+        watch::Sender<bool>,
+    ) {
         // Setup llama & export logs, once and for all
         INIT.call_once(|| unsafe {
             llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());

@@ -489,7 +492,7 @@ impl LlamacppBackend {
                     if requests.len() == conf.max_batch_size {
                         flush(&mut requests, &mut n_tokens);
                     }
-                },
+                }
                 Ok(None) => break, // closed
                 Err(_) => flush(&mut requests, &mut n_tokens), // timeout
             }

@@ -498,8 +501,14 @@ impl LlamacppBackend {
 
         spawn_blocking(move || {
             let mut llamacpp = match Llamacpp::new(conf) {
-                Ok(v) => { let _ = ok_tx.send(Ok(())); v },
-                Err(e) => { let _ = ok_tx.send(Err(e)); return; },
+                Ok(v) => {
+                    let _ = ok_tx.send(Ok(()));
+                    v
+                }
+                Err(e) => {
+                    let _ = ok_tx.send(Err(e));
+                    return;
+                }
             };
             let vocab = tokenizer.get_added_vocabulary();
 

@@ -522,7 +531,7 @@ impl LlamacppBackend {
                     _ => {
                         let _ = request.tx.send(Err(InferError::IncompleteGeneration));
                         continue;
-                    },
+                    }
                 };
                 let last_pos = request.input_ids.len() - 1;
 

@@ -570,7 +579,7 @@ impl LlamacppBackend {
                         let _ = requests[seq.id].tx.send(Err(InferError::IncompleteGeneration));
                         seq.running = false;
                         continue;
-                    },
+                    }
                 };
                 let special = vocab.is_special_token(&piece);
 

@@ -1,12 +1,15 @@
 mod backend;
 
-use backend::{LlamacppNuma, LlamacppGGMLType, LlamacppSplitMode, LlamacppConfig, LlamacppBackend, BackendError};
-use clap::{Parser};
+use backend::{
+    BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
+    LlamacppSplitMode,
+};
+use clap::Parser;
 use text_generation_router::{logging, server, usage_stats};
 use thiserror::Error;
-use tokenizers::{Tokenizer, FromPretrainedParameters};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tokio::sync::oneshot::error::RecvError;
-use tracing::{warn, error};
+use tracing::{error, warn};
 
 /// Backend Configuration
 #[derive(Parser, Debug)]

@@ -161,11 +164,7 @@ struct Args {
 async fn main() -> Result<(), RouterError> {
     let args = Args::parse();
 
-    logging::init_logging(
-        args.otlp_endpoint,
-        args.otlp_service_name,
-        args.json_output
-    );
+    logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
 
     let n_threads = match args.n_threads {
         Some(0) | None => num_cpus::get(),

@@ -218,10 +217,7 @@ async fn main() -> Result<(), RouterError> {
             token,
             ..Default::default()
         };
-        Tokenizer::from_pretrained(
-            args.model_id.clone(),
-            Some(params)
-        )?
+        Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
     };
 
     let (backend, ok, shutdown) = LlamacppBackend::new(
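
The `FromStr` hunk above only shows the fallback arm of the split-mode parser. A minimal, self-contained sketch of that parsing behavior follows; the `Layer` and `Row` variants, the lowercased match scrutinee, and the `main` demo are assumptions inferred from the error message, not taken verbatim from the backend.

use std::str::FromStr;

#[derive(Debug, Clone, Copy)]
pub enum LlamacppSplitMode {
    GPU(usize), // a bare integer pins the model to one GPU (variant name from the diff)
    Layer,      // assumed variant: split by layer, per the "`layer` or `row`" error message
    Row,        // assumed variant: split by row
}

impl FromStr for LlamacppSplitMode {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "layer" => Ok(LlamacppSplitMode::Layer),
            "row" => Ok(LlamacppSplitMode::Row),
            // Fallback arm as shown in the diff: try to parse a GPU index.
            _ => match s.parse::<usize>() {
                Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
                Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
            },
        }
    }
}

fn main() {
    // Usage sketch: a CLI flag value would go through FromStr like this.
    assert!(matches!(LlamacppSplitMode::from_str("layer"), Ok(LlamacppSplitMode::Layer)));
    assert!(matches!(LlamacppSplitMode::from_str("3"), Ok(LlamacppSplitMode::GPU(3))));
    assert!(LlamacppSplitMode::from_str("banana").is_err());
}

The point of the fallback arm is that `layer`/`row` pick a split strategy while any other input is treated as a GPU index, which is why it parses a `usize` before giving up.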