Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 11:24:53 +00:00
commit a3967a57bc (parent 2dadceaf07)

Fix experts issue

Signed-off-by: yuanwu <yuan.wu@intel.com>
File diff suppressed because it is too large
@@ -61,6 +61,11 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader


+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
+
 def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
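
The helper added above follows a standard pattern for dumping debug tensors from a distributed run: every rank executes the forward pass, but only rank 0 writes to disk, so concurrent processes never race on the same file. A minimal self-contained sketch of the same pattern (the debug_save name, the detach/CPU copy, and the __main__ demo are illustrative, not part of this commit):

import torch
import torch.distributed as dist


def debug_save(tensor: torch.Tensor, name: str) -> None:
    # In a multi-process job (e.g. launched with torchrun), every rank
    # executes this call, so gate the write on rank 0 to avoid several
    # processes racing on the same file.
    # If torch.distributed was never initialized (single-process run),
    # there is no rank, so save unconditionally.
    if not dist.is_initialized() or dist.get_rank() == 0:
        # Detach and move to CPU so the dump holds no autograd state
        # and no device memory.
        torch.save(tensor.detach().cpu(), name)


if __name__ == "__main__":
    debug_save(torch.randn(2, 3), "example.activation.pt")
    print(torch.load("example.activation.pt").shape)  # torch.Size([2, 3])
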
@@ -377,7 +382,7 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()

+        self.index = index
         with no_fp8(weights):
             self.self_attn = FlashLlamaAttention(
                 index=index,
@@ -438,6 +443,7 @@ class FlashLlamaLayer(nn.Module):
         seqlen,
         adapter_data,
         cross_attention_states,
+        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -454,6 +460,10 @@ class FlashLlamaLayer(nn.Module):
             adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )

+        if run_index != -1:
+            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
+
         if self.residual_multiplier is not None:
             attn_output *= self.residual_multiplier

@@ -462,6 +472,10 @@ class FlashLlamaLayer(nn.Module):
         )

         mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+        if run_index != -1:
+            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
+
+
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier

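
The two gated torch_save calls above dump one attention output and one MLP output per decoder layer per traced forward pass, which is what makes a layer-by-layer comparison against a reference run possible when chasing a numerical bug like the experts issue. A hedged sketch of how such dumps could be compared after the fact (the ref/new directory layout, the compare_dump helper, and the tolerances are assumptions, not part of the commit):

import torch


def compare_dump(run_index: int, layer: int, kind: str,
                 ref_dir: str = "ref", new_dir: str = "new") -> bool:
    # File names follow the scheme used in the diff:
    #   trans.{run_index}.Llama4TextDecoderLayer.{layer}.{kind}.pt
    # where kind is "attention.attention_states" or "mlp".
    name = f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.{kind}.pt"
    ref = torch.load(f"{ref_dir}/{name}", map_location="cpu")
    new = torch.load(f"{new_dir}/{name}", map_location="cpu")
    # Loose tolerances: low-precision kernels rarely match bit for bit.
    return torch.allclose(ref.float(), new.float(), atol=1e-2, rtol=1e-2)


# e.g. compare the MLP output of layer 0 on the first dumped pass:
# print(compare_dump(0, 0, "mlp"))
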
@@ -471,7 +485,7 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()

+        self.run_index = -1
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -568,11 +582,12 @@ class FlashLlamaModel(torch.nn.Module):
                 seqlen,
                 adapter_data,
                 cross_attention_states,
+                self.run_index,
                 hpu_attention_meta=hpu_attention_meta,
             )

         hidden_states, _ = self.norm(hidden_states, residual)

+        self.run_index += 1
         return hidden_states
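
Reconstructed from the hunks above, the dump machinery is a per-instance pass counter: FlashLlamaModel starts run_index at -1, threads it into every FlashLlamaLayer.forward, and increments it after each full forward, so the very first pass is never dumped and later passes are numbered from 0. A minimal sketch of that lifecycle in isolation (DumpingModel is a toy stand-in, not the real model):

import torch
from torch import nn


class DumpingModel(nn.Module):
    """Toy stand-in showing the run_index lifecycle from the diff."""

    def __init__(self):
        super().__init__()
        self.run_index = -1  # -1 disables dumping on the first pass
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        hidden = self.linear(x)
        if self.run_index != -1:
            # Filename encodes the pass index, mirroring the
            # trans.{run_index}....pt scheme in the diff.
            torch.save(hidden.detach().cpu(), f"trans.{self.run_index}.hidden.pt")
        self.run_index += 1  # numbered dumps start on the second pass
        return hidden


model = DumpingModel()
model(torch.randn(1, 4))  # first pass: run_index == -1, nothing saved
model(torch.randn(1, 4))  # second pass: writes trans.0.hidden.pt
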