diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a598a8bb..2b152be6 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -349,8 +349,8 @@ fn main() -> ExitCode {
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));
             }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+            Ok(ShardStatus::Failed(rank)) => {
+                tracing::error!("Shard {} failed to start.", rank);
                 shutdown_shards(shutdown, &shutdown_receiver);
                 return ExitCode::FAILURE;
             }
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
     let mut exit_code = ExitCode::SUCCESS;
 
     while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed.");
             exit_code = ExitCode::FAILURE;
             break;
         };
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed(usize),
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -627,9 +627,7 @@ fn shard_manager(
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
-            status_sender
-                .send(ShardStatus::Failed((rank, err.to_string())))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
     };
@@ -658,11 +656,7 @@ fn shard_manager(
     loop {
         // Process exited
        if p.poll().is_some() {
-            let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
-            status_sender
-                .send(ShardStatus::Failed((rank, err)))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
 
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 6f437957..c05e9a1d 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -33,7 +32,7 @@ except Exception as e:
 
 
 class OPT(CausalLM):
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
@@ -49,7 +48,7 @@
 
 class OPTSharded(OPT):
     def __init__(
-            self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -69,14 +68,8 @@
         )
         tokenizer.pad_token_id = config.pad_token_id
 
-        # Only download weights for small models
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-        torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
 
@@ -99,27 +92,25 @@
 
     @staticmethod
     def load_weights(
-            model,
-            filenames: List[str],
-            quantize: bool,
-            device: torch.device,
-            rank: int,
-            world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                    file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
                         continue
 
-                    full_name = f"model.{name}"
-
-                    module_name, param_name = full_name.rsplit(".", 1)
+                    module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
-                    current_tensor = parameters[full_name]
+                    current_tensor = parameters[name]
 
                     slice_ = f.get_slice(name)
 
@@ -166,9 +157,9 @@
                         )
 
                         if (
-                                type(module)
-                                in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                                and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
@@ -212,11 +203,11 @@
                     tensor = tensor.to(device)
                     module._parameters[param_name] = tensor
-                    if full_name == "model.decoder.embed_tokens.weight":
+                    if name == "decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor
 
     def forward(
-            self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
         outputs = self.model.forward(
             input_ids=input_ids,
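
For reviewers, the launcher hunks above all orbit one change: `ShardStatus::Failed` drops its captured-stderr `String` payload, so a failing shard now reports only its rank. The sketch below is a minimal, dependency-free approximation of the status-channel pattern the wait loops use; the enum variants and the `try_recv` match mirror the diff, but the channel setup, the immediately-failing shard thread, and `eprintln!` standing in for `tracing::error!` are illustrative assumptions, not the launcher's actual code:

```rust
use std::sync::mpsc::{self, TryRecvError};
use std::thread::{sleep, spawn};
use std::time::Duration;

// After this change a failed shard carries only its rank.
#[derive(Debug)]
enum ShardStatus {
    Ready,
    Failed(usize),
}

fn main() {
    let (status_sender, status_receiver) = mpsc::channel();

    // Hypothetical shard that fails immediately; the real shard_manager
    // polls the spawned server process instead.
    spawn(move || {
        let rank = 0;
        status_sender.send(ShardStatus::Failed(rank)).unwrap();
    });

    // Startup wait loop mirroring the launcher's match on try_recv().
    loop {
        match status_receiver.try_recv() {
            Ok(ShardStatus::Ready) => {
                println!("shard ready");
                break;
            }
            Ok(ShardStatus::Failed(rank)) => {
                eprintln!("Shard {rank} failed to start.");
                break;
            }
            Err(TryRecvError::Empty) => sleep(Duration::from_millis(100)),
            Err(TryRecvError::Disconnected) => break,
        }
    }
}
```

The trade-off is visible in the sketch: the waiter still learns which rank died, but because the `p.stderr.take()` / `read_to_string` capture is deleted, the child's stderr has to surface through some other channel, for example the shard process logging directly.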