Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-09 11:24:53 +00:00
Add save
Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
parent ccddbba752
commit dafc597a8b
@@ -71,6 +71,11 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
_CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
_CONFIG_FOR_DOC = "Llama4Config"

def torch_save(tensor, name):
    # Only save on the main process (rank 0) when using distributed training
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        torch.save(tensor, name)


class Llama4TextExperts(nn.Module):
    def __init__(self, prefix, config: Llama4TextConfig, weights):
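Editor's note: the torch_save helper above only writes from rank 0, so tensor-parallel workers sharing a filesystem do not overwrite each other's dumps. A minimal sketch of how such dumps might later be compared against a reference run; the reference file name and tolerances are illustrative assumptions, not part of this commit:

import torch

def compare_dumps(path_a, path_b, atol=1e-3, rtol=1e-3):
    # Load two tensors written by torch_save and report how closely they match.
    a = torch.load(path_a, map_location="cpu").float()
    b = torch.load(path_b, map_location="cpu").float()
    print(f"max abs diff: {(a - b).abs().max().item()}")
    return torch.allclose(a, b, atol=atol, rtol=rtol)

# e.g. the TGI dump vs a hypothetical reference dump with the same run index
compare_dumps(
    "tgi.0.Llama4TextModel.input.hidden_states.pt",
    "ref.0.Llama4TextModel.input.hidden_states.pt",
)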
@@ -233,10 +238,11 @@ class Llama4TextMoe(nn.Module):

    def forward(self, hidden_states, adapter_data):
        seq_len, hidden_dim = hidden_states.shape
        #seq_len, hidden_dim = hidden_states.shape
        print(f"hidden_states.shape: {hidden_states.shape}")
        hidden_states = hidden_states.view(-1, self.hidden_dim)
        tokens_per_expert = hidden_states.shape[0]
        router_logits = self.router(hidden_states)
        tokens_per_expert = seq_len

        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
        router_scores = (
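Editor's note: for context, a generic sketch of how top-k router logits are commonly turned into a dense per-token score matrix. The masking value, sigmoid activation, and shapes below are illustrative assumptions, not taken from this file:

import torch

def dense_router_scores(router_logits, top_k):
    # Keep each token's top_k expert logits, mask the rest, squash to (0, 1).
    top_values, top_indices = torch.topk(router_logits, top_k, dim=1)
    masked = torch.full_like(router_logits, float("-inf"))
    masked.scatter_(1, top_indices, top_values)
    return torch.sigmoid(masked), top_indices

logits = torch.randn(4, 16)                 # 4 tokens, 16 experts (made up)
scores, indices = dense_router_scores(logits, top_k=1)
print(scores.shape, indices.shape)          # torch.Size([4, 16]) torch.Size([4, 1])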
@@ -536,18 +542,18 @@ class Llama4TextDecoderLayer(nn.Module):
        else:
            self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)

        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
        # self.input_layernorm = FastRMSNorm.load(
        #     prefix=f"{prefix}.input_layernorm",
        #     weights=weights,
        #     eps=config.rms_norm_eps,
        # )
        # self.post_attention_layernorm = FastRMSNorm.load(
        #     prefix=f"{prefix}.post_attention_layernorm",
        #     weights=weights,
        #     eps=config.rms_norm_eps,
        # )
        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )


        self.layer_idx = layer_idx
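Editor's note: the change above swaps the HF-style Llama4TextRMSNorm modules for TGI's FastRMSNorm.load. As the updated call sites in the hunks below suggest, the new norm returns a tuple rather than a single tensor, which is why the forward pass now unpacks "hidden_states, _ = ...". A minimal sketch of an RMSNorm with that return convention; the class name and eps are illustrative, not TGI's implementation:

import torch
from torch import nn

class TupleRMSNorm(nn.Module):
    # RMSNorm that returns (normed, residual), mimicking the unpacking below.
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states.float() * torch.rsqrt(variance + self.eps)
        return self.weight * normed.to(hidden_states.dtype), residual

y, res = TupleRMSNorm(16)(torch.randn(2, 16))   # call sites unpack two values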
@@ -566,7 +572,7 @@ class Llama4TextDecoderLayer(nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.input_layernorm(hidden_states)

        attention_states = self.self_attn(
            hidden_states,
@@ -584,7 +590,7 @@ class Llama4TextDecoderLayer(nn.Module):
        # Fully Connected
        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, _ = self.post_attention_layernorm(hidden_states)
        hidden_states = self.feed_forward(hidden_states, adapter_data)
        hidden_states = residual + hidden_states.view(residual.shape)
        #outputs = (hidden_states,)
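Editor's note: the skip connection above reshapes the feed-forward output back to the residual's shape before adding, since the MoE/MLP path works on a flattened (tokens, hidden) view. A toy illustration with made-up shapes:

import torch

residual = torch.randn(2, 4, 16)      # illustrative residual shape
ff_out = torch.randn(2 * 4, 16)       # flattened (tokens, hidden) feed-forward output
hidden = residual + ff_out.view(residual.shape)
print(hidden.shape)                   # torch.Size([2, 4, 16])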
@@ -642,9 +648,9 @@ class Llama4TextModel(nn.Module):
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.run_index = 0


        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
        #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

    def forward(
@@ -660,6 +666,8 @@
    ) -> torch.Tensor:

        hidden_states = inputs_embeds
        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
        log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
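Editor's note: the rotary cos/sin tables are computed once per forward from the first layer's rotary embedding and reused by every decoder layer. A generic sketch of such a precomputation; dim, base, and dtype are illustrative assumptions, not the model's actual configuration:

import torch

def get_cos_sin(position_ids, dim=128, base=500000.0, dtype=torch.bfloat16):
    # One (seq_len, dim) cos/sin pair per position, computed once and shared.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = position_ids.float()[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos, sin = get_cos_sin(torch.arange(8))
print(cos.shape, sin.shape)   # torch.Size([8, 128]) for both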
@@ -679,7 +687,12 @@
                hpu_attention_meta=hpu_attention_meta,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
        log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
        hidden_states, _ = self.norm(hidden_states)
        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
        log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
        self.run_index += 1
        return hidden_states
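Editor's note: the dumps written above follow the pattern tgi.<run_index>.Llama4TextModel.<stage>.hidden_states.pt, with run_index incremented once per forward pass. A small sketch for grouping the files by run index after a debug session; it assumes the dumps sit in the current working directory:

import glob
import re

for path in sorted(glob.glob("tgi.*.Llama4TextModel.*.pt")):
    m = re.match(r"tgi\.(\d+)\.Llama4TextModel\.(.+)\.pt", path)
    if m:
        print(f"run {m.group(1)}: stage {m.group(2)} -> {path}")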
@@ -1520,6 +1520,10 @@ class FlashCausalLM(Model):
        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
            logger.info("skip warmup hpu graph, not recommmended")
            del _batch, batch
            print(f"max_input_tokens: {max_input_tokens}")
            print(f"max_total_tokens: {max_total_tokens}")
            print(f"num_blocks: {num_blocks}")
            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
            return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

        self.warmup_hpu_graph(batch)
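Editor's note: the early return above is gated by the VLLM_SKIP_WARMUP environment variable; only the literal string "true" (case-insensitive) skips warmup, and the method then returns the usable KV-cache size as num_blocks * BLOCK_SIZE together with the token limits. A minimal sketch of the same flag convention; the helper name is illustrative:

import os

def env_flag(name, default="false"):
    # Same truthiness rule as the VLLM_SKIP_WARMUP check above.
    return os.getenv(name, default).lower() == "true"

if env_flag("VLLM_SKIP_WARMUP"):
    print("skipping HPU graph warmup")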