Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 11:54:52 +00:00

Commit: ef51a1e0b7
Commit message: rebase
Parent: 438883cb10
@@ -349,8 +349,8 @@ fn main() -> ExitCode {
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));
             }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+            Ok(ShardStatus::Failed(rank)) => {
+                tracing::error!("Shard {} failed to start.", rank);
                 shutdown_shards(shutdown, &shutdown_receiver);
                 return ExitCode::FAILURE;
             }
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
     let mut exit_code = ExitCode::SUCCESS;
 
     while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed.");
             exit_code = ExitCode::FAILURE;
             break;
         };
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed(usize),
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -627,9 +627,7 @@ fn shard_manager(
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
-            status_sender
-                .send(ShardStatus::Failed((rank, err.to_string())))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
     };
@@ -658,11 +656,7 @@ fn shard_manager(
     loop {
         // Process exited
         if p.poll().is_some() {
-            let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
-            status_sender
-                .send(ShardStatus::Failed((rank, err)))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
 
@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True
@@ -33,7 +32,7 @@ except Exception as e:
 
 class OPT(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         """Overwrite forward to ignore position_ids"""
 
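For context, the `"""Overwrite forward to ignore position_ids"""` docstring above reflects a common pattern: `position_ids` is accepted for interface compatibility but never forwarded, since the underlying OPT model derives positions from the attention mask. A minimal sketch of that pattern (toy classes, not the repo's code):

class BaseCausalLM:
    """Toy stand-in for the CausalLM base class used in the repo."""

    def forward(self, input_ids, attention_mask, past_key_values=None):
        # Pretend to run the model; a real implementation would call self.model(...)
        return ("logits", past_key_values)


class OPTLike(BaseCausalLM):
    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        # Accept position_ids to satisfy the caller, but drop it before delegating.
        return super().forward(
            input_ids, attention_mask, past_key_values=past_key_values
        )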
@@ -49,7 +48,7 @@ class OPT(CausalLM):
 
 class OPTSharded(OPT):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
||||||
@ -69,14 +68,8 @@ class OPTSharded(OPT):
|
|||||||
)
|
)
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
|
|
||||||
# Only download weights for small models
|
|
||||||
if self.master:
|
|
||||||
download_weights(model_id, revision=revision, extension=".safetensors")
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
if not filenames:
|
|
||||||
raise ValueError("No safetensors weights found")
|
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoModelForCausalLM.from_config(config)
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
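The `init_empty_weights()` context manager used above comes from `accelerate`: the model skeleton is built on the meta device, so no real memory is allocated until `load_weights` swaps in tensors read from the safetensors files. A small sketch, assuming network access and using a small OPT config as a stand-in:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical small config; the server would use the sharded model's own config.
config = AutoConfig.from_pretrained("facebook/opt-125m")

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# All parameters live on the meta device and hold no data yet.
print(next(model.parameters()).device)  # meta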
@@ -99,27 +92,25 @@ class OPTSharded(OPT):
 
     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
                         continue
 
-                    full_name = f"model.{name}"
-
-                    module_name, param_name = full_name.rsplit(".", 1)
+                    module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
-                    current_tensor = parameters[full_name]
+                    current_tensor = parameters[name]
 
                     slice_ = f.get_slice(name)
 
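The key change in this hunk is that safetensors keys are used as-is instead of being prefixed with `model.` before looking up the target parameter. A minimal sketch of that lookup pattern, using a toy module and a throwaway file (both hypothetical, not the repo's code):

import torch
from torch import nn
from safetensors import safe_open
from safetensors.torch import save_file

# Toy model whose parameter names mirror checkpoint keys like "decoder.weight".
model = nn.Sequential()
model.add_module("decoder", nn.Linear(4, 4))

save_file(
    {"decoder.weight": torch.randn(4, 4), "decoder.bias": torch.zeros(4)},
    "toy.safetensors",
)

parameters = dict(model.named_parameters())
with safe_open("toy.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        module_name, param_name = name.rsplit(".", 1)  # e.g. "decoder", "weight"
        module = model.get_submodule(module_name)
        current_tensor = parameters[name]  # works because keys match parameter names
        slice_ = f.get_slice(name)  # lazy handle; a sharded loader would index a sub-range
        tensor = slice_[:]  # here we simply load the whole tensor
        module._parameters[param_name] = tensor.to(current_tensor.dtype)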
@@ -166,9 +157,9 @@ class OPTSharded(OPT):
                         )
 
                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor,
@@ -212,11 +203,11 @@ class OPTSharded(OPT):
                     tensor = tensor.to(device)
 
                     module._parameters[param_name] = tensor
-                    if full_name == "model.decoder.embed_tokens.weight":
+                    if name == "decoder.embed_tokens.weight":
                         model.lm_head._parameters["weight"] = tensor
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
         outputs = self.model.forward(
             input_ids=input_ids,
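The changed line above ties the LM head to the input embeddings: when `decoder.embed_tokens.weight` is loaded, the same tensor is also installed as `lm_head.weight`. A toy sketch of that tying (hypothetical sizes, not the repo's code):

import torch
from torch import nn

embed_tokens = nn.Embedding(10, 8)
lm_head = nn.Linear(8, 10, bias=False)

tensor = torch.randn(10, 8)
embed_tokens._parameters["weight"] = tensor
lm_head._parameters["weight"] = tensor  # shared storage, not a copy

assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()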