Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-09 19:34:53 +00:00

commit dafc597a8b
parent ccddbba752

    Add save

    Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -71,6 +71,11 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
 _CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
 _CONFIG_FOR_DOC = "Llama4Config"
 
+
+def torch_save(tensor, name):
+    # Only save on the main process (rank 0) when using distributed training
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        torch.save(tensor, name)
 
 class Llama4TextExperts(nn.Module):
     def __init__(self, prefix, config: Llama4TextConfig, weights):
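Review note: the new torch_save helper writes tensors only from rank 0, so multi-card runs produce a single set of .pt dumps per tensor. A minimal sketch of how such dumps might be compared against a reference run is shown below; the compare_dumps function, the reference-directory layout, and the tolerances are illustrative assumptions, not part of this commit.

# Hypothetical offline check of the dump files written by torch_save. File names
# follow the "tgi.<run_index>.<module>.<tensor>.pt" pattern used in this commit;
# the reference directory and tolerances are assumptions made for this sketch.
import glob
import os

import torch


def compare_dumps(tgi_dir: str, ref_dir: str, atol: float = 1e-3, rtol: float = 1e-3):
    for path in sorted(glob.glob(os.path.join(tgi_dir, "tgi.*.pt"))):
        name = os.path.basename(path)
        ref_path = os.path.join(ref_dir, name)
        if not os.path.exists(ref_path):
            print(f"{name}: no reference dump found")
            continue
        a = torch.load(path, map_location="cpu").float()
        b = torch.load(ref_path, map_location="cpu").float()
        if a.shape != b.shape:
            print(f"{name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
            continue
        max_diff = (a - b).abs().max().item()
        status = "OK" if torch.allclose(a, b, atol=atol, rtol=rtol) else "MISMATCH"
        print(f"{name}: max_abs_diff={max_diff:.6f} {status}")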
@@ -233,10 +238,11 @@ class Llama4TextMoe(nn.Module):
 
 
     def forward(self, hidden_states, adapter_data):
-        seq_len, hidden_dim = hidden_states.shape
+        #seq_len, hidden_dim = hidden_states.shape
+        print(f"hidden_states.shape: {hidden_states.shape}")
         hidden_states = hidden_states.view(-1, self.hidden_dim)
+        tokens_per_expert = hidden_states.shape[0]
         router_logits = self.router(hidden_states)
-        tokens_per_expert = seq_len
 
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
         router_scores = (
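Review note: tokens_per_expert is now taken from the flattened tensor instead of from seq_len, which is only meaningful when hidden_states is 3-D. A small shape sketch (the sizes are made up, not taken from the model) of why the flattened count is the more robust choice:

# Made-up sizes, purely to illustrate the shape reasoning behind the change.
import torch

hidden_dim = 8
hidden_states = torch.randn(4, 3, hidden_dim)      # [batch, seq_len, hidden_dim]
flat = hidden_states.view(-1, hidden_dim)          # [batch * seq_len, hidden_dim]
tokens_per_expert = flat.shape[0]                   # 12: one row per routed token
assert tokens_per_expert == 4 * 3

# If the input is already flattened to [num_tokens, hidden_dim], as flash-style
# batches usually are, the same expression still gives the true token count,
# while unpacking "seq_len, hidden_dim" from a 2-D tensor would mislabel it.
flat_input = torch.randn(5, hidden_dim)
assert flat_input.view(-1, hidden_dim).shape[0] == 5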
@@ -536,18 +542,18 @@ class Llama4TextDecoderLayer(nn.Module):
         else:
             self.feed_forward = LlamaMLP(f"{prefix}.feed_forward", config, weights, layer_idx)
 
-        self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
-        self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
-        # self.input_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.input_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
-        # self.post_attention_layernorm = FastRMSNorm.load(
-        #     prefix=f"{prefix}.post_attention_layernorm",
-        #     weights=weights,
-        #     eps=config.rms_norm_eps,
-        # )
+        #self.input_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.input_layernorm", config=config, weights=weights)
+        #self.post_attention_layernorm = Llama4TextRMSNorm(prefix=f"{prefix}.post_attention_layernorm", config=config, weights=weights)
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
 
 
         self.layer_idx = layer_idx
@@ -566,7 +572,7 @@ class Llama4TextDecoderLayer(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, _ = self.input_layernorm(hidden_states)
 
         attention_states = self.self_attn(
             hidden_states,
@@ -584,7 +590,7 @@ class Llama4TextDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
 
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, _ = self.post_attention_layernorm(hidden_states)
         hidden_states = self.feed_forward(hidden_states, adapter_data)
         hidden_states = residual + hidden_states.view(residual.shape)
         #outputs = (hidden_states,)
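Review note: the two hunks above unpack a tuple from the layernorm calls because the decoder layer now builds its norms with FastRMSNorm.load instead of Llama4TextRMSNorm. A minimal sketch of the (output, residual) return convention those call sites assume follows; it illustrates the expected interface only and is not the FastRMSNorm implementation.

# Sketch of an RMSNorm that returns (normalized_output, residual), matching the
# "hidden_states, _ = self.input_layernorm(hidden_states)" call sites above.
# This is an assumed interface for illustration, not TGI's FastRMSNorm code.
import torch
from torch import nn


class TupleRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states, residual=None):
        # Fuse the residual add into the norm when a residual is passed in.
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states.float() * torch.rsqrt(variance + self.eps)
        return self.weight * normed.to(self.weight.dtype), residual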
@@ -642,9 +648,9 @@ class Llama4TextModel(nn.Module):
             weights=weights,
             eps=config.rms_norm_eps,
         )
+        self.run_index = 0
 
-        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
-
+        #self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
 
     def forward(
@@ -660,6 +666,8 @@ class Llama4TextModel(nn.Module):
     ) -> torch.Tensor:
 
         hidden_states = inputs_embeds
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.input.hidden_states.pt")
+        log_master(logger.debug, f"inputs_embeds.shape={inputs_embeds.shape}")
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
@@ -679,7 +687,12 @@ class Llama4TextModel(nn.Module):
                 hpu_attention_meta=hpu_attention_meta,
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.layers.hidden_states.pt")
+        log_master(logger.debug, f"hidden_states.shape={hidden_states.shape}")
+        hidden_states, _ = self.norm(hidden_states)
+        torch_save(hidden_states, f"tgi.{self.run_index}.Llama4TextModel.norm.hidden_states.pt")
+        log_master(logger.debug, f"normalized hidden_states.shape={hidden_states.shape}")
+        self.run_index += 1
         return hidden_states
 
 
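Review note: the dumps above run on every forward pass, tagged with self.run_index so successive calls do not overwrite each other. If the instrumentation is meant to stay in the tree, one option is to gate it behind an environment variable; the sketch below is only a suggestion, and the TGI_DEBUG_SAVE variable name and the CPU copy are assumptions, not part of this commit.

# Possible opt-in variant of the torch_save helper. TGI_DEBUG_SAVE is a
# hypothetical variable name chosen for this sketch.
import os

import torch


def torch_save(tensor, name):
    if os.getenv("TGI_DEBUG_SAVE", "0") != "1":
        return
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        # Copy to CPU so saving does not depend on the accelerator backend.
        torch.save(tensor.detach().to("cpu"), name)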
@@ -1520,6 +1520,10 @@ class FlashCausalLM(Model):
         if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
             logger.info("skip warmup hpu graph, not recommmended")
             del _batch, batch
+            print(f"max_input_tokens: {max_input_tokens}")
+            print(f"max_total_tokens: {max_total_tokens}")
+            print(f"num_blocks: {num_blocks}")
+            print(f"BLOCK_SIZE: {BLOCK_SIZE}")
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
         self.warmup_hpu_graph(batch)
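Review note: the first value returned here is num_blocks * BLOCK_SIZE, i.e. the number of token slots available in the KV cache, which is what the new prints expose. A quick arithmetic illustration with made-up numbers:

# Made-up numbers, only to show what the returned capacity means.
num_blocks, BLOCK_SIZE = 1024, 128
max_supported_tokens = int(num_blocks * BLOCK_SIZE)  # 131072 token slots in the KV cache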