Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 19:34:53 +00:00
Fix experts issue
Signed-off-by: yuanwu <yuan.wu@intel.com>
commit a3967a57bc
parent 2dadceaf07
File diff suppressed because it is too large
@@ -61,6 +61,11 @@ from text_generation_server.utils.weights import (
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
+
 
 def load_attention(config, prefix: str, weights, layer_id):
     # Only defined in granite.
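Note (not part of the diff): a minimal sketch of how the new torch_save helper is expected to behave. In a single-process debug run torch.distributed is never initialized, so every call writes its file; under tensor-parallel execution only rank 0 writes, so all ranks can call it with the same file name without racing. The file name below is illustrative only.

# Sketch only: the torch_save helper added above, exercised outside the server.
import torch
import torch.distributed as dist


def torch_save(tensor, name):
    # Only save on the main process (rank 0) when using distributed training
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(tensor, name)


# Single-process run: is_initialized() is False, so the file is always written.
torch_save(torch.randn(4, 8), "trans.debug.pt")  # illustrative file name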
@@ -377,7 +382,7 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-
+        self.index = index
         with no_fp8(weights):
             self.self_attn = FlashLlamaAttention(
                 index=index,
@@ -438,6 +443,7 @@ class FlashLlamaLayer(nn.Module):
         seqlen,
         adapter_data,
         cross_attention_states,
+        run_index,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -454,6 +460,10 @@ class FlashLlamaLayer(nn.Module):
             adapter_data,
             hpu_attention_meta=hpu_attention_meta,
         )
 
+        if run_index != -1:
+            torch_save(attn_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.attention.attention_states.pt")
+
+
         if self.residual_multiplier is not None:
             attn_output *= self.residual_multiplier
@@ -462,6 +472,10 @@ class FlashLlamaLayer(nn.Module):
         )
 
         mlp_output = self.mlp(normed_attn_res_output, adapter_data)
+        if run_index != -1:
+            torch_save(mlp_output, f"trans.{run_index}.Llama4TextDecoderLayer.{self.index}.mlp.pt")
+
+
         if self.residual_multiplier is not None:
             mlp_output *= self.residual_multiplier
 
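Note (not part of the diff): once a pass has run with dumping enabled, the per-layer tensors written by the two torch_save calls added in this hunk and the one above can be loaded back for offline comparison against a reference run. A minimal sketch, assuming the file names produced by those calls; the run index and layer index values are placeholders.

# Sketch only: inspecting the dumped per-layer tensors offline.
import torch

run_index = 0  # placeholder: forward pass to inspect
layer = 0      # placeholder: decoder layer index (self.index in the diff)

attn = torch.load(f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.attention.attention_states.pt")
mlp = torch.load(f"trans.{run_index}.Llama4TextDecoderLayer.{layer}.mlp.pt")
print(attn.shape, attn.dtype)
print(mlp.shape, mlp.dtype)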
@@ -471,7 +485,7 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-
+        self.run_index = -1
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -568,11 +582,12 @@ class FlashLlamaModel(torch.nn.Module):
                 seqlen,
                 adapter_data,
                 cross_attention_states,
+                self.run_index,
                 hpu_attention_meta=hpu_attention_meta,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
-
+        self.run_index += 1
         return hidden_states
 
 
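Note (not part of the diff): the run_index lifecycle implied by the two hunks above. self.run_index starts at -1, is passed into every decoder layer, and is incremented once per model forward pass, so the first pass writes nothing and later passes are tagged 0, 1, 2, and so on. A minimal stand-alone sketch:

# Sketch only: run_index gating and per-pass increment, mirrored in miniature.
class TinyModel:
    def __init__(self, num_layers=2):
        self.run_index = -1          # first forward pass: dumping disabled
        self.num_layers = num_layers

    def forward(self):
        for index in range(self.num_layers):
            if self.run_index != -1:
                print(f"would save trans.{self.run_index}.Llama4TextDecoderLayer.{index}.*.pt")
        self.run_index += 1          # -1 -> 0 -> 1 -> ...


m = TinyModel()
for _ in range(3):
    m.forward()
# Pass 1 writes nothing; passes 2 and 3 would be tagged run_index 0 and 1.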