remove a tad of cpu bottleneck

This commit is contained in:
OlivierDehaene 2023-12-11 10:32:13 +01:00
parent af1989459c
commit e69eed8ea3
6 changed files with 104 additions and 85 deletions

View File

@ -188,10 +188,7 @@ def download_weights(
# Try to see if there are local pytorch weights # Try to see if there are local pytorch weights
try: try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
try: local_pt_files = utils.weight_files(model_id, revision, ".bin")
local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
except (FileNotFoundError, utils.EntryNotFoundError):
local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
# No local pytorch weights # No local pytorch weights
except utils.LocalEntryNotFoundError: except utils.LocalEntryNotFoundError:
@ -202,10 +199,7 @@ def download_weights(
) )
# Try to see if there are pytorch weights on the hub # Try to see if there are pytorch weights on the hub
try: pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
except utils.EntryNotFoundError:
pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
# Download pytorch weights # Download pytorch weights
local_pt_files = utils.download_weights(pt_filenames, model_id, revision) local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

View File

@ -282,7 +282,7 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "mistral": if model_type in ["mistral", "mixtral"]:
if MISTRAL: if MISTRAL:
return FlashMistral( return FlashMistral(
model_id, model_id,
@ -292,7 +292,7 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
raise NotImplementedError("Mistral model requires flash attention v2") raise NotImplementedError("Mistral models requires flash attention v2")
if model_type == "opt": if model_type == "opt":
return OPTSharded( return OPTSharded(

View File

@ -86,7 +86,7 @@ class MixtralConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.num_hidden_layers = 4 self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
@ -123,7 +123,7 @@ def load_attention(config, prefix, weights):
else: else:
return TensorParallelColumnLinear.load_multi( return TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0, dim=0,
weights=weights, weights=weights,
bias=False, bias=False,
@ -135,7 +135,7 @@ def _load_gqa(config, prefix: str, weights):
assert config.num_attention_heads % weights.process_group.size() == 0 assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col( weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize, quantize=config.quantize,
dim=0, dim=0,
) )
@ -156,34 +156,35 @@ def _load_gqa(config, prefix: str, weights):
) )
def _load_experts(config, prefix, weights): def _load_experts(config, prefix, mat, weights):
if config.quantize is not None: if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.") raise NotImplementedError("Mixtral does not support weight quantization yet.")
slice_ = weights._get_slice(prefix) assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size() world_size = weights.process_group.size()
rank = weights.process_group.rank() rank = weights.process_group.rank()
if world_size == 1:
tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
else:
assert ( assert (
config.intermediate_size % world_size == 0 config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
block_size = config.intermediate_size // world_size block_size = config.intermediate_size // world_size
start = rank * block_size start = rank * block_size
stop = (rank + 1) * block_size stop = (rank + 1) * block_size
expert_slices = [] tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device)
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
expert_start = i * config.intermediate_size slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
expert_slices.append(slice_[start + expert_start:stop + expert_start])
tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
return tensor return tensor
@ -223,9 +224,9 @@ class MixtralAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.wo = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.wo", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=False, bias=False,
) )
@ -299,7 +300,25 @@ class MixtralAttention(torch.nn.Module):
max_s, max_s,
) )
return self.wo(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.view(-1)
selected_experts = selected_experts.view(-1)
return selected_experts, weights
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module): class BlockSparseMoE(nn.Module):
@ -339,9 +358,12 @@ class BlockSparseMoE(nn.Module):
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.w1", weights) self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
self.w2 = _load_experts(config, f"{prefix}.w2", weights) self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.w3", weights) self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
self.offsets = None
self.offsets_block_rows = 0
self.process_group = weights.process_group self.process_group = weights.process_group
@ -361,13 +383,18 @@ class BlockSparseMoE(nn.Module):
# dimensionality of a single expert. # dimensionality of a single expert.
block_rows = padded_tokens // self.blocking block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking blocks_per_row = self.ffn_dim // self.blocking
offsets = torch.arange( if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0, 0,
block_rows * blocks_per_row + 1, block_rows * blocks_per_row + 1,
blocks_per_row, blocks_per_row,
dtype=torch.int32, dtype=torch.int32,
device=x.device, device=x.device,
) )
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[:block_rows]
# Indices for the sparse matrix. The indices for # Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending # the intermediate matrix are dynamic depending
@ -375,7 +402,6 @@ class BlockSparseMoE(nn.Module):
column_indices = ops.topology(padded_bins, self.blocking, block_rows, column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row) blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory. # For now, use meta init to save the device memory.
data = torch.empty( data = torch.empty(
column_indices.numel(), column_indices.numel(),
@ -400,7 +426,7 @@ class BlockSparseMoE(nn.Module):
def indices_and_padded_bins(self, selected_experts: torch.Tensor): def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather # Sort the expert ids to produce the scatter/gather
# indices for the permutation. # indices for the permutation.
selected_experts = selected_experts.int() # selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts? # returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens? # and indices == how to sort tokens?
@ -418,7 +444,7 @@ class BlockSparseMoE(nn.Module):
# position of each bin. # position of each bin.
# List of size num_experts # List of size num_experts
padded_tokens_per_expert = ops.round_up(tokens_per_expert, padded_tokens_per_expert = round_up(tokens_per_expert,
self.blocking) self.blocking)
# padded_tokens_per_expert => [128, O, 128, ...] # padded_tokens_per_expert => [128, O, 128, ...]
@ -446,13 +472,7 @@ class BlockSparseMoE(nn.Module):
# gate_logits: (sequence_length, n_experts) # gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x) gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax selected_experts, weights = select_experts(gate_logits, self.top_k)
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.flatten().to(x.dtype)
selected_experts = selected_experts.flatten()
( (
indices, indices,
@ -465,7 +485,7 @@ class BlockSparseMoE(nn.Module):
# Permute tokens and pad to prepare expert computation # Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim) # (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k, x.shape[0] + self.num_experts * self.blocking) self.top_k)
# Create the sparse matrix topology # Create the sparse matrix topology
with torch.no_grad(): with torch.no_grad():
@ -476,8 +496,8 @@ class BlockSparseMoE(nn.Module):
# (top_k * sequence_length + padding, ffn_dim * n_experts) # (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix( x = stk.Matrix(
topo.size(), topo.size(),
self.act(stk.ops.sdd(x, self.w1.t(), topo).data) * self.act(stk.ops.sdd(x, self.w1, topo).data) *
stk.ops.sdd(x, self.w3.t(), topo).data, stk.ops.sdd(x, self.w3, topo).data,
topo.row_indices, topo.row_indices,
topo.column_indices, topo.column_indices,
topo.offsets, topo.offsets,
@ -512,18 +532,18 @@ class BlockSparseMoE(nn.Module):
class MixtralLayer(nn.Module): class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, layer_id, config, weights):
super().__init__() super().__init__()
prefix = f"layers.{layer_id}" prefix = f"model.layers.{layer_id}"
self.attention = MixtralAttention( self.self_attn = MixtralAttention(
prefix=f"{prefix}.attention", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights
) )
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights) self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
self.attention_norm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
) )
self.ffn_norm = FastRMSNorm.load( self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.ffn_norm", prefix=f"{prefix}.post_attention_layernorm",
weights=weights, weights=weights,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
@ -542,10 +562,10 @@ class MixtralLayer(nn.Module):
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
): ):
normed_hidden_states, res = self.attention_norm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention # Self Attention
attn_output = self.attention( attn_output = self.self_attn(
normed_hidden_states, normed_hidden_states,
cos, cos,
sin, sin,
@ -559,21 +579,21 @@ class MixtralLayer(nn.Module):
) )
# faster post attention rms norm # faster post attention rms norm
normed_attn_res_output, attn_res = self.ffn_norm( normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res attn_output, res
) )
mlp_output = self.block_sparse_moe(normed_attn_res_output) block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
return mlp_output, attn_res return block_sparse_moe_output, attn_res
class MixtralModel(torch.nn.Module): class MixtralModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
self.tok_embeddings = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix="tok_embeddings", weights=weights prefix="model.embed_tokens", weights=weights
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
@ -587,12 +607,12 @@ class MixtralModel(torch.nn.Module):
] ]
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="norm", weights=weights, eps=config.rms_norm_eps prefix="model.norm", weights=weights, eps=config.rms_norm_eps
) )
self.head_size = self.layers[0].attention.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].attention.num_heads self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].attention.num_key_value_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward( def forward(
self, self,
@ -606,11 +626,11 @@ class MixtralModel(torch.nn.Module):
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.tok_embeddings(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype position_ids, max_s, hidden_states.dtype
) )
@ -642,7 +662,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
self.model = MixtralModel(config, weights) self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = TensorParallelHead.load(
config, config,
prefix="output", prefix="lm_head",
weights=weights, weights=weights,
) )
self.max_past = config.sliding_window self.max_past = config.sliding_window

View File

@ -91,8 +91,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device request.batch, self.model.tokenizer, self.model.dtype, self.model.device
) )
# from torch.profiler import profile, ProfilerActivity
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
generations, next_batch = self.model.generate_token(batch) generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
# if self.model.rank == 0:
# prefill_prof.export_chrome_trace("prefill.json")
return generate_pb2.PrefillResponse( return generate_pb2.PrefillResponse(
generations=[generation.to_pb() for generation in generations], generations=[generation.to_pb() for generation in generations],
@ -118,8 +122,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
else: else:
batch = batches[0] batch = batches[0]
# from torch.profiler import profile, ProfilerActivity
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof:
generations, next_batch = self.model.generate_token(batch) generations, next_batch = self.model.generate_token(batch)
self.cache.set(next_batch) self.cache.set(next_batch)
# if self.model.rank == 0:
# prefill_prof.export_chrome_trace("decode.json")
return generate_pb2.DecodeResponse( return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations], generations=[generation.to_pb() for generation in generations],

View File

@ -98,10 +98,7 @@ def weight_files(
if extension != ".safetensors": if extension != ".safetensors":
raise e raise e
# Try to see if there are pytorch weights # Try to see if there are pytorch weights
try:
pt_filenames = weight_hub_files(model_id, revision, extension=".bin") pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
except EntryNotFoundError:
pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
# Change pytorch extension to safetensors extension # Change pytorch extension to safetensors extension
# It is possible that we have safetensors weights locally even though they are not on the # It is possible that we have safetensors weights locally even though they are not on the
# hub if we converted weights locally without pushing them # hub if we converted weights locally without pushing them

View File

@ -23,7 +23,7 @@ class Weights:
with safe_open(filename, framework="pytorch") as f: with safe_open(filename, framework="pytorch") as f:
for k in f.keys(): for k in f.keys():
if k in routing: if k in routing:
logger.warning( raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}" f"Key {k} was found in multiple files: {filename} and {routing[k]}"
) )
routing[k] = filename routing[k] = filename
@ -116,7 +116,7 @@ class Weights:
size = slice_.get_shape()[dim] size = slice_.get_shape()[dim]
assert ( assert (
size % world_size == 0 size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards" ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim) return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str): def _get_qweight(self, name: str):