Merge remote-tracking branch 'origin/main' into feat/rw

OlivierDehaene 2023-05-30 15:24:44 +02:00
commit 3e517bfc9d
5 changed files with 33 additions and 40 deletions

View File

@@ -455,11 +455,12 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
 }
 
 fn num_cuda_devices() -> Option<usize> {
-    if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
-        let n_devices = cuda_visible_devices.split(',').count();
-        return Some(n_devices);
-    }
-    None
+    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
+        Ok(devices) => devices,
+        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
+    };
+    let n_devices = devices.split(',').count();
+    Some(n_devices)
 }
 
 #[derive(Deserialize)]
@@ -509,9 +510,9 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
     let num_shard = match (sharded, num_shard) {
         (Some(true), None) => {
             // try to default to the number of available GPUs
-            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
-            let n_devices =
-                num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
+            let n_devices = num_cuda_devices()
+                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
             if n_devices <= 1 {
                 panic!("`sharded` is true but only found {n_devices} CUDA devices");
             }
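The launcher change reads in one pass: prefer CUDA_VISIBLE_DEVICES, fall back to NVIDIA_VISIBLE_DEVICES (the variable typically set by NVIDIA container runtimes), and return None only when both are unset; find_num_shards then uses that count when --num-shard is not given. A minimal Python sketch of the same fallback logic, for illustration only and not part of the commit:

    import os

    def num_cuda_devices() -> int | None:
        # Mirror of the Rust helper above: prefer CUDA_VISIBLE_DEVICES,
        # fall back to NVIDIA_VISIBLE_DEVICES, give up if both are unset.
        devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if devices is None:
            devices = os.environ.get("NVIDIA_VISIBLE_DEVICES")
        if devices is None:
            return None
        # Both variables hold a comma-separated list of device ids.
        return len(devices.split(","))

With CUDA_VISIBLE_DEVICES=0,1 this yields 2, which find_num_shards accepts as the shard count; with both variables unset and sharded=true, the launcher still fails with the expect message, as before.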

View File

@@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 module._parameters[param_name] = tensor
                 if name == "word_embeddings.weight":
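The same flattening is applied verbatim to the Galactica and GPTNeoX shards below: the multi-line raise collapses to one line and the `quantize is None` / `else` branches move up, with no behavioral change. Pulled out as a standalone helper, the shared dispatch is roughly the following sketch (the function name is hypothetical; the `bitsandbytes` case, which replaces the module's linear layer, is handled earlier in each file):

    import torch

    def dispatch_quantize(
        tensor: torch.Tensor, quantize: str | None, device: torch.device
    ) -> torch.Tensor:
        # Hypothetical distillation of the branch shared by the sharded models:
        # gptq is explicitly unsupported, no quantization moves the tensor to
        # the target device, and any other value is rejected.
        if quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None:
            return tensor.to(device)
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")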

View File

@@ -364,14 +364,12 @@ class GalacticaSharded(Galactica):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 module._parameters[param_name] = tensor
                 if name == "model.decoder.embed_tokens.weight":

View File

@@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor

View File

@@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM):
 
                     module.linear = replace_linear(state)
 
                 elif quantize == "gptq" and not module_name.endswith("wo"):
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
-                elif quantize is None:
+                    raise NotImplementedError("`gptq` is not implemented for now")
+                elif quantize is None or module_name.endswith("wo"):
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
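T5 is the one file where the edit is more than formatting: the gptq branch already skipped modules ending in "wo" (T5's feed-forward output projections), but the old `elif quantize is None:` meant such a module reaching this point with quantization enabled fell through to the ValueError. The new `quantize is None or module_name.endswith("wo")` condition instead moves those tensors to the device unquantized; this presumably pairs with a matching `wo` exclusion in the bitsandbytes branch above the hunk, which is not visible here. A hedged restatement of the resulting behavior (helper and parameter names are illustrative, not from the commit):

    import torch

    def t5_dispatch(
        tensor: torch.Tensor,
        quantize: str | None,
        module_name: str,
        device: torch.device,
    ) -> torch.Tensor:
        # Illustrative restatement of the T5 branch: "wo" projections bypass
        # quantization entirely and are moved to the device as-is, where
        # previously they raised ValueError whenever quantize was set.
        if quantize == "gptq" and not module_name.endswith("wo"):
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None or module_name.endswith("wo"):
            return tensor.to(device)
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")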