This commit is contained in:
OlivierDehaene 2024-04-10 16:47:41 +02:00
parent 2e7f6e8012
commit 07a3050b20
3 changed files with 43 additions and 108 deletions

View File

@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3"
transformers = { git = "https://github.com/huggingface/transformers", rev = "517a3e670d8fc11374895e870dd0dd041467c7fe" }
transformers = "^4.38"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }

View File

@ -43,6 +43,43 @@ else:
dropout_layer_norm = None
class CohereRotary(PositionRotaryEmbedding):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
if IS_CUDA_SYSTEM:
import rotary_emb
q1 = query[..., ::2]
q2 = query[..., 1::2]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., ::2]
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class CohereLayerNorm(nn.Module):
def __init__(self, prefix, weights, eps):
super().__init__()
@ -146,93 +183,6 @@ def _load_gqa(config, prefix: str, weights):
)
class CohereRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base
** (
torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
/ self.dim
)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, device_type, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[None, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float().to(position_ids.device)
@ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.repeat_interleave(freqs, 2, dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos[0], sin[0]
def rotate_half(x):
# Split and rotate
x1 = x[..., ::2]
x2 = x[..., 1::2]
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
return rot_x
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
dtype = q.dtype
q = q.float()
k = k.float()
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
class FlashCohereAttention(torch.nn.Module):
def __init__(
self,
@ -245,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.rotary_emb = CohereRotary.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
@ -325,9 +275,7 @@ class FlashCohereAttention(torch.nn.Module):
key = key.view(-1, self.num_key_value_heads, self.head_size)
value = value.view(-1, self.num_key_value_heads, self.head_size)
query, key = apply_rotary_pos_emb(query, key, cos, sin)
# self.rotary_emb(query, key, true_cos.reshape(*cos.shape), true_sin.reshape(*sin.shape))
self.rotary_emb(query, key, cos, sin)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
@ -487,11 +435,6 @@ class FlashCohereModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.rotary_true = CohereRotaryEmbedding(
self.head_size,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
)
def forward(
self,
@ -508,10 +451,9 @@ class FlashCohereModel(torch.nn.Module):
# Get rotary cos and sin for this forward
# Avoid to index in each layer
# cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
# position_ids, max_s, hidden_states.dtype
# )
cos, sin = self.rotary_true(hidden_states.device.type, position_ids)
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):

View File

@ -19,7 +19,6 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True
try:
@ -35,12 +34,6 @@ except Exception:
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
# V2 = False
# log_once(
# logger.warning,
# "Disabling exllama v2 and using v1 instead because there are issues when sharding",
# )
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False