diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index de53350c..f7bef8c3 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -110,7 +110,7 @@ class Llama4TextExperts(nn.Module):
         self.intermediate_size = config.intermediate_size // weights.process_group.size()
         self.hidden_size = config.hidden_size
         self.expert_dim = self.intermediate_size
-        self.gate_up_proj = nn.Parameter(weights.get_sharded(f"{prefix}.gate_up_proj", dim=1), requires_grad=False)
+        self.gate_up_proj = nn.Parameter(weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -119,7 +119,7 @@ class Llama4TextExperts(nn.Module):
         )
 
-        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=0), requires_grad=False)
+        self.down_proj = nn.Parameter(weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -144,23 +144,23 @@ class Llama4TextExperts(nn.Module):
             torch.Tensor
         """
         gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2*self.expert_dim)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate_up_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate_up_proj.pt")
 
         down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(down_proj, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.down_proj.pt")
 
         hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.hidden_states.pt")
 
         gate_up = torch.bmm(hidden_states, gate_up_proj)
         gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
-        if run_index != -1:
+        if run_index == 0:
             torch_save(gate, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.gate.pt")
             torch_save(up, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.up.pt")
@@ -169,7 +169,7 @@ class Llama4TextExperts(nn.Module):
         next_states = next_states.view(-1, self.hidden_size)
-        if run_index != -1:
+        if run_index == 0:
             torch_save(next_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.expert.next_states.pt")
 
         # Reduce sum
@@ -303,7 +303,7 @@ class Llama4TextMoe(nn.Module):
         # assert isinstance(self.experts, MoELayer)
-        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights)
+        self.experts = Llama4TextExperts(config=config, prefix=f"{prefix}.experts", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -319,7 +319,7 @@ class Llama4TextMoe(nn.Module):
             logger.debug, f"TextMode2 Free memory real: {real_free_memory / 1e9:.2f}GB"
         )
-        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights)
+        self.shared_expert = Llama4TextMLP(config=config, prefix=f"{prefix}.shared_expert", weights=weights, layer_idx=layer_idx)
         synchronize(weights.device)
         real_free_memory = get_free_memory(weights.device, 1)
         log_master(
@@ -334,8 +334,8 @@ class Llama4TextMoe(nn.Module):
         hidden_states = hidden_states.view(-1, self.hidden_dim)
         tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        if run_index != -1:
-            torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
+        #if run_index != -1:
+        #    torch_save(router_logits, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_logits.pt")
 
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
@@ -347,8 +347,8 @@ class Llama4TextMoe(nn.Module):
             torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
         )
         router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-        if run_index != -1:
-            torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
+        #if run_index != -1:
+        #    torch_save(router_scores, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.router_scores.pt")
 
         router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)
@@ -357,20 +357,20 @@ class Llama4TextMoe(nn.Module):
             dim=0,
             index=router_indices,
         ).to(hidden_states.device)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
+        #if run_index != -1:
+        #    torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.gather.pt")
         # we gather inputs corresponding to each expert based on the router indices
         routed_in = routed_in * router_scores.reshape(-1, 1)
-        if run_index != -1:
-            torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
+        #if run_index != -1:
+        #    torch_save(routed_in, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_in.pt")
         routed_out = self.experts(routed_in, run_index)
-        if run_index != -1:
-            torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
+        #if run_index != -1:
+        #    torch_save(routed_out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.routed_out.pt")
         out = self.shared_expert(hidden_states, run_index, reduce=False)
-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
+        #if run_index != -1:
+        #    torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.out.pt")
         # now that we finished expert computation -> we scatter add because we gathered previously
         # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound
         # this scales a lot better if you do EP!
@@ -381,8 +381,8 @@ class Llama4TextMoe(nn.Module):
         # if self.process_group.size() > 1:
         #     torch.distributed.all_reduce(out, group=self.process_group)
-        if run_index != -1:
-            torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")
+        #if run_index != -1:
+        #    torch_save(out, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.moe.add.out.pt")
         return out
@@ -689,10 +689,10 @@ class Llama4TextAttention(FlashLlamaAttention):
         # key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
         # value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
-            torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
+        #if run_index != -1:
+        #    torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.query_states.pt")
+        #    torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.key_states.pt")
+        #    torch_save(value_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.value_states.pt")
         if self.use_rope:  # the 16E model skips rope for long context on certain layers
             #self.rotary_emb(query_states, torch.select(kv_states, dim=1, index=0), cos, sin)
@@ -703,18 +703,18 @@ class Llama4TextAttention(FlashLlamaAttention):
 
-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
-            torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
+        #if run_index != -1:
+        #    torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.query_states.pt")
+        #    torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.emb.key_states.pt")
         if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
             query_states = self.qk_norm(query_states)
             key_states = self.qk_norm(key_states)
-            if run_index != -1:
-                torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
-                torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
+            #if run_index != -1:
+            #    torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.query_states.pt")
+            #    torch_save(key_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.qk_norm.key_states.pt")
 
         # query_states = query_states.view(-1, self.num_heads, self.head_size)
@@ -742,9 +742,9 @@ class Llama4TextAttention(FlashLlamaAttention):
             #seq_len = input_shape / bs
             attn_scales = attn_scales.view(*input_shape, 1, 1)
             query_states = (query_states * attn_scales).to(query_states.dtype)
-        if run_index != -1:
-            torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
-            torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")
+        #if run_index != -1:
+        #    torch_save(query_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attn_scales.query_states.pt")
+        #    torch_save(attention_mask, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_mask.pt")
 
         # Prefill
@@ -806,8 +806,8 @@ class Llama4TextAttention(FlashLlamaAttention):
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        if run_index != -1:
-            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
+        #if run_index != -1:
+        #    torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.reshape.attn_output.pt")
         attn_output = self.o_proj(attn_output)
         return attn_output
@@ -873,11 +873,11 @@ class Llama4TextDecoderLayer(nn.Module):
         run_index: int = 0,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input.hidden_states.pt")
         hidden_states = self.input_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.input_layernorm.hidden_states.pt")
 
         # use local attention mask for ROPE layers
         if self.use_chunked_attention and chunk_causal_mask is not None:
@@ -896,24 +896,24 @@ class Llama4TextDecoderLayer(nn.Module):
             position_ids=position_ids,
             hpu_attention_meta=hpu_attention_meta,
         )
-        if run_index != -1:
-            torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
+        #if run_index != -1:
+        #    torch_save(attention_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.attention_states.pt")
         hidden_states = residual + attention_states
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.attention.hidden_states.pt")
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.post_attention_layernorm.hidden_states.pt")
         hidden_states = self.feed_forward(hidden_states, adapter_data, run_index)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.feed_forward.hidden_states.pt")
         hidden_states = residual + hidden_states.view(residual.shape)
-        if run_index != -1:
-            torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
+        #if run_index != -1:
+        #    torch_save(hidden_states, f"trans.{run_index}.Llama4TextDecoderLayer.{self.layer_idx}.output.hidden_states.pt")
         #outputs = (hidden_states,)
         return hidden_states
         # if residual is None:
@@ -988,8 +988,8 @@ class Llama4TextModel(nn.Module):
    ) -> torch.Tensor:
         hidden_states = inputs_embeds
-        if self.run_index != -1:
-            torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        #if self.run_index != -1:
+        #    torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
         log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -1030,11 +1030,11 @@ class Llama4TextModel(nn.Module):
                 run_index=self.run_index,
             )
 
-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
         log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
         hidden_states, _ = self.norm(hidden_states)
-        if self.run_index != -1:
+        if self.run_index == 0:
             torch_save(hidden_states, f"trans.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
         log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
         self.run_index += 1
@@ -2014,4 +2014,4 @@ class Llama4ForConditionalGeneration(nn.Module):
             attention_mask
         )
 
-        return logits, speculative_logits
\ No newline at end of file
+        return logits, speculative_logits
diff --git a/backends/gaudi/server/text_generation_server/utils/weights.py b/backends/gaudi/server/text_generation_server/utils/weights.py
index acd598d7..a16b503a 100644
--- a/backends/gaudi/server/text_generation_server/utils/weights.py
+++ b/backends/gaudi/server/text_generation_server/utils/weights.py
@@ -303,7 +303,7 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
-        tensors = []
+        tensors_slices = []
         block_offset = 0
         for block_size in block_sizes:
             assert (
@@ -312,15 +312,19 @@ class Weights:
             shard_block_size = block_size // world_size
             start = rank * shard_block_size
            stop = (rank + 1) * shard_block_size
-            if dim == 0:
-                tensor = slice_[block_offset + start : block_offset + stop]
-            elif dim == 1:
-                tensor = slice_[:, block_offset + start : block_offset + stop]
-            else:
-                raise NotImplementedError("Currently only dim=0 or dim=1 is supported")
-            tensors.append(tensor)
+            tensors_slices += range(block_offset + start, block_offset + stop)
             block_offset += block_size
-        tensor = torch.cat(tensors, dim=dim)
+
+
+        if dim == 0:
+            tensor = slice_[tensors_slices, ...]
+        elif dim == 1 or dim == -2:
+            tensor = slice_[:, tensors_slices, ...]
+        elif dim == 2 or dim == -1:
+            tensor = slice_[..., tensors_slices]
+        else:
+            raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
+
         tensor = tensor.to(device=self.device)  # Avoid casting quantizer dtypes.
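
For context on the weights.py change, the Python sketch below reproduces the index-list sharding idea in isolation so it can be checked on a toy fused gate/up tensor. The packed_shard helper, the toy shapes, and passing block_sizes as an explicit list are illustrative assumptions, not part of the patch (the real get_packed_sharded also accepts an int for block_sizes and reads from a safetensors slice).

# Standalone sketch of the index-list sharding used by the reworked
# get_packed_sharded path above. packed_shard, the toy shapes and the
# explicit block_sizes list are illustrative assumptions, not TGI API.
import torch


def packed_shard(tensor, dim, block_sizes, rank, world_size):
    # Collect, for every packed block, the indices owned by this rank,
    # then take them all in a single advanced-indexing operation.
    tensors_slices = []
    block_offset = 0
    for block_size in block_sizes:
        assert block_size % world_size == 0, "packed block must split evenly across shards"
        shard_block_size = block_size // world_size
        start = rank * shard_block_size
        stop = (rank + 1) * shard_block_size
        tensors_slices += range(block_offset + start, block_offset + stop)
        block_offset += block_size

    if dim == 0:
        return tensor[tensors_slices, ...]
    elif dim in (1, -2):
        return tensor[:, tensors_slices, ...]
    elif dim in (2, -1):
        return tensor[..., tensors_slices]
    raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")


if __name__ == "__main__":
    # Toy fused gate_up_proj: [num_experts, hidden, 2 * intermediate],
    # with the gate block followed by the up block on the last dim.
    num_experts, hidden, intermediate = 2, 4, 6
    fused = torch.randn(num_experts, hidden, 2 * intermediate)
    # rank 0 of 2 keeps the first half of the gate block and the first half of the up block.
    shard = packed_shard(fused, dim=-1, block_sizes=[intermediate, intermediate], rank=0, world_size=2)
    assert shard.shape == (num_experts, hidden, intermediate)
    expected = torch.cat([fused[..., :3], fused[..., 6:9]], dim=-1)
    assert torch.equal(shard, expected)

Compared with slicing each block and concatenating with torch.cat, gathering all per-rank indices first and indexing once avoids materializing per-block temporaries and is what lets the patch support the dim=2/-1 case needed for the fused gate_up_proj experts weight.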