Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 11:54:52 +00:00

Commit 3e517bfc9d
Merge remote-tracking branch 'origin/main' into feat/rw
@@ -455,11 +455,12 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
 }
 
 fn num_cuda_devices() -> Option<usize> {
-    if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
-        let n_devices = cuda_visible_devices.split(',').count();
-        return Some(n_devices);
-    }
-    None
+    let devices = match env::var("CUDA_VISIBLE_DEVICES") {
+        Ok(devices) => devices,
+        Err(_) => env::var("NVIDIA_VISIBLE_DEVICES").ok()?,
+    };
+    let n_devices = devices.split(',').count();
+    Some(n_devices)
 }
 
 #[derive(Deserialize)]
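The rewritten function now falls back to NVIDIA_VISIBLE_DEVICES (which NVIDIA container runtimes set) when CUDA_VISIBLE_DEVICES is absent. For illustration only, here is the same lookup order as a minimal Python sketch; the name mirrors the Rust function above and nothing in this block is part of the commit:

    import os

    def num_cuda_devices():
        # Prefer CUDA_VISIBLE_DEVICES; fall back to NVIDIA_VISIBLE_DEVICES.
        devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if devices is None:
            devices = os.environ.get("NVIDIA_VISIBLE_DEVICES")
        if devices is None:
            return None  # neither variable set: device count unknown
        # Both variables hold a comma-separated device list, e.g. "0,1".
        return len(devices.split(","))

As in the Rust version, counting comma-separated entries means a value such as NVIDIA_VISIBLE_DEVICES=all would be reported as a single device.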
@@ -509,9 +510,9 @@ fn find_num_shards(sharded: Option<bool>, num_shard: Option<usize>) -> usize {
     let num_shard = match (sharded, num_shard) {
         (Some(true), None) => {
             // try to default to the number of available GPUs
-            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
-            let n_devices =
-                num_cuda_devices().expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
+            tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES");
+            let n_devices = num_cuda_devices()
+                .expect("--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set");
             if n_devices <= 1 {
                 panic!("`sharded` is true but only found {n_devices} CUDA devices");
             }
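Only the log and expect messages change here; the branch itself is unchanged: when sharding is requested without an explicit shard count, the launcher derives it from the detected devices and refuses to shard on a single GPU. A sketch of just this branch in Python, reusing the num_cuda_devices sketch above (the other match arms are elided, as they do not appear in the hunk):

    def find_num_shards(sharded, num_shard):
        if sharded is True and num_shard is None:
            # Default to the number of visible GPUs.
            n_devices = num_cuda_devices()
            if n_devices is None:
                raise RuntimeError(
                    "--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES are not set"
                )
            if n_devices <= 1:
                raise RuntimeError(
                    f"`sharded` is true but only found {n_devices} CUDA devices"
                )
            return n_devices
        ...  # remaining (sharded, num_shard) combinations elided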
@@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM):
                     return linear
 
                 module.linear = replace_linear(state)
             elif quantize == "gptq":
-                raise NotImplementedError(
-                    "`gptq` is not implemented for now"
-                )
+                raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
                 tensor = tensor.to(device)
             else:
                 raise ValueError(f"Unexpected quantize `{quantize}`")
 
             module._parameters[param_name] = tensor
             if name == "word_embeddings.weight":
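This hunk, and the matching ones for Galactica and GPT-NeoX below, only collapse the NotImplementedError call onto one line; the quantize dispatch is untouched. Isolated from the surrounding weight-loading loop, the shared pattern looks roughly like this (a sketch with a hypothetical function name; the bitsandbytes branch body is elided):

    def dispatch_quantize(tensor, device, quantize):
        if quantize == "bitsandbytes":
            ...  # replace module.linear with an 8-bit linear (see hunk above)
        elif quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None:
            tensor = tensor.to(device)  # no quantization: plain device move
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")
        return tensor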
@@ -364,14 +364,12 @@ class GalacticaSharded(Galactica):
                     return linear
 
                 module.linear = replace_linear(state)
             elif quantize == "gptq":
-                raise NotImplementedError(
-                    "`gptq` is not implemented for now"
-                )
+                raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
                 tensor = tensor.to(device)
             else:
                 raise ValueError(f"Unexpected quantize `{quantize}`")
 
             module._parameters[param_name] = tensor
             if name == "model.decoder.embed_tokens.weight":
@@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM):
                     return linear
 
                 module.linear = replace_linear(state)
             elif quantize == "gptq":
-                raise NotImplementedError(
-                    "`gptq` is not implemented for now"
-                )
+                raise NotImplementedError("`gptq` is not implemented for now")
             elif quantize is None:
                 tensor = tensor.to(device)
             else:
                 raise ValueError(f"Unexpected quantize `{quantize}`")
 
             if current_parameter_tensor is not None:
                 module._parameters[param_name] = tensor
@@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM):
 
                 module.linear = replace_linear(state)
 
             elif quantize == "gptq" and not module_name.endswith("wo"):
-                raise NotImplementedError(
-                    "`gptq` is not implemented for now"
-                )
-            elif quantize is None:
+                raise NotImplementedError("`gptq` is not implemented for now")
+            elif quantize is None or module_name.endswith("wo"):
                 tensor = tensor.to(device)
             else:
                 raise ValueError(f"Unexpected quantize `{quantize}`")
 
             if current_parameter_tensor is not None:
                 module._parameters[param_name] = tensor
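Unlike the three hunks above, this one changes behavior: T5's wo (feed-forward output) projections are excluded from quantization, so with a quantize mode set they previously fell through to the ValueError branch; after the merge they take the plain .to(device) path. A hedged sketch of the resulting dispatch (the bitsandbytes condition is assumed from the surrounding code, since only the later branches appear in the hunk):

    def load_t5_tensor(tensor, device, quantize, module_name):
        # Assumed from context: bitsandbytes also skips `wo` projections.
        if quantize == "bitsandbytes" and not module_name.endswith("wo"):
            ...  # replace module.linear with an 8-bit linear
        elif quantize == "gptq" and not module_name.endswith("wo"):
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None or module_name.endswith("wo"):
            tensor = tensor.to(device)  # `wo` stays unquantized on-device
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")
        return tensor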