Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Tp ready.

Commit: 21c15d576d
Parent: eaf9448b48
@@ -410,6 +410,7 @@ class IdeficsAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads //= weights.process_group.size()
 
         if self.is_cross_attention:
             # kv_input_dim = (
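The hunk above enforces that the text-attention heads split evenly across tensor-parallel ranks and then keeps only this rank's share. Below is a minimal standalone sketch of that rule with toy numbers; `shard_heads` is a hypothetical helper, not part of text-generation-inference.

# Sketch of the head-sharding rule introduced above (hypothetical helper).
def shard_heads(num_heads: int, num_shards: int) -> int:
    """Number of attention heads each tensor-parallel shard owns."""
    if num_heads % num_shards != 0:
        raise ValueError(
            f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {num_heads} "
            f"and `num_shards`: {num_shards}"
        )
    return num_heads // num_shards

# e.g. 32 heads sharded over 4 GPUs keeps 8 heads per GPU; 32 over 5 raises.
assert shard_heads(32, 4) == 8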
@@ -440,7 +441,7 @@ class IdeficsAttention(nn.Module):
         # self.rotary_emb = PositionRotaryEmbedding.load(
         #     prefix=f"{prefix}.rotary_emb", weights=weights
         # )
-        self.rotary_emb = IdeficsEmbedding(self.head_dim, device="cuda:0")  # TODO: verify; not replaced by `PositionRotaryEmbedding` since that looks specific to flash attention
+        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device)  # TODO: verify; not replaced by `PositionRotaryEmbedding` since that looks specific to flash attention
 
         self.qk_layer_norms = qk_layer_norms
         if self.qk_layer_norms:
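The only functional change in this hunk is dropping the hardcoded device: with tensor parallelism each rank loads its shard onto its own GPU, so any buffers the rotary embedding builds must follow `weights.device`. A hedged sketch of the pattern; the `RotaryCache` class below is hypothetical, not the actual `IdeficsEmbedding`.

# Hypothetical illustration of why buffers must be created on the shard's own
# device instead of a hardcoded "cuda:0".
import torch


class RotaryCache(torch.nn.Module):
    def __init__(self, dim: int, device: torch.device, base: float = 10000.0):
        super().__init__()
        # inv_freq lives wherever this rank's weights live, so later
        # cos/sin lookups never trigger a cross-device copy.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq)


device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")
cache = RotaryCache(dim=64, device=device)  # this rank's device, not "cuda:0" on every rank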
@@ -525,7 +526,7 @@ class IdeficsAttention(nn.Module):
             )
 
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
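Why `-1` instead of `self.hidden_size`: after the head sharding above, each rank only holds `num_heads // num_shards` heads, so the flattened attention output is narrower than the full hidden size. A small shape check with assumed toy dimensions:

# Toy shapes (assumed, not from any Idefics config) showing why the reshape
# target must be inferred with -1 once the heads are sharded.
import torch

bsz, q_len, num_heads, head_dim, num_shards = 2, 5, 32, 64, 4
hidden_size = num_heads * head_dim        # 2048, the full model width
local_heads = num_heads // num_shards     # 8 heads actually held by this rank

attn_output = torch.randn(bsz, local_heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2)          # (bsz, q_len, local_heads, head_dim)
attn_output = attn_output.reshape(bsz, q_len, -1)  # (2, 5, 512); reshaping to 2048 would fail

assert attn_output.shape == (bsz, q_len, local_heads * head_dim)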
@@ -1315,4 +1316,4 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         reordered_past = ()
         for layer_past in past:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
-        return reordered_past
+        return reordered_past
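`_reorder_cache` simply re-indexes every cached key/value tensor along the batch-of-beams dimension so the cache follows the surviving beams. A tiny self-contained illustration, with all shapes made up:

# Toy example of beam reordering with index_select (shapes are assumptions).
import torch

bsz_times_beams, num_heads, seq_len, head_dim = 4, 2, 3, 8
past_state = torch.arange(bsz_times_beams).float().view(-1, 1, 1, 1).expand(
    bsz_times_beams, num_heads, seq_len, head_dim
)
beam_idx = torch.tensor([2, 2, 0, 1])  # beam 2 was duplicated, beam 3 dropped

reordered = past_state.index_select(0, beam_idx)
assert torch.equal(reordered[:, 0, 0, 0], torch.tensor([2.0, 2.0, 0.0, 1.0]))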
@@ -159,6 +159,7 @@ class IdeficsPerceiverAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.n_heads //= weights.process_group.size()
 
         # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
         self.q_proj = TensorParallelColumnLinear.load(
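The Perceiver resampler gets the same treatment: a per-rank head count paired with column-parallel Q/K/V projections. The identity that makes a column-split projection safe is that concatenating the per-rank outputs recovers the full projection; a quick numerical check under assumed toy sizes:

# Sanity check (toy sizes, not TGI code) of column-parallel sharding: splitting
# a linear layer's output features across ranks and concatenating the results
# reproduces the unsharded projection.
import torch

torch.manual_seed(0)
in_dim, out_dim, num_shards = 16, 12, 4
x = torch.randn(3, in_dim)
w = torch.randn(out_dim, in_dim)          # full weight, rows are output features

full = x @ w.T                            # (3, 12)
shards = [x @ w_shard.T for w_shard in w.chunk(num_shards, dim=0)]
assert torch.allclose(torch.cat(shards, dim=-1), full)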
@@ -132,6 +132,9 @@ class IdeficsVisionAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
+
 
         self.k_proj = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
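In the vision tower both `num_heads` and `embed_dim` are divided, which keeps the per-head width identical on every rank. With assumed toy numbers:

# Assumed toy numbers (not read from a config): dividing both num_heads and
# embed_dim by the shard count leaves head_dim unchanged, so per-rank
# view/reshape calls still line up.
embed_dim, num_heads, num_shards = 1280, 16, 4
head_dim = embed_dim // num_heads                 # 80, independent of sharding

local_heads = num_heads // num_shards             # 4 heads on this rank
local_embed_dim = embed_dim // num_shards         # 320 channels on this rank

assert local_embed_dim == local_heads * head_dim  # the local reshape target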
@@ -158,7 +161,7 @@ class IdeficsVisionAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
-        bsz, tgt_len, embed_dim = hidden_states.size()
+        bsz, tgt_len, _ = hidden_states.size()
 
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scale
@@ -221,7 +224,7 @@ class IdeficsVisionAttention(nn.Module):
 
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
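After the division above, `self.embed_dim` is the per-rank width, so the reshape target now matches the local slice this rank computed (the full width from the incoming `hidden_states` is no longer the right size, hence the `_` in the earlier hunk). If the output projection is split row-wise across ranks, a common tensor-parallel layout that this diff does not show, the per-rank partial products sum back to the full-width result. A numerical check with toy sizes:

# Toy-sized check (assumed layout, not taken from this diff): a row-parallel
# output projection applied to each rank's local slice, summed across ranks,
# equals the unsharded projection of the full-width activation.
import torch

torch.manual_seed(0)
embed_dim, num_shards, tokens = 16, 4, 5
x = torch.randn(tokens, embed_dim)                # full-width attention output
w = torch.randn(embed_dim, embed_dim)             # out_proj weight, (out, in)

full = x @ w.T
partials = [
    x_shard @ w_shard.T
    for x_shard, w_shard in zip(x.chunk(num_shards, dim=-1), w.chunk(num_shards, dim=-1))
]
assert torch.allclose(sum(partials), full, atol=1e-5)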