OlivierDehaene 2023-02-28 11:48:15 +01:00
parent 438883cb10
commit ef51a1e0b7
2 changed files with 23 additions and 38 deletions

View File

@@ -349,8 +349,8 @@ fn main() -> ExitCode {
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));
             }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+            Ok(ShardStatus::Failed(rank)) => {
+                tracing::error!("Shard {} failed to start.", rank);
                 shutdown_shards(shutdown, &shutdown_receiver);
                 return ExitCode::FAILURE;
             }
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
     let mut exit_code = ExitCode::SUCCESS;
 
     while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed.");
             exit_code = ExitCode::FAILURE;
             break;
         };
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed(usize),
 }
 
 #[allow(clippy::too_many_arguments)]
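
With this change the status channel carries only the failing shard's rank, so the launcher's wait loop can no longer print a captured error string and just reports which shard failed. A minimal, self-contained sketch of that pattern follows; the loop structure and names here are illustrative, not the launcher's actual code.

// Sketch: a status enum carrying only the shard rank, polled with try_recv()
// in a wait loop, mirroring the pattern in the hunks above.
use std::sync::mpsc::{self, TryRecvError};
use std::thread;
use std::thread::sleep;
use std::time::Duration;

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

fn main() {
    let (status_sender, status_receiver) = mpsc::channel();

    // Hypothetical shard thread reporting a failure for rank 0.
    thread::spawn(move || {
        status_sender.send(ShardStatus::Failed(0)).unwrap();
    });

    loop {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                println!("shard ready");
                break;
            }
            Err(TryRecvError::Empty) => {
                // Nothing received yet: back off briefly before polling again.
                sleep(Duration::from_millis(100));
            }
            Ok(ShardStatus::Failed(rank)) => {
                // Only the rank is available here; the error itself must be
                // read from the shard's own logs.
                eprintln!("Shard {rank} failed to start.");
                break;
            }
            Err(TryRecvError::Disconnected) => break,
        }
    }
}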
@@ -627,9 +627,7 @@ fn shard_manager(
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
-            status_sender
-                .send(ShardStatus::Failed((rank, err.to_string())))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
     };
@@ -658,11 +656,7 @@ fn shard_manager(
     loop {
         // Process exited
        if p.poll().is_some() {
-            let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
-            status_sender
-                .send(ShardStatus::Failed((rank, err)))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
 
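
On the shard-manager side, the monitoring loop no longer reads the child's stderr into a String before reporting; it only sends the rank. Below is a rough sketch of such a loop, using std::process::Child::try_wait() as a stand-in for the p.poll() call in the diff (the process being spawned and the function names are illustrative assumptions).

// Sketch: watch a child process standing in for a shard and report only the
// rank over an mpsc channel when it exits.
use std::process::{Child, Command};
use std::sync::mpsc::{self, Sender};
use std::thread;
use std::thread::sleep;
use std::time::Duration;

#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

fn monitor_shard(mut shard: Child, rank: usize, status_sender: Sender<ShardStatus>) {
    loop {
        // Process exited: report only the rank; stderr is no longer folded into
        // the status message and has to be inspected in the shard's log output.
        if shard.try_wait().unwrap().is_some() {
            status_sender.send(ShardStatus::Failed(rank)).unwrap();
            return;
        }
        sleep(Duration::from_millis(100));
    }
}

fn main() {
    let (status_sender, status_receiver) = mpsc::channel();
    // Hypothetical short-lived child standing in for a text-generation-server shard.
    let child = Command::new("sleep").arg("1").spawn().unwrap();
    thread::spawn(move || monitor_shard(child, 0, status_sender));
    if let Ok(ShardStatus::Failed(rank)) = status_receiver.recv() {
        eprintln!("Shard {rank} failed.");
    }
}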

View File

@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -69,14 +68,8 @@ class OPTSharded(OPT):
         )
         tokenizer.pad_token_id = config.pad_token_id
 
-        # Only download weights for small models
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
@@ -115,11 +108,9 @@ class OPTSharded(OPT):
                 if name == "lm_head.weight":
                     continue
 
-                full_name = f"model.{name}"
-
-                module_name, param_name = full_name.rsplit(".", 1)
+                module_name, param_name = name.rsplit(".", 1)
                 module = model.get_submodule(module_name)
-                current_tensor = parameters[full_name]
+                current_tensor = parameters[name]
 
                 slice_ = f.get_slice(name)
@@ -212,7 +203,7 @@ class OPTSharded(OPT):
 
                 tensor = tensor.to(device)
                 module._parameters[param_name] = tensor
-                if full_name == "model.decoder.embed_tokens.weight":
+                if name == "decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
     def forward(