Tp ready.

Nicolas Patry 2023-08-14 17:05:21 +00:00
parent eaf9448b48
commit 21c15d576d
3 changed files with 10 additions and 5 deletions

View File

@@ -410,6 +410,7 @@ class IdeficsAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads //= weights.process_group.size()
         if self.is_cross_attention:
             # kv_input_dim = (
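For orientation, here is a minimal, illustrative sketch (not code from this repository) of why the local head count is divided by the shard count: under tensor parallelism each GPU keeps only its own slice of the heads, so the per-shard projection width shrinks accordingly. The sizes and `world_size` below are assumptions.

# Illustrative only; these values are assumptions, not taken from text-generation-inference.
num_heads, head_dim, world_size = 32, 128, 4
if num_heads % world_size != 0:
    raise ValueError(
        f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {num_heads} "
        f"and `num_shards`: {world_size})"
    )
local_num_heads = num_heads // world_size        # 8 heads live on each shard
local_proj_width = local_num_heads * head_dim    # 1024 q/k/v output columns per shard
print(local_num_heads, local_proj_width)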
@@ -440,7 +441,7 @@ class IdeficsAttention(nn.Module):
         # self.rotary_emb = PositionRotaryEmbedding.load(
         #     prefix=f"{prefix}.rotary_emb", weights=weights
         # )
-        self.rotary_emb = IdeficsEmbedding(self.head_dim, device="cuda:0")  # TODO: verify; I did not replace this since it looks like it is specific to `PositionRotaryEmbedding` and flash
+        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device)  # TODO: verify; I did not replace this since it looks like it is specific to `PositionRotaryEmbedding` and flash
         self.qk_layer_norms = qk_layer_norms
         if self.qk_layer_norms:
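Replacing the hard-coded "cuda:0" with weights.device matters once the model is sharded across several GPUs: each shard's rotary sin/cos cache has to live on that shard's own device. A hypothetical sketch of deriving such a per-rank device (assuming a torch.distributed setup; this is not the actual Weights implementation in TGI):

import torch
import torch.distributed as dist

# Assumed setup; in TGI the device ultimately comes from the sharded `Weights` object.
rank = dist.get_rank() if dist.is_initialized() else 0
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
# Buffers created on `device` (for example a rotary embedding cache) stay next to
# this shard's weights instead of all piling onto cuda:0.
inv_freq = torch.arange(0, 64, 2, dtype=torch.float32, device=device)  # example buffer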
@@ -525,7 +526,7 @@ class IdeficsAttention(nn.Module):
             )
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
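Because the heads were sharded above, this shard's attention output is only hidden_size // num_shards wide, so reshaping to the full self.hidden_size would no longer match; -1 lets reshape infer the local width. A small shape sketch with assumed sizes (full hidden_size 4096, 4 shards):

import torch

bsz, q_len, num_heads, head_dim, world_size = 2, 7, 32, 128, 4
local_heads = num_heads // world_size                       # 8
attn_output = torch.randn(bsz, local_heads, q_len, head_dim)

attn_output = attn_output.transpose(1, 2)
# reshape(bsz, q_len, self.hidden_size) would ask for 4096 features, but this
# shard only holds 8 * 128 = 1024; -1 infers the local width instead.
attn_output = attn_output.reshape(bsz, q_len, -1)
print(attn_output.shape)                                    # torch.Size([2, 7, 1024])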

View File

@@ -159,6 +159,7 @@ class IdeficsPerceiverAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.n_heads //= weights.process_group.size()
         # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
         self.q_proj = TensorParallelColumnLinear.load(
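For readers unfamiliar with column-parallel layers, a rough, self-contained sketch of the idea behind such a load: each rank keeps only its own slice of the output dimension of the full weight. The function name, signature, and slicing below are illustrative assumptions, not the TGI TensorParallelColumnLinear API.

import torch

def load_column_parallel(full_weight: torch.Tensor, rank: int, world_size: int) -> torch.nn.Linear:
    # full_weight has shape (out_features, in_features); keep this rank's block of rows,
    # i.e. this rank's share of the output features.
    out_features, in_features = full_weight.shape
    shard = out_features // world_size
    local_weight = full_weight[rank * shard : (rank + 1) * shard]
    layer = torch.nn.Linear(in_features, shard, bias=False)
    with torch.no_grad():
        layer.weight.copy_(local_weight)
    return layer

q_proj = load_column_parallel(torch.randn(4096, 4096), rank=0, world_size=4)
print(q_proj.weight.shape)  # torch.Size([1024, 4096])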

View File

@@ -132,6 +132,9 @@ class IdeficsVisionAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
         self.k_proj = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
@@ -158,7 +161,7 @@ class IdeficsVisionAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        bsz, tgt_len, embed_dim = hidden_states.size()
+        bsz, tgt_len, _ = hidden_states.size()
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scale
@@ -221,7 +224,7 @@ class IdeficsVisionAttention(nn.Module):
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
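In the vision attention, __init__ divides both num_heads and embed_dim by the shard count, while the forward input still carries the full embedding width; that is why the final reshape has to use the sharded self.embed_dim rather than the embed_dim unpacked from hidden_states.size(). A shape sketch with assumed sizes (embed_dim 1280, 16 heads, 4 shards; not the actual Idefics configuration):

import torch

bsz, tgt_len, embed_dim, num_heads, world_size = 2, 257, 1280, 16, 4
head_dim = embed_dim // num_heads                         # 80
local_heads = num_heads // world_size                     # 4
local_embed_dim = embed_dim // world_size                 # 320

hidden_states = torch.randn(bsz, tgt_len, embed_dim)      # input keeps the full width
attn_output = torch.randn(bsz, local_heads, tgt_len, head_dim)

attn_output = attn_output.transpose(1, 2)
# Using the full embed_dim (1280) from hidden_states.size() would fail here:
# this shard's output only has local_heads * head_dim = 320 features.
attn_output = attn_output.reshape(bsz, tgt_len, local_embed_dim)
print(attn_output.shape)                                  # torch.Size([2, 257, 320])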