From ca071bdd1d34cd7fab8e9430b2173cecb7e4664e Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 10 Dec 2024 10:41:40 +0000
Subject: [PATCH] revert silu

---
 .../custom_modeling/flash_llama_modeling.py | 27 +++------------------------
 1 file changed, 3 insertions(+), 24 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b093bcd7..10309006 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -398,31 +398,10 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            # x = gate_up_states.view(-1, 1,self.intermediate_size)
-            # from loguru import logger
-            # logger.info(f"gate_up_states: {gate_up_states.shape}")
-            # x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
-            # logger.info(f"x: {x.shape}")
-
-            # return self.down_proj(
-            #     x, adapter_data
-            # )
-
-            # gate_up_states: torch.Size([4096, 2, 14336])
-            # x: torch.Size([4096, 14336])
-
-            # gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            # x = gate_up_states.view(-1, 2, self.intermediate_size)
-            # # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1])
-
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-
-            return self.down_proj(out, adapter_data)


 class FlashLlamaLayer(nn.Module):
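
The restored branch computes the SwiGLU activation eagerly: it views the packed gate_up_proj output as (-1, 2, intermediate_size) and returns down_proj(act(gate) * up, adapter_data), instead of allocating an output buffer with torch.empty and calling the fused ops.silu_and_mul kernel. The snippet below is a minimal sketch of why the two paths agree, assuming the packed [gate, up] layout along the last dimension that the view implies, that the activation is SiLU, and that the fused kernel follows the usual silu(x[..., :d]) * x[..., d:] convention; names and shapes are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F

def swiglu_eager(gate_up: torch.Tensor, intermediate_size: int) -> torch.Tensor:
    # View the packed projection as (tokens, 2, intermediate_size):
    # index 0 is the gate half, index 1 is the up half.
    x = gate_up.view(-1, 2, intermediate_size)
    return F.silu(x[:, 0]) * x[:, 1]

# Illustrative shapes only.
tokens, intermediate_size = 4, 8
gate_up = torch.randn(tokens, 2 * intermediate_size)

eager = swiglu_eager(gate_up, intermediate_size)
# Same result computed by slicing the two halves directly, which is the
# convention a fused silu-and-mul kernel is assumed to follow here.
fused_equivalent = F.silu(gate_up[:, :intermediate_size]) * gate_up[:, intermediate_size:]
assert torch.allclose(eager, fused_equivalent)

Under those assumptions the revert keeps the same numerical result while dropping the fused-kernel dependency and the intermediate buffer.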