diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
index caf11fb0..54bbecf7 100644
--- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py
@@ -287,27 +287,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
-# this was adapted from LlamaRMSNorm
-# class IdeficsRMSNorm(nn.Module):
-#     def __init__(self, prefix, weights, eps=1e-6):
-#         """
-#         IdeficsRMSNorm is equivalent to T5LayerNorm
-#         """
-#         super().__init__()
-#
-#         weight = weights.get_tensor(f"{prefix}.weight")
-#         self.weight = nn.Parameter(weight)
-#         self.variance_epsilon = eps
-#
-#     def forward(self, hidden_states):
-#         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-#
-#         # convert into half-precision if necessary
-#         if self.weight.dtype in [torch.float16, torch.bfloat16]:
-#             hidden_states = hidden_states.to(self.weight.dtype)
-#
-#         return self.weight * hidden_states
 class IdeficsRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
@@ -371,56 +350,6 @@ class IdeficsRMSNorm(nn.Module):
         return normed_hidden_states
 
 
-# this was adapted from LlamaRotaryEmbedding
-class IdeficsEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-#     gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-#     gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-#     cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-#     sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-#     q_embed = (q * cos) + (rotate_half(q) * sin)
-#     k_embed = (k * cos) + (rotate_half(k) * sin)
-#     return q_embed, k_embed
-
-
 # this was adapted from LlamaMLP
 class IdeficsMLP(nn.Module):
     def __init__(
@@ -430,19 +359,23 @@ class IdeficsMLP(nn.Module):
         weights,
     ):
         super().__init__()
-        self.gate_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False,
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=False,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
         )
-        self.up_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.up_proj", weights=weights, bias=False,
-        )
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_states):
+        gate_up_states = self.gate_up_proj(hidden_states)
+        shape = gate_up_states.shape
+        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
+        return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])
 
 
 # this was adapted from LlamaAttention
@@ -496,14 +429,12 @@ class IdeficsAttention(nn.Module):
                 config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
             )
         else:
-            self.q_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
-            )
-            self.k_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
-            )
-            self.v_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+            self.qkv = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
             )
         self.o_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
@@ -538,10 +469,21 @@ class IdeficsAttention(nn.Module):
 
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
-        if not is_cross_attention:
-            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
-            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+        if is_cross_attention:
+            query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            query_states = query_states.transpose(1, 2)
+            _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = (
+                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+            )
+        else:
+            qkv = self.qkv(hidden_states)
+            query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2)
+
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
         kv_seq_len = q_len
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
@@ -559,13 +501,6 @@ class IdeficsAttention(nn.Module):
             query_states = query_states.transpose(1, 2)
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
-        else:
-            query_states = query_states.transpose(1, 2)
-            _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
-            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = (
-                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
-            )
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -676,14 +611,14 @@ class IdeficsDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -773,7 +708,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
             attention_mask=image_attention_mask,
             output_attentions=output_attentions,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         # when there are no images the model is used in pure language mode
         gate = 0 if no_images else 1
         hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
@@ -782,7 +717,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
 
         outputs = (hidden_states,)
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 9507d331..5a307269 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -16,7 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
-PROFILE = False
+PROFILE = True
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
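Note for reviewers: the two fusions in the diff (gate/up in `IdeficsMLP`, q/k/v in the self-attention branch of `IdeficsAttention`) come down to one wider matmul followed by a view or split. The snippet below is a minimal, standalone sketch of that tensor bookkeeping only; it substitutes plain `nn.Linear` for `TensorParallelColumnLinear.load_multi` / `TensorParallelRowLinear`, and the class names `FusedGateUpMLP` and `FusedQKV` are illustrative, not part of the codebase.

```python
import torch
import torch.nn as nn


class FusedGateUpMLP(nn.Module):
    """Sketch of the fused gate/up projection: one matmul, then a 2-way view."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Stands in for TensorParallelColumnLinear.load_multi over [gate_proj, up_proj] (dim=0).
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(hidden_states)              # (bsz, seq, 2 * intermediate)
        shape = gate_up.shape
        gate_up = gate_up.view(*shape[:-1], 2, shape[-1] // 2)  # (bsz, seq, 2, intermediate)
        # Index 0 is the gate half, index 1 the up half, matching the concatenation order above.
        return self.down_proj(self.act_fn(gate_up[..., 0, :]) * gate_up[..., 1, :])


class FusedQKV(nn.Module):
    """Sketch of the fused q/k/v projection used in the self-attention branch."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Stands in for TensorParallelColumnLinear.load_multi over [q_proj, k_proj, v_proj] (dim=0).
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv(hidden_states)                            # (bsz, q_len, 3 * hidden)
        # Same split as in the diff: qkv.split(num_heads * head_dim, dim=2) -> q, k, v.
        query, key, value = qkv.split(self.num_heads * self.head_dim, dim=2)
        query = query.view(bsz, q_len, self.num_heads, self.head_dim)
        key = key.view(bsz, q_len, self.num_heads, self.head_dim)
        value = value.view(bsz, q_len, self.num_heads, self.head_dim)
        return query, key, value


if __name__ == "__main__":
    x = torch.randn(2, 5, 64)
    print(FusedGateUpMLP(64, 128)(x).shape)   # torch.Size([2, 5, 64])
    q, k, v = FusedQKV(64, 8)(x)
    print(q.shape, k.shape, v.shape)          # each torch.Size([2, 5, 8, 8])
```

The point of the fusion is to replace two (or three) smaller GEMMs with one larger one per layer, which generally gives better GPU utilization; the view/split afterwards recovers the same per-head layout the unfused code produced, so the numerical result is unchanged.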