This commit is contained in:
OlivierDehaene 2023-02-28 11:48:15 +01:00
parent 438883cb10
commit ef51a1e0b7
2 changed files with 23 additions and 38 deletions

View File

@ -349,8 +349,8 @@ fn main() -> ExitCode {
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
Ok(ShardStatus::Failed(rank)) => {
tracing::error!("Shard {} failed to start.", rank);
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
@ -457,8 +457,8 @@ fn main() -> ExitCode {
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed:\n{err}");
if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed.");
exit_code = ExitCode::FAILURE;
break;
};
@ -488,7 +488,7 @@ fn main() -> ExitCode {
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
Failed(usize),
}
#[allow(clippy::too_many_arguments)]
@ -627,9 +627,7 @@ fn shard_manager(
tracing::error!("Please install it with `make install-server`")
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}
};
@ -658,11 +656,7 @@ fn shard_manager(
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
status_sender.send(ShardStatus::Failed(rank)).unwrap();
return;
}

View File

@ -20,7 +20,6 @@ from text_generation.models import CausalLM
from text_generation.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
)
HAS_BITS_AND_BYTES = True
@ -33,7 +32,7 @@ except Exception as e:
class OPT(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
@ -49,7 +48,7 @@ class OPT(CausalLM):
class OPTSharded(OPT):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
@ -69,14 +68,8 @@ class OPTSharded(OPT):
)
tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models
if self.master:
download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
if not filenames:
raise ValueError("No safetensors weights found")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
@ -99,27 +92,25 @@ class OPTSharded(OPT):
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
full_name = f"model.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
current_tensor = parameters[name]
slice_ = f.get_slice(name)
@ -166,9 +157,9 @@ class OPTSharded(OPT):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
@ -212,11 +203,11 @@ class OPTSharded(OPT):
tensor = tensor.to(device)
module._parameters[param_name] = tensor
if full_name == "model.decoder.embed_tokens.weight":
if name == "decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,