From e69eed8ea3c3508c739905c46dde27fc17c01cc5 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 11 Dec 2023 10:32:13 +0100
Subject: [PATCH] remove a tad of cpu bottleneck

---
 server/text_generation_server/cli.py           |  10 +-
 .../text_generation_server/models/__init__.py  |   4 +-
 .../custom_modeling/flash_mixtral_modeling.py  | 158 ++++++++++--------
 server/text_generation_server/server.py        |   8 +
 server/text_generation_server/utils/hub.py     |   5 +-
 .../text_generation_server/utils/weights.py    |   4 +-
 6 files changed, 104 insertions(+), 85 deletions(-)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 11ceecda..cb151173 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -188,10 +188,7 @@ def download_weights(
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
-        try:
-            local_pt_files = utils.weight_files(model_id, revision, extension=".bin")
-        except (FileNotFoundError, utils.EntryNotFoundError):
-            local_pt_files = utils.weight_files(model_id, revision, extension=".pt")
+        local_pt_files = utils.weight_files(model_id, revision, ".bin")

     # No local pytorch weights
     except utils.LocalEntryNotFoundError:
@@ -202,10 +199,7 @@ def download_weights(
         )

         # Try to see if there are pytorch weights on the hub
-        try:
-            pt_filenames = utils.weight_hub_files(model_id, revision, extension=".bin")
-        except utils.EntryNotFoundError:
-            pt_filenames = utils.weight_hub_files(model_id, revision, extension=".pt")
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")

         # Download pytorch weights
         local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 995222df..aae81be2 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -282,7 +282,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == "mistral":
+    if model_type in ["mistral", "mixtral"]:
         if MISTRAL:
             return FlashMistral(
                 model_id,
@@ -292,7 +292,7 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral model requires flash attention v2")
+        raise NotImplementedError("Mistral models require flash attention v2")

     if model_type == "opt":
         return OPTSharded(
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 31709216..f8de6cf7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -86,7 +86,7 @@ class MixtralConfig(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.num_hidden_layers = 4
+        self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window

@@ -123,7 +123,7 @@ def load_attention(config, prefix, weights):
     else:
         return TensorParallelColumnLinear.load_multi(
             config,
-            prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
             bias=False,
         )
@@ -135,7 +135,7 @@ def _load_gqa(config, prefix: str, weights):
     assert config.num_attention_heads % weights.process_group.size() == 0

     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.wq", f"{prefix}.wk", f"{prefix}.wv"],
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         quantize=config.quantize,
         dim=0,
     )
@@ -156,34 +156,35 @@ def _load_gqa(config, prefix: str, weights):
     )


-def _load_experts(config, prefix, weights):
+def _load_experts(config, prefix, mat, weights):
     if config.quantize is not None:
         raise NotImplementedError("Mixtral does not support weight quantization yet.")
-    slice_ = weights._get_slice(prefix)
+    assert mat in ["w1", "w2", "w3"]
+
     world_size = weights.process_group.size()
     rank = weights.process_group.rank()

-    if world_size == 1:
-        tensor = slice_[:].to(dtype=weights.dtype).to(device=weights.device)
-    else:
-        assert (
-            config.intermediate_size % world_size == 0
-        ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
-        assert slice_.get_shape()[0] == config.num_local_experts * config.intermediate_size
+    assert (
+        config.intermediate_size % world_size == 0
+    ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"

-        block_size = config.intermediate_size // world_size
-        start = rank * block_size
-        stop = (rank + 1) * block_size
+    block_size = config.intermediate_size // world_size
+    start = rank * block_size
+    stop = (rank + 1) * block_size

-        expert_slices = []
-        for i in range(config.num_local_experts):
-            expert_start = i * config.intermediate_size
+    tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
+                         dtype=weights.dtype,
+                         device=weights.device)

-            expert_slices.append(slice_[start + expert_start:stop + expert_start])
-
-        tensor = torch.cat(expert_slices, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+    for i in range(config.num_local_experts):
+        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
+        if mat == "w2":
+            expert_slice = slice_[:, start:stop].t().contiguous()
+        else:
+            expert_slice = slice_[start:stop]
+        tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
     return tensor
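The rewritten _load_experts reads each expert matrix separately from the checkpoint, takes only this rank's shard of the FFN dimension, and writes it into one preallocated buffer instead of concatenating per-expert slices. Below is a minimal standalone sketch of the same slicing arithmetic, with made-up sizes and plain random tensors standing in for the slices that weights._get_slice would return; it is illustrative only, not part of the patch.

import torch

# Illustrative only: 2 experts, intermediate_size=8, hidden_size=4, 2 shards, rank 0.
num_experts, intermediate_size, hidden_size, world_size, rank = 2, 8, 4, 2, 0
block_size = intermediate_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

# stand-ins for the per-expert w1/w3 checkpoint slices
experts = [torch.randn(intermediate_size, hidden_size) for _ in range(num_experts)]

tensor = torch.empty((num_experts * block_size, hidden_size))
for i, w in enumerate(experts):
    # w1/w3 are sliced on dim 0; w2 would be sliced on dim 1 and transposed instead
    tensor[i * block_size:(i + 1) * block_size] = w[start:stop]
print(tensor.shape)  # torch.Size([8, 4]) == (num_experts * block_size, hidden_size)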
@@ -223,9 +224,9 @@ class MixtralAttention(torch.nn.Module):

         self.query_key_value = load_attention(config, prefix, weights)

-        self.wo = TensorParallelRowLinear.load(
+        self.o_proj = TensorParallelRowLinear.load(
             config,
-            prefix=f"{prefix}.wo",
+            prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
@@ -299,7 +300,25 @@ class MixtralAttention(torch.nn.Module):
             max_s,
         )

-        return self.wo(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+@torch.jit.script
+def select_experts(gate_logits: torch.Tensor, top_k: int):
+    # all_probs: (sequence_length, n_experts) and upcast for softmax
+    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
+    # weights, selected_experts: (sequence_length, top-k)
+    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
+    weights /= weights.sum(dim=-1, keepdim=True)
+    weights = weights.view(-1)
+    selected_experts = selected_experts.view(-1)
+
+    return selected_experts, weights
+
+
+@torch.jit.script
+def round_up(x: torch.Tensor, value: int):
+    return torch.div(x + (value - 1), value, rounding_mode="trunc") * value


 class BlockSparseMoE(nn.Module):
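The two scripted helpers above move the per-token routing math out of eager Python. A toy run of the same computations, assuming only PyTorch and invented sizes (4 tokens, 8 experts, top-2 routing):

import torch

gate_logits = torch.randn(4, 8)  # 4 tokens, 8 experts
top_k = 2

# same math as select_experts: float32 softmax, top-k, renormalise, flatten
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
print(selected_experts.view(-1).shape, weights.view(-1).shape)  # torch.Size([8]) torch.Size([8])

# round_up pads per-expert token counts up to the stk blocking size
tokens_per_expert = torch.tensor([3, 0, 5, 1])
blocking = 4
padded = torch.div(tokens_per_expert + (blocking - 1), blocking, rounding_mode="trunc") * blocking
print(padded)  # tensor([4, 0, 8, 4])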
FastLinear.load(config, f"{prefix}.gate", weights, bias=False) # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) - self.w1 = _load_experts(config, f"{prefix}.w1", weights) - self.w2 = _load_experts(config, f"{prefix}.w2", weights) - self.w3 = _load_experts(config, f"{prefix}.w3", weights) + self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t() + self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights) + self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t() + + self.offsets = None + self.offsets_block_rows = 0 self.process_group = weights.process_group @@ -361,13 +383,18 @@ class BlockSparseMoE(nn.Module): # dimensionality of a single expert. block_rows = padded_tokens // self.blocking blocks_per_row = self.ffn_dim // self.blocking - offsets = torch.arange( - 0, - block_rows * blocks_per_row + 1, - blocks_per_row, - dtype=torch.int32, - device=x.device, - ) + if self.offsets is None or block_rows > self.offsets_block_rows: + self.offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + self.offsets_block_rows = block_rows + offsets = self.offsets + else: + offsets = self.offsets[:block_rows] # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending @@ -375,7 +402,6 @@ class BlockSparseMoE(nn.Module): column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) - # TODO(tgale): This is unused. Remove the need for this in stk. # For now, use meta init to save the device memory. data = torch.empty( column_indices.numel(), @@ -400,7 +426,7 @@ class BlockSparseMoE(nn.Module): def indices_and_padded_bins(self, selected_experts: torch.Tensor): # Sort the expert ids to produce the scatter/gather # indices for the permutation. - selected_experts = selected_experts.int() + # selected_experts = selected_experts.int() # returns bin_ids == num of experts for this sequence ? == unique selected experts? # and indices == how to sort tokens? @@ -418,8 +444,8 @@ class BlockSparseMoE(nn.Module): # position of each bin. # List of size num_experts - padded_tokens_per_expert = ops.round_up(tokens_per_expert, - self.blocking) + padded_tokens_per_expert = round_up(tokens_per_expert, + self.blocking) # padded_tokens_per_expert => [128, O, 128, ...] 
@@ -446,13 +472,7 @@ class BlockSparseMoE(nn.Module):

         # gate_logits: (sequence_length, n_experts)
         gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
-        # weights, selected_experts: (sequence_length, top-k)
-        weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
-        weights /= weights.sum(dim=-1, keepdim=True)
-        weights = weights.flatten().to(x.dtype)
-        selected_experts = selected_experts.flatten()
+        selected_experts, weights = select_experts(gate_logits, self.top_k)

         (
             indices,
@@ -465,7 +485,7 @@ class BlockSparseMoE(nn.Module):
         # Permute tokens and pad to prepare expert computation
         # (top_k * sequence_length + padding, model_dim)
         x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
-                              self.top_k, x.shape[0] + self.num_experts * self.blocking)
+                              self.top_k)

         # Create the sparse matrix topology
         with torch.no_grad():
@@ -476,8 +496,8 @@ class BlockSparseMoE(nn.Module):
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1.t(), topo).data) *
-            stk.ops.sdd(x, self.w3.t(), topo).data,
+            self.act(stk.ops.sdd(x, self.w1, topo).data) *
+            stk.ops.sdd(x, self.w3, topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
@@ -512,18 +532,18 @@ class BlockSparseMoE(nn.Module):
 class MixtralLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        prefix = f"layers.{layer_id}"
+        prefix = f"model.layers.{layer_id}"

-        self.attention = MixtralAttention(
-            prefix=f"{prefix}.attention", config=config, weights=weights
+        self.self_attn = MixtralAttention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
         self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)

-        self.attention_norm = FastRMSNorm.load(
-            prefix=f"{prefix}.attention_norm", weights=weights, eps=config.rms_norm_eps
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.ffn_norm = FastRMSNorm.load(
-            prefix=f"{prefix}.ffn_norm",
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -542,10 +562,10 @@ class MixtralLayer(nn.Module):
         max_s,
         prefill_cache_indices,
     ):
-        normed_hidden_states, res = self.attention_norm(hidden_states, residual)
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

         # Self Attention
-        attn_output = self.attention(
+        attn_output = self.self_attn(
             normed_hidden_states,
             cos,
             sin,
@@ -559,21 +579,21 @@ class MixtralLayer(nn.Module):
         )

         # faster post attention rms norm
-        normed_attn_res_output, attn_res = self.ffn_norm(
+        normed_attn_res_output, attn_res = self.post_attention_layernorm(
             attn_output, res
         )

-        mlp_output = self.block_sparse_moe(normed_attn_res_output)
+        block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)

-        return mlp_output, attn_res
+        return block_sparse_moe_output, attn_res


 class MixtralModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
-        self.tok_embeddings = TensorParallelEmbedding(
-            prefix="tok_embeddings", weights=weights
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
         )

         self.layers = nn.ModuleList(
prefix="norm", weights=weights, eps=config.rms_norm_eps + prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) - self.head_size = self.layers[0].attention.head_size - self.num_heads = self.layers[0].attention.num_heads - self.num_key_value_heads = self.layers[0].attention.num_key_value_heads + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads def forward( self, @@ -606,11 +626,11 @@ class MixtralModel(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( position_ids, max_s, hidden_states.dtype ) @@ -642,7 +662,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): self.model = MixtralModel(config, weights) self.lm_head = TensorParallelHead.load( config, - prefix="output", + prefix="lm_head", weights=weights, ) self.max_past = config.sliding_window diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index ebe066e3..b740976b 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -91,8 +91,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) + # from torch.profiler import profile, ProfilerActivity + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof: generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) + # if self.model.rank == 0: + # prefill_prof.export_chrome_trace("prefill.json") return generate_pb2.PrefillResponse( generations=[generation.to_pb() for generation in generations], @@ -118,8 +122,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): else: batch = batches[0] + # from torch.profiler import profile, ProfilerActivity + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prefill_prof: generations, next_batch = self.model.generate_token(batch) self.cache.set(next_batch) + # if self.model.rank == 0: + # prefill_prof.export_chrome_trace("decode.json") return generate_pb2.DecodeResponse( generations=[generation.to_pb() for generation in generations], diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 05aaf277..23743c9b 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -98,10 +98,7 @@ def weight_files( if extension != ".safetensors": raise e # Try to see if there are pytorch weights - try: - pt_filenames = weight_hub_files(model_id, revision, extension=".bin") - except EntryNotFoundError: - pt_filenames = weight_hub_files(model_id, revision, extension=".pt") + pt_filenames = weight_hub_files(model_id, revision, extension=".bin") # Change pytorch extension to safetensors extension # It is possible that we have safetensors weights locally even though they are not on the # hub if we converted weights locally without pushing them diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 37c5f489..f3344988 100644 --- 
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 05aaf277..23743c9b 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -98,10 +98,7 @@ def weight_files(
             if extension != ".safetensors":
                 raise e
             # Try to see if there are pytorch weights
-            try:
-                pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
-            except EntryNotFoundError:
-                pt_filenames = weight_hub_files(model_id, revision, extension=".pt")
+            pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
             # Change pytorch extension to safetensors extension
             # It is possible that we have safetensors weights locally even though they are not on the
             # hub if we converted weights locally without pushing them
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 37c5f489..f3344988 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -23,7 +23,7 @@ class Weights:
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
                     if k in routing:
-                        logger.warning(
+                        raise RuntimeError(
                             f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                         )
                     routing[k] = filename
@@ -116,7 +116,7 @@ class Weights:
         size = slice_.get_shape()[dim]
         assert (
             size % world_size == 0
-        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)

     def _get_qweight(self, name: str):
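The first weights.py hunk turns a silently logged duplicate tensor name into a hard error. The routing loop it guards can be exercised on its own; a minimal sketch with hypothetical file names and keys, where the second file re-declares key "b" so the loop raises:

routing = {}
files = {
    "model-00001.safetensors": ["a", "b"],
    "model-00002.safetensors": ["b", "c"],  # "b" appears twice on purpose
}
for filename, keys in files.items():
    for k in keys:
        if k in routing:
            raise RuntimeError(
                f"Key {k} was found in multiple files: {filename} and {routing[k]}"
            )
        routing[k] = filename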