Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-05-01 15:02:09 +00:00
feat: bundle launcher and refactor cli wrappers
This commit is contained in:
parent
af2b2e8388
commit
30f4deba77
@@ -1260,8 +1260,64 @@ fn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::R
     Ok(exit_status)
 }
 
+pub fn internal_main_args() -> Result<(), LauncherError> {
+    let args: Vec<String> = std::env::args()
+        // skips the first arg if it's python
+        .skip_while(|a| a.contains("python"))
+        .collect();
+    let args = Args::parse_from(args);
+
+    internal_main(
+        args.model_id,
+        args.revision,
+        args.validation_workers,
+        args.sharded,
+        args.num_shard,
+        args.quantize,
+        args.speculate,
+        args.dtype,
+        args.trust_remote_code,
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens,
+        args.max_input_length,
+        args.max_total_tokens,
+        args.waiting_served_ratio,
+        args.max_batch_prefill_tokens,
+        args.max_batch_total_tokens,
+        args.max_waiting_tokens,
+        args.max_batch_size,
+        args.cuda_graphs,
+        args.hostname,
+        args.port,
+        args.shard_uds_path,
+        args.master_addr,
+        args.master_port,
+        args.huggingface_hub_cache,
+        args.weights_cache_override,
+        args.disable_custom_kernels,
+        args.cuda_memory_fraction,
+        args.rope_scaling,
+        args.rope_factor,
+        args.json_output,
+        args.otlp_endpoint,
+        args.cors_allow_origin,
+        args.watermark_gamma,
+        args.watermark_delta,
+        args.ngrok,
+        args.ngrok_authtoken,
+        args.ngrok_edge,
+        args.tokenizer_config_path,
+        args.disable_grammar_support,
+        args.env,
+        args.max_client_batch_size,
+    )
+}
+
 #[allow(clippy::too_many_arguments)]
-pub fn launcher_main(
+pub fn internal_main(
     model_id: String,
     revision: Option<String>,
     validation_workers: usize,
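The new `internal_main_args` wrapper exists so the same launcher code path can be invoked from the native binary and from an embedding Python process, where argv begins with the interpreter rather than the program. A minimal sketch of the argv filtering (std only; the argument values are made up for illustration):

    fn filtered_args(raw: Vec<String>) -> Vec<String> {
        raw.into_iter()
            // Drop the leading run of arguments that contain "python"
            // (e.g. an interpreter path). The first surviving element is
            // what clap's `parse_from` then treats as the program name.
            .skip_while(|a| a.contains("python"))
            .collect()
    }

    fn main() {
        let raw = vec![
            "/usr/bin/python3".to_string(),
            "text-generation-launcher".to_string(),
            "--model-id".to_string(),
            "google/gemma-2b-it".to_string(),
        ];
        let args = filtered_args(raw);
        assert_eq!(args[0], "text-generation-launcher");
        assert_eq!(args.len(), 3);
    }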
@@ -1639,388 +1695,3 @@ pub fn launcher_main
 
     exit_code
 }
-
-#[allow(clippy::too_many_arguments)]
-pub fn launcher_main_without_server(
-    model_id: String,
-    revision: Option<String>,
-    validation_workers: usize,
-    sharded: Option<bool>,
-    num_shard: Option<usize>,
-    quantize: Option<Quantization>,
-    speculate: Option<usize>,
-    dtype: Option<Dtype>,
-    trust_remote_code: bool,
-    max_concurrent_requests: usize,
-    max_best_of: usize,
-    max_stop_sequences: usize,
-    max_top_n_tokens: u32,
-    max_input_tokens: Option<usize>,
-    max_input_length: Option<usize>,
-    max_total_tokens: Option<usize>,
-    waiting_served_ratio: f32,
-    max_batch_prefill_tokens: Option<u32>,
-    max_batch_total_tokens: Option<u32>,
-    max_waiting_tokens: usize,
-    max_batch_size: Option<usize>,
-    cuda_graphs: Option<Vec<usize>>,
-    hostname: String,
-    port: u16,
-    shard_uds_path: String,
-    master_addr: String,
-    master_port: usize,
-    huggingface_hub_cache: Option<String>,
-    weights_cache_override: Option<String>,
-    disable_custom_kernels: bool,
-    cuda_memory_fraction: f32,
-    rope_scaling: Option<RopeScaling>,
-    rope_factor: Option<f32>,
-    json_output: bool,
-    otlp_endpoint: Option<String>,
-    cors_allow_origin: Vec<String>,
-    watermark_gamma: Option<f32>,
-    watermark_delta: Option<f32>,
-    ngrok: bool,
-    ngrok_authtoken: Option<String>,
-    ngrok_edge: Option<String>,
-    tokenizer_config_path: Option<String>,
-    disable_grammar_support: bool,
-    env: bool,
-    max_client_batch_size: usize,
-    webserver_callback: Box<dyn FnOnce() -> Result<(), LauncherError>>,
-) -> Result<(), LauncherError> {
-    let args = Args {
-        model_id,
-        revision,
-        validation_workers,
-        sharded,
-        num_shard,
-        quantize,
-        speculate,
-        dtype,
-        trust_remote_code,
-        max_concurrent_requests,
-        max_best_of,
-        max_stop_sequences,
-        max_top_n_tokens,
-        max_input_tokens,
-        max_input_length,
-        max_total_tokens,
-        waiting_served_ratio,
-        max_batch_prefill_tokens,
-        max_batch_total_tokens,
-        max_waiting_tokens,
-        max_batch_size,
-        cuda_graphs,
-        hostname,
-        port,
-        shard_uds_path,
-        master_addr,
-        master_port,
-        huggingface_hub_cache,
-        weights_cache_override,
-        disable_custom_kernels,
-        cuda_memory_fraction,
-        rope_scaling,
-        rope_factor,
-        json_output,
-        otlp_endpoint,
-        cors_allow_origin,
-        watermark_gamma,
-        watermark_delta,
-        ngrok,
-        ngrok_authtoken,
-        ngrok_edge,
-        tokenizer_config_path,
-        disable_grammar_support,
-        env,
-        max_client_batch_size,
-    };
-
-    // Filter events with LOG_LEVEL
-    let env_filter =
-        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
-
-    if args.json_output {
-        tracing_subscriber::fmt()
-            .with_env_filter(env_filter)
-            .json()
-            .init();
-    } else {
-        tracing_subscriber::fmt()
-            .with_env_filter(env_filter)
-            .compact()
-            .init();
-    }
-
-    if args.env {
-        let env_runtime = env_runtime::Env::new();
-        tracing::info!("{}", env_runtime);
-    }
-
-    tracing::info!("{:#?}", args);
-
-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
-        let model_id = args.model_id.clone();
-        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-        let filename = if !path.exists() {
-            // Assume it's a hub id
-            let api = Api::new()?;
-            let repo = if let Some(ref revision) = args.revision {
-                api.repo(Repo::with_revision(
-                    model_id,
-                    RepoType::Model,
-                    revision.to_string(),
-                ))
-            } else {
-                api.model(model_id)
-            };
-            repo.get("config.json")?
-        } else {
-            path.push("config.json");
-            path
-        };
-
-        let content = std::fs::read_to_string(filename)?;
-        let config: Config = serde_json::from_str(&content)?;
-
-        // Quantization usually means you're even more RAM constrained.
-        let max_default = 4096;
-
-        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
-            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    max_default
-                } else {
-                    max_position_embeddings
-                }
-            }
-            _ => {
-                return Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )));
-            }
-        };
-        Ok(max_position_embeddings)
-    };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
-
-    let max_input_tokens = {
-        match (args.max_input_tokens, args.max_input_length) {
-            (Some(max_input_tokens), Some(max_input_length)) => {
-                return Err(LauncherError::ArgumentValidation(
-                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
-                )));
-            }
-            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
-            (None, None) => {
-                let value = max_position_embeddings - 1;
-                tracing::info!("Default `max_input_tokens` to {value}");
-                value
-            }
-        }
-    };
-    let max_total_tokens = {
-        match args.max_total_tokens {
-            Some(max_total_tokens) => max_total_tokens,
-            None => {
-                let value = max_position_embeddings;
-                tracing::info!("Default `max_total_tokens` to {value}");
-                value
-            }
-        }
-    };
-    let max_batch_prefill_tokens = {
-        match args.max_batch_prefill_tokens {
-            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
-            None => {
-                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
-                    max_batch_size * max_input_tokens
-                } else {
-                    // Adding some edge in order to account for potential block_size alignement
-                    // issue.
-                    max_input_tokens + 50
-                } as u32;
-                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
-                value
-            }
-        }
-    };
-
-    // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_tokens must be < `max_total_tokens`".to_string(),
-        ));
-    }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
-
-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
-        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
-        #[allow(deprecated)]
-        (
-            None,
-            Some(
-                Quantization::Bitsandbytes
-                | Quantization::BitsandbytesNF4
-                | Quantization::BitsandbytesFP4,
-            ),
-        ) => {
-            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
-            vec![]
-        }
-        _ => {
-            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
-            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
-            cuda_graphs
-        }
-    };
-
-    if args.validation_workers == 0 {
-        return Err(LauncherError::ArgumentValidation(
-            "`validation_workers` must be > 0".to_string(),
-        ));
-    }
-    if args.trust_remote_code {
-        tracing::warn!(
-            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
-            args.model_id
-        );
-    }
-
-    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
-    if num_shard > 1 {
-        tracing::info!("Sharding model on {num_shard} processes");
-    }
-
-    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_batch_prefill_tokens, max_batch_total_tokens
-            )));
-        }
-        if max_total_tokens as u32 > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_total_tokens, max_batch_total_tokens
-            )));
-        }
-    }
-
-    if args.ngrok {
-        if args.ngrok_authtoken.is_none() {
-            return Err(LauncherError::ArgumentValidation(
-                "`ngrok-authtoken` must be set when using ngrok tunneling".to_string(),
-            ));
-        }
-
-        if args.ngrok_edge.is_none() {
-            return Err(LauncherError::ArgumentValidation(
-                "`ngrok-edge` must be set when using ngrok tunneling".to_string(),
-            ));
-        }
-    }
-
-    // Signal handler
-    let running = Arc::new(AtomicBool::new(true));
-    let r = running.clone();
-    ctrlc::set_handler(move || {
-        r.store(false, Ordering::SeqCst);
-    })
-    .expect("Error setting Ctrl-C handler");
-
-    // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
-
-    if !running.load(Ordering::SeqCst) {
-        // Launcher was asked to stop
-        return Ok(());
-    }
-
-    // Shared shutdown bool
-    let shutdown = Arc::new(AtomicBool::new(false));
-    // Shared shutdown channel
-    // When shutting down, the main thread will wait for all senders to be dropped
-    let (shutdown_sender, shutdown_receiver) = mpsc::channel();
-
-    // Shared channel to track shard status
-    let (status_sender, status_receiver) = mpsc::channel();
-
-    spawn_shards(
-        num_shard,
-        &args,
-        cuda_graphs,
-        max_total_tokens,
-        shutdown.clone(),
-        &shutdown_receiver,
-        shutdown_sender,
-        &status_receiver,
-        status_sender,
-        running.clone(),
-    )?;
-
-    // We might have received a termination signal
-    if !running.load(Ordering::SeqCst) {
-        shutdown_shards(shutdown, &shutdown_receiver);
-        return Ok(());
-    }
-
-    // let mut webserver = spawn_webserver(
-    //     num_shard,
-    //     args,
-    //     max_input_tokens,
-    //     max_total_tokens,
-    //     max_batch_prefill_tokens,
-    //     shutdown.clone(),
-    //     &shutdown_receiver,
-    // )
-    // .map_err(|err| {
-    //     shutdown_shards(shutdown.clone(), &shutdown_receiver);
-    //     err
-    // })?;
-
-    webserver_callback()?;
-
-    println!("Webserver started");
-
-    // Default exit code
-    let mut exit_code = Ok(());
-
-    while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} crashed");
-            exit_code = Err(LauncherError::ShardFailed);
-            break;
-        };
-
-        // match webserver.try_wait().unwrap() {
-        //     Some(_) => {
-        //         tracing::error!("Webserver Crashed");
-        //         shutdown_shards(shutdown, &shutdown_receiver);
-        //         return Err(LauncherError::WebserverFailed);
-        //     }
-        //     None => {
-        //         sleep(Duration::from_millis(100));
-        //     }
-        // };
-    }
-
-    // Graceful termination
-    // terminate("webserver", webserver, Duration::from_secs(90)).unwrap();
-    shutdown_shards(shutdown, &shutdown_receiver);
-
-    exit_code
-}
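Deleting `launcher_main_without_server` removes the callback-injection variant of the launcher: rather than handing the shard-spawning logic a boxed `webserver_callback`, the bundled build now drives launcher and router through the `internal_main_args`-style entry points added above and in the router crate below.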
@@ -1,53 +1,5 @@
-use clap::Parser;
-use text_generation_launcher::{launcher_main, Args, LauncherError};
+use text_generation_launcher::{internal_main_args, LauncherError};
 
 fn main() -> Result<(), LauncherError> {
-    let args = Args::parse();
-    launcher_main(
-        args.model_id,
-        args.revision,
-        args.validation_workers,
-        args.sharded,
-        args.num_shard,
-        args.quantize,
-        args.speculate,
-        args.dtype,
-        args.trust_remote_code,
-        args.max_concurrent_requests,
-        args.max_best_of,
-        args.max_stop_sequences,
-        args.max_top_n_tokens,
-        args.max_input_tokens,
-        args.max_input_length,
-        args.max_total_tokens,
-        args.waiting_served_ratio,
-        args.max_batch_prefill_tokens,
-        args.max_batch_total_tokens,
-        args.max_waiting_tokens,
-        args.max_batch_size,
-        args.cuda_graphs,
-        args.hostname,
-        args.port,
-        args.shard_uds_path,
-        args.master_addr,
-        args.master_port,
-        args.huggingface_hub_cache,
-        args.weights_cache_override,
-        args.disable_custom_kernels,
-        args.cuda_memory_fraction,
-        args.rope_scaling,
-        args.rope_factor,
-        args.json_output,
-        args.otlp_endpoint,
-        args.cors_allow_origin,
-        args.watermark_gamma,
-        args.watermark_delta,
-        args.ngrok,
-        args.ngrok_authtoken,
-        args.ngrok_edge,
-        args.tokenizer_config_path,
-        args.disable_grammar_support,
-        args.env,
-        args.max_client_batch_size,
-    )
+    internal_main_args()
 }
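The binary is now a thin shim over the library crate; the router's main.rs below gets the same treatment. Keeping all argument parsing and startup logic behind `internal_main_args` is what lets the pyo3 bindings in tgi/src/lib.rs further down reuse it unchanged.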
@@ -7,6 +7,7 @@ pub mod server;
 mod validation;
 
 use axum::http::HeaderValue;
+use clap::Parser;
 use config::Config;
 use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
 use hf_hub::{Cache, Repo, RepoType};
@@ -175,6 +176,108 @@ pub enum RouterError {
     Axum(#[from] axum::BoxError),
 }
 
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+pub struct Args {
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
+    #[clap(default_value = "4", long, env)]
+    max_stop_sequences: usize,
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+    #[clap(default_value = "1024", long, env)]
+    max_input_tokens: usize,
+    #[clap(default_value = "2048", long, env)]
+    max_total_tokens: usize,
+    #[clap(default_value = "1.2", long, env)]
+    waiting_served_ratio: f32,
+    #[clap(default_value = "4096", long, env)]
+    max_batch_prefill_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
+    #[clap(long, env)]
+    max_batch_size: Option<usize>,
+    #[clap(default_value = "0.0.0.0", long, env)]
+    hostname: String,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
+    master_shard_uds_path: String,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+    #[clap(long, env)]
+    tokenizer_config_path: Option<String>,
+    #[clap(long, env)]
+    revision: Option<String>,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
+    #[clap(long, env)]
+    json_output: bool,
+    #[clap(long, env)]
+    otlp_endpoint: Option<String>,
+    #[clap(long, env)]
+    cors_allow_origin: Option<Vec<String>>,
+    #[clap(long, env)]
+    ngrok: bool,
+    #[clap(long, env)]
+    ngrok_authtoken: Option<String>,
+    #[clap(long, env)]
+    ngrok_edge: Option<String>,
+    #[clap(long, env, default_value_t = false)]
+    messages_api_enabled: bool,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
+}
+
+pub async fn internal_main_args() -> Result<(), RouterError> {
+    let args: Vec<String> = std::env::args()
+        // skips the first arg if it's python
+        .skip_while(|a| a.contains("python"))
+        .collect();
+    let args = Args::parse_from(args);
+
+    println!("{:?}", args);
+    let out = internal_main(
+        args.max_concurrent_requests,
+        args.max_best_of,
+        args.max_stop_sequences,
+        args.max_top_n_tokens,
+        args.max_input_tokens,
+        args.max_total_tokens,
+        args.waiting_served_ratio,
+        args.max_batch_prefill_tokens,
+        args.max_batch_total_tokens,
+        args.max_waiting_tokens,
+        args.max_batch_size,
+        args.hostname,
+        args.port,
+        args.master_shard_uds_path,
+        args.tokenizer_name,
+        args.tokenizer_config_path,
+        args.revision,
+        args.validation_workers,
+        args.json_output,
+        args.otlp_endpoint,
+        args.cors_allow_origin,
+        args.ngrok,
+        args.ngrok_authtoken,
+        args.ngrok_edge,
+        args.messages_api_enabled,
+        args.disable_grammar_support,
+        args.max_client_batch_size,
+    )
+    .await;
+    println!("[internal_main_args] {:?}", out);
+    out
+}
+
 #[allow(clippy::too_many_arguments)]
 pub async fn internal_main(
     max_concurrent_requests: usize,
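The router's `Args` follows the usual clap derive pattern: each field gets a --kebab-case flag via `long` and an environment-variable fallback via `env` (the variable name is the field name uppercased). A pared-down sketch (hypothetical `MiniArgs` subset; assumes clap with the derive and env features enabled):

    use clap::Parser;

    #[derive(Parser, Debug)]
    #[clap(author, version, about, long_about = None)]
    struct MiniArgs {
        // --max-concurrent-requests on the CLI, MAX_CONCURRENT_REQUESTS in the env
        #[clap(default_value = "128", long, env)]
        max_concurrent_requests: usize,
        // --port / -p on the CLI, PORT in the env
        #[clap(default_value = "3000", long, short, env)]
        port: u16,
    }

    fn main() {
        // parse_from treats the first element as the program name, which is
        // why internal_main_args strips a leading interpreter path first.
        let args = MiniArgs::parse_from(["router", "--port", "8080"]);
        assert_eq!(args.max_concurrent_requests, 128);
        assert_eq!(args.port, 8080);
    }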
@@ -1,100 +1,7 @@
-use clap::Parser;
-use text_generation_router::{internal_main, RouterError};
+use text_generation_router::{internal_main_args, RouterError};
 
-/// App Configuration
-#[derive(Parser, Debug)]
-#[clap(author, version, about, long_about = None)]
-struct Args {
-    #[clap(default_value = "128", long, env)]
-    max_concurrent_requests: usize,
-    #[clap(default_value = "2", long, env)]
-    max_best_of: usize,
-    #[clap(default_value = "4", long, env)]
-    max_stop_sequences: usize,
-    #[clap(default_value = "5", long, env)]
-    max_top_n_tokens: u32,
-    #[clap(default_value = "1024", long, env)]
-    max_input_tokens: usize,
-    #[clap(default_value = "2048", long, env)]
-    max_total_tokens: usize,
-    #[clap(default_value = "1.2", long, env)]
-    waiting_served_ratio: f32,
-    #[clap(default_value = "4096", long, env)]
-    max_batch_prefill_tokens: u32,
-    #[clap(long, env)]
-    max_batch_total_tokens: Option<u32>,
-    #[clap(default_value = "20", long, env)]
-    max_waiting_tokens: usize,
-    #[clap(long, env)]
-    max_batch_size: Option<usize>,
-    #[clap(default_value = "0.0.0.0", long, env)]
-    hostname: String,
-    #[clap(default_value = "3000", long, short, env)]
-    port: u16,
-    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
-    master_shard_uds_path: String,
-    #[clap(default_value = "bigscience/bloom", long, env)]
-    tokenizer_name: String,
-    #[clap(long, env)]
-    tokenizer_config_path: Option<String>,
-    #[clap(long, env)]
-    revision: Option<String>,
-    #[clap(default_value = "2", long, env)]
-    validation_workers: usize,
-    #[clap(long, env)]
-    json_output: bool,
-    #[clap(long, env)]
-    otlp_endpoint: Option<String>,
-    #[clap(long, env)]
-    cors_allow_origin: Option<Vec<String>>,
-    #[clap(long, env)]
-    ngrok: bool,
-    #[clap(long, env)]
-    ngrok_authtoken: Option<String>,
-    #[clap(long, env)]
-    ngrok_edge: Option<String>,
-    #[clap(long, env, default_value_t = false)]
-    messages_api_enabled: bool,
-    #[clap(long, env, default_value_t = false)]
-    disable_grammar_support: bool,
-    #[clap(default_value = "4", long, env)]
-    max_client_batch_size: usize,
-}
-
 #[tokio::main]
 async fn main() -> Result<(), RouterError> {
-    // Get args
-    let args = Args::parse();
-
-    internal_main(
-        args.max_concurrent_requests,
-        args.max_best_of,
-        args.max_stop_sequences,
-        args.max_top_n_tokens,
-        args.max_input_tokens,
-        args.max_total_tokens,
-        args.waiting_served_ratio,
-        args.max_batch_prefill_tokens,
-        args.max_batch_total_tokens,
-        args.max_waiting_tokens,
-        args.max_batch_size,
-        args.hostname,
-        args.port,
-        args.master_shard_uds_path,
-        args.tokenizer_name,
-        args.tokenizer_config_path,
-        args.revision,
-        args.validation_workers,
-        args.json_output,
-        args.otlp_endpoint,
-        args.cors_allow_origin,
-        args.ngrok,
-        args.ngrok_authtoken,
-        args.ngrok_edge,
-        args.messages_api_enabled,
-        args.disable_grammar_support,
-        args.max_client_batch_size,
-    )
-    .await?;
+    internal_main_args().await?;
     Ok(())
 }
@@ -13,3 +13,5 @@ library-install:
 	pip install -e .
 
 install: build comment-gitignore library-install remove-comment-gitignore
+
+quick-install: build library-install
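The new `quick-install` target presumably serves fast iteration: judging by its prerequisites, it rebuilds and pip-installs the package while skipping the gitignore commenting and uncommenting steps that the full `install` target performs.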
@@ -31,3 +31,5 @@ python-packages = ["tgi", "text_generation_server"]
 
 [project.scripts]
 text-generation-server = "tgi:text_generation_server_cli_main"
+text-generation-router = "tgi:text_generation_router_cli_main"
+text-generation-launcher = "tgi:text_generation_launcher_cli_main"
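Each console script points at a Python wrapper exported from the `tgi` package (defined near the end of this diff), so installing the bundled wheel puts `text-generation-server`, `text-generation-router`, and `text-generation-launcher` on PATH, all backed by the same compiled extension module.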
tgi/src/lib.rs (316 lines changed)
@@ -1,6 +1,8 @@
 use pyo3::{prelude::*, wrap_pyfunction};
-use text_generation_launcher::{launcher_main, launcher_main_without_server};
-use text_generation_router::internal_main;
+use std::thread;
+use text_generation_launcher::{internal_main, internal_main_args as internal_main_args_launcher};
+use text_generation_router::internal_main_args as internal_main_args_router;
+use tokio::runtime::Runtime;
 
 #[allow(clippy::too_many_arguments)]
 #[pyfunction]
@@ -100,7 +102,7 @@ fn rust_launcher(
     max_client_batch_size: usize,
 ) -> PyResult<&PyAny> {
     pyo3_asyncio::tokio::future_into_py(py, async move {
-        launcher_main(
+        internal_main(
            model_id,
            revision,
            validation_workers,
@@ -153,251 +155,6 @@ fn rust_launcher(
     })
 }
 
-#[allow(clippy::too_many_arguments)]
-#[pyfunction]
-#[pyo3(signature = (
-    model_id,
-    revision,
-    validation_workers,
-    sharded,
-    num_shard,
-    _quantize,
-    speculate,
-    _dtype,
-    trust_remote_code,
-    max_concurrent_requests,
-    max_best_of,
-    max_stop_sequences,
-    max_top_n_tokens,
-    max_input_tokens,
-    max_input_length,
-    max_total_tokens,
-    waiting_served_ratio,
-    max_batch_prefill_tokens,
-    max_batch_total_tokens,
-    max_waiting_tokens,
-    max_batch_size,
-    cuda_graphs,
-    hostname,
-    port,
-    shard_uds_path,
-    master_addr,
-    master_port,
-    huggingface_hub_cache,
-    weights_cache_override,
-    disable_custom_kernels,
-    cuda_memory_fraction,
-    _rope_scaling,
-    rope_factor,
-    json_output,
-    otlp_endpoint,
-    cors_allow_origin,
-    watermark_gamma,
-    watermark_delta,
-    ngrok,
-    ngrok_authtoken,
-    ngrok_edge,
-    tokenizer_config_path,
-    disable_grammar_support,
-    env,
-    max_client_batch_size,
-))]
-fn fully_packaged(
-    py: Python<'_>,
-    model_id: String,
-    revision: Option<String>,
-    validation_workers: usize,
-    sharded: Option<bool>,
-    num_shard: Option<usize>,
-    _quantize: Option<String>, // Option<Quantization>,
-    speculate: Option<usize>,
-    _dtype: Option<String>, // Option<Dtype>,
-    trust_remote_code: bool,
-    max_concurrent_requests: usize,
-    max_best_of: usize,
-    max_stop_sequences: usize,
-    max_top_n_tokens: u32,
-    max_input_tokens: Option<usize>,
-    max_input_length: Option<usize>,
-    max_total_tokens: Option<usize>,
-    waiting_served_ratio: f32,
-    max_batch_prefill_tokens: Option<u32>,
-    max_batch_total_tokens: Option<u32>,
-    max_waiting_tokens: usize,
-    max_batch_size: Option<usize>,
-    cuda_graphs: Option<Vec<usize>>,
-    hostname: String,
-    port: u16,
-    shard_uds_path: String,
-    master_addr: String,
-    master_port: usize,
-    huggingface_hub_cache: Option<String>,
-    weights_cache_override: Option<String>,
-    disable_custom_kernels: bool,
-    cuda_memory_fraction: f32,
-    _rope_scaling: Option<f32>, // Option<RopeScaling>,
-    rope_factor: Option<f32>,
-    json_output: bool,
-    otlp_endpoint: Option<String>,
-    cors_allow_origin: Vec<String>,
-    watermark_gamma: Option<f32>,
-    watermark_delta: Option<f32>,
-    ngrok: bool,
-    ngrok_authtoken: Option<String>,
-    ngrok_edge: Option<String>,
-    tokenizer_config_path: Option<String>,
-    disable_grammar_support: bool,
-    env: bool,
-    max_client_batch_size: usize,
-) -> PyResult<&PyAny> {
-    pyo3_asyncio::tokio::future_into_py(py, async move {
-        use std::thread;
-        use tokio::runtime::Runtime;
-
-        let model_id_clone = model_id.clone();
-        let max_concurrent_requests_clone = max_concurrent_requests;
-        let max_best_of_clone = max_best_of;
-        let max_stop_sequences_clone = max_stop_sequences;
-        let max_top_n_tokens_clone = max_top_n_tokens;
-        let max_input_tokens_clone = max_input_tokens.unwrap_or(1024);
-        let max_total_tokens_clone = max_total_tokens.unwrap_or(2048);
-        let waiting_served_ratio_clone = waiting_served_ratio;
-
-        let max_batch_prefill_tokens_clone = max_batch_prefill_tokens.unwrap_or(4096);
-        let max_batch_total_tokens_clone = max_batch_total_tokens;
-        let max_waiting_tokens_clone = max_waiting_tokens;
-        let max_batch_size_clone = max_batch_size;
-        let hostname_clone = hostname.clone();
-        let port_clone = port;
-
-        // TODO: fix this
-        let _shard_uds_path_clone = shard_uds_path.clone();
-
-        let tokenizer_config_path = tokenizer_config_path.clone();
-        let revision = revision.clone();
-        let validation_workers = validation_workers;
-        let json_output = json_output;
-
-        let otlp_endpoint = otlp_endpoint.clone();
-        let cors_allow_origin = cors_allow_origin.clone();
-        let ngrok = ngrok;
-        let ngrok_authtoken = ngrok_authtoken.clone();
-        let ngrok_edge = ngrok_edge.clone();
-        let messages_api_enabled = true;
-        let disable_grammar_support = disable_grammar_support;
-        let max_client_batch_size = max_client_batch_size;
-
-        let ngrok_clone = ngrok;
-        let ngrok_authtoken_clone = ngrok_authtoken.clone();
-        let ngrok_edge_clone = ngrok_edge.clone();
-        let messages_api_enabled_clone = messages_api_enabled;
-        let disable_grammar_support_clone = disable_grammar_support;
-        let max_client_batch_size_clone = max_client_batch_size;
-
-        let tokenizer_config_path_clone = tokenizer_config_path.clone();
-        let revision_clone = revision.clone();
-        let validation_workers_clone = validation_workers;
-        let json_output_clone = json_output;
-        let otlp_endpoint_clone = otlp_endpoint.clone();
-
-        let webserver_callback = move || {
-            let handle = thread::spawn(move || {
-                let rt = Runtime::new().unwrap();
-                rt.block_on(async {
-                    internal_main(
-                        max_concurrent_requests_clone,
-                        max_best_of_clone,
-                        max_stop_sequences_clone,
-                        max_top_n_tokens_clone,
-                        max_input_tokens_clone,
-                        max_total_tokens_clone,
-                        waiting_served_ratio_clone,
-                        max_batch_prefill_tokens_clone,
-                        max_batch_total_tokens_clone,
-                        max_waiting_tokens_clone,
-                        max_batch_size_clone,
-                        hostname_clone,
-                        port_clone,
-                        "/tmp/text-generation-server-0".to_string(),
-                        model_id_clone,
-                        tokenizer_config_path_clone,
-                        revision_clone,
-                        validation_workers_clone,
-                        json_output_clone,
-                        otlp_endpoint_clone,
-                        None,
-                        ngrok_clone,
-                        ngrok_authtoken_clone,
-                        ngrok_edge_clone,
-                        messages_api_enabled_clone,
-                        disable_grammar_support_clone,
-                        max_client_batch_size_clone,
-                    )
-                    .await
-                })
-            });
-            match handle.join() {
-                Ok(_) => println!("Server exited successfully"),
-                Err(e) => println!("Server exited with error: {:?}", e),
-            }
-            Ok(())
-        };
-
-        // parse the arguments and run the main function
-        launcher_main_without_server(
-            model_id,
-            revision,
-            validation_workers,
-            sharded,
-            num_shard,
-            None,
-            speculate,
-            None,
-            trust_remote_code,
-            max_concurrent_requests,
-            max_best_of,
-            max_stop_sequences,
-            max_top_n_tokens,
-            max_input_tokens,
-            max_input_length,
-            max_total_tokens,
-            waiting_served_ratio,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-            max_waiting_tokens,
-            max_batch_size,
-            cuda_graphs,
-            hostname,
-            port,
-            shard_uds_path,
-            master_addr,
-            master_port,
-            huggingface_hub_cache,
-            weights_cache_override,
-            disable_custom_kernels,
-            cuda_memory_fraction,
-            None,
-            rope_factor,
-            json_output,
-            otlp_endpoint,
-            cors_allow_origin,
-            watermark_gamma,
-            watermark_delta,
-            ngrok,
-            ngrok_authtoken,
-            ngrok_edge,
-            tokenizer_config_path,
-            disable_grammar_support,
-            env,
-            max_client_batch_size,
-            Box::new(webserver_callback),
-        )
-        .unwrap();
-
-        Ok(Python::with_gil(|py| py.None()))
-    })
-}
-
 /// Asynchronous sleep function.
 #[pyfunction]
 fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
@@ -407,49 +164,38 @@ fn rust_sleep(py: Python<'_>) -> PyResult<&PyAny> {
     })
 }
 
-// TODO: remove hardcoding
 #[pyfunction]
-fn rust_server(py: Python<'_>) -> PyResult<&PyAny> {
-    pyo3_asyncio::tokio::future_into_py(py, async {
-        let _ = internal_main(
-            128, // max_concurrent_requests: usize,
-            2, // max_best_of: usize,
-            4, // max_stop_sequences: usize,
-            5, // max_top_n_tokens: u32,
-            1024, // max_input_tokens: usize,
-            2048, // max_total_tokens: usize,
-            1.2, // waiting_served_ratio: f32,
-            4096, // max_batch_prefill_tokens: u32,
-            None, // max_batch_total_tokens: Option<u32>,
-            20, // max_waiting_tokens: usize,
-            None, // max_batch_size: Option<usize>,
-            "0.0.0.0".to_string(), // hostname: String,
-            3000, // port: u16,
-            "/tmp/text-generation-server-0".to_string(), // master_shard_uds_path: String,
-            "llava-hf/llava-v1.6-mistral-7b-hf".to_string(), // tokenizer_name: String,
-            None, // tokenizer_config_path: Option<String>,
-            None, // revision: Option<String>,
-            2, // validation_workers: usize,
-            false, // json_output: bool,
-            None, // otlp_endpoint: Option<String>,
-            None, // cors_allow_origin: Option<Vec<String>>,
-            false, // ngrok: bool,
-            None, // ngrok_authtoken: Option<String>,
-            None, // ngrok_edge: Option<String>,
-            false, // messages_api_enabled: bool,
-            false, // disable_grammar_support: bool,
-            4, // max_client_batch_size: usize,
-        )
-        .await;
-        Ok(Python::with_gil(|py| py.None()))
-    })
+fn rust_router(_py: Python<'_>) -> PyResult<String> {
+    let handle = thread::spawn(move || {
+        let rt = Runtime::new().unwrap();
+        rt.block_on(async { internal_main_args_router().await })
+    });
+    match handle.join() {
+        Ok(thread_output) => match thread_output {
+            Ok(_) => println!("Inner server exited successfully"),
+            Err(e) => println!("Inner server exited with error: {:?}", e),
+        },
+        Err(e) => {
+            println!("Server exited with error: {:?}", e);
+        }
+    }
+    Ok("Completed".to_string())
+}
+
+#[pyfunction]
+fn rust_launcher_cli(_py: Python<'_>) -> PyResult<String> {
+    match internal_main_args_launcher() {
+        Ok(_) => println!("Server exited successfully"),
+        Err(e) => println!("Server exited with error: {:?}", e),
+    }
+    Ok("Completed".to_string())
 }
 
 #[pymodule]
 fn tgi(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(rust_sleep, m)?)?;
-    m.add_function(wrap_pyfunction!(rust_server, m)?)?;
+    m.add_function(wrap_pyfunction!(rust_router, m)?)?;
     m.add_function(wrap_pyfunction!(rust_launcher, m)?)?;
-    m.add_function(wrap_pyfunction!(fully_packaged, m)?)?;
+    m.add_function(wrap_pyfunction!(rust_launcher_cli, m)?)?;
     Ok(())
 }
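Unlike `rust_launcher`, which returns an awaitable via pyo3_asyncio, `rust_router` and `rust_launcher_cli` are synchronous from Python's point of view: each blocks the calling thread until the server exits. A minimal sketch of the thread-plus-runtime pattern (the async body is a stand-in for `internal_main_args_router().await`; assumes a tokio dependency):

    use std::thread;
    use tokio::runtime::Runtime;

    fn main() {
        // Spawn a dedicated OS thread and give it its own Tokio runtime, so
        // the server's event loop stays independent of any loop the embedding
        // (Python) process might already be running.
        let handle = thread::spawn(|| {
            let rt = Runtime::new().unwrap();
            rt.block_on(async { Ok::<(), std::io::Error>(()) })
        });
        match handle.join() {
            Ok(Ok(())) => println!("Inner server exited successfully"),
            Ok(Err(e)) => println!("Inner server exited with error: {:?}", e),
            Err(e) => println!("Server exited with error: {:?}", e),
        }
    }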
@@ -1,9 +1,8 @@
 from .tgi import *
 import threading
-from tgi import rust_launcher, rust_sleep, fully_packaged
+from tgi import rust_router, rust_launcher, rust_launcher_cli
 import asyncio
 from dataclasses import dataclass, asdict
-import sys
 from text_generation_server.cli import app
 
 # add the rust_launcher coroutine to the __all__ list
@@ -17,6 +16,14 @@ def text_generation_server_cli_main():
     app()
 
 
+def text_generation_router_cli_main():
+    rust_router()
+
+
+def text_generation_launcher_cli_main():
+    rust_launcher_cli()
+
+
 @dataclass
 class Args:
     model_id = "google/gemma-2b-it"
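These wrappers are what the new [project.scripts] entries resolve to: running `text-generation-router` in a shell imports the `tgi` extension module and calls `rust_router()`, which blocks until the Rust side returns.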
@@ -81,7 +88,7 @@ class TGI(object):
         print(args)
         args = Args(**args)
         try:
-            await fully_packaged(
+            await rust_launcher(
                 args.model_id,
                 args.revision,
                 args.validation_workers,