mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Fix f180
This commit is contained in:
parent
211b54ac41
commit
7b88baddf7
@ -241,19 +241,21 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
num_heads = config.n_head
|
||||
num_heads_kv = config.n_head_kv
|
||||
# num_heads_kv = config.n_head_kv
|
||||
num_groups = config.n_head_kv
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
self.num_groups = num_groups
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
self.num_groups = num_heads // (num_heads_kv * 2)
|
||||
# self.num_groups = num_heads // (num_heads_kv * 2)
|
||||
self.num_heads = num_heads // self.num_groups
|
||||
self.num_heads_kv = num_heads_kv // self.num_groups
|
||||
# self.num_heads_kv = num_heads_kv // self.num_groups
|
||||
process_group = weights.process_group
|
||||
|
||||
if process_group.size() > self.num_groups:
|
||||
@ -264,6 +266,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
|
||||
)
|
||||
|
||||
self.num_groups = self.num_groups // process_group.size()
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
|
Loading…
Reference in New Issue
Block a user