revert silu

Mohit Sharma 2024-12-10 10:41:40 +00:00
parent 2cca808c31
commit ca071bdd1d

@@ -398,31 +398,10 @@ class LlamaMLP(nn.Module):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            # x = gate_up_states.view(-1, 1,self.intermediate_size)
-            # from loguru import logger
-            # logger.info(f"gate_up_states: {gate_up_states.shape}")
-            # x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
-            # logger.info(f"x: {x.shape}")
-            # return self.down_proj(
-            #     x, adapter_data
-            # )
-            # gate_up_states: torch.Size([4096, 2, 14336])
-            # x: torch.Size([4096, 14336])
-            # gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            # x = gate_up_states.view(-1, 2, self.intermediate_size)
-            # # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1])
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)

 class FlashLlamaLayer(nn.Module):
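
For reference, a minimal standalone sketch (not part of this commit) of what the two paths compute, assuming gate_up_proj packs the gate and up projections contiguously along the last dimension and self.act is SiLU: the restored eager path views the projection as (tokens, 2, intermediate_size) and multiplies silu(gate) by up, which matches the output the removed ops.silu_and_mul kernel would have written into its preallocated buffer.

# Standalone sketch, not from the diff: plain-PyTorch versions of both code paths,
# assuming the last dimension is packed as [gate, up] and the activation is SiLU.
import torch
import torch.nn.functional as F


def eager_gate_up(gate_up_states: torch.Tensor, intermediate_size: int) -> torch.Tensor:
    # The path this commit restores: view as (tokens, 2, intermediate_size),
    # activate the gate half and multiply by the up half.
    x = gate_up_states.view(-1, 2, intermediate_size)
    return F.silu(x[:, 0]) * x[:, 1]


def fused_reference(gate_up_states: torch.Tensor, intermediate_size: int) -> torch.Tensor:
    # Plain-PyTorch reference for a silu-and-mul kernel:
    # out = silu(x[..., :d]) * x[..., d:].
    gate, up = gate_up_states.split(intermediate_size, dim=-1)
    return F.silu(gate) * up


if __name__ == "__main__":
    x = torch.randn(4, 2 * 16)
    assert torch.allclose(eager_gate_up(x, 16), fused_reference(x, 16))

The two functions agree because view(-1, 2, d) on a contiguous (tokens, 2*d) tensor places the first d columns in index 0 and the last d columns in index 1, the same split the fused kernel performs on the last dimension.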