Mirror of https://github.com/huggingface/text-generation-inference.git

Commit ef51a1e0b7 ("rebase"), parent 438883cb10
@@ -349,8 +349,8 @@ fn main() -> ExitCode {
             Err(TryRecvError::Empty) => {
                 sleep(Duration::from_millis(100));
             }
-            Ok(ShardStatus::Failed((rank, err))) => {
-                tracing::error!("Shard {} failed to start:\n{}", rank, err);
+            Ok(ShardStatus::Failed(rank)) => {
+                tracing::error!("Shard {} failed to start.", rank);
                 shutdown_shards(shutdown, &shutdown_receiver);
                 return ExitCode::FAILURE;
             }
@@ -457,8 +457,8 @@ fn main() -> ExitCode {
     let mut exit_code = ExitCode::SUCCESS;

     while running.load(Ordering::SeqCst) {
-        if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {rank} failed:\n{err}");
+        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {
+            tracing::error!("Shard {rank} failed.");
             exit_code = ExitCode::FAILURE;
             break;
         };
@@ -488,7 +488,7 @@ fn main() -> ExitCode {
 #[derive(Debug)]
 enum ShardStatus {
     Ready,
-    Failed((usize, String)),
+    Failed(usize),
 }

 #[allow(clippy::too_many_arguments)]
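The three hunks above drop the captured stderr string from `ShardStatus::Failed`, so a failure now carries only the shard's rank. For illustration, here is a minimal sketch of the same status-channel pattern transposed to Python (the launcher itself uses `std::sync::mpsc` with `try_recv`); the shard function and its failure condition are made up:

```python
import threading
import time
from queue import Empty, Queue

READY = "ready"    # counterpart of ShardStatus::Ready
FAILED = "failed"  # counterpart of ShardStatus::Failed(rank)

def shard(rank: int, status_queue: Queue) -> None:
    # Stand-in shard: pretend rank 1 fails to start.
    status_queue.put((FAILED, rank) if rank == 1 else (READY, rank))

status_queue: Queue = Queue()
for rank in range(2):
    threading.Thread(target=shard, args=(rank, status_queue)).start()

ready = 0
while ready < 2:
    try:
        status, rank = status_queue.get_nowait()
    except Empty:
        # Counterpart of Err(TryRecvError::Empty): nothing yet, poll again.
        time.sleep(0.1)
        continue
    if status == FAILED:
        print(f"Shard {rank} failed to start.")
        break
    ready += 1
```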
@@ -627,9 +627,7 @@ fn shard_manager(
                     tracing::error!("Please install it with `make install-server`")
                 }
             }
-            status_sender
-                .send(ShardStatus::Failed((rank, err.to_string())))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }
     };
@@ -658,11 +656,7 @@ fn shard_manager(
     loop {
         // Process exited
         if p.poll().is_some() {
-            let mut err = String::new();
-            p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
-            status_sender
-                .send(ShardStatus::Failed((rank, err)))
-                .unwrap();
+            status_sender.send(ShardStatus::Failed(rank)).unwrap();
             return;
         }

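Before this change, `shard_manager` read the dead child's stderr into a string and shipped it along with the failure; afterwards it only reports the rank. A small Python sketch of that old polling loop, assuming a child process spawned with a piped stderr (the failing command is hypothetical):

```python
import subprocess
import time

# Hypothetical failing shard process; stderr is piped so it can be captured.
p = subprocess.Popen(
    ["python", "-c", "import sys; sys.exit('shard cannot start')"],
    stderr=subprocess.PIPE,
    text=True,
)

while True:
    # poll() returns None while the child is alive, its exit code once it dies.
    if p.poll() is not None:
        err = p.stderr.read()  # the old code shipped this string with the failure
        print(f"Shard exited with code {p.returncode}:\n{err}")
        break
    time.sleep(0.1)
```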
@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )

 HAS_BITS_AND_BYTES = True
@@ -69,14 +68,8 @@ class OPTSharded(OPT):
         )
         tokenizer.pad_token_id = config.pad_token_id

-        # Only download weights for small models
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")

         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
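The removed block implemented a "rank 0 downloads, everyone waits" pattern: only the master process fetched the `.safetensors` files, and the `torch.distributed.barrier` kept the other shards from listing weight files before the download finished. A generic sketch of that pattern, assuming an already-initialized process group; `fetch_once_then_share` and its arguments are illustrative stand-ins, not repo APIs:

```python
import torch.distributed

def fetch_once_then_share(download_fn, list_fn, process_group):
    """Run download_fn on rank 0 only, then let every rank call list_fn.

    Assumes torch.distributed.init_process_group has already run; download_fn
    and list_fn stand in for the repo's download_weights/weight_files helpers.
    """
    if torch.distributed.get_rank() == 0:
        download_fn()  # fetch weights into the shared local cache once
    # Every rank blocks here until rank 0 has finished downloading.
    torch.distributed.barrier(group=process_group)
    return list_fn()
```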
@@ -115,11 +108,9 @@ class OPTSharded(OPT):
                 if name == "lm_head.weight":
                     continue

-                full_name = f"model.{name}"
-
-                module_name, param_name = full_name.rsplit(".", 1)
+                module_name, param_name = name.rsplit(".", 1)
                 module = model.get_submodule(module_name)
-                current_tensor = parameters[full_name]
+                current_tensor = parameters[name]

                 slice_ = f.get_slice(name)

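The loop above resolves each safetensors tensor name to its module and loads only this shard's slice via `f.get_slice(name)`. A minimal self-contained sketch of that lazy slicing with the `safetensors` Python API; the file name, tensor name, and row split are placeholders:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Placeholder checkpoint so the example runs on its own.
save_file({"decoder.embed_tokens.weight": torch.randn(8, 4)}, "model.safetensors")

world_size, rank = 2, 0
with safe_open("model.safetensors", framework="pt") as f:
    slice_ = f.get_slice("decoder.embed_tokens.weight")
    rows = slice_.get_shape()[0]
    block = rows // world_size
    start, stop = rank * block, (rank + 1) * block
    # Only rows [start, stop) are read from disk, not the whole tensor.
    shard = slice_[start:stop]
print(shard.shape)  # torch.Size([4, 4])
```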
@@ -212,7 +203,7 @@ class OPTSharded(OPT):
                 tensor = tensor.to(device)

                 module._parameters[param_name] = tensor
-                if full_name == "model.decoder.embed_tokens.weight":
+                if name == "decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor

     def forward(
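The last hunk is the weight-tying step: OPT shares its input embedding matrix with the LM head, so once `decoder.embed_tokens.weight` is loaded the same tensor is installed into `lm_head`. A toy illustration of tying via `_parameters` (the modules and sizes here are made up):

```python
import torch
from torch import nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)

loaded = nn.Parameter(torch.randn(10, 4))
embed._parameters["weight"] = loaded
# Install the very same tensor in the head, so the two modules stay tied.
lm_head._parameters["weight"] = loaded
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```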