Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Adding new env variables for TPU backends.
commit 5bc3d65dd3 (parent 06c3d4b1ec)
@@ -448,6 +448,8 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
+    max_total_tokens: usize,
+    max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -512,6 +514,7 @@ fn shard_manager(
         (Some(scaling), Some(factor)) => Some((scaling, factor)),
         (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
     };
+
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -564,6 +567,14 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }

+    envs.push((
+        "MAX_TOTAL_TOKENS".into(),
+        max_total_tokens.to_string().into(),
+    ));
+    if let Some(max_batch_size) = max_batch_size {
+        envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
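
For orientation: after this hunk the launcher always exports MAX_TOTAL_TOKENS into the shard's environment, and exports MAX_BATCH_SIZE only when a value was configured (the launcher's --max-batch-size flag). A minimal sketch of how a TPU backend process could read these variables at startup; the helper name and the fallback value are illustrative assumptions, not part of this commit:

use std::env;

// Sketch only: reads the two variables exported by the launcher above.
fn read_batching_limits() -> (usize, Option<usize>) {
    // Always present after this change; the fallback is a made-up default.
    let max_total_tokens = env::var("MAX_TOTAL_TOKENS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(4096);

    // Only present when the launcher was given --max-batch-size.
    let max_batch_size = env::var("MAX_BATCH_SIZE")
        .ok()
        .and_then(|v| v.parse::<usize>().ok());

    (max_total_tokens, max_batch_size)
}

fn main() {
    let (max_total_tokens, max_batch_size) = read_batching_limits();
    println!("max_total_tokens={max_total_tokens}, max_batch_size={max_batch_size:?}");
}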
@@ -967,6 +978,7 @@ fn spawn_shards(
    num_shard: usize,
    args: &Args,
    cuda_graphs: Vec<usize>,
+   max_total_tokens: usize,
    shutdown: Arc<AtomicBool>,
    shutdown_receiver: &mpsc::Receiver<()>,
    shutdown_sender: mpsc::Sender<()>,
@@ -998,6 +1010,7 @@ fn spawn_shards(
        let cuda_memory_fraction = args.cuda_memory_fraction;
        let rope_scaling = args.rope_scaling;
        let rope_factor = args.rope_factor;
+        let max_batch_size = args.max_batch_size;
        thread::spawn(move || {
            shard_manager(
                model_id,
@@ -1020,6 +1033,8 @@ fn spawn_shards(
                cuda_memory_fraction,
                rope_scaling,
                rope_factor,
+                max_total_tokens,
+                max_batch_size,
                otlp_endpoint,
                status_sender,
                shutdown,
@@ -1473,6 +1488,7 @@ fn main() -> Result<(), LauncherError> {
        num_shard,
        &args,
        cuda_graphs,
+        max_total_tokens,
        shutdown.clone(),
        &shutdown_receiver,
        shutdown_sender,
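
Taken together, the change threads max_total_tokens and max_batch_size from main through spawn_shards into shard_manager, which hands them to the shard through its process environment rather than CLI flags. A self-contained sketch of that pattern, using placeholder values and a stand-in child binary rather than the launcher's actual shard command:

use std::process::Command;

fn main() -> std::io::Result<()> {
    // Placeholder values; in the launcher they come from the CLI Args.
    let max_total_tokens: usize = 4096;
    let max_batch_size: Option<usize> = Some(4);

    // `env` is a stand-in for the real shard binary.
    let mut shard = Command::new("env");
    shard.env("MAX_TOTAL_TOKENS", max_total_tokens.to_string());
    // Mirror the commit: only export MAX_BATCH_SIZE when it is configured.
    if let Some(max_batch_size) = max_batch_size {
        shard.env("MAX_BATCH_SIZE", max_batch_size.to_string());
    }

    let status = shard.status()?;
    println!("shard exited with {status}");
    Ok(())
}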