Moving to SYSTEM enum.

Nicolas Patry 2024-05-06 19:16:04 +02:00
parent c84718c8b6
commit 34b9289c3d
18 changed files with 286 additions and 190 deletions
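The refactor replaces the three per-platform booleans exported by text_generation_server.utils.import_utils (IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM) with a single SYSTEM string ("cuda", "rocm" or "xpu") plus platform-neutral helpers (empty_cache, synchronize, get_free_memory), so call sites compare one value instead of testing several flags. A minimal sketch of the new call-site pattern, assuming the package is installed and a supported accelerator was detected; pick_rope_backend is illustrative and not part of the diff:

from text_generation_server.utils.import_utils import SYSTEM


def pick_rope_backend():
    # Same shape as the call sites touched below: one string comparison per backend.
    if SYSTEM == "cuda":
        return "flash-attn rotary kernel"
    elif SYSTEM == "rocm":
        return "vllm pos_encoding_ops kernel"
    elif SYSTEM == "xpu":
        return "ipex rotary_embedding"
    raise NotImplementedError(f"Unsupported system {SYSTEM}")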

View File

@@ -1,7 +1,7 @@
 import os
 import torch
 from text_generation_server.utils.import_utils import (
-    IS_ROCM_SYSTEM,
+    SYSTEM,
 )

 try:
@@ -10,7 +10,7 @@ except Exception:
     major = 1

 HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False

View File

@@ -2,9 +2,7 @@ import torch
 from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
+    SYSTEM,
 )
@@ -35,19 +33,60 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
 torch.nn.LayerNorm.load = load_layer_norm
 torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
+                residual = hidden_states
+
+                return super(FastLayerNorm, self).forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+elif SYSTEM == "rocm":
     from vllm import layernorm_ops
-elif IS_XPU_SYSTEM:
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+
+elif SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
-else:
-    dropout_layer_norm = None
-
-
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if IS_XPU_SYSTEM:
+
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
             res_out = hidden_states
             out = ipex.llm.functional.add_layer_norm(
                 residual, hidden_states, self.weight, self.bias, self.eps, True
@@ -55,7 +94,19 @@ class FastLayerNorm(nn.LayerNorm):
             if residual is not None:
                 res_out = residual
             return out, res_out
-        elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+
+
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if SYSTEM == "xpu":
+            res_out = hidden_states
+            out = ipex.llm.functional.add_layer_norm(
+                residual, hidden_states, self.weight, self.bias, self.eps, True
+            )
+            if residual is not None:
+                res_out = residual
+            return out, res_out
+        elif hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
@@ -102,7 +153,7 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)

     def forward(self, hidden_states, residual=None):
-        if IS_XPU_SYSTEM:
+        if SYSTEM == "xpu":
             residual_out = hidden_states
             out = ipex.llm.functional.add_rms_norm(
                 residual,
@@ -131,7 +182,7 @@ class FastRMSNorm(nn.Module):
                 hidden_states = hidden_states.to(self.weight.dtype)

             return self.weight * hidden_states, residual
-        elif IS_CUDA_SYSTEM:
+        elif SYSTEM == "cuda":
             # faster post attention rms norm
             (
                 normed_hidden_states,
@@ -158,7 +209,7 @@ class FastRMSNorm(nn.Module):
                 res = hidden_states

             return normed_hidden_states, res
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
             if residual is not None:
                 hidden_states += residual

View File

@@ -1,4 +1,5 @@
 import torch
+from text_generation_server.utils.import_utils import SYSTEM


 class FastLinear(torch.nn.Module):
@@ -126,7 +127,7 @@ def get_linear(weight, bias, quantize):
            raise NotImplementedError(
                f"The passed weight is not `awq` compatible, loader needs to be updated."
            )
-        if IS_ROCM_SYSTEM:
+        if SYSTEM == "rocm":
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."

View File

@@ -1,15 +1,12 @@
 import torch
 from torch import nn
-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
     from vllm import pos_encoding_ops
@@ -50,7 +47,7 @@ class PositionRotaryEmbedding(nn.Module):
         sin: torch.Tensor,
     ):
         # Such controlflows may add some overhead.
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             rotary_dim = cos.shape[-1]
             q1 = query[..., :rotary_dim]
             q2 = query[..., rotary_dim : 2 * rotary_dim]
@@ -61,7 +58,7 @@ class PositionRotaryEmbedding(nn.Module):
             k2 = key[..., rotary_dim : 2 * rotary_dim]

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
             # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
@@ -69,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module):
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
@@ -223,7 +220,7 @@ class PositionRotaryEmbedding(nn.Module):
         """
         Return cos and sin for the asked position ids
         """
-        if IS_ROCM_SYSTEM:
+        if SYSTEM == "rocm":
             # For RoCm, we always use float cos/sin to avoid a cast.
             # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
             # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.

View File

@@ -2,7 +2,7 @@ import math
 import torch

 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -25,7 +25,7 @@ class CacheManager:
         self.repeat_slots = repeat_slots

         element_size = torch.tensor([], dtype=dtype).element_size()
-        if IS_XPU_SYSTEM:
+        if SYSTEM == "xpu":
             x = 1
         else:
             x = self.block_size // element_size

View File

@@ -26,18 +26,22 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
+)
+from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
 else:
     dropout_layer_norm = None
@@ -52,7 +56,7 @@ class CohereRotary(PositionRotaryEmbedding):
         sin: torch.Tensor,
     ):
         # Such controlflows may add some overhead.
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             import rotary_emb

             q1 = query[..., ::2]
@@ -64,7 +68,7 @@ class CohereRotary(PositionRotaryEmbedding):
             k2 = key[..., 1::2]

             rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import pos_encoding_ops

             # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
@@ -90,7 +94,7 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps

     def forward(self, hidden_states):
-        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+        if hidden_states.shape[-1] > 8192 or SYSTEM == "rocm":
             hidden_states = hidden_states.reshape(
                 -1, self.weight.shape[0], self.weight.shape[1]
             )

View File

@@ -21,21 +21,26 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if not IS_XPU_SYSTEM:
+if SYSTEM != "xpu":
     from vllm.model_executor.layers.fused_moe import fused_moe

 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     FastLinear,
-    FastLayerNorm,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)
+from text_generation_server.layers.layernorm import (
+    FastLayerNorm,
+)
 from text_generation_server.utils.log import log_once

View File

@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np

 from torch import nn
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if not IS_XPU_SYSTEM:
+if SYSTEM != "xpu":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -36,14 +36,18 @@ from loguru import logger
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     FastLinear,
-    FastRMSNorm,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.layernorm import (
+    FastRMSNorm,
+)
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+)


 class MixtralConfig(PretrainedConfig):

View File

@@ -55,12 +55,14 @@ from text_generation_server.layers import (
     PositionRotaryEmbedding,
     FastLinear,
 )
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
     import dropout_layer_norm
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
     from vllm import layernorm_ops
+else:
+    raise RuntimeError(f"Unsupported system {SYSTEM}")


 @dataclass
@@ -373,7 +375,7 @@ class IdeficsRMSNorm(nn.Module):
                 hidden_states = hidden_states.to(self.weight.dtype)

             return self.weight * hidden_states
-        elif IS_CUDA_SYSTEM:
+        elif SYSTEM == "cuda":
             # faster post attention rms norm
             unwrap = False
             if len(hidden_states.shape) > 2:
@@ -405,7 +407,7 @@ class IdeficsRMSNorm(nn.Module):
                 normed_hidden_states = normed_hidden_states.view(*shape)

             return normed_hidden_states
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
             if residual is not None:
                 hidden_states += residual

View File

@@ -12,10 +12,6 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
-
-from text_generation_server.models import Model
-from text_generation_server.utils.tokens import batch_top_tokens
-from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -32,13 +28,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
-
-tracer = trace.get_tracer(__name__)
-
 from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
+    empty_cache,
+    synchronize,
+    get_free_memory,
 )

+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -757,10 +754,8 @@ class FlashCausalLM(Model):
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()

         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -780,10 +775,7 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.synchronize(self.device)
-        elif IS_XPU_SYSTEM:
-            torch.xpu.synchronize(self.device)
+        synchronize(self.device)

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -791,20 +783,7 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(
-                self.device
-            ).total_memory
-
-            free_memory = max(
-                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-            )
-        elif IS_XPU_SYSTEM:
-            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory * 0.5)
-        else:
-            raise NotImplementedError("FlashModel is only available on GPU")
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

         num_blocks = (
             # Leave 5% for some wiggle room

View File

@@ -18,7 +18,7 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)

-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM


 class FlashLlama(FlashCausalLM):
@@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:

View File

@@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:

View File

@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:

View File

@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:

View File

@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:

View File

@@ -2,13 +2,8 @@ import os
 import torch

 from loguru import logger
-import math

-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -16,83 +11,22 @@ HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex

-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-    if not torch.cuda.is_available():
-        raise ImportError("CUDA is not available")
-
-    major, minor = torch.cuda.get_device_capability()
-    is_sm75 = major == 7 and minor == 5
-    is_sm8x = major == 8 and minor >= 0
-    is_sm90 = major == 9 and minor == 0
-
-    HAS_FLASH_ATTN = False
-    HAS_FLASH_ATTN_V2_CUDA = False
-    HAS_FLASH_ATTN_V2_ROCM = False
-    try:
-        try:
-            import flash_attn_2_cuda
-        except ImportError:
-            architecture_suffix = ""
-            if IS_CUDA_SYSTEM:
-                architecture_suffix = "-cuda"
-            elif IS_ROCM_SYSTEM:
-                architecture_suffix = "-rocm"
-            raise ImportError(
-                "Flash Attention V2 is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-            )
-        if not (is_sm8x or is_sm90):
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported for "
-                "Flash Attention V2"
-            )
-        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-    except ImportError as e:
-        try:
-            import flash_attn_cuda
-        except ImportError:
-            raise ImportError(
-                "Flash Attention is not installed.\n"
-                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                "or install flash attention with `cd server && make install install-flash-attention`"
-            ) from e
-
-        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            ) from e
-        elif IS_ROCM_SYSTEM:
-            for idx in range(torch.cuda.device_count()):
-                if "MI210" not in torch.cuda.get_device_name(
-                    idx
-                ) and "MI250" not in torch.cuda.get_device_name(idx):
-                    raise ImportError(
-                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                    )
-
-        logger.warning(f"Unable to use Flash Attention V2: {e}")
-        HAS_FLASH_ATTN = True
-
-
-def attention(
-    q,
-    k,
-    v,
-    out,
-    cu_seqlens,
-    max_s,
-    softmax_scale,
-    window_size_left=-1,
-):
-    if window_size_left <= 0 and window_size_left != -1:
-        raise ValueError("`window_size_left` must be > 0 or -1")
-
-    if IS_XPU_SYSTEM:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
+
         if window_size_left != -1:
             raise ValueError(
                 f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
@@ -114,7 +48,77 @@ def attention(
             None,
         )

-    if HAS_FLASH_ATTN_V2_CUDA:
+
+if SYSTEM in {"cuda", "rocm"}:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
+    HAS_FLASH_ATTN = False
+    HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
+    try:
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = f"-{SYSTEM}"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if not (is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
+        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif SYSTEM == "rocm":
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True
+
+
+if HAS_FLASH_ATTN_V2_CUDA:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -136,7 +140,21 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN_V2_ROCM:
+
+elif HAS_FLASH_ATTN_V2_ROCM:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
                 f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
@@ -159,7 +177,19 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN:
+
+elif HAS_FLASH_ATTN:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
@@ -209,4 +239,5 @@ def attention(
             None,
         )
-    else:
-        raise NotImplementedError("flash attention is not installed")
+
+else:
+    raise NotImplementedError("flash attention is not installed")
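The net effect in this file is that attention is bound once at import time, per detected backend, instead of branching on every call; the selection rides on module-level flags (HAS_FLASH_ATTN_V2_CUDA, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN) computed from SYSTEM and the installed packages. A minimal, self-contained sketch of that select-at-import pattern, with placeholder bodies and a hard-coded SYSTEM standing in for the real detection:

# Sketch only: dispatch on flags computed once at import time, instead of
# checking the platform inside every attention() call.
SYSTEM = "cuda"  # assumption for the sketch; the real value comes from import_utils

HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"

if SYSTEM == "xpu":

    def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, window_size_left=-1):
        ...  # ipex varlen attention path

elif HAS_FLASH_ATTN_V2_CUDA:

    def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, window_size_left=-1):
        ...  # flash_attn_2_cuda.varlen_fwd path

elif HAS_FLASH_ATTN_V2_ROCM:

    def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, window_size_left=-1):
        ...  # ROCm flash-attn v2 path (no window attention)

else:
    raise NotImplementedError("flash attention is not installed")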

View File

@@ -10,6 +10,32 @@ def is_xpu_available():
     return hasattr(torch, "xpu") and torch.xpu.is_available()


-IS_ROCM_SYSTEM = torch.version.hip is not None
-IS_CUDA_SYSTEM = torch.version.cuda is not None
-IS_XPU_SYSTEM = is_xpu_available()
+def get_cuda_free_memory(device, memory_fraction):
+    total_free_memory, _ = torch.cuda.mem_get_info(device)
+    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
+    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
+    return free_memory
+
+
+def get_xpu_free_memory(device):
+    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
+    free_memory = int(total_gpu_memory * 0.5)
+    return free_memory
+
+
+SYSTEM = None
+if torch.version.hip is not None:
+    SYSTEM = "rocm"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif torch.version.cuda is not None:
+    SYSTEM = "cuda"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif is_xpu_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
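With import_utils.py resolving the backend once, callers can stay platform-agnostic, as the flash_causal_lm.py warmup hunks above show. A minimal usage sketch, assuming the server package is installed and one of the three backends was detected; warmup_sketch and the 0.9 fraction are illustrative, not part of the repository:

from text_generation_server.utils.import_utils import (
    SYSTEM,
    empty_cache,
    synchronize,
    get_free_memory,
)

MEMORY_FRACTION = 0.9  # illustrative; the server takes this from its own config


def warmup_sketch(device):
    # The aliases hide torch.cuda.* vs torch.xpu.* behind one interface,
    # mirroring the warmup() changes above.
    empty_cache()
    synchronize(device)
    free_memory = get_free_memory(device, MEMORY_FRACTION)
    return SYSTEM, free_memory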

View File

@@ -1,13 +1,9 @@
 import torch
-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 _PARTITION_SIZE = 512

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
@@ -18,17 +14,17 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
+    if SYSTEM == "cuda":
         from vllm._C import cache_ops

         cache_ops.reshape_and_cache(
             key, value, key_cache, value_cache, slots, "auto", 1.0
         )
-    elif IS_ROCM_SYSTEM:
+    elif SYSTEM == "rocm":
         from vllm import cache_ops

         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    elif IS_XPU_SYSTEM:
+    elif SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
@@ -68,7 +64,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    if IS_XPU_SYSTEM:
+    if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
@@ -91,7 +87,7 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
            from vllm._C import ops

            ops.paged_attention_v1(
@@ -109,7 +105,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import attention_ops

             attention_ops.paged_attention_v1(
@@ -143,7 +139,7 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)

-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             from vllm._C import ops

             ops.paged_attention_v2(
@@ -164,7 +160,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import attention_ops

             attention_ops.paged_attention_v2(