mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
revert silu
This commit is contained in:
parent
2cca808c31
commit
ca071bdd1d
@ -398,31 +398,10 @@ class LlamaMLP(nn.Module):
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
# x = gate_up_states.view(-1, 1,self.intermediate_size)
|
||||
# from loguru import logger
|
||||
# logger.info(f"gate_up_states: {gate_up_states.shape}")
|
||||
# x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||
# logger.info(f"x: {x.shape}")
|
||||
|
||||
# return self.down_proj(
|
||||
# x, adapter_data
|
||||
# )
|
||||
|
||||
# gate_up_states: torch.Size([4096, 2, 14336])
|
||||
# x: torch.Size([4096, 14336])
|
||||
|
||||
# gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
# x = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
# # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1])
|
||||
|
||||
output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
|
||||
|
||||
out = torch.empty(
|
||||
output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
ops.silu_and_mul(out, gate_up_states)
|
||||
|
||||
return self.down_proj(out, adapter_data)
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user