mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
remove a tad of cpu bottleneck
This commit is contained in:
parent
af1989459c
commit
e69eed8ea3
@ -188,10 +188,7 @@ def download_weights(
|
||||
# Try to see if there are local pytorch weights
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
try:
|
||||
local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
|
||||
except (FileNotFoundError, utils.EntryNotFoundError):
|
||||
local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
|
||||
local_pt_files = utils.weight_files(model_id, revision, ".bin")
|
||||
|
||||
# No local pytorch weights
|
||||
except utils.LocalEntryNotFoundError:
|
||||
@ -202,10 +199,7 @@ def download_weights(
|
||||
)
|
||||
|
||||
# Try to see if there are pytorch weights on the hub
|
||||
try:
|
||||
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
|
||||
except utils.EntryNotFoundError:
|
||||
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
|
||||
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
|
||||
# Download pytorch weights
|
||||
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
|
||||
|
||||
|
@ -282,7 +282,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "mistral":
|
||||
if model_type in ["mistral", "mixtral"]:
|
||||
if MISTRAL:
|
||||
return FlashMistral(
|
||||
model_id,
|
||||
@ -292,7 +292,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError("Mistral model requires flash attention v2")
|
||||
raise NotImplementedError("Mistral models requires flash attention v2")
|
||||
|
||||
if model_type == "opt":
|
||||
return OPTSharded(
|
||||
|
@ -86,7 +86,7 @@ class MixtralConfig(PretrainedConfig):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = 4
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
@ -123,7 +123,7 @@ def load_attention(config, prefix, weights):
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
@ -135,7 +135,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
@ -156,34 +156,35 @@ def _load_gqa(config, prefix: str, weights):
|
||||
)
|
||||
|
||||
|
||||
def _load_experts(config, prefix, weights):
|
||||
def _load_experts(config, prefix, mat, weights):
|
||||
if config.quantize is not None:
|
||||
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||
|
||||
slice_ = weights._get_slice(prefix)
|
||||
assert mat in ["w1", "w2", "w3"]
|
||||
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
if world_size == 1:
|
||||
tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
|
||||
else:
|
||||
assert (
|
||||
config.intermediate_size % world_size == 0
|
||||
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||
assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
|
||||
assert (
|
||||
config.intermediate_size % world_size == 0
|
||||
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
block_size = config.intermediate_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
block_size = config.intermediate_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
expert_slices = []
|
||||
for i in range(config.num_local_experts):
|
||||
expert_start = i * config.intermediate_size
|
||||
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
|
||||
dtype=weights.dtype,
|
||||
device=weights.device)
|
||||
|
||||
expert_slices.append(slice_[start + expert_start:stop + expert_start])
|
||||
|
||||
tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
|
||||
for i in range(config.num_local_experts):
|
||||
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||
|
||||
if mat == "w2":
|
||||
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||
else:
|
||||
expert_slice = slice_[start:stop]
|
||||
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
|
||||
return tensor
|
||||
|
||||
|
||||
@ -223,9 +224,9 @@ class MixtralAttention(torch.nn.Module):
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.wo = TensorParallelRowLinear.load(
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.wo",
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
@ -299,7 +300,25 @@ class MixtralAttention(torch.nn.Module):
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.wo(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def select_experts(gate_logits: torch.Tensor, top_k: int):
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
# weights, selected_experts: (sequence_length, top-k)
|
||||
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
|
||||
weights /= weights.sum(dim=-1, keepdim=True)
|
||||
weights = weights.view(-1)
|
||||
selected_experts = selected_experts.view(-1)
|
||||
|
||||
return selected_experts, weights
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def round_up(x: torch.Tensor, value: int):
|
||||
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
|
||||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
@ -339,9 +358,12 @@ class BlockSparseMoE(nn.Module):
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
self.w1 = _load_experts(config, f"{prefix}.w1", weights)
|
||||
self.w2 = _load_experts(config, f"{prefix}.w2", weights)
|
||||
self.w3 = _load_experts(config, f"{prefix}.w3", weights)
|
||||
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
|
||||
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
|
||||
|
||||
self.offsets = None
|
||||
self.offsets_block_rows = 0
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
@ -361,13 +383,18 @@ class BlockSparseMoE(nn.Module):
|
||||
# dimensionality of a single expert.
|
||||
block_rows = padded_tokens // self.blocking
|
||||
blocks_per_row = self.ffn_dim // self.blocking
|
||||
offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||
self.offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
self.offsets_block_rows = block_rows
|
||||
offsets = self.offsets
|
||||
else:
|
||||
offsets = self.offsets[:block_rows]
|
||||
|
||||
# Indices for the sparse matrix. The indices for
|
||||
# the intermediate matrix are dynamic depending
|
||||
@ -375,7 +402,6 @@ class BlockSparseMoE(nn.Module):
|
||||
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
|
||||
blocks_per_row)
|
||||
|
||||
# TODO(tgale): This is unused. Remove the need for this in stk.
|
||||
# For now, use meta init to save the device memory.
|
||||
data = torch.empty(
|
||||
column_indices.numel(),
|
||||
@ -400,7 +426,7 @@ class BlockSparseMoE(nn.Module):
|
||||
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||
# Sort the expert ids to produce the scatter/gather
|
||||
# indices for the permutation.
|
||||
selected_experts = selected_experts.int()
|
||||
# selected_experts = selected_experts.int()
|
||||
|
||||
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||
# and indices == how to sort tokens?
|
||||
@ -418,8 +444,8 @@ class BlockSparseMoE(nn.Module):
|
||||
# position of each bin.
|
||||
|
||||
# List of size num_experts
|
||||
padded_tokens_per_expert = ops.round_up(tokens_per_expert,
|
||||
self.blocking)
|
||||
padded_tokens_per_expert = round_up(tokens_per_expert,
|
||||
self.blocking)
|
||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||
|
||||
# Cumulative selected experts per token
|
||||
@ -446,13 +472,7 @@ class BlockSparseMoE(nn.Module):
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
# weights, selected_experts: (sequence_length, top-k)
|
||||
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
|
||||
weights /= weights.sum(dim=-1, keepdim=True)
|
||||
weights = weights.flatten().to(x.dtype)
|
||||
selected_experts = selected_experts.flatten()
|
||||
selected_experts, weights = select_experts(gate_logits, self.top_k)
|
||||
|
||||
(
|
||||
indices,
|
||||
@ -465,7 +485,7 @@ class BlockSparseMoE(nn.Module):
|
||||
# Permute tokens and pad to prepare expert computation
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
|
||||
self.top_k, x.shape[0] + self.num_experts * self.blocking)
|
||||
self.top_k)
|
||||
|
||||
# Create the sparse matrix topology
|
||||
with torch.no_grad():
|
||||
@ -476,8 +496,8 @@ class BlockSparseMoE(nn.Module):
|
||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||
x = stk.Matrix(
|
||||
topo.size(),
|
||||
self.act(stk.ops.sdd(x, self.w1.t(), topo).data) *
|
||||
stk.ops.sdd(x, self.w3.t(), topo).data,
|
||||
self.act(stk.ops.sdd(x, self.w1, topo).data) *
|
||||
stk.ops.sdd(x, self.w3, topo).data,
|
||||
topo.row_indices,
|
||||
topo.column_indices,
|
||||
topo.offsets,
|
||||
@ -512,18 +532,18 @@ class BlockSparseMoE(nn.Module):
|
||||
class MixtralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"layers.{layer_id}"
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
|
||||
self.attention = MixtralAttention(
|
||||
prefix=f"{prefix}.attention", config=config, weights=weights
|
||||
self.self_attn = MixtralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
|
||||
|
||||
self.attention_norm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.ffn_norm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.ffn_norm",
|
||||
self.post_attention_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
@ -542,10 +562,10 @@ class MixtralLayer(nn.Module):
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
normed_hidden_states, res = self.attention_norm(hidden_states, residual)
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.attention(
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
@ -559,21 +579,21 @@ class MixtralLayer(nn.Module):
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.ffn_norm(
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.block_sparse_moe(normed_attn_res_output)
|
||||
block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
return block_sparse_moe_output, attn_res
|
||||
|
||||
|
||||
class MixtralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.tok_embeddings = TensorParallelEmbedding(
|
||||
prefix="tok_embeddings", weights=weights
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
@ -587,12 +607,12 @@ class MixtralModel(torch.nn.Module):
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].attention.head_size
|
||||
self.num_heads = self.layers[0].attention.num_heads
|
||||
self.num_key_value_heads = self.layers[0].attention.num_key_value_heads
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -606,11 +626,11 @@ class MixtralModel(torch.nn.Module):
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
@ -642,7 +662,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
self.model = MixtralModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="output",
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
@ -91,8 +91,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
|
||||
# from torch.profiler import profile, ProfilerActivity
|
||||
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
# if self.model.rank == 0:
|
||||
# prefill_prof.export_chrome_trace("prefill.json")
|
||||
|
||||
return generate_pb2.PrefillResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
@ -118,8 +122,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
# from torch.profiler import profile, ProfilerActivity
|
||||
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
|
||||
generations, next_batch = self.model.generate_token(batch)
|
||||
self.cache.set(next_batch)
|
||||
# if self.model.rank == 0:
|
||||
# prefill_prof.export_chrome_trace("decode.json")
|
||||
|
||||
return generate_pb2.DecodeResponse(
|
||||
generations=[generation.to_pb() for generation in generations],
|
||||
|
@ -98,10 +98,7 @@ def weight_files(
|
||||
if extension != ".safetensors":
|
||||
raise e
|
||||
# Try to see if there are pytorch weights
|
||||
try:
|
||||
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
|
||||
except EntryNotFoundError:
|
||||
pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
|
||||
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
|
||||
# Change pytorch extension to safetensors extension
|
||||
# It is possible that we have safetensors weights locally even though they are not on the
|
||||
# hub if we converted weights locally without pushing them
|
||||
|
@ -23,7 +23,7 @@ class Weights:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
logger.warning(
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
@ -116,7 +116,7 @@ class Weights:
|
||||
size = slice_.get_shape()[dim]
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
return self.get_partial_sharded(tensor_name, dim)
|
||||
|
||||
def _get_qweight(self, name: str):
|
||||
|
Loading…
Reference in New Issue
Block a user