diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index 42c079d5..cbed601b 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -410,6 +410,7 @@ class IdeficsAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads //= weights.process_group.size()
 
         if self.is_cross_attention:
             # kv_input_dim = (
@@ -440,7 +441,7 @@ class IdeficsAttention(nn.Module):
         # self.rotary_emb = PositionRotaryEmbedding.load(
         #     prefix=f"{prefix}.rotary_emb", weights=weights
         # )
-        self.rotary_emb = IdeficsEmbedding(self.head_dim, device="cuda:0") #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
+        self.rotary_emb = IdeficsEmbedding(self.head_dim, device=weights.device) #TO Verify, i did not replace by since it looks like it is specfic to `PositionRotaryEmbedding` and flash
 
         self.qk_layer_norms = qk_layer_norms
         if self.qk_layer_norms:
@@ -525,7 +526,7 @@ class IdeficsAttention(nn.Module):
             )
 
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
@@ -1315,4 +1316,4 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         reordered_past = ()
         for layer_past in past:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
-        return reordered_past
\ No newline at end of file
+        return reordered_past
diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
index c946ee7b..c0e5b400 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py
@@ -159,6 +159,7 @@ class IdeficsPerceiverAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.n_heads //= weights.process_group.size()
 
         # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
         self.q_proj = TensorParallelColumnLinear.load(
diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py
index e42f5c1f..6caf2918 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_vision.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py
@@ -132,6 +132,9 @@ class IdeficsVisionAttention(nn.Module):
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                 f"and `num_shards`: {weights.process_group.size()}"
             )
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.embed_dim = self.embed_dim // weights.process_group.size()
+
         self.k_proj = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
         )
@@ -158,7 +161,7 @@ class IdeficsVisionAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
-        bsz, tgt_len, embed_dim = hidden_states.size()
+        bsz, tgt_len, _ = hidden_states.size()
 
         # get query proj
         query_states = self.q_proj(hidden_states) * self.scale
@@ -221,7 +224,7 @@ class IdeficsVisionAttention(nn.Module):
 
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
-        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
 
         attn_output = self.out_proj(attn_output)
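
Note for reviewers: all of these hunks apply the same tensor-parallel pattern. After the divisibility check, each shard keeps only `num_heads / num_shards` heads, so the attention output must be reshaped to the per-shard width (`-1`, or the divided `embed_dim`) rather than the full `hidden_size` before the output projection; the row-parallel output projection then restores the full hidden size across ranks. Below is a minimal single-process sketch of that pattern under stated assumptions: `ToyShardedAttention` and `world_size` are hypothetical, and plain `nn.Linear` layers stand in for one rank's slice of `TensorParallelColumnLinear` and the row-parallel `o_proj`; this is not code from this repo.

```python
# Minimal sketch (not TGI code) of the head-sharding pattern used in these hunks.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyShardedAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, world_size: int):
        super().__init__()
        if num_heads % world_size != 0:
            raise ValueError("`num_heads` must be divisible by `num_shards`")
        self.head_dim = hidden_size // num_heads
        # Same idea as the diff: this rank keeps only its share of the heads.
        self.num_heads = num_heads // world_size
        shard_dim = self.num_heads * self.head_dim  # per-shard width, not hidden_size
        self.qkv = nn.Linear(hidden_size, 3 * shard_dim, bias=False)  # column-parallel slice
        self.o_proj = nn.Linear(shard_dim, hidden_size, bias=False)   # row-parallel slice

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()  # don't unpack embed_dim; it is sharded
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        # Reshape with the per-shard width (the diff's `-1` / divided embed_dim),
        # because only num_heads // world_size heads live on this rank.
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1)
        return self.o_proj(attn_output)  # a real row-parallel layer would all-reduce here


if __name__ == "__main__":
    layer = ToyShardedAttention(hidden_size=64, num_heads=8, world_size=2)
    print(layer(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])
```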