tp optims

HuaYZhao 2023-11-22 17:22:47 +08:00
parent b226e469c9
commit dad29f7299
5 changed files with 69 additions and 22 deletions

View File

@@ -524,22 +524,29 @@ fn shard_manager(
     if let Some(watermark_delta) = watermark_delta {
         envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
     }
 
+    // for mpi
+    let n_devices = match num_cuda_devices() {
+        Some(value) => value.to_string(),
+        None => String::from("0"),
+    };
+
     // Start process
     tracing::info!("Starting shard");
-    let mut p = match Command::new("text-generation-server")
+    tracing::info!("run with mpi and use cuda device num: {}", n_devices);
+    // tracing::info!("{:?}", shard_args);
+    tracing::info!("{:?}", envs);
+    let mut p = match Command::new("mpirun")
+        .args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"])
         .args(shard_args)
         .envs(envs)
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .process_group(0)
         .spawn()
     {
         Ok(p) => p,
         Err(err) => {
             if err.kind() == io::ErrorKind::NotFound {
-                tracing::error!("text-generation-server not found in PATH");
-                tracing::error!("Please install it with `make install-server`")
+                tracing::error!("start mpi failed! ");
             }
             {
                 tracing::error!("{}", err);

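Note: the effect of this launcher change is that each shard is no longer spawned directly but through mpirun, which fans out one text-generation-server rank per visible CUDA device. A minimal Python sketch of the equivalent spawn, assuming hypothetical placeholder values for the shard arguments and environment (the real launcher builds these itself in Rust), is:

import os
import subprocess

n_devices = "4"                      # analogous to num_cuda_devices() in the Rust launcher; placeholder
shard_args = ["serve", "my-model"]   # placeholder sub-arguments, not the launcher's real list
extra_envs = {"USE_CUSTOM_NCCL": "1"}

# Mirrors Command::new("mpirun").args(&["-n", &n_devices, "--allow-run-as-root",
# "text-generation-server"]).args(shard_args).envs(envs)
cmd = ["mpirun", "-n", n_devices, "--allow-run-as-root", "text-generation-server", *shard_args]
proc = subprocess.Popen(
    cmd,
    env={**os.environ, **extra_envs},   # add to, rather than replace, the inherited environment
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
)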
View File

@@ -402,9 +402,19 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
-        )
+
+        import os
+        if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix="model.embed_tokens", weights=weights
+            )
+        else:
+            from torch.nn import functional as F
+            from loguru import logger
+            embeddings = weights.get_tensor(f"model.embed_tokens.weight")
+            self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
+                                                              padding_idx=config.pad_token_id)
+            logger.info("Disabled embedding tensor parallel! ")
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(

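Note: with USE_TP_EMBEDDING=1 the vocabulary rows stay sharded across ranks and the partial lookups are recombined by an all-reduce; the new fallback instead keeps the full (vocab, hidden) matrix on every rank, with one extra zero row appended (the padding-row convention TensorParallelEmbedding already uses for ids outside a shard). A rough standalone sketch of that fallback, with made-up sizes and pad_token_id assumed to be 0:

import torch
from torch import nn
from torch.nn import functional as F

vocab, hidden = 32000, 4096                # illustrative sizes, not from the diff
full_weight = torch.randn(vocab, hidden)   # stands in for weights.get_tensor("model.embed_tokens.weight")

# F.pad(w, (0, 0, 0, 1)) appends one zero row at index `vocab`,
# so an id equal to vocab maps to an all-zero embedding.
padded = F.pad(full_weight, (0, 0, 0, 1))
embed = nn.Embedding.from_pretrained(padded, padding_idx=0)  # pad_token_id assumed to be 0 here

tokens = torch.tensor([[1, 5, vocab]])     # `vocab` hits the appended zero row
out = embed(tokens)                        # shape (1, 3, hidden); no cross-rank reduce needed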
View File

@@ -16,9 +16,12 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+import os
 
 tracer = trace.get_tracer(__name__)
 
+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+if USE_CUSTOM_NCCL:
+    from text_generation_server.utils.my_dist import initialize_mpi_distributed
 
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -29,7 +32,10 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
+        if USE_CUSTOM_NCCL:
+            self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
+        else:
+            self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
@@ -37,6 +43,7 @@ class FlashLlama(FlashCausalLM):
             raise NotImplementedError("FlashLlama is only available on GPU")
 
         try:
+            raise
             tokenizer = LlamaTokenizer.from_pretrained(
                 model_id,
                 revision=revision,
@@ -58,7 +65,10 @@ class FlashLlama(FlashCausalLM):
         )
         config.quantize = quantize
 
-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -67,7 +77,10 @@ class FlashLlama(FlashCausalLM):
         model = FlashLlamaForCausalLM(config, weights)
 
-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)
 
         super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,

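Note: text_generation_server.utils.my_dist is not included in this diff, so the contents of initialize_mpi_distributed are unknown. A plausible sketch, assuming it relies on the OMPI_COMM_WORLD_* variables that mpirun exports and on mpi4py for the barrier-capable COMM handle (the real helper may differ, and the _ProcessGroup shim below is purely hypothetical):

import os

def initialize_mpi_distributed():
    """Hypothetical sketch of initialize_mpi_distributed; not the repo's implementation."""
    from mpi4py import MPI  # assumption: mpi4py supplies the COMM used for COMM.barrier()

    comm = MPI.COMM_WORLD
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    assert comm.Get_rank() == rank and comm.Get_size() == world_size

    # The rest of the diff expects a process_group exposing .rank(), .size() and a
    # .tp_comm handle for my_custom_comm; this shim only models that shape.
    class _ProcessGroup:
        def __init__(self, comm):
            self.tp_comm = comm  # placeholder: the real code likely wraps a custom NCCL communicator
        def rank(self):
            return rank
        def size(self):
            return world_size

    return _ProcessGroup(comm), rank, world_size, comm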
View File

@@ -144,6 +144,7 @@ def serve(
     dtype: Optional[str] = None,
     trust_remote_code: bool = False,
 ):
+    logger.info(os.environ)
     unix_socket_template = "unix://{}-{}"
     if sharded:
         server_urls = [
@@ -154,6 +155,13 @@
     else:
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]
+    if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
+        server_urls = [
+            unix_socket_template.format(uds_path, rank)
+            for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
+        ]
+        local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
+
     try:
         model = get_model(
@@ -206,4 +214,4 @@ def serve(
     asyncio.run(
         serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
     )

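Note: under USE_CUSTOM_NCCL every MPI rank serves on its own unix socket, and the rank comes from the environment mpirun exports rather than from the launcher's shard id. A minimal standalone sketch of that URL selection (variable names mirror the diff; the path is illustrative):

import os

uds_path = "/tmp/text-generation-server"   # illustrative; the real path comes from the CLI
unix_socket_template = "unix://{}-{}"

if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])   # set by mpirun
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])         # set by mpirun
    # One socket per rank; every rank knows the full list but binds only its own.
    server_urls = [unix_socket_template.format(uds_path, r) for r in range(world_size)]
    local_url = server_urls[rank]
    print(f"rank {rank}/{world_size} will serve on {local_url}")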
View File

@@ -1,6 +1,5 @@
 import os
 import torch
-import torch.distributed
 
 from torch import nn
 from torch.nn import functional as F
@@ -52,7 +51,9 @@ try:
     HAS_EETQ = True
 except ImportError:
     pass
+import my_custom_comm
 
+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))
 
 # Monkey patching
 @classmethod
@@ -358,6 +359,7 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
+                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -365,6 +367,7 @@ class TensorParallelHead(SuperLayer):
                 # just load the entire thing.
                 weight = weights.get_tensor(f"{prefix}.weight")
                 should_gather = False
+                logger.info("Disabled lm head parallel! ")
         else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
@@ -469,7 +472,10 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
@@ -504,7 +510,10 @@ class TensorParallelEmbedding(nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
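
Note: the same dispatch appears in both TensorParallelRowLinear and TensorParallelEmbedding; it could be factored into one helper. A sketch under the assumption that my_custom_comm.custom_allreduce(tensor, comm) reduces in place on the communicator stored at process_group.tp_comm, exactly as the diff calls it (the helper name all_reduce_out is ours, not the repo's):

import os
import torch
import torch.distributed

USE_CUSTOM_NCCL = (
    int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1
    and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
)

def all_reduce_out(out: torch.Tensor, process_group) -> torch.Tensor:
    """Hypothetical helper mirroring the repeated branch in the diff."""
    if process_group.size() <= 1:
        return out
    if USE_CUSTOM_NCCL:
        import my_custom_comm  # author's custom extension; only the call shown in the diff is known
        my_custom_comm.custom_allreduce(out, process_group.tp_comm)  # in-place reduce, per the diff
    else:
        torch.distributed.all_reduce(out, group=process_group)       # default torch.distributed path
    return out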