Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Further speedup.
Fused QKV in attention (the full fusion cannot be done for cross-attention because the shapes are different) and fused gate + up in the MLP. Now at ~40 ms.
This commit is contained in:
parent fc02d99e57
commit 8030d66f0d
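The core of the change: instead of running several narrow projections back to back, the projection weights are concatenated once at load time (via `TensorParallelColumnLinear.load_multi` in the diff below) and a single wider matmul is run, with the result split afterwards. A rough standalone sketch of why the two are equivalent, in plain torch with made-up sizes:

```python
import torch

hidden_size, out_size = 512, 512  # illustrative only
x = torch.randn(2, 16, hidden_size)

w_q = torch.randn(out_size, hidden_size)
w_k = torch.randn(out_size, hidden_size)
w_v = torch.randn(out_size, hidden_size)

# Three separate matmuls (one per projection)...
q, k, v = x @ w_q.T, x @ w_k.T, x @ w_v.T

# ...versus one matmul over the weights concatenated along the output
# dimension (dim=0), followed by a cheap split of the output.
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)           # [3 * out_size, hidden_size]
q2, k2, v2 = (x @ w_qkv.T).split(out_size, dim=-1)

assert torch.allclose(q, q2, atol=1e-4)
assert torch.allclose(v, v2, atol=1e-4)
```

The same trick is applied to the MLP's gate and up projections; cross-attention cannot fuse all three attention projections because queries and keys/values are computed from different tensors (see the attention hunks further down).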
@@ -287,27 +287,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# this was adapted from LlamaRMSNorm
# class IdeficsRMSNorm(nn.Module):
#     def __init__(self, prefix, weights, eps=1e-6):
#         """
#         IdeficsRMSNorm is equivalent to T5LayerNorm
#         """
#         super().__init__()
#
#         weight = weights.get_tensor(f"{prefix}.weight")
#         self.weight = nn.Parameter(weight)
#         self.variance_epsilon = eps
#
#     def forward(self, hidden_states):
#         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
#         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
#
#         # convert into half-precision if necessary
#         if self.weight.dtype in [torch.float16, torch.bfloat16]:
#             hidden_states = hidden_states.to(self.weight.dtype)
#
#         return self.weight * hidden_states
class IdeficsRMSNorm(nn.Module):
    def __init__(self, prefix, weights, eps=1e-6):
        """
@@ -371,56 +350,6 @@ class IdeficsRMSNorm
        return normed_hidden_states


# this was adapted from LlamaRotaryEmbedding
class IdeficsEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
#     gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
#     gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
#     cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
#     sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
#     q_embed = (q * cos) + (rotate_half(q) * sin)
#     k_embed = (k * cos) + (rotate_half(k) * sin)
#     return q_embed, k_embed


# this was adapted from LlamaMLP
class IdeficsMLP(nn.Module):
    def __init__(
@@ -430,19 +359,23 @@ class IdeficsMLP(nn.Module):
        weights,
    ):
        super().__init__()
        self.gate_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.gate_proj", weights=weights, bias=False,
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
        )
        self.up_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.up_proj", weights=weights, bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        shape = gate_up_states.shape
        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
        return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])


# this was adapted from LlamaAttention
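The new `forward` above recovers the gate and up halves of the fused projection with a `view`. A minimal sketch of the same arithmetic, using plain `nn.Linear` and SiLU in place of TGI's tensor-parallel layers and `ACT2FN` (sizes are illustrative):

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 512, 1376  # illustrative only
x = torch.randn(2, 16, hidden_size)

# Stand-in for gate_up_proj: its weight is [gate_proj; up_proj] stacked along
# the output dimension, which is what load_multi(..., dim=0) builds.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

gate_up_states = gate_up_proj(x)                      # [bs, seq, 2 * intermediate]
shape = gate_up_states.shape
gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
# Index 0 is the gate half, index 1 the up half, exactly as if the two
# projections had been applied separately.
out = down_proj(nn.functional.silu(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])
```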
@@ -496,14 +429,12 @@ class IdeficsAttention(nn.Module):
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            )
        else:
            self.q_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
            )
            self.k_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
            )
            self.v_proj = TensorParallelColumnLinear.load(
                config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
            self.qkv = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                dim=0,
                weights=weights,
                bias=False,
            )
        self.o_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
@@ -538,10 +469,21 @@ class IdeficsAttention(nn.Module):

        bsz, q_len, _ = hidden_states.size()

        if is_cross_attention:
            query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
        if not is_cross_attention:
            key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
            query_states = query_states.transpose(1, 2)
            _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = (
                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            )
        else:
            qkv = self.qkv(hidden_states)
            query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
            kv_seq_len = q_len
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
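In the hunk above, the self-attention path runs one fused projection and splits the result, while cross-attention keeps separate projections: the query is computed from `hidden_states` (length `q_len`) but the keys/values come from `key_value_states` (length `kv_len`), so the three matmuls do not share one input. An illustrative sketch of the two paths with plain `nn.Linear` and made-up sizes:

```python
import torch
import torch.nn as nn

bsz, q_len, kv_len = 2, 16, 64
hidden_size, num_heads = 512, 8
head_dim = hidden_size // num_heads

hidden_states = torch.randn(bsz, q_len, hidden_size)
key_value_states = torch.randn(bsz, kv_len, hidden_size)  # e.g. image features

# Self-attention: q, k and v are all projections of `hidden_states`,
# so a single fused matmul plus a split works.
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
qkv = qkv_proj(hidden_states)
query, key, value = qkv.split(num_heads * head_dim, dim=2)  # each [bsz, q_len, hidden_size]

# Cross-attention: q comes from `hidden_states`, while k and v come from
# `key_value_states`, so the projections stay separate.
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
query = q_proj(hidden_states)     # [bsz, q_len, hidden_size]
key = k_proj(key_value_states)    # [bsz, kv_len, hidden_size]
value = v_proj(key_value_states)  # [bsz, kv_len, hidden_size]
```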
@@ -559,13 +501,6 @@ class IdeficsAttention(nn.Module):
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
        else:
            query_states = query_states.transpose(1, 2)
            _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = (
                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
@@ -676,14 +611,14 @@ class IdeficsDecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
@@ -773,7 +708,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
            attention_mask=image_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # when there are no images the model is used in pure language mode
        gate = 0 if no_images else 1
        hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
@@ -782,7 +717,7 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states

        outputs = (hidden_states,)
@@ -16,7 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

PROFILE = False
PROFILE = True

class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):