mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Fix f180
This commit is contained in:
parent
211b54ac41
commit
7b88baddf7
@ -241,19 +241,21 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
|
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
num_heads = config.n_head
|
num_heads = config.n_head
|
||||||
num_heads_kv = config.n_head_kv
|
# num_heads_kv = config.n_head_kv
|
||||||
|
num_groups = config.n_head_kv
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.head_size = hidden_size // num_heads
|
self.head_size = hidden_size // num_heads
|
||||||
|
self.num_groups = num_groups
|
||||||
|
|
||||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||||
)
|
)
|
||||||
self.softmax_scale = self.head_size ** (-0.5)
|
self.softmax_scale = self.head_size ** (-0.5)
|
||||||
|
|
||||||
self.num_groups = num_heads // (num_heads_kv * 2)
|
# self.num_groups = num_heads // (num_heads_kv * 2)
|
||||||
self.num_heads = num_heads // self.num_groups
|
self.num_heads = num_heads // self.num_groups
|
||||||
self.num_heads_kv = num_heads_kv // self.num_groups
|
# self.num_heads_kv = num_heads_kv // self.num_groups
|
||||||
process_group = weights.process_group
|
process_group = weights.process_group
|
||||||
|
|
||||||
if process_group.size() > self.num_groups:
|
if process_group.size() > self.num_groups:
|
||||||
@ -264,6 +266,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
|
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_groups = self.num_groups // process_group.size()
|
self.num_groups = self.num_groups // process_group.size()
|
||||||
|
|
||||||
self.query_key_value = TensorParallelColumnLinear.load(
|
self.query_key_value = TensorParallelColumnLinear.load(
|
||||||
|
Loading…
Reference in New Issue
Block a user