remove a tad of cpu bottleneck

OlivierDehaene 2023-12-11 10:32:13 +01:00
parent af1989459c
commit e69eed8ea3
6 changed files with 104 additions and 85 deletions

View File

@@ -188,10 +188,7 @@ def download_weights(
# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
try:
local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
except (FileNotFoundError, utils.EntryNotFoundError):
local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
local_pt_files = utils.weight_files(model_id, revision, ".bin")
# No local pytorch weights
except utils.LocalEntryNotFoundError:
@@ -202,10 +199,7 @@ def download_weights(
)
# Try to see if there are pytorch weights on the hub
try:
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
except utils.EntryNotFoundError:
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
# Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

View File

@@ -282,7 +282,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "mistral":
if model_type in ["mistral", "mixtral"]:
if MISTRAL:
return FlashMistral(
model_id,
@@ -292,7 +292,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
raise NotImplementedError("Mistral models require flash attention v2")
if model_type == "opt":
return OPTSharded(

View File

@@ -86,7 +86,7 @@ class MixtralConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = 4
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
@@ -123,7 +123,7 @@ def load_attention(config, prefix, weights):
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
@@ -135,7 +135,7 @@ def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
@@ -156,34 +156,35 @@ def _load_gqa(config, prefix: str, weights):
)
def _load_experts(config, prefix, weights):
def _load_experts(config, prefix, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
slice_ = weights._get_slice(prefix)
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
if world_size == 1:
tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
else:
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
expert_slices = []
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device)
for i in range(config.num_local_experts):
expert_start = i * config.intermediate_size
expert_slices.append(slice_[start + expert_start:stop + expert_start])
tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
return tensor
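
For orientation, a minimal standalone sketch of the layout the new _load_experts produces, using made-up sizes and plain tensors in place of the Weights helper: each expert's weight is sliced along the intermediate dimension for the local rank and copied into one preallocated tensor, instead of building a list of slices and concatenating them.

import torch

# Made-up sizes; the real values come from MixtralConfig and the process group.
num_local_experts, intermediate_size, hidden_size = 4, 16, 8
world_size, rank = 2, 0

block_size = intermediate_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# Stand-ins for weights._get_slice(f"{prefix}.{i}.w1.weight"), one per expert.
expert_weights = [torch.randn(intermediate_size, hidden_size) for _ in range(num_local_experts)]

# Preallocate once, then copy each expert's local shard into its row block.
merged = torch.empty(num_local_experts * block_size, hidden_size)
for i, w in enumerate(expert_weights):
    merged[i * block_size:(i + 1) * block_size] = w[start:stop]

print(merged.shape)  # torch.Size([32, 8])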
@@ -223,9 +224,9 @@ class MixtralAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.wo = TensorParallelRowLinear.load(
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.wo",
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
@@ -299,7 +300,25 @@ class MixtralAttention(torch.nn.Module):
max_s,
)
return self.wo(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.view(-1)
selected_experts = selected_experts.view(-1)
return selected_experts, weights
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module):
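
A quick usage illustration of the scripted select_experts helper added above (shapes are made up): the softmax, top-k and renormalization now run as a single TorchScript graph per step rather than several eager Python calls.

import torch

@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
    # Condensed from the hunk above, reproduced so the snippet is self-contained.
    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
    weights /= weights.sum(dim=-1, keepdim=True)
    return selected_experts.view(-1), weights.view(-1)

gate_logits = torch.randn(5, 8)  # (sequence_length=5, n_experts=8)
selected_experts, weights = select_experts(gate_logits, 2)
print(selected_experts.shape, weights.shape)  # torch.Size([10]) torch.Size([10])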
@@ -339,9 +358,12 @@ class BlockSparseMoE(nn.Module):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.w1", weights)
self.w2 = _load_experts(config, f"{prefix}.w2", weights)
self.w3 = _load_experts(config, f"{prefix}.w3", weights)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
self.offsets = None
self.offsets_block_rows = 0
self.process_group = weights.process_group
@@ -361,13 +383,18 @@ class BlockSparseMoE(nn.Module):
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
offsets = torch.arange(
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[:block_rows]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
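
A minimal sketch of the caching pattern in this hunk, pulled out into a standalone class with illustrative sizes: the torch.arange allocation (and its copy to device) only happens when a larger batch is seen, otherwise the cached tensor is sliced. The slice here keeps block_rows + 1 entries, an assumption of this sketch matching how the uncached code built block_rows + 1 CSR-style row offsets.

import torch

class OffsetsCache:
    # Reuse the row-offsets tensor across forward passes instead of rebuilding it.
    def __init__(self, blocks_per_row: int, device: str = "cpu"):
        self.blocks_per_row = blocks_per_row
        self.device = device
        self.offsets = None
        self.offsets_block_rows = 0

    def get(self, block_rows: int) -> torch.Tensor:
        if self.offsets is None or block_rows > self.offsets_block_rows:
            # Only pay for torch.arange (and the host-to-device copy) on growth.
            self.offsets = torch.arange(
                0,
                block_rows * self.blocks_per_row + 1,
                self.blocks_per_row,
                dtype=torch.int32,
                device=self.device,
            )
            self.offsets_block_rows = block_rows
            return self.offsets
        # block_rows + 1 offsets describe block_rows rows.
        return self.offsets[: block_rows + 1]

cache = OffsetsCache(blocks_per_row=4)
print(cache.get(8))  # builds: 9 offsets, 0..32 step 4
print(cache.get(3))  # reuses: the first 4 cached offsets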
@@ -375,7 +402,6 @@ class BlockSparseMoE(nn.Module):
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
@@ -400,7 +426,7 @@ class BlockSparseMoE(nn.Module):
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
selected_experts = selected_experts.int()
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
@@ -418,7 +444,7 @@ class BlockSparseMoE(nn.Module):
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = ops.round_up(tokens_per_expert,
padded_tokens_per_expert = round_up(tokens_per_expert,
self.blocking)
# padded_tokens_per_expert => [128, 0, 128, ...]
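
As a worked example of the scripted round_up that replaces ops.round_up here (made-up token counts, blocking of 128): each per-expert count is rounded up to the next multiple of the blocking size, producing exactly the kind of list the comment above describes.

import torch

@torch.jit.script
def round_up(x: torch.Tensor, value: int):
    # Integer ceil-division times the block size.
    return torch.div(x + (value - 1), value, rounding_mode="trunc") * value

tokens_per_expert = torch.tensor([5, 0, 130, 128])
print(round_up(tokens_per_expert, 128))  # tensor([128,   0, 256, 128])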
@@ -446,13 +472,7 @@ class BlockSparseMoE(nn.Module):
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.flatten().to(x.dtype)
selected_experts = selected_experts.flatten()
selected_experts, weights = select_experts(gate_logits, self.top_k)
(
indices,
@@ -465,7 +485,7 @@ class BlockSparseMoE(nn.Module):
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k, x.shape[0] + self.num_experts * self.blocking)
self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
@@ -476,8 +496,8 @@ class BlockSparseMoE(nn.Module):
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data) *
stk.ops.sdd(x, self.w3.t(), topo).data,
self.act(stk.ops.sdd(x, self.w1, topo).data) *
stk.ops.sdd(x, self.w3, topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
@@ -512,18 +532,18 @@ class BlockSparseMoE(nn.Module):
class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"layers.{layer_id}"
prefix = f"model.layers.{layer_id}"
self.attention = MixtralAttention(
prefix=f"{prefix}.attention", config=config, weights=weights
self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
self.attention_norm = FastRMSNorm.load(
prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.ffn_norm = FastRMSNorm.load(
prefix=f"{prefix}.ffn_norm",
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
@@ -542,10 +562,10 @@ class MixtralLayer(nn.Module):
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.attention_norm(hidden_states, residual)
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.attention(
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
@@ -559,21 +579,21 @@ class MixtralLayer(nn.Module):
)
# faster post attention rms norm
normed_attn_res_output, attn_res = self.ffn_norm(
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.block_sparse_moe(normed_attn_res_output)
block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
return mlp_output, attn_res
return block_sparse_moe_output, attn_res
class MixtralModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tok_embeddings = TensorParallelEmbedding(
prefix="tok_embeddings", weights=weights
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
@@ -587,12 +607,12 @@ class MixtralModel(torch.nn.Module):
]
)
self.norm = FastRMSNorm.load(
prefix="norm", weights=weights, eps=config.rms_norm_eps
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
self.num_key_value_heads = self.layers[0].attention.num_key_value_heads
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
@@ -606,11 +626,11 @@ class MixtralModel(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.tok_embeddings(input_ids)
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
@@ -642,7 +662,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="output",
prefix="lm_head",
weights=weights,
)
self.max_past = config.sliding_window

View File

@@ -91,8 +91,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
# from torch.profiler import profile, ProfilerActivity
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch)
# if self.model.rank == 0:
# prefill_prof.export_chrome_trace("prefill.json")
return generate_pb2.PrefillResponse(
generations=[generation.to_pb() for generation in generations],
@@ -118,8 +122,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
else:
batch = batches[0]
# from torch.profiler import profile, ProfilerActivity
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch)
# if self.model.rank == 0:
# prefill_prof.export_chrome_trace("decode.json")
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
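
The commented-out lines above show how these paths were profiled while hunting the CPU bottleneck. A self-contained sketch of the same pattern, with a stand-in workload instead of self.model.generate_token(batch) and CPU-only activities so it runs without a GPU:

import torch
from torch.profiler import profile, ProfilerActivity

def fake_generate_token():
    # Stand-in for self.model.generate_token(batch).
    x = torch.randn(256, 256)
    return x @ x

with profile(activities=[ProfilerActivity.CPU]) as prof:
    fake_generate_token()

# Load the trace in chrome://tracing or https://ui.perfetto.dev
prof.export_chrome_trace("decode.json")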

View File

@@ -98,10 +98,7 @@ def weight_files(
if extension != ".safetensors":
raise e
# Try to see if there are pytorch weights
try:
pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
except EntryNotFoundError:
pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
pt_filenames = weight_hub_files(model_id, revision, ".bin")
# Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them

View File

@@ -23,7 +23,7 @@ class Weights:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
logger.warning(
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
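
A compact, standalone sketch of what this routing loop builds, with made-up filenames and keys instead of safe_open handles: a key-to-file map that now fails hard on duplicate keys rather than just warning and keeping the last file seen.

def build_routing(files_to_keys: dict) -> dict:
    routing = {}
    for filename, keys in files_to_keys.items():
        for k in keys:
            if k in routing:
                # Previously a logger.warning; the commit turns duplicates into a hard error.
                raise RuntimeError(
                    f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                )
            routing[k] = filename
    return routing

print(build_routing({"model-00001.safetensors": ["a.weight"], "model-00002.safetensors": ["b.weight"]}))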
@@ -116,7 +116,7 @@ class Weights:
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):