mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
tp optims
This commit is contained in:
parent
b226e469c9
commit
dad29f7299
@ -524,22 +524,29 @@ fn shard_manager(
|
||||
if let Some(watermark_delta) = watermark_delta {
|
||||
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
||||
}
|
||||
|
||||
// for mpi
|
||||
let n_devices = match num_cuda_devices() {
|
||||
Some(value) => value.to_string(),
|
||||
None => String::from("0"),
|
||||
};
|
||||
// Start process
|
||||
tracing::info!("Starting shard");
|
||||
let mut p = match Command::new("text-generation-server")
|
||||
.args(shard_args)
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
.spawn()
|
||||
tracing::info!("run with mpi and use cuda device num: {}", n_devices);
|
||||
// tracing::info!("{:?}", shard_args);
|
||||
tracing::info!("{:?}", envs);
|
||||
let mut p = match Command::new("mpirun")
|
||||
.args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"])
|
||||
.args(shard_args)
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
.spawn()
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
tracing::error!("start mpi failed! ");
|
||||
}
|
||||
{
|
||||
tracing::error!("{}", err);
|
||||
|
@ -402,9 +402,19 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
|
||||
import os
|
||||
if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
else:
|
||||
from torch.nn import functional as F
|
||||
from loguru import logger
|
||||
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
|
||||
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
|
||||
padding_idx=config.pad_token_id)
|
||||
logger.info("Disabled embedding tensor parallel! ")
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
|
@ -16,9 +16,12 @@ from text_generation_server.utils import (
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
import os
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
|
||||
if USE_CUSTOM_NCCL:
|
||||
from text_generation_server.utils.my_dist import initialize_mpi_distributed
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(
|
||||
@ -29,7 +32,10 @@ class FlashLlama(FlashCausalLM):
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if USE_CUSTOM_NCCL:
|
||||
self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
|
||||
else:
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
@ -37,6 +43,7 @@ class FlashLlama(FlashCausalLM):
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
try:
|
||||
raise
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
@ -58,7 +65,10 @@ class FlashLlama(FlashCausalLM):
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
if USE_CUSTOM_NCCL:
|
||||
COMM.barrier()
|
||||
else:
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
@ -67,7 +77,10 @@ class FlashLlama(FlashCausalLM):
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
if USE_CUSTOM_NCCL:
|
||||
COMM.barrier()
|
||||
else:
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -144,6 +144,7 @@ def serve(
|
||||
dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
logger.info(os.environ)
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
server_urls = [
|
||||
@ -154,6 +155,13 @@ def serve(
|
||||
else:
|
||||
local_url = unix_socket_template.format(uds_path, 0)
|
||||
server_urls = [local_url]
|
||||
|
||||
if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
|
||||
server_urls = [
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
|
||||
]
|
||||
local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
|
||||
|
||||
try:
|
||||
model = get_model(
|
||||
@ -206,4 +214,4 @@ def serve(
|
||||
|
||||
asyncio.run(
|
||||
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
|
||||
)
|
||||
)
|
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
@ -52,7 +51,9 @@ try:
|
||||
HAS_EETQ = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import my_custom_comm
|
||||
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
|
||||
USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))
|
||||
|
||||
# Monkey patching
|
||||
@classmethod
|
||||
@ -358,6 +359,7 @@ class TensorParallelHead(SuperLayer):
|
||||
def load(config, prefix: str, weights):
|
||||
if weights.process_group.size() > 1:
|
||||
try:
|
||||
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
should_gather = True
|
||||
except AssertionError:
|
||||
@ -365,6 +367,7 @@ class TensorParallelHead(SuperLayer):
|
||||
# just load the entire thing.
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
logger.info("Disabled lm head parallel! ")
|
||||
else:
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
@ -469,7 +472,10 @@ class TensorParallelRowLinear(SuperLayer):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
||||
@ -504,7 +510,10 @@ class TensorParallelEmbedding(nn.Module):
|
||||
)
|
||||
out = torch.nn.functional.embedding(input, self.weight)
|
||||
if self.reduce and self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
if USE_CUSTOM_NCCL:
|
||||
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
|
||||
else:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user