diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index b4fc86b7..147e4c38 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -524,22 +524,29 @@ fn shard_manager(
     if let Some(watermark_delta) = watermark_delta {
         envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
     }
-
+    // for mpi
+    let n_devices = match num_cuda_devices() {
+        Some(value) => value.to_string(),
+        None => String::from("0"),
+    };
     // Start process
     tracing::info!("Starting shard");
-    let mut p = match Command::new("text-generation-server")
-        .args(shard_args)
-        .envs(envs)
-        .stdout(Stdio::piped())
-        .stderr(Stdio::piped())
-        .process_group(0)
-        .spawn()
+    tracing::info!("run with mpi and use cuda device num: {}", n_devices);
+    // tracing::info!("{:?}", shard_args);
+    tracing::info!("{:?}", envs);
+    let mut p = match Command::new("mpirun")
+        .args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"])
+        .args(shard_args)
+        .envs(envs)
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        .process_group(0)
+        .spawn()
     {
         Ok(p) => p,
         Err(err) => {
             if err.kind() == io::ErrorKind::NotFound {
-                tracing::error!("text-generation-server not found in PATH");
-                tracing::error!("Please install it with `make install-server`")
+                tracing::error!("start mpi failed! ");
             }
             {
                 tracing::error!("{}", err);
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 69608e1c..c3cca1ac 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -402,9 +402,19 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
-        )
+
+        import os
+        if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix="model.embed_tokens", weights=weights
+            )
+        else:
+            from torch.nn import functional as F
+            from loguru import logger
+            embeddings = weights.get_tensor(f"model.embed_tokens.weight")
+            self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
+                                                             padding_idx=config.pad_token_id)
+            logger.info("Disabled embedding tensor parallel! ")
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index d2ed0b15..094e0726 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -16,9 +16,12 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+import os
 
 tracer = trace.get_tracer(__name__)
 
-
+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+if USE_CUSTOM_NCCL:
+    from text_generation_server.utils.my_dist import initialize_mpi_distributed
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -29,7 +32,10 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
+        if USE_CUSTOM_NCCL:
+            self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
+        else:
+            self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
@@ -37,6 +43,7 @@
             raise NotImplementedError("FlashLlama is only available on GPU")
 
         try:
+            raise
            tokenizer = LlamaTokenizer.from_pretrained(
                 model_id,
                 revision=revision,
@@ -58,7 +65,10 @@
         )
         config.quantize = quantize
 
-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -67,7 +77,10 @@
 
         model = FlashLlamaForCausalLM(config, weights)
 
-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 75d2b159..5848d3fa 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -144,6 +144,7 @@ def serve(
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
+        logger.info(os.environ)
         unix_socket_template = "unix://{}-{}"
         if sharded:
             server_urls = [
@@ -154,6 +155,13 @@
         else:
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]
+
+        if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
+            server_urls = [
+                unix_socket_template.format(uds_path, rank)
+                for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
+            ]
+            local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
 
         try:
             model = get_model(
@@ -206,4 +214,4 @@
 
     asyncio.run(
         serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
-    )
+    )
\ No newline at end of file
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7bb95dd2..d0667bda 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,6 +1,5 @@
 import os
 import torch
-import torch.distributed
 
 from torch import nn
 from torch.nn import functional as F
@@ -52,7 +51,9 @@ try:
     HAS_EETQ = True
 except ImportError:
     pass
-
+import my_custom_comm
+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))
 
 # Monkey patching
 @classmethod
@@ -358,6 +359,7 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
+                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:
@@ -365,6 +367,7 @@
                 # just load the entire thing.
                 weight = weights.get_tensor(f"{prefix}.weight")
                 should_gather = False
+                logger.info("Disabled lm head parallel! ")
         else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
@@ -469,7 +472,10 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
 
@@ -504,7 +510,10 @@ class TensorParallelEmbedding(nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out
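
Note on the new modules: the patch imports text_generation_server.utils.my_dist.initialize_mpi_distributed and the my_custom_comm extension, but includes neither file. The sketch below is only an assumption about what utils/my_dist.py could look like, shaped to match how flash_llama.py unpacks its return value (process_group, rank, world_size, COMM) and how layers.py reads process_group.tp_comm. It assumes mpi4py supplies the MPI handle, and create_comm is a hypothetical name standing in for whatever factory my_custom_comm actually exposes.

    import torch
    from mpi4py import MPI

    import my_custom_comm


    class MpiProcessGroup:
        # Duck-typed stand-in for a torch.distributed group: the modeling code in this
        # patch only calls .rank() and .size(), and reads .tp_comm (which it hands to
        # my_custom_comm.custom_allreduce).
        def __init__(self, rank: int, world_size: int, tp_comm):
            self._rank = rank
            self._world_size = world_size
            self.tp_comm = tp_comm

        def rank(self) -> int:
            return self._rank

        def size(self) -> int:
            return self._world_size


    def initialize_mpi_distributed():
        # Returns (process_group, rank, world_size, COMM), matching the unpacking in flash_llama.py.
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        world_size = comm.Get_size()

        # One GPU per MPI rank; pin it before any communicator touches CUDA.
        torch.cuda.set_device(rank)

        # Hypothetical factory: whatever call my_custom_comm actually provides to build
        # the tensor-parallel communicator would go here.
        tp_comm = my_custom_comm.create_comm(rank, world_size)

        return MpiProcessGroup(rank, world_size, tp_comm), rank, world_size, comm

Under this layout, COMM.barrier() in flash_llama.py would be mpi4py's barrier, and the mpirun invocation added to the launcher is what populates OMPI_COMM_WORLD_SIZE / OMPI_COMM_WORLD_RANK, which server.py then uses to assign one unix socket per rank.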