mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00

Commit dad29f7299 ("tp optims"), parent b226e469c9

@@ -524,10 +524,18 @@ fn shard_manager(
     if let Some(watermark_delta) = watermark_delta {
         envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
     }
-
+    // for mpi
+    let n_devices = match num_cuda_devices() {
+        Some(value) => value.to_string(),
+        None => String::from("0"),
+    };
     // Start process
     tracing::info!("Starting shard");
-    let mut p = match Command::new("text-generation-server")
+    tracing::info!("run with mpi and use cuda device num: {}", n_devices);
+    // tracing::info!("{:?}", shard_args);
+    tracing::info!("{:?}", envs);
+    let mut p = match Command::new("mpirun")
+        .args(&["-n", &n_devices, "--allow-run-as-root", "text-generation-server"])
         .args(shard_args)
         .envs(envs)
         .stdout(Stdio::piped())

@@ -538,8 +546,7 @@ fn shard_manager(
         Ok(p) => p,
         Err(err) => {
             if err.kind() == io::ErrorKind::NotFound {
-                tracing::error!("text-generation-server not found in PATH");
-                tracing::error!("Please install it with `make install-server`")
+                tracing::error!("start mpi failed! ");
             }
             {
                 tracing::error!("{}", err);

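With this change the launcher's shard_manager no longer spawns text-generation-server directly: it detects the number of CUDA devices and hands the whole shard command to mpirun, one MPI rank per GPU. A minimal Python sketch of the equivalent invocation (the shard_args placeholder is hypothetical and stands in for whatever arguments the launcher already assembles; this is an illustration, not the launcher's actual code):

# Rough Python equivalent of the mpirun command built above; the
# shard_args placeholder is hypothetical and stands in for the
# launcher's existing shard arguments.
import subprocess
import torch

n_devices = torch.cuda.device_count()          # the launcher falls back to "0" when detection fails
shard_args = ["serve", "some/model-id"]        # illustrative placeholder only

cmd = [
    "mpirun",
    "-n", str(n_devices),                      # one MPI rank per visible GPU
    "--allow-run-as-root",
    "text-generation-server",
    *shard_args,
]
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
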
@@ -402,9 +402,19 @@ class FlashLlamaModel(torch.nn.Module):
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
-        )
+
+        import os
+        if int(os.getenv("USE_TP_EMBEDDING", "1")) == 1:
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix="model.embed_tokens", weights=weights
+            )
+        else:
+            from torch.nn import functional as F
+            from loguru import logger
+            embeddings = weights.get_tensor(f"model.embed_tokens.weight")
+            self.embed_tokens = nn.Embedding.from_pretrained(F.pad(embeddings, (0, 0, 0, 1)),
+                                                              padding_idx=config.pad_token_id)
+            logger.info("Disabled embedding tensor parallel! ")
         self.layers = nn.ModuleList(
             [
                 FlashLlamaLayer(

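Setting USE_TP_EMBEDDING=0 loads the full model.embed_tokens.weight on every rank instead of a TensorParallelEmbedding shard; the F.pad(embeddings, (0, 0, 0, 1)) call appends one extra row so that an out-of-vocabulary padding index still resolves to a valid slot. A self-contained sketch of that padding step, with made-up tensor sizes:

# Standalone illustration of the non-parallel embedding branch above:
# pad one extra row onto the full embedding matrix.
import torch
from torch import nn
from torch.nn import functional as F

vocab_size, hidden_size = 1000, 64                 # illustrative sizes only
embeddings = torch.randn(vocab_size, hidden_size)

# (0, 0, 0, 1) pads nothing on the hidden dimension and appends one
# row on the vocabulary dimension -> shape (vocab_size + 1, hidden_size).
padded = F.pad(embeddings, (0, 0, 0, 1))
embed_tokens = nn.Embedding.from_pretrained(padded, padding_idx=None)

print(padded.shape)                                # torch.Size([1001, 64])
print(embed_tokens(torch.tensor([vocab_size])))    # the zero row added by the pad
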
@@ -16,9 +16,12 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+import os

 tracer = trace.get_tracer(__name__)

+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+if USE_CUSTOM_NCCL:
+    from text_generation_server.utils.my_dist import initialize_mpi_distributed

 class FlashLlama(FlashCausalLM):
     def __init__(

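initialize_mpi_distributed lives in the new text_generation_server.utils.my_dist module, which is not shown in this diff; the call sites only tell us it returns a process-group-like object, the rank, the world size, and an MPI communicator whose barrier() is used in place of torch.distributed.barrier. The following is a purely hypothetical sketch of such a helper on top of mpi4py, written only to match those call sites; the commit's real implementation may differ:

# Hypothetical sketch of an initialize_mpi_distributed() helper, based
# only on how this commit calls it; not the commit's actual my_dist code.
import torch
from mpi4py import MPI


class MpiProcessGroup:
    """Stand-in exposing rank()/size() the way the model code expects."""

    def __init__(self, comm):
        self.comm = comm
        # The layers added in this commit read process_group.tp_comm and
        # hand it to my_custom_comm.custom_allreduce; the real module
        # presumably initializes that communicator here.
        self.tp_comm = None

    def rank(self) -> int:
        return self.comm.Get_rank()

    def size(self) -> int:
        return self.comm.Get_size()


def initialize_mpi_distributed():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    torch.cuda.set_device(rank)   # one GPU per MPI rank, matching `mpirun -n <n_gpus>`
    return MpiProcessGroup(comm), rank, world_size, comm
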
@@ -29,6 +32,9 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        self.process_group, rank, world_size = initialize_torch_distributed()
+        if USE_CUSTOM_NCCL:
+            self.process_group, rank, world_size, COMM = initialize_mpi_distributed()
+        else:
+            self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -37,6 +43,7 @@ class FlashLlama(FlashCausalLM):
             raise NotImplementedError("FlashLlama is only available on GPU")

         try:
+            raise
             tokenizer = LlamaTokenizer.from_pretrained(
                 model_id,
                 revision=revision,

@@ -58,6 +65,9 @@ class FlashLlama(FlashCausalLM):
         )
         config.quantize = quantize

-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

@@ -67,6 +77,9 @@ class FlashLlama(FlashCausalLM):

         model = FlashLlamaForCausalLM(config, weights)

-        torch.distributed.barrier(group=self.process_group)
+        if USE_CUSTOM_NCCL:
+            COMM.barrier()
+        else:
+            torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,

@@ -144,6 +144,7 @@ def serve(
     dtype: Optional[str] = None,
     trust_remote_code: bool = False,
 ):
+    logger.info(os.environ)
     unix_socket_template = "unix://{}-{}"
     if sharded:
         server_urls = [

@@ -155,6 +156,13 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]

+    if int(os.environ.get("USE_CUSTOM_NCCL", 0)):
+        server_urls = [
+            unix_socket_template.format(uds_path, rank)
+            for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
+        ]
+        local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]
+
     try:
         model = get_model(
             model_id, revision, sharded, quantize, dtype, trust_remote_code

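Under the custom NCCL path every MPI rank hosts its own gRPC shard, so the socket list is rebuilt from OMPI_COMM_WORLD_SIZE and each rank serves the entry selected by its OMPI_COMM_WORLD_RANK (both exported by Open MPI's mpirun). A small worked example of that mapping, using made-up values:

# Worked example of the per-rank socket selection above, using
# made-up environment values that mpirun would normally export.
import os

os.environ.setdefault("OMPI_COMM_WORLD_SIZE", "4")
os.environ.setdefault("OMPI_COMM_WORLD_RANK", "2")

uds_path = "/tmp/text-generation-server"           # illustrative path
unix_socket_template = "unix://{}-{}"

server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["OMPI_COMM_WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["OMPI_COMM_WORLD_RANK"])]

print(server_urls)  # unix:///tmp/text-generation-server-0 ... -3
print(local_url)    # unix:///tmp/text-generation-server-2
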
@@ -1,6 +1,5 @@
 import os
 import torch
-import torch.distributed

 from torch import nn
 from torch.nn import functional as F

@@ -52,7 +51,9 @@ try:
     HAS_EETQ = True
 except ImportError:
     pass
-
+import my_custom_comm
+USE_CUSTOM_NCCL = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(os.getenv("USE_CUSTOM_NCCL", "0")) == 1
+USE_LM_HEAD_PARALLEL = int(os.getenv("USE_LM_HEAD_PARALLEL", "1"))

 # Monkey patching
 @classmethod

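The two module-level gates behave differently: USE_CUSTOM_NCCL only becomes true when the process was launched by mpirun with more than one rank and the user opted in with USE_CUSTOM_NCCL=1, while USE_LM_HEAD_PARALLEL defaults to on and is read as a plain integer. A short sketch of how the flags evaluate for a hypothetical 4-rank launch:

# Sketch of the flag evaluation above for a hypothetical environment;
# the values mimic what a 4-rank `mpirun` launch with opt-in would export.
env = {
    "OMPI_COMM_WORLD_SIZE": "4",   # set by mpirun -n 4 (illustrative)
    "USE_CUSTOM_NCCL": "1",        # explicit opt-in
    "USE_LM_HEAD_PARALLEL": "0",   # load the full lm_head on every rank
}

use_custom_nccl = int(env.get("OMPI_COMM_WORLD_SIZE", "1")) > 1 and int(env.get("USE_CUSTOM_NCCL", "0")) == 1
use_lm_head_parallel = int(env.get("USE_LM_HEAD_PARALLEL", "1"))

print(use_custom_nccl)       # True: multi-rank MPI launch plus opt-in
print(use_lm_head_parallel)  # 0: TensorParallelHead falls back to the unsharded weight
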
@@ -358,6 +359,7 @@ class TensorParallelHead(SuperLayer):
     def load(config, prefix: str, weights):
         if weights.process_group.size() > 1:
             try:
+                assert USE_CUSTOM_NCCL == 0 and USE_LM_HEAD_PARALLEL == 1
                 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                 should_gather = True
             except AssertionError:

@@ -365,6 +367,7 @@ class TensorParallelHead(SuperLayer):
                 # just load the entire thing.
                 weight = weights.get_tensor(f"{prefix}.weight")
                 should_gather = False
+                logger.info("Disabled lm head parallel! ")
         else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False

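The new assert reuses the existing try/except AssertionError fallback in TensorParallelHead.load as a switch: when the custom NCCL path is active or USE_LM_HEAD_PARALLEL=0, the assertion fails and every rank loads the whole lm_head weight with should_gather = False instead of a dim-0 shard. A stripped-down sketch of that selection logic (the loader callables are hypothetical stand-ins, not the real Weights API):

# Stripped-down sketch of the weight-selection logic above; get_sharded
# and get_full are hypothetical stand-ins for the Weights methods.
def choose_lm_head_weight(world_size, use_custom_nccl, use_lm_head_parallel,
                          get_sharded, get_full):
    if world_size > 1:
        try:
            # Only shard the head on the stock torch.distributed path.
            assert use_custom_nccl == 0 and use_lm_head_parallel == 1
            return get_sharded(), True      # should_gather = True
        except AssertionError:
            # Custom NCCL or lm-head parallelism disabled: full weight, no gather.
            return get_full(), False
    return get_full(), False


# Example: 4-way TP with the custom NCCL path enabled -> unsharded head.
weight, should_gather = choose_lm_head_weight(
    4, use_custom_nccl=1, use_lm_head_parallel=1,
    get_sharded=lambda: "sharded-weight", get_full=lambda: "full-weight",
)
print(weight, should_gather)   # full-weight False
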
@@ -469,6 +472,9 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out

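Both TensorParallelRowLinear.forward and TensorParallelEmbedding.forward (next hunk) now route the cross-rank sum either through my_custom_comm.custom_allreduce over process_group.tp_comm or through the stock torch.distributed.all_reduce. A minimal sketch of that dispatch factored into one helper, assuming the my_custom_comm extension added by this commit is importable:

# Minimal sketch of the allreduce dispatch shared by both forward()
# methods; my_custom_comm and process_group.tp_comm come from this
# commit's custom NCCL path and are assumed to be available.
import torch
import torch.distributed

try:
    import my_custom_comm
    HAS_CUSTOM_COMM = True
except ImportError:
    HAS_CUSTOM_COMM = False


def tp_all_reduce(out: torch.Tensor, process_group, use_custom_nccl: bool) -> torch.Tensor:
    """Sum `out` across tensor-parallel ranks in place and return it."""
    if process_group.size() > 1:
        if use_custom_nccl and HAS_CUSTOM_COMM:
            my_custom_comm.custom_allreduce(out, process_group.tp_comm)
        else:
            torch.distributed.all_reduce(out, group=process_group)
    return out
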
@@ -504,6 +510,9 @@ class TensorParallelEmbedding(nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            torch.distributed.all_reduce(out, group=self.process_group)
+            if USE_CUSTOM_NCCL:
+                my_custom_comm.custom_allreduce(out, self.process_group.tp_comm)
+            else:
+                torch.distributed.all_reduce(out, group=self.process_group)
         return out