fixed

parent 2e7f6e8012
commit 07a3050b20
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = { git = "https://github.com/huggingface/transformers", rev = "517a3e670d8fc11374895e870dd0dd041467c7fe" }
+transformers = "^4.38"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
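The hunk above appears to come from the server's Poetry manifest: it replaces the `transformers` dependency pinned to a git revision with the released `^4.38` series. As a hedged, illustrative aside (not part of the commit), a runtime check that the installed release satisfies that constraint could look like this:

```python
# Hypothetical sanity check, not from the repository: confirm the installed
# transformers release satisfies the "^4.38" constraint declared above.
from importlib.metadata import version

major, minor = (int(x) for x in version("transformers").split(".")[:2])
assert (major, minor) >= (4, 38) and major < 5, (
    f"transformers {major}.{minor} does not satisfy ^4.38"
)
```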
@@ -43,6 +43,43 @@ else:
     dropout_layer_norm = None
 
 
+class CohereRotary(PositionRotaryEmbedding):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        # Such controlflows may add some overhead.
+        if IS_CUDA_SYSTEM:
+            import rotary_emb
+
+            q1 = query[..., ::2]
+            q2 = query[..., 1::2]
+
+            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+            k1 = key[..., ::2]
+            k2 = key[..., 1::2]
+
+            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        elif IS_ROCM_SYSTEM:
+            from vllm import pos_encoding_ops
+
+            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
+            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
+
+            head_size = query.shape[-1]
+
+            # Inplace operation, updating query and key.
+            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+        else:
+            raise ValueError(
+                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
+            )
+
+
 class CohereLayerNorm(nn.Module):
     def __init__(self, prefix, weights, eps):
         super().__init__()
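The new `CohereRotary` subclass added above routes the rotary position embedding through fused kernels: flash-attn's `rotary_emb` extension on CUDA and vLLM's `pos_encoding_ops` on ROCm, both of which update `query` and `key` in place using the interleaved even/odd pairing. As a mental model only (not the commit's code), here is a plain-PyTorch sketch of the same interleaved rotation, assuming `cos`/`sin` of shape `[num_tokens, head_size // 2]`:

```python
# Rough reference of what the fused kernels compute; shapes are assumptions:
# x is [num_tokens, num_heads, head_size], cos/sin are [num_tokens, head_size // 2].
import torch


def rotary_interleaved_(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    x1 = x[..., ::2].clone()      # even positions (first element of each pair)
    x2 = x[..., 1::2].clone()     # odd positions (second element of each pair)
    cos = cos.unsqueeze(1)        # broadcast over the heads dimension
    sin = sin.unsqueeze(1)
    x[..., ::2] = x1 * cos - x2 * sin   # in-place update, like the kernels
    x[..., 1::2] = x1 * sin + x2 * cos
```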
@@ -146,93 +183,6 @@ def _load_gqa(config, prefix: str, weights):
     )
 
 
-class CohereRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        super().__init__()
-        self.scaling_factor = scaling_factor
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (
-            self.base
-            ** (
-                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
-                / self.dim
-            )
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, device_type, position_ids):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        inv_freq_expanded = (
-            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        )
-        position_ids_expanded = position_ids[None, None, :].float()
-
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = (
-            device_type
-            if isinstance(device_type, str) and device_type != "mps"
-            else "cpu"
-        )
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (
-                inv_freq_expanded.float().to(position_ids.device)
-                @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.repeat_interleave(freqs, 2, dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos[0], sin[0]
-
-
-def rotate_half(x):
-    # Split and rotate
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
-    return rot_x
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    dtype = q.dtype
-    q = q.float()
-    k = k.float()
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
-
-
 class FlashCohereAttention(torch.nn.Module):
     def __init__(
         self,
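The helpers removed above (`CohereRotaryEmbedding`, `rotate_half`, `apply_rotary_pos_emb`) are the eager reference path borrowed from `transformers`; they become redundant once the fused `CohereRotary` handles the rotation. The two formulations agree because the eager path builds `cos`/`sin` by duplicating each frequency (`torch.repeat_interleave(freqs, 2, dim=-1)`), which matches the interleaved pairing the kernels use. A small, illustrative check on toy tensors (not part of the commit):

```python
# Illustrative only: the rotate_half formulation removed above matches an
# explicit pairwise (interleaved) rotation when cos/sin repeat each frequency.
import torch


def rotate_half(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)


torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)                      # [batch, heads, seq, head_size]
freqs = torch.randn(4, 4)                        # [seq, head_size // 2]
emb = torch.repeat_interleave(freqs, 2, dim=-1)  # duplicate each frequency
cos, sin = emb.cos(), emb.sin()

ref = q * cos + rotate_half(q) * sin             # removed eager formulation

pair = torch.empty_like(q)                       # explicit pairwise rotation
q1, q2 = q[..., ::2], q[..., 1::2]
pair[..., ::2] = q1 * cos[..., ::2] - q2 * sin[..., ::2]
pair[..., 1::2] = q1 * sin[..., 1::2] + q2 * cos[..., 1::2]

assert torch.allclose(ref, pair, atol=1e-6)
```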
@@ -245,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
+        self.rotary_emb = CohereRotary.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
@@ -325,9 +275,7 @@ class FlashCohereAttention(torch.nn.Module):
         key = key.view(-1, self.num_key_value_heads, self.head_size)
         value = value.view(-1, self.num_key_value_heads, self.head_size)
 
-        query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-        # self.rotary_emb(query, key, true_cos.reshape(*cos.shape), true_sin.reshape(*sin.shape))
+        self.rotary_emb(query, key, cos, sin)
 
         paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
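Note the changed call convention in the hunk above: the removed helper returned fresh tensors (`query, key = apply_rotary_pos_emb(...)`), whereas `self.rotary_emb(query, key, cos, sin)` mutates both tensors in place and returns nothing, so the call site no longer rebinds them. A toy contrast of the two conventions (function names are invented for illustration; the rotation itself is elided):

```python
# Toy illustration of functional vs. in-place call conventions.
import torch


def rotate_functional(q, k):
    return q * 1.0, k * 1.0      # returns new tensors; caller must rebind


def rotate_inplace_(q, k):
    q.mul_(1.0)                  # mutates the arguments in place
    k.mul_(1.0)                  # returns None, like the fused path


q, k = torch.randn(4, 8), torch.randn(4, 8)
q, k = rotate_functional(q, k)   # old style
rotate_inplace_(q, k)            # new style: no rebinding needed
```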
@@ -487,11 +435,6 @@ class FlashCohereModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
-        self.rotary_true = CohereRotaryEmbedding(
-            self.head_size,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-        )
 
     def forward(
         self,
@@ -508,10 +451,9 @@ class FlashCohereModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        # cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-        #     position_ids, max_s, hidden_states.dtype
-        # )
-        cos, sin = self.rotary_true(hidden_states.device.type, position_ids)
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
 
         residual = None
         for i, layer in enumerate(self.layers):
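With the temporary `rotary_true` module removed in the previous hunk, the model again computes `cos`/`sin` once per forward through `get_cos_sin` on the first layer's rotary embedding and reuses them for every layer. A simplified sketch of that pattern (only `get_cos_sin(position_ids, max_s, dtype)` is taken from the diff; the layer call signature here is illustrative scaffolding, not the repository's exact code):

```python
# Simplified "compute once, reuse in every layer" pattern; the layer interface
# shown here is an assumption, not the repository's exact signature.
import torch


def run_layers(layers, hidden_states, position_ids, max_s):
    rotary_emb = layers[0].self_attn.rotary_emb
    cos, sin = rotary_emb.get_cos_sin(position_ids, max_s, hidden_states.dtype)

    residual = None
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual, cos, sin)
    return hidden_states
```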
@@ -19,7 +19,6 @@ from accelerate import init_empty_weights
 
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
-from text_generation_server.utils.log import log_once
 
 HAS_AWQ = True
 try:
@@ -35,12 +34,6 @@ except Exception:
     HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
-#     V2 = False
-#     log_once(
-#         logger.warning,
-#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
-#     )
 
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
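The last two hunks drop the unused `log_once` import and the commented-out workaround that forced exllama v1 when sharding across processes, leaving exllama selection to the `EXLLAMA_VERSION` and `DISABLE_EXLLAMA` environment variables seen in the surrounding context. A minimal sketch of that gating, mirroring the context lines rather than copying the module verbatim:

```python
# Minimal sketch of the environment-variable gating kept by this commit.
import os

V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"    # default to exllama v2 kernels
HAS_EXLLAMA = True                               # assumes the kernels imported fine

if os.getenv("DISABLE_EXLLAMA") == "True":       # explicit opt-out switch
    HAS_EXLLAMA = False

print(f"exllama enabled: {HAS_EXLLAMA}, v2: {V2}")
```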