mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

Moving to SYSTEM enum.

This commit is contained in:
parent c84718c8b6
commit 34b9289c3d
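
The commit replaces the per-backend booleans IS_CUDA_SYSTEM, IS_ROCM_SYSTEM and IS_XPU_SYSTEM exported by import_utils with a single SYSTEM string that call sites compare against. A minimal, self-contained sketch of that pattern follows; detect_system and demo_dispatch are illustrative names, not functions from the repository, and the detection logic only mirrors the import_utils hunk near the end of this diff.

import torch


def detect_system():
    # Mirrors the detection order in the new import_utils hunk: ROCm builds of
    # torch set torch.version.hip, CUDA builds set torch.version.cuda, and
    # XPU support is probed via the optional torch.xpu module.
    if torch.version.hip is not None:
        return "rocm"
    if torch.version.cuda is not None:
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return None


SYSTEM = detect_system()


def demo_dispatch():
    # Before the change: if IS_CUDA_SYSTEM: ... elif IS_ROCM_SYSTEM: ...
    # After the change: a single string comparison per backend.
    if SYSTEM == "cuda":
        return "cuda kernels"
    elif SYSTEM == "rocm":
        return "rocm kernels"
    elif SYSTEM == "xpu":
        return "xpu kernels"
    return "no accelerator"
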
@@ -1,7 +1,7 @@
import os
import torch
from text_generation_server.utils.import_utils import (
IS_ROCM_SYSTEM,
SYSTEM,
)

try:

@@ -10,7 +10,7 @@ except Exception:
major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False

@@ -2,9 +2,7 @@ import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
SYSTEM,
)


@@ -35,19 +33,60 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
import dropout_layer_norm
elif IS_ROCM_SYSTEM:

class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states

return normed_hidden_states, residual

elif SYSTEM == "rocm":
from vllm import layernorm_ops
elif IS_XPU_SYSTEM:

class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states

return super(FastLayerNorm, self).forward(hidden_states), residual

elif SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
else:
dropout_layer_norm = None


class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if IS_XPU_SYSTEM:
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True

@@ -55,7 +94,19 @@ class FastLayerNorm(nn.LayerNorm):
if residual is not None:
res_out = residual
return out, res_out
elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:


class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if SYSTEM == "xpu":
res_out = hidden_states
out = ipex.llm.functional.add_layer_norm(
residual, hidden_states, self.weight, self.bias, self.eps, True
)
if residual is not None:
res_out = residual
return out, res_out
elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
if residual is not None:
hidden_states += residual
residual = hidden_states

@@ -102,7 +153,7 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps)

def forward(self, hidden_states, residual=None):
if IS_XPU_SYSTEM:
if SYSTEM == "xpu":
residual_out = hidden_states
out = ipex.llm.functional.add_rms_norm(
residual,

@@ -131,7 +182,7 @@ class FastRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
elif SYSTEM == "cuda":
# faster post attention rms norm
(
normed_hidden_states,

@@ -158,7 +209,7 @@ class FastRMSNorm(nn.Module):
res = hidden_states

return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual

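The layernorm hunks above move the FastLayerNorm and FastRMSNorm definitions inside per-SYSTEM branches, but every variant keeps the same contract: add the residual into the hidden states, normalize, and return both the normalized tensor and the pre-normalization sum. A hedged sketch of that contract follows, with the fused dropout_layer_norm / vllm / ipex kernels replaced by plain PyTorch so it runs anywhere; FastLayerNormSketch is an illustrative name, not a class from the repository.

import torch
from torch import nn


class FastLayerNormSketch(nn.LayerNorm):
    # Same call shape as the FastLayerNorm variants in the diff:
    # forward(hidden_states, residual) -> (normed, residual_for_next_layer).
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        return super().forward(hidden_states), residual


if __name__ == "__main__":
    ln = FastLayerNormSketch(64)
    x, res = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
    normed, new_res = ln(x, res)
    assert normed.shape == new_res.shape == x.shape
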
@@ -1,4 +1,5 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM


class FastLinear(torch.nn.Module):

@@ -126,7 +127,7 @@ def get_linear(weight, bias, quantize):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
if IS_ROCM_SYSTEM:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."

@@ -1,15 +1,12 @@
import torch
from torch import nn

from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
)
from text_generation_server.utils.import_utils import SYSTEM

if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import pos_encoding_ops


@@ -50,7 +47,7 @@ class PositionRotaryEmbedding(nn.Module):
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim : 2 * rotary_dim]

@@ -61,7 +58,7 @@ class PositionRotaryEmbedding(nn.Module):
k2 = key[..., rotary_dim : 2 * rotary_dim]

rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

@@ -69,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module):

# Inplace operation, updating query and key.
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)

@@ -223,7 +220,7 @@ class PositionRotaryEmbedding(nn.Module):
"""
Return cos and sin for the asked position ids
"""
if IS_ROCM_SYSTEM:
if SYSTEM == "rocm":
# For RoCm, we always use float cos/sin to avoid a cast.
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.

@@ -2,7 +2,7 @@ import math
import torch

from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

BLOCK_SIZE: int = 16
# Will be set in warmup

@@ -25,7 +25,7 @@ class CacheManager:
self.repeat_slots = repeat_slots

element_size = torch.tensor([], dtype=dtype).element_size()
if IS_XPU_SYSTEM:
if SYSTEM == "xpu":
x = 1
else:
x = self.block_size // element_size

@@ -26,18 +26,22 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)

if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
import dropout_layer_norm
else:
dropout_layer_norm = None

@@ -52,7 +56,7 @@ class CohereRotary(PositionRotaryEmbedding):
sin: torch.Tensor,
):
# Such controlflows may add some overhead.
if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
import rotary_emb

q1 = query[..., ::2]

@@ -64,7 +68,7 @@ class CohereRotary(PositionRotaryEmbedding):
k2 = key[..., 1::2]

rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import pos_encoding_ops

# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.

@@ -90,7 +94,7 @@ class CohereLayerNorm(nn.Module):
self.eps = eps

def forward(self, hidden_states):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
hidden_states = hidden_states.reshape(
-1, self.weight.shape[0], self.weight.shape[1]
)

@@ -21,21 +21,26 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

if not IS_XPU_SYSTEM:
if SYSTEM != "xpu":
from vllm.model_executor.layers.fused_moe import fused_moe

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
FastLinear,
FastLayerNorm,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.utils.log import log_once

@@ -24,9 +24,9 @@ import torch.distributed
import numpy as np

from torch import nn
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

if not IS_XPU_SYSTEM:
if SYSTEM != "xpu":
from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig

@@ -36,14 +36,18 @@ from loguru import logger
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
FastLinear,
FastRMSNorm,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)


class MixtralConfig(PretrainedConfig):

@@ -55,12 +55,14 @@ from text_generation_server.layers import (
PositionRotaryEmbedding,
FastLinear,
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import layernorm_ops
else:
raise RuntimeError(f"Unsupported system {SYSTEM}")


@dataclass

@@ -373,7 +375,7 @@ class IdeficsRMSNorm(nn.Module):
hidden_states = hidden_states.to(self.weight.dtype)

return self.weight * hidden_states
elif IS_CUDA_SYSTEM:
elif SYSTEM == "cuda":
# faster post attention rms norm
unwrap = False
if len(hidden_states.shape) > 2:

@@ -405,7 +407,7 @@ class IdeficsRMSNorm(nn.Module):
normed_hidden_states = normed_hidden_states.view(*shape)

return normed_hidden_states
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual

@@ -12,10 +12,6 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
Tokens,

@@ -32,13 +28,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
empty_cache,
synchronize,
get_free_memory,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):

@@ -757,10 +754,8 @@ class FlashCausalLM(Model):

def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.empty_cache()
elif IS_XPU_SYSTEM:
torch.xpu.empty_cache()
empty_cache()

try:
cache_manager = set_cache_manager(
batch.blocks,

@@ -780,10 +775,7 @@ class FlashCausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`"
) from e

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
torch.cuda.synchronize(self.device)
elif IS_XPU_SYSTEM:
torch.xpu.synchronize(self.device)
synchronize(self.device)

# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory

@@ -791,20 +783,7 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
total_gpu_memory = torch.cuda.get_device_properties(
self.device
).total_memory

free_memory = max(
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
)
elif IS_XPU_SYSTEM:
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
free_memory = int(total_gpu_memory * 0.5)
else:
raise NotImplementedError("FlashModel is only available on GPU")
free_memory = get_free_memory(self.device, MEMORY_FRACTION)

num_blocks = (
# Leave 5% for some wiggle room

@@ -18,7 +18,7 @@ from text_generation_server.utils import (

tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM


class FlashLlama(FlashCausalLM):

@@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:

@@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

@@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:

@@ -14,7 +14,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:

@@ -15,7 +15,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:

@@ -18,7 +18,7 @@ from text_generation_server.utils import (
Weights,
)

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:

@@ -2,13 +2,8 @@ import os
import torch

from loguru import logger
import math

from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
from text_generation_server.utils.import_utils import SYSTEM

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")

@@ -16,83 +11,22 @@ HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False

if IS_XPU_SYSTEM:
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = ""
if IS_CUDA_SYSTEM:
architecture_suffix = "-cuda"
elif IS_ROCM_SYSTEM:
architecture_suffix = "-rocm"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e

if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif IS_ROCM_SYSTEM:
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)

logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True


def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")

if IS_XPU_SYSTEM:
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."

@@ -114,7 +48,77 @@ def attention(
None,
)

if HAS_FLASH_ATTN_V2_CUDA:

if SYSTEM in {"cuda", "rocm"}:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e

if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif SYSTEM == "rocm":
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)

logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True


if HAS_FLASH_ATTN_V2_CUDA:

def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,

@@ -136,7 +140,21 @@ def attention(
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM:

elif HAS_FLASH_ATTN_V2_ROCM:

def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."

@@ -159,7 +177,19 @@ def attention(
False,
None,
)
elif HAS_FLASH_ATTN:

elif HAS_FLASH_ATTN:

def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"

@@ -209,4 +239,5 @@ def attention(
None,
)

else:
raise NotImplementedError("flash attention is not installed")

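The flash_attn hunk above keeps the same capability gating while switching it to SYSTEM: Flash Attention v2 requires an SM 8.x or SM 9.0 GPU, v1 also accepts SM 7.5, and the ROCm fallback is limited to specific devices (MI210/MI250 in the file above). A simplified, hedged sketch of that gating follows; has_module and detect_flash_attention are illustrative helpers, and the MI210/MI250 check is omitted for brevity.

import importlib.util

import torch


def has_module(name):
    # Stand-in for the try/except ImportError blocks in the real file.
    return importlib.util.find_spec(name) is not None


def detect_flash_attention(system):
    has_v2_cuda = has_v2_rocm = has_v1 = False
    if system in {"cuda", "rocm"} and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8
        is_sm90 = major == 9 and minor == 0
        if has_module("flash_attn_2_cuda") and (is_sm8x or is_sm90):
            has_v2_cuda = system == "cuda"
            has_v2_rocm = system == "rocm"
        elif has_module("flash_attn_cuda") and (is_sm75 or is_sm8x or is_sm90):
            has_v1 = True
    return has_v2_cuda, has_v2_rocm, has_v1
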
@@ -10,6 +10,32 @@ def is_xpu_available():
return hasattr(torch, "xpu") and torch.xpu.is_available()


IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()
def get_cuda_free_memory(device, memory_fraction):
total_free_memory, _ = torch.cuda.mem_get_info(device)
total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
return free_memory


def get_xpu_free_memory(device):
total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
free_memory = int(total_gpu_memory * 0.5)
return free_memory


SYSTEM = None
if torch.version.hip is not None:
SYSTEM = "rocm"
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None:
SYSTEM = "cuda"
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif is_xpu_available():
SYSTEM = "xpu"
empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory

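With the import_utils hunk above, callers stop branching on IS_*_SYSTEM and instead use the backend-neutral SYSTEM, empty_cache, synchronize and get_free_memory exports, as the flash_causal_lm warmup hunk earlier in this diff does. A hedged usage sketch follows; it assumes the text_generation_server package is importable, estimate_kv_cache_blocks and the MEMORY_FRACTION value are illustrative, and the real fraction comes from text_generation_server.utils.dist.

from text_generation_server.utils.import_utils import (
    SYSTEM,
    empty_cache,
    synchronize,
    get_free_memory,
)

MEMORY_FRACTION = 0.9  # illustrative value; TGI reads MEMORY_FRACTION from utils.dist


def estimate_kv_cache_blocks(device, cache_block_bytes):
    # Same call sequence as the warmup() hunk: clear the allocator cache,
    # synchronize the device, then size the KV cache from the free memory,
    # keeping 5% as wiggle room.
    empty_cache()
    synchronize(device)
    free_memory = get_free_memory(device, MEMORY_FRACTION)
    return int(free_memory * 0.95) // cache_block_bytes


def describe_backend():
    # SYSTEM is "rocm", "cuda" or "xpu" (None when no accelerator is detected).
    return f"running on {SYSTEM}"
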
@@ -1,13 +1,9 @@
import torch
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
)
from text_generation_server.utils.import_utils import SYSTEM

_PARTITION_SIZE = 512

if IS_XPU_SYSTEM:
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex


@@ -18,17 +14,17 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
from vllm._C import cache_ops

cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import cache_ops

cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
elif IS_XPU_SYSTEM:
elif SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)

@@ -68,7 +64,7 @@ def attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
if IS_XPU_SYSTEM:
if SYSTEM == "xpu":
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,

@@ -91,7 +87,7 @@ def attention(
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
from vllm._C import ops

ops.paged_attention_v1(

@@ -109,7 +105,7 @@ def attention(
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import attention_ops

attention_ops.paged_attention_v1(

@@ -143,7 +139,7 @@ def attention(
)
max_logits = torch.empty_like(exp_sums)

if IS_CUDA_SYSTEM:
if SYSTEM == "cuda":
from vllm._C import ops

ops.paged_attention_v2(

@@ -164,7 +160,7 @@ def attention(
"auto",
1.0,
)
elif IS_ROCM_SYSTEM:
elif SYSTEM == "rocm":
from vllm import attention_ops

attention_ops.paged_attention_v2(