tp optims

HuaYZhao 2023-11-22 17:22:47 +08:00
parent b226e469c9
commit dad29f7299
5 changed files with 69 additions and 22 deletions

View File

@@ -524,22 +524,29 @@ fn shard_manager(
if let Some(watermark_delta) = watermark_delta {
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// for mpi
let n_devices = match num_cuda_devices() {
Some(value) => value.to_string(),
None => String::from("0"),
};
// Start process
tracing::info!("Starting shard");
let mut p = match Command::new("text-generation-server")
.args(shard_args)
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
.spawn()
tracing::info!("run with mpi and use cuda device num: {}", n_devices);
// tracing::info!("{:?}", shard_args);
tracing::info!("{:?}", envs);
let mut p = match Command::new("mpirun")
.args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"])
.args(shard_args)
.envs(envs)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
.spawn()
{
Ok(p) => p,
Err(err) => {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
tracing::error!("start mpi failed! ");
}
{
tracing::error!("{}", err);

View File

@@ -402,9 +402,19 @@ class FlashLlamaModel(torch.nn.Module):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
import os
if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
else:
from torch.nn import functional as F
from loguru import logger
embeddings = weights.get_tensor(f"model.embed_tokens.weight")
self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
padding_idx=config.pad_token_id)
logger.info("Disabled embedding tensor parallel! ")
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
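
Note: the non-TP fallback above pads the full embedding table with one extra row before wrapping it in a plain nn.Embedding, presumably to mirror the trailing null row that TensorParallelEmbedding reserves, so token indices behave the same on both paths. A minimal standalone sketch of that padding step (vocab and hidden sizes are made up for illustration):

import torch
from torch import nn
from torch.nn import functional as F

vocab_size, hidden_size = 32000, 4096          # illustrative shapes
table = torch.randn(vocab_size, hidden_size)   # stands in for model.embed_tokens.weight
padded = F.pad(table, (0, 0, 0, 1))            # append one zero row -> (32001, 4096)
embed = nn.Embedding.from_pretrained(padded)   # frozen copy of the padded table
print(embed.weight.shape)                      # torch.Size([32001, 4096])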

View File

@@ -16,9 +16,12 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
import os
tracer = trace.get_tracer(__name__)
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
if USE_CUSTOM_NCCL:
from text_generation_server.utils.my_dist import initialize_mpi_distributed
class FlashLlama(FlashCausalLM):
def __init__(
@@ -29,7 +32,10 @@ class FlashLlama(FlashCausalLM):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if USE_CUSTOM_NCCL:
self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
else:
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
@@ -37,6 +43,7 @@ class FlashLlama(FlashCausalLM):
raise NotImplementedError("FlashLlama is only available on GPU")
try:
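# the bare raise below deliberately aborts the LlamaTokenizer path; the except branch (not shown in this hunk) then falls back to the generic tokenizer instead (intent assumed)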
raise
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
@@ -58,7 +65,10 @@ class FlashLlama(FlashCausalLM):
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
if USE_CUSTOM_NCCL:
COMM.barrier()
else:
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -67,7 +77,10 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
if USE_CUSTOM_NCCL:
COMM.barrier()
else:
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
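
Note: initialize_mpi_distributed comes from text_generation_server.utils.my_dist, which is not part of this diff. Judging only from how it is used above (it returns a process-group-like object exposing rank()/size(), the rank, the world size, and a COMM with barrier()), a hypothetical mpi4py-based sketch could look like the following; none of it is the repo's actual code:

from mpi4py import MPI

class MpiGroup:
    # Stand-in for a torch process group: exposes rank()/size() plus the
    # tp_comm handle that my_custom_comm.custom_allreduce consumes.
    def __init__(self, comm, tp_comm=None):
        self._comm = comm
        self.tp_comm = tp_comm
    def rank(self):
        return self._comm.Get_rank()
    def size(self):
        return self._comm.Get_size()

def initialize_mpi_distributed():
    comm = MPI.COMM_WORLD
    # tp_comm is left as None here; the real helper presumably builds the
    # communicator handle expected by the custom allreduce extension.
    group = MpiGroup(comm)
    return group, comm.Get_rank(), comm.Get_size(), comm  # comm.barrier() matches COMM.barrier()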

View File

@@ -144,6 +144,7 @@ def serve(
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
logger.info(os.environ)
unix_socket_template = "unix://{}-{}"
if sharded:
server_urls = [
@@ -155,6 +156,13 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url]
if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
server_urls = [
unix_socket_template.format(uds_path, rank)
for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
try:
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
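
Note: under the custom-NCCL path every MPI rank computes the same socket list and then picks its own entry by OMPI_COMM_WORLD_RANK. A tiny illustration with assumed values (world size 4, rank 2, and an assumed default uds_path):

unix_socket_template = "unix://{}-{}"
uds_path = "/tmp/text-generation-server"       # assumed default
world_size, rank = 4, 2                        # stand-ins for the OMPI_* variables
server_urls = [unix_socket_template.format(uds_path, r) for r in range(world_size)]
local_url = server_urls[rank]
print(local_url)                               # unix:///tmp/text-generation-server-2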

View File

@@ -1,6 +1,5 @@
import os
import torch
import torch.distributed
from torch import nn
from torch.nn import functional as F
@@ -52,7 +51,9 @@ try:
HAS_EETQ = True
except ImportError:
pass
import my_custom_comm
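# my_custom_comm (imported above, unconditionally, even when USE_CUSTOM_NCCL is off) provides the custom_allreduce used in the forward() methods further down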
USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))
# Monkey patching
@classmethod
@@ -358,6 +359,7 @@ class TensorParallelHead(SuperLayer):
def load(config, prefix: str, weights):
if weights.process_group.size() > 1:
try:
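# the assert below doubles as a switch: with custom NCCL on or LM-head parallel off it fails, and the except AssertionError branch loads the full, unsharded weight instead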
assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
should_gather = True
except AssertionError:
@@ -365,6 +367,7 @@ class TensorParallelHead(SuperLayer):
# just load the entire thing.
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
logger.info("Disabled lm head parallel! ")
else:
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
@@ -469,7 +472,10 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super().forward(input)
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
if USE_CUSTOM_NCCL:
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@@ -504,7 +510,10 @@ class TensorParallelEmbedding(nn.Module):
)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
if USE_CUSTOM_NCCL:
my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
else:
torch.distributed.all_reduce(out, group=self.process_group)
return out
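
Note: TensorParallelRowLinear.forward and TensorParallelEmbedding.forward now repeat the same USE_CUSTOM_NCCL branch. A possible refactor (not part of this commit; only custom_allreduce and tp_comm are taken from the diff above, the helper itself is illustrative) would be a single dispatch function:

import os
import torch
import torch.distributed

USE_CUSTOM_NCCL = (
    int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1
    and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
)

def all_reduce_(out: torch.Tensor, process_group) -> torch.Tensor:
    # In-place sum all-reduce across the tensor-parallel group, using the
    # custom MPI/NCCL extension when enabled and torch.distributed otherwise.
    if process_group.size() == 1:
        return out
    if USE_CUSTOM_NCCL:
        import my_custom_comm
        my_custom_comm.custom_allreduce(out, process_group.tp_comm)
    else:
        torch.distributed.all_reduce(out, group=process_group)
    return out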