Further speedup.

Fused QKV in self-attention (the same fusion cannot be done for cross
attention because the query and key/value inputs have different shapes).
Fused gate + up projections in the MLP.

Now at ~40ms.
Nicolas Patry 2023-08-15 12:47:02 +00:00
parent fc02d99e57
commit 8030d66f0d
2 changed files with 37 additions and 102 deletions
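The ~40ms figure is an end-to-end number for this model in text-generation-inference; the general reason projection fusion helps is that one larger GEMM replaces several small ones, with fewer dispatches. A generic, CPU-only micro-benchmark sketch of that effect (illustrative shapes only, not the model's code; on GPU you would also need torch.cuda.synchronize() around the timing, and the gap is usually larger there):

import time

import torch

hidden, tokens, iters = 4096, 512, 50
x = torch.randn(tokens, hidden)
w_q, w_k, w_v = (torch.randn(hidden, hidden) for _ in range(3))
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)  # one (3 * hidden, hidden) matrix

def bench(fn, iters=iters):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

t_separate = bench(lambda: [torch.nn.functional.linear(x, w) for w in (w_q, w_k, w_v)])
t_fused = bench(lambda: torch.nn.functional.linear(x, w_qkv))
print(f"3 separate projections: {t_separate * 1e3:.2f} ms/iter")
print(f"1 fused projection:     {t_fused * 1e3:.2f} ms/iter")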


@@ -287,27 +287,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 # this was adapted from LlamaRMSNorm
-# class IdeficsRMSNorm(nn.Module):
-#     def __init__(self, prefix, weights, eps=1e-6):
-#         """
-#         IdeficsRMSNorm is equivalent to T5LayerNorm
-#         """
-#         super().__init__()
-#
-#         weight = weights.get_tensor(f"{prefix}.weight")
-#         self.weight = nn.Parameter(weight)
-#         self.variance_epsilon = eps
-#
-#     def forward(self, hidden_states):
-#         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-#
-#         # convert into half-precision if necessary
-#         if self.weight.dtype in [torch.float16, torch.bfloat16]:
-#             hidden_states = hidden_states.to(self.weight.dtype)
-#
-#         return self.weight * hidden_states
-
 class IdeficsRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
@@ -371,56 +350,6 @@ class IdeficsRMSNorm(nn.Module):
             return normed_hidden_states
 
-# this was adapted from LlamaRotaryEmbedding
-class IdeficsEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
-        # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-#     gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
-#     gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
-#     cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-#     sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
-#     q_embed = (q * cos) + (rotate_half(q) * sin)
-#     k_embed = (k * cos) + (rotate_half(k) * sin)
-#     return q_embed, k_embed
-
 # this was adapted from LlamaMLP
 class IdeficsMLP(nn.Module):
     def __init__(
@@ -430,19 +359,23 @@ class IdeficsMLP(nn.Module):
         weights,
     ):
         super().__init__()
-        self.gate_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False,
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=False,
         )
         self.down_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
         )
-        self.up_proj = TensorParallelColumnLinear.load(
-            config, prefix=f"{prefix}.up_proj", weights=weights, bias=False,
-        )
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+    def forward(self, hidden_states):
+        gate_up_states = self.gate_up_proj(hidden_states)
+        shape = gate_up_states.shape
+        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
+        return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])
 
 # this was adapted from LlamaAttention
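A standalone sketch of the equivalence between the old and new MLP forward, using plain nn.Linear stand-ins instead of the TensorParallel layers (sizes are illustrative; SiLU stands in for ACT2FN[config.hidden_act]; load_multi is assumed to concatenate the gate weights before the up weights along dim 0, which is what the prefix order and the [:, :, 0] / [:, :, 1] indexing suggest):

import torch
from torch import nn

hidden_size, intermediate_size = 64, 128

gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)
down = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()

# Fused gate+up: one weight matrix with the gate rows first, the up rows second.
gate_up = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
gate_up.weight.data = torch.cat([gate.weight.data, up.weight.data], dim=0)

x = torch.randn(2, 5, hidden_size)

# Unfused reference (the old forward).
ref = down(act(gate(x)) * up(x))

# Fused version, mirroring the new forward: view the last dim as (2, intermediate)
# so [..., 0, :] is the gate half and [..., 1, :] is the up half.
gate_up_states = gate_up(x)
shape = gate_up_states.shape
gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
out = down(act(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-5)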
@@ -496,14 +429,12 @@ class IdeficsAttention(nn.Module):
                 config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
             )
         else:
-            self.q_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
-            )
-            self.k_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
-            )
-            self.v_proj = TensorParallelColumnLinear.load(
-                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
+            self.qkv = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
             )
         self.o_proj = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
@@ -538,10 +469,21 @@ class IdeficsAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
-        if not is_cross_attention:
-            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
-            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+        if is_cross_attention:
+            query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            query_states = query_states.transpose(1, 2)
+            _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = (
+                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+            )
+        else:
+            qkv = self.qkv(hidden_states)
+            query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2)
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
 
         kv_seq_len = q_len
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
@@ -559,13 +501,6 @@ class IdeficsAttention(nn.Module):
             query_states = query_states.transpose(1, 2)
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
-        else:
-            query_states = query_states.transpose(1, 2)
-            _, kv_len, _ = key_value_states.size()  # Note that, in this case, `kv_len` == `kv_seq_len`
-            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = (
-                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
-            )
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
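In the attention hunks above, the self-attention path now runs a single fused projection and splits the result, while cross attention keeps separate projections: the query is computed from hidden_states (length q_len) but the keys/values come from key_value_states (length kv_len), so one matmul over a single input cannot produce all three. A rough standalone sketch of the fused self-attention path (plain tensors and hypothetical sizes, not the TensorParallel layers; the stacking order q/k/v along dim 0 is assumed from the prefixes list):

import torch

bsz, q_len, num_heads, head_dim = 2, 7, 4, 16
hidden_size = num_heads * head_dim
hidden_states = torch.randn(bsz, q_len, hidden_size)

# One (3 * hidden, hidden) weight, q/k/v stacked along dim 0 as in load_multi.
w_q, w_k, w_v = (torch.randn(hidden_size, hidden_size) for _ in range(3))
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)

qkv = torch.nn.functional.linear(hidden_states, w_qkv)      # (bsz, q_len, 3 * hidden)
query, key, value = qkv.split(num_heads * head_dim, dim=2)  # three (bsz, q_len, hidden) slices
query = query.view(bsz, q_len, num_heads, head_dim)
key = key.view(bsz, q_len, num_heads, head_dim)
value = value.view(bsz, q_len, num_heads, head_dim)

# The first slice matches what a standalone q projection would have produced.
assert torch.allclose(query.flatten(2), torch.nn.functional.linear(hidden_states, w_q))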
@@ -676,14 +611,14 @@ class IdeficsDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -773,7 +708,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
             attention_mask=image_attention_mask,
             output_attentions=output_attentions,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         # when there are no images the model is used in pure language mode
         gate = 0 if no_images else 1
         hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
@@ -782,7 +717,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
+        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
         hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
 
         outputs = (hidden_states,)


@@ -16,7 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
-PROFILE = False
+PROFILE = True
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
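The only change in this second file is flipping the module-level PROFILE flag from False to True; how the server consumes it is not shown in this diff. As a generic illustration only (not TGI's actual code, and all names below are hypothetical), such a flag usually gates a torch.profiler context around the step being measured:

import torch

PROFILE = True

def generate_step(batch):
    ...  # model forward / decode logic would go here

def maybe_profiled_generate(batch):
    if not PROFILE:
        return generate_step(batch)
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        record_shapes=True,
    ) as prof:
        result = generate_step(batch)
    # Print the hottest ops so a ~40ms budget can be broken down per kernel.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    return result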