mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
rebase
This commit is contained in:
parent 438883cb10
commit ef51a1e0b7
launcher/src/main.rs

@@ -349,8 +349,8 @@ fn main() -> ExitCode {
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));
             }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+            Ok(ShardStatus::Failed(rank)) => {
+                tracing::error!("Shard {} failed to start.", rank);
                 shutdown_shards(shutdown, &shutdown_receiver);
                 return ExitCode::FAILURE;
             }
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
     let mut exit_code = ExitCode::SUCCESS;

     while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed.");
             exit_code = ExitCode::FAILURE;
             break;
         };
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed(usize),
 }

 #[allow(clippy::too_many_arguments)]
@@ -627,9 +627,7 @@ fn shard_manager(
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
-            status_sender
-                .send(ShardStatus::Failed((rank, err.to_string())))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
     };
@@ -658,11 +656,7 @@ fn shard_manager(
     loop {
         // Process exited
         if p.poll().is_some() {
-            let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
-            status_sender
-                .send(ShardStatus::Failed((rank, err)))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }

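Taken together, the launcher hunks above drop the error payload from the status channel: a shard manager now reports only the failing rank, and the main loop polls the channel with try_recv, sleeping briefly when it is empty. A minimal sketch of that poll-or-sleep pattern, written in Python with queue.Queue standing in for the Rust mpsc channel (the names here are illustrative, not the launcher's API):

import queue
import threading
import time

status = queue.Queue()  # stands in for the launcher's mpsc status channel

def shard_manager(rank):
    # ... spawn and babysit the shard process; on failure, report
    # only the rank (the analogue of ShardStatus::Failed(rank)).
    status.put(("failed", rank))

threading.Thread(target=shard_manager, args=(0,), daemon=True).start()

while True:
    try:
        kind, rank = status.get_nowait()  # try_recv()
    except queue.Empty:
        time.sleep(0.1)  # sleep(Duration::from_millis(100))
        continue
    if kind == "failed":
        print(f"Shard {rank} failed.")
        break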
server/text_generation/models/opt.py

@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -33,7 +32,7 @@ except Exception as e:

 class OPT(CausalLM):
     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
@@ -49,7 +48,7 @@ class OPT(CausalLM):

 class OPTSharded(OPT):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -69,14 +68,8 @@ class OPTSharded(OPT):
         )
         tokenizer.pad_token_id = config.pad_token_id

-        # Only download weights for small models
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
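The deleted lines implemented a common multi-rank idiom: only the master rank downloads the checkpoint, and the barrier keeps the other ranks from touching files that are not on disk yet. Its removal here, along with the download_weights import above, suggests the weights are staged before this code runs. A hedged sketch of the removed idiom, with placeholder callables rather than the repository's helpers:

import torch.distributed as dist

def weights_ready(master, process_group, download, list_files):
    # Only the master rank hits the network; the barrier makes every
    # other rank wait until the files exist before listing them.
    if master:
        download()
    dist.barrier(group=process_group)
    return list_files()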
@@ -99,27 +92,25 @@ class OPTSharded(OPT):

     @staticmethod
     def load_weights(
-        model,
-        filenames: List[str],
-        quantize: bool,
-        device: torch.device,
-        rank: int,
-        world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
                         continue

-                    full_name = f"model.{name}"
-
-                    module_name, param_name = full_name.rsplit(".", 1)
+                    module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
-                    current_tensor = parameters[full_name]
+                    current_tensor = parameters[name]

                     slice_ = f.get_slice(name)

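After this hunk, checkpoint tensor names are used directly: the parameter lookup, the submodule resolution, and (further down) the tied-embedding check all key on name rather than a model.-prefixed full_name, which assumes the safetensors files were saved without that prefix. A minimal sketch of the surrounding safetensors pattern, reading only one rank's block of each tensor (splitting on dim 0 here for simplicity; the real loader picks the dimension per layer type):

from safetensors import safe_open

def load_rank_shard(filename, rank, world_size):
    """Read only this rank's rows of every tensor in one safetensors file."""
    shards = {}
    with safe_open(filename, framework="pt", device="cpu") as f:
        for name in f.keys():
            slice_ = f.get_slice(name)         # lazy handle; nothing read yet
            dim0 = slice_.get_shape()[0]
            block = dim0 // world_size
            start, stop = rank * block, (rank + 1) * block
            shards[name] = slice_[start:stop]  # materializes only this block
    return shards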
@@ -166,9 +157,9 @@ class OPTSharded(OPT):
                     )

                     if (
-                        type(module)
-                        in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                        and param_name == "weight"
+                        type(module)
+                        in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                        and param_name == "weight"
                     ):
                         tensor = Int8Params(
                             tensor,
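The condition above (unchanged apart from formatting) gates int8 conversion to the weight matrices of tensor-parallel linear layers. A hedged sketch of the bitsandbytes conversion it guards, assuming the Int8Params API of that era:

from bitsandbytes.nn import Int8Params

def to_int8(tensor, device):
    # Wrapping the fp16 weight in Int8Params makes bitsandbytes
    # quantize it to int8 when it is moved onto the GPU.
    return Int8Params(tensor, requires_grad=False, has_fp16_weights=False).to(device)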
@@ -212,11 +203,11 @@ class OPTSharded(OPT):
                     tensor = tensor.to(device)

                     module._parameters[param_name] = tensor
-                    if full_name == "model.decoder.embed_tokens.weight":
+                    if name == "decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor

     def forward(
-        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
         outputs = self.model.forward(
             input_ids=input_ids,
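The last hunk keeps OPT's weight tying intact under the new naming: lm_head.weight is skipped while reading the checkpoint, and once decoder.embed_tokens.weight is placed, the same tensor is aliased into the head. A toy illustration of that aliasing:

import torch

embed = torch.nn.Embedding(4, 2)
lm_head = torch.nn.Linear(2, 4, bias=False)

# Alias the head onto the embedding matrix: both modules now share
# storage, so placing the (sharded) embedding also updates the head.
lm_head._parameters["weight"] = embed.weight
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()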