mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Adding new env variables for TPU backends.
This commit is contained in:
parent
06c3d4b1ec
commit
5bc3d65dd3
@ -448,6 +448,8 @@ fn shard_manager(
|
||||
cuda_memory_fraction: f32,
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
rope_factor: Option<f32>,
|
||||
max_total_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
otlp_endpoint: Option<String>,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
@ -512,6 +514,7 @@ fn shard_manager(
|
||||
(Some(scaling), Some(factor)) => Some((scaling, factor)),
|
||||
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
|
||||
};
|
||||
|
||||
// OpenTelemetry
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
shard_args.push("--otlp-endpoint".to_string());
|
||||
@ -564,6 +567,14 @@ fn shard_manager(
|
||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||
}
|
||||
|
||||
envs.push((
|
||||
"MAX_TOTAL_TOKENS".into(),
|
||||
max_total_tokens.to_string().into(),
|
||||
));
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||
}
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
@ -967,6 +978,7 @@ fn spawn_shards(
|
||||
num_shard: usize,
|
||||
args: &Args,
|
||||
cuda_graphs: Vec<usize>,
|
||||
max_total_tokens: usize,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
shutdown_sender: mpsc::Sender<()>,
|
||||
@ -998,6 +1010,7 @@ fn spawn_shards(
|
||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
let max_batch_size = args.max_batch_size;
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
@ -1020,6 +1033,8 @@ fn spawn_shards(
|
||||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
otlp_endpoint,
|
||||
status_sender,
|
||||
shutdown,
|
||||
@ -1473,6 +1488,7 @@ fn main() -> Result<(), LauncherError> {
|
||||
num_shard,
|
||||
&args,
|
||||
cuda_graphs,
|
||||
max_total_tokens,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
|
Loading…
Reference in New Issue
Block a user