https://github.com/huggingface/text-generation-inference.git

commit 8617d4795a
parent 80ce8910f1

    move controlflow in forward
@@ -161,7 +161,7 @@ class LlamaRMSNorm(nn.Module):
             )
             return out, residual
         else:
-            raise RuntimeError("system not supported")
+            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


 def load_attention(config, prefix, weights):
@@ -166,6 +166,8 @@ class MistralRMSNorm(nn.Module):
                 self.variance_epsilon,
             )
             return out, residual
+        else:
+            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")


 def load_attention(config, prefix, weights):
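Both RMSNorm hunks apply the same convention: the platform dispatch in `forward` now ends with an explicit `else:` that raises an actionable `ValueError` (Llama previously raised a bare `RuntimeError("system not supported")`; Mistral had no terminal branch at all). Below is a minimal sketch of that pattern; the flag names and error text come from the diff, while `fused_cuda_path` and `manual_rocm_path` are hypothetical stand-ins for the real kernel calls.

```python
# Minimal sketch of the exhaustive-dispatch pattern used in both hunks (not TGI code).
IS_CUDA_SYSTEM = False
IS_ROCM_SYSTEM = False

def fused_cuda_path(x):
    return x  # hypothetical stand-in for the fused CUDA kernel path

def manual_rocm_path(x):
    return x  # hypothetical stand-in for the manual ROCm fallback

def rms_norm_forward(x):
    if IS_CUDA_SYSTEM:
        return fused_cuda_path(x)
    elif IS_ROCM_SYSTEM:
        return manual_rocm_path(x)
    else:
        # Unsupported platforms now fail loudly, pointing at the issue tracker.
        raise ValueError(
            "Your system seem to be not supported. Please check your install or "
            "open an issue at "
            "https://github.com/huggingface/text-generation-inference/issues "
            "with a clear reproduction."
        )
```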
@@ -556,32 +556,6 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
 
-        def rope_forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-            rotary_dim = cos.shape[-1]
-            x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
-
-            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-            return x
-    elif IS_ROCM_SYSTEM:
-        # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
-        # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
-        def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-            rotary_dim = cos.shape[-1]
-
-            dtype = x.dtype
-            x_upcast = x.to(torch.float32)
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            x1 = x_upcast[..., :rotary_dim]
-            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
-
-            # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
-            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
-            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
-            return x
-
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
             base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
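The two helpers deleted here are not gone for good: the hunk below re-adds their bodies inside `PositionRotaryEmbedding.forward`. The trade-off is between picking an implementation once at import time (the monkey-patching removed in the last hunk of this diff) and branching on every call, which the comment added below concedes "may add some overhead". A toy contrast of the two styles; all names in this sketch are hypothetical, not TGI code.

```python
# Toy illustration of the two dispatch styles this commit trades.
IS_CUDA_SYSTEM = True

class RopePatched:
    pass  # forward is attached below, once, at import time

def _forward_cuda(self, x):
    return x * 2  # stand-in for the fused kernel path

def _forward_fallback(self, x):
    return x + x  # stand-in for the manual path

# Old style: the branch runs once; every later call is a plain method call.
RopePatched.forward = _forward_cuda if IS_CUDA_SYSTEM else _forward_fallback

class RopeBranching:
    # New style: the branch runs on every call, which keeps all the logic in
    # one readable place at the cost of a per-call check.
    def forward(self, x):
        if IS_CUDA_SYSTEM:
            return x * 2
        else:
            return x + x

assert RopePatched().forward(3) == RopeBranching().forward(3) == 6
```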
@@ -609,6 +583,36 @@ try:
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
 
+        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            # Such controlflows may add some overhead.
+            if IS_CUDA_SYSTEM:
+                rotary_dim = cos.shape[-1]
+                x1 = x[..., :rotary_dim]
+                x2 = x[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+                return x
+            elif IS_ROCM_SYSTEM:
+                # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
+                # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
+                def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+                    rotary_dim = cos.shape[-1]
+
+                    dtype = x.dtype
+                    x_upcast = x.to(torch.float32)
+                    cos = cos.to(torch.float32)
+                    sin = sin.to(torch.float32)
+
+                    x1 = x_upcast[..., :rotary_dim]
+                    x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
+
+                    # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
+                    x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
+                    x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
+                    return x
+            else:
+                raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
+
         @classmethod
         def static(cls, config, dim, base, device):
             inv_freq = _create_inv_freq(dim, base, device)
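For reference, the manual ROCm path implements the standard rotary-embedding rotation: each pair (x1, x2) becomes (x1·cos − x2·sin, x1·sin + x2·cos), upcast to float32 to match the Flash Attention kernel's behavior. Below is a CPU-runnable sketch of that update; the in-place assignments mirror the hunk above, while the shapes and the cos/sin construction are illustrative assumptions, not taken from this file.

```python
import torch

def apply_rotary_manual(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # In-place pairwise rotation, mirroring the ROCm branch in the hunk above.
    rotary_dim = cos.shape[-1]
    dtype = x.dtype
    x1 = x[..., :rotary_dim].to(torch.float32)
    x2 = x[..., rotary_dim : 2 * rotary_dim].to(torch.float32)
    cos, sin = cos.to(torch.float32), sin.to(torch.float32)
    x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
    x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
    return x

# Illustrative shapes (seq, heads, head_dim); angles built as in _create_inv_freq.
seq, heads, head_dim = 4, 2, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)
cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]  # add a heads axis
out = apply_rotary_manual(torch.randn(seq, heads, head_dim), cos, sin)
```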
@@ -718,11 +722,6 @@ try:
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
-    if IS_CUDA_SYSTEM:
-        PositionRotaryEmbedding.forward = rope_forward_cuda
-    elif IS_ROCM_SYSTEM:
-        PositionRotaryEmbedding.forward = rope_forward_rocm
-
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
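With these module-level assignments removed, `PositionRotaryEmbedding.forward` is an ordinary method, called with the `cos`/`sin` tensors produced by the cache lookup visible in this hunk's context lines. A simplified, self-contained version of that lookup follows; only the `index_select` and `unsqueeze(1)` lines come from the hunk, and the class name, constructor, and shapes are stand-ins.

```python
import torch

class CosSinCache:
    # Simplified stand-in for the cached cos/sin lookup shown in the context lines.
    def __init__(self, inv_freq: torch.Tensor, max_positions: int):
        t = torch.arange(max_positions, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = freqs.cos()
        self._sin_cached = freqs.sin()

    def get_cos_sin(self, position_ids: torch.Tensor):
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        # unsqueeze(1) inserts a singleton axis so cos/sin broadcast over heads.
        return cos.unsqueeze(1), sin.unsqueeze(1)

inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
cache = CosSinCache(inv_freq, max_positions=128)
cos, sin = cache.get_cos_sin(torch.tensor([0, 1, 2, 3]))  # shapes: (4, 1, 4)
```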