mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
revert silu
This commit is contained in:
parent
2cca808c31
commit
ca071bdd1d
@ -398,31 +398,10 @@ class LlamaMLP(nn.Module):
|
|||||||
return self.down_proj(out, adapter_data)
|
return self.down_proj(out, adapter_data)
|
||||||
else:
|
else:
|
||||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
# x = gate_up_states.view(-1, 1,self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
# from loguru import logger
|
return self.down_proj(
|
||||||
# logger.info(f"gate_up_states: {gate_up_states.shape}")
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
# x = self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
|
||||||
# logger.info(f"x: {x.shape}")
|
|
||||||
|
|
||||||
# return self.down_proj(
|
|
||||||
# x, adapter_data
|
|
||||||
# )
|
|
||||||
|
|
||||||
# gate_up_states: torch.Size([4096, 2, 14336])
|
|
||||||
# x: torch.Size([4096, 14336])
|
|
||||||
|
|
||||||
# gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
|
||||||
# x = gate_up_states.view(-1, 2, self.intermediate_size)
|
|
||||||
# # x = gate_up_states[:, 0] * self.act(gate_up_states[:, 1])
|
|
||||||
|
|
||||||
output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
|
|
||||||
|
|
||||||
out = torch.empty(
|
|
||||||
output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
|
|
||||||
)
|
)
|
||||||
ops.silu_and_mul(out, gate_up_states)
|
|
||||||
|
|
||||||
return self.down_proj(out, adapter_data)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashLlamaLayer(nn.Module):
|
class FlashLlamaLayer(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user