mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
faster
This commit is contained in:
parent
56296cc43c
commit
cfc89bb396
@ -244,7 +244,9 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
max_s,
|
max_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), reduce=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CohereMLP(nn.Module):
|
class CohereMLP(nn.Module):
|
||||||
@ -282,7 +284,9 @@ class CohereMLP(nn.Module):
|
|||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashCohereLayer(nn.Module):
|
class FlashCohereLayer(nn.Module):
|
||||||
@ -299,6 +303,7 @@ class FlashCohereLayer(nn.Module):
|
|||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.layer_norm_eps,
|
eps=config.layer_norm_eps,
|
||||||
)
|
)
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -331,6 +336,9 @@ class FlashCohereLayer(nn.Module):
|
|||||||
mlp_output = self.mlp(normed_hidden_states)
|
mlp_output = self.mlp(normed_hidden_states)
|
||||||
output = attn_output + mlp_output
|
output = attn_output + mlp_output
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(output, group=self.process_group)
|
||||||
|
|
||||||
return output, res
|
return output, res
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user