diff --git a/launcher/src/main.rs b/launcher/src/main.rs index dc12c90f..0810d979 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -455,11 +455,12 @@ fn shutdown_shards(shutdown: Arc>, shutdown_receiver: &mpsc::Receive } fn num_cuda_devices() -> Option { - if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") { - let n_devices = cuda_visible_devices.split(',').count(); - return Some(n_devices); - } - None + let devices = match env::var("CUDA_VISIBLE_DEVICES") { + Ok(devices) => devices, + Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?, + }; + let n_devices = devices.split(',').count(); + Some(n_devices) } #[derive(Deserialize)] @@ -509,9 +510,9 @@ fn find_num_shards(sharded: Option, num_shard: Option) -> usize { let num_shard = match (sharded, num_shard) { (Some(true), None) => { // try to default to the number of available GPUs - tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES"); - let n_devices = - num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set"); + tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES"); + let n_devices = num_cuda_devices() + .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"); if n_devices <= 1 { panic!("`sharded` is true but only found {n_devices} CUDA devices"); } diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 088a1457..90db59fc 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") module._parameters[param_name] = tensor if name == "word_embeddings.weight": diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 0a3f341b..954421f0 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -364,14 +364,12 @@ class GalacticaSharded(Galactica): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") module._parameters[param_name] = tensor if name == "model.decoder.embed_tokens.weight": diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index e4a85082..c0e8170d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") if current_parameter_tensor is not None: module._parameters[param_name] = tensor diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 17cc50e0..32dcb806 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM): module.linear = replace_linear(state) - elif quantize == "gptq" and not module_name.endswith("wo"): - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq" and not module_name.endswith("wo"): + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None or module_name.endswith("wo"): + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") if current_parameter_tensor is not None: module._parameters[param_name] = tensor