Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
faster
This commit is contained in:
parent 56296cc43c
commit cfc89bb396
@@ -244,7 +244,9 @@ class FlashCohereAttention(torch.nn.Module):
                 max_s,
             )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
+        )


 class CohereMLP(nn.Module):
@@ -282,7 +284,9 @@ class CohereMLP(nn.Module):
     def forward(self, hidden_states):
         gate_up_states = self.gate_up_proj(hidden_states)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
+        )


 class FlashCohereLayer(nn.Module):
@@ -299,6 +303,7 @@ class FlashCohereLayer(nn.Module):
             weights=weights,
             eps=config.layer_norm_eps,
         )
+        self.process_group = weights.process_group

     def forward(
         self,
@@ -331,6 +336,9 @@ class FlashCohereLayer(nn.Module):
         mlp_output = self.mlp(normed_hidden_states)
         output = attn_output + mlp_output

+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(output, group=self.process_group)
+
         return output, res


Loading…
Reference in New Issue
Block a user