mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

black

This commit is contained in:
parent 0812e3bdc9
commit 265c76d328

2  .github/workflows/build.yaml  vendored
@@ -425,4 +425,4 @@ jobs:
      - name: Run tests
        run: |
          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          pytest -s -vv integration-tests
@@ -35,4 +35,4 @@ By default, as its performances have experimentally been better, Triton implemen

The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Kernel for sliding window attention (Mistral)
@@ -11,7 +11,7 @@ __device__ __forceinline__ __half __compat_hrcp(__half x) {

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{
        _Float16_2{static_cast<_Float16>(1.0f),
                   static_cast<_Float16>(1.0f)} / x.data};
}

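For readers unfamiliar with the HIP compatibility shim above: `__compat_h2rcp` builds the element-wise reciprocal of a `__half2` (a pair of fp16 values) by dividing a {1, 1} pair by the input. A rough PyTorch illustration of the same arithmetic (not part of the diff, purely for orientation):

```python
import torch

# Two fp16 values standing in for the two lanes of a __half2.
x = torch.tensor([2.0, 4.0], dtype=torch.float16)

# __compat_h2rcp divides a {1.0, 1.0} pair by x.data; element-wise this is 1/x.
recip = torch.ones_like(x) / x
print(recip)  # tensor([0.5000, 0.2500], dtype=torch.float16)
```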
@@ -8,6 +8,7 @@ if SYSTEM == "rocm":
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")

+
class FastLinear(torch.nn.Module):
    def __init__(
        self,
@@ -47,7 +47,7 @@ class BLOOMSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -197,7 +197,9 @@ class LlamaMLP(nn.Module):
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
-                   "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                   "tanh"
+                   if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                   else "none"
                ),
            )
        )
@@ -229,7 +231,11 @@ class LlamaMLP(nn.Module):
        )

    def forward(self, hidden_states):
-       if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+       if (
+           SYSTEM == "rocm"
+           and self.hidden_act == "silu"
+           and hidden_states.shape[0] == 1
+       ):
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
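The line Black splits in the first of these two LlamaMLP hunks chooses the GELU approximation from the configured `hidden_act` string. A standalone sketch of just that selection (the helper name `gelu_for` and the surrounding usage are illustrative, not from the diff):

```python
import torch

def gelu_for(hidden_act: str):
    # "gelu_fast" and "gelu_pytorch_tanh" use the tanh approximation of GELU;
    # anything else falls back to the exact formulation ("none").
    return lambda x: torch.nn.functional.gelu(
        x,
        approximate=(
            "tanh" if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
        ),
    )

act = gelu_for("gelu_pytorch_tanh")
y = act(torch.randn(2, 8))
```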
@@ -266,7 +266,9 @@ class MistralMLP(nn.Module):
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
-                   "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                   "tanh"
+                   if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                   else "none"
                ),
            )
        )
@@ -289,7 +291,11 @@ class MistralMLP(nn.Module):
        )

    def forward(self, hidden_states):
-       if SYSTEM == "rocm" and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+       if (
+           SYSTEM == "rocm"
+           and self.hidden_act == "silu"
+           and hidden_states.shape[0] == 1
+       ):
            out = torch.empty(
                hidden_states.shape[0],
                self.intermediate_size,
@@ -64,6 +64,7 @@ elif SYSTEM == "rocm":
else:
    raise RuntimeError(f"Unsupported system {SYSTEM}")

+
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
    image_hidden_states: Optional[torch.FloatTensor] = None
@@ -831,28 +831,40 @@ class FlashCausalLM(Model):
            torch.cuda.tunable.tuning_enable(True)

            if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
-               tuning_sequences = [int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
+               tuning_sequences = [
+                   int(val)
+                   for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
+               ]
            else:
                tuning_sequences = [1, 2, 4, 8, 16, 32]

-           tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
+           tunableop_filepath = os.path.join(
+               "/data",
+               f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+           )

-           logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
+           logger.info(
+               f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}."
+           )

            if os.path.isfile(tunableop_filepath):
-               logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
+               logger.info(
+                   f"The file {tunableop_filepath} already exists and will be reused."
+               )
                torch.cuda.tunable.read_file(tunableop_filepath)

            os.makedirs("/data", exist_ok=True)

            for seqlen in tuning_sequences:
                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                self.tunableop_warmup(seqlen)
            torch.cuda.tunable.write_file(tunableop_filepath)
            torch.cuda.tunable.tuning_enable(False)
        else:
-           logger.info("PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.")
+           logger.info(
+               "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
+           )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
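Condensed for reference, the control flow this hunk reformats: pick target sequence lengths from PYTORCH_TUNABLEOP_SEQLENS (or a default list), reuse a previously written GEMM-selection CSV if one exists, run one warmup forward per length, then persist the results. A minimal sketch assuming a PyTorch build that ships `torch.cuda.tunable`; the `warmup_forward` callable and `filepath` argument are placeholders, not names from the codebase:

```python
import os
import torch

def tunableop_warmup(warmup_forward, filepath: str):
    # Target decode lengths: either user-provided via env var, or a small default set.
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", False):
        tuning_sequences = [
            int(val) for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
        ]
    else:
        tuning_sequences = [1, 2, 4, 8, 16, 32]

    torch.cuda.tunable.tuning_enable(True)
    if os.path.isfile(filepath):
        # Reuse GEMM selections tuned in a previous run.
        torch.cuda.tunable.read_file(filepath)

    for seqlen in tuning_sequences:
        warmup_forward(seqlen)  # one dummy forward pass per target length

    torch.cuda.tunable.write_file(filepath)
    torch.cuda.tunable.tuning_enable(False)
```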
@@ -877,7 +889,9 @@ class FlashCausalLM(Model):
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
-           cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
+           cu_seqlen_prefill=torch.tensor(
+               [0, seqlen], device=self.device, dtype=torch.int32
+           ),
            kv_cache=get_cache_manager().kv_cache,
            block_tables=None,
            input_lengths=None,
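The `cu_seqlen_prefill` tensor built in this hunk follows the flash-attention convention for variable-length batches: cumulative sequence lengths with a leading zero, so `[0, seqlen]` describes a single prefill of length `seqlen`. A small illustration with made-up lengths:

```python
import torch

# Three prefill requests of lengths 5, 2 and 7 packed into one batch:
lengths = torch.tensor([5, 2, 7], dtype=torch.int32)
cu_seqlen_prefill = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlen_prefill[1:] = torch.cumsum(lengths, dim=0)
print(cu_seqlen_prefill)  # tensor([ 0,  5,  7, 14], dtype=torch.int32)

# The single-sequence case used during warmup reduces to:
seqlen = 7
single = torch.tensor([0, seqlen], dtype=torch.int32)
```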
@@ -29,7 +29,7 @@ class FlashCohere(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -31,7 +31,7 @@ class FlashDbrx(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -29,7 +29,7 @@ class FlashGemma(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -319,7 +319,7 @@ class BaseFlashMistral(FlashCausalLM):
        tokenizer_class=AutoTokenizer,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -30,7 +30,7 @@ class FlashNeoXSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -30,7 +30,7 @@ class FlashPhi(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -35,7 +35,7 @@ class FlashQwen2(BaseFlashMistral):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -31,7 +31,7 @@ class FlashRWSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -34,7 +34,7 @@ class FlashSantacoderSharded(FlashCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -34,7 +34,7 @@ class FlashStarcoder2(BaseFlashMistral):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -172,7 +172,7 @@ class GalacticaSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -29,7 +29,7 @@ class GPTNeoxSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -36,7 +36,7 @@ class IDEFICSSharded(IdeficsCausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -413,7 +413,7 @@ class Mamba(Model):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, _rank, world_size = initialize_torch_distributed()
        if world_size > 1:
            raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
@@ -48,7 +48,7 @@ class MPTSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -27,7 +27,7 @@ class OPTSharded(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -27,7 +27,7 @@ class Phi(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, _rank, _world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -24,7 +24,7 @@ class SantaCoder(CausalLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
@@ -30,7 +30,7 @@ class T5Sharded(Seq2SeqLM):
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
@@ -187,6 +187,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            total_ns=time.time_ns() - start,
        )

+
def serve(
    model_id: str,
    revision: Optional[str],
@@ -64,12 +64,17 @@ if SYSTEM in {"cuda", "rocm"}:
    is_sm94 = major == 9 and minor == 4

    if SYSTEM == "rocm":
-       if os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true" or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1":
+       if (
+           os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
+           or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
+       ):
            ROCM_USE_FLASH_ATTN_V2_TRITON = True
            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
        else:
            ROCM_USE_FLASH_ATTN_V2_CK = True
-           logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
+           logger.info(
+               "ROCm: using Flash Attention 2 Composable Kernel implementation."
+           )

    try:
        try:
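The condition Black wraps here accepts either ROCM_USE_FLASH_ATTN_V2_TRITON=true or =1. A sketch of the same truthiness check factored into a reusable helper (the name `env_flag` is illustrative, not from the codebase):

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    # Accept "1" and any casing of "true" as enabled, mirroring the inline
    # check on ROCM_USE_FLASH_ATTN_V2_TRITON in the hunk above.
    value = os.getenv(name, default).strip().lower()
    return value in ("1", "true")

if env_flag("ROCM_USE_FLASH_ATTN_V2_TRITON"):
    print("Using the Triton flash attention path on ROCm")
```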
@@ -158,6 +163,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
+
    def attention(
        q,
        k,
@@ -192,8 +198,9 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
            False,
            None,
        )

elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
+
    def attention(
        q,
        k,
@@ -217,6 +224,7 @@ elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
            softmax_scale,
        )
        return output

elif HAS_FLASH_ATTN:
+
    def attention(
@@ -46,16 +46,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):

@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
-   rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
-                                 stride).to(tl.uint32)
+   rng_offsets = dropout_offsets(
+       philox_seed, philox_offset, dropout_p, m, n, stride
+   ).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
-   rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
-                            stride)
+   rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    rng_keep = rng_output > dropout_p
    return rng_keep

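For context, `dropout_mask` above draws counter-based (Philox) random numbers per position and keeps the elements whose draw exceeds `dropout_p`. A plain PyTorch reference of that keep-mask idea, with a seeded generator standing in for the Philox seed/offset pair; purely illustrative:

```python
import torch

def dropout_keep_mask(shape, dropout_p: float, seed: int = 0) -> torch.Tensor:
    # Uniform noise in [0, 1); an element survives dropout when its draw
    # is strictly greater than dropout_p, as in the Triton kernel above.
    gen = torch.Generator().manual_seed(seed)
    rng = torch.rand(shape, generator=gen)
    return rng > dropout_p

mask = dropout_keep_mask((4, 4), dropout_p=0.1)
```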
@@ -65,9 +65,9 @@ def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
-       tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
+       tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
    elif second:
-       tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
+       tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor
@@ -133,9 +133,7 @@ def _attn_fwd_inner(
    # if not is_modulo_mn. last step might get wasted but that is okay.
    # check if this masking works for that case.
    if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
-       boundary_m = tl.full([BLOCK_M],
-                            actual_seqlen_k,
-                            dtype=tl.int32)
+       boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
        size_n = start_n + OFFS_N[None, :]
        mask = size_n < boundary_m[:, None]
        qk = tl.where(mask, qk, float("-inf"))
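The block being reformatted masks out key positions past the true sequence length before the softmax. A small PyTorch illustration of the same masking arithmetic (shapes and values chosen arbitrarily for the example):

```python
import torch

BLOCK_M, BLOCK_N = 2, 4
actual_seqlen_k, start_n = 6, 4       # last K block starts at position 4, but only 6 keys exist
offs_n = torch.arange(BLOCK_N)

qk = torch.zeros(BLOCK_M, BLOCK_N)
boundary_m = torch.full((BLOCK_M,), actual_seqlen_k)
size_n = start_n + offs_n[None, :]    # absolute key indices 4, 5, 6, 7
mask = size_n < boundary_m[:, None]   # True only for indices 4 and 5
qk = torch.where(mask, qk, torch.tensor(float("-inf")))
# positions 6 and 7 now carry -inf and vanish after the softmax
```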
@@ -146,8 +144,9 @@ def _attn_fwd_inner(
        # -- compute qk ----
        qk += tl.dot(q, k)
        if bias_ptr is not None:
-           bias = load_fn(bias_ptr, False, MASK_STEPS
-                          and (n_extra_tokens != 0), "zero")
+           bias = load_fn(
+               bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
+           )
            # While bias is added after multiplying qk with sm_scale, our
            # optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
@@ -159,9 +158,12 @@ def _attn_fwd_inner(
        # CAVEAT: Must update l_ij before applying dropout
        l_ij = tl.sum(p, 1)
        if ENABLE_DROPOUT:
-           philox_offset = (batch_philox_offset +
-                            start_m * BLOCK_M * actual_seqlen_k + start_n -
-                            BLOCK_N)
+           philox_offset = (
+               batch_philox_offset
+               + start_m * BLOCK_M * actual_seqlen_k
+               + start_n
+               - BLOCK_N
+           )
            keep = dropout_mask(
                philox_seed,
                philox_offset,
@@ -173,8 +175,7 @@ def _attn_fwd_inner(
            if RETURN_ENCODED_SOFTMAX:
                tl.store(
                    encoded_softmax_block_ptr,
-                   tl.where(keep, p,
-                            -p).to(encoded_softmax_block_ptr.type.element_ty),
+                   tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
                )
            p = tl.where(keep, p, 0.0)
        elif RETURN_ENCODED_SOFTMAX:
@@ -202,8 +203,9 @@ def _attn_fwd_inner(
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
-           encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                  (0, BLOCK_N))
+           encoded_softmax_block_ptr = tl.advance(
+               encoded_softmax_block_ptr, (0, BLOCK_N)
+           )
    return acc, l_i, m_i

@@ -341,7 +343,7 @@ def attn_fwd(
    philox_offset_base,
    encoded_softmax,
    HQ: tl.constexpr,
-   HK:tl.constexpr,
+   HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
@@ -392,15 +394,17 @@ def attn_fwd(
    # This captures the decrease in n_blocks if we have a rectangular attn
    # matrix
    n_blocks_seqlen = cdiv_fn(
-       (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
+       (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
+   )
    # This is what adjusts the block_max for the current WG, only
    # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
    n_blocks = min(n_blocks, n_blocks_seqlen)
    # If we have no blocks after adjusting for seqlen deltas, this WG is
    # part of the blocks that are all 0. We exit early.
    if n_blocks <= 0:
-       o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-                   off_h_q * stride_oh)
+       o_offset = (
+           off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
+       )
        O_block_ptr = tl.make_block_ptr(
            base=Out + o_offset,
            shape=(seqlen_q, BLOCK_DMODEL),
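`cdiv_fn` here is presumably a ceiling division, used to cap how many K/V blocks a workgroup must visit when the attention matrix is rectangular (seqlen_q != seqlen_k) and causal. A tiny worked example of that bound, with block sizes picked only for illustration:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: smallest n such that n * b >= a.
    return (a + b - 1) // b

BLOCK_M = BLOCK_N = 128
seqlen_q = seqlen_k = 200
start_m = 0  # first block of query rows

# Number of K/V blocks this workgroup can possibly attend to under causality:
n_blocks_seqlen = cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
print(n_blocks_seqlen)  # 1
```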
@@ -436,11 +440,10 @@ def attn_fwd(
        n_extra_tokens = BLOCK_N - seqlen_k
    elif seqlen_k % BLOCK_N:
        n_extra_tokens = seqlen_k % BLOCK_N
-   PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
+   PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL

    # Compute pointers for all the tensors used in this kernel.
-   q_offset = (off_z * stride_qz + off_h_q * stride_qh +
-               cu_seqlens_q_start * stride_qm)
+   q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -449,8 +452,7 @@ def attn_fwd(
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
-   k_offset = (off_z * stride_kz + off_h_k * stride_kh +
-               cu_seqlens_k_start * stride_kn)
+   k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
@@ -459,8 +461,7 @@ def attn_fwd(
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
-   v_offset = (off_z * stride_vz + off_h_k * stride_vh +
-               cu_seqlens_k_start * stride_vk)
+   v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
@@ -481,9 +482,9 @@ def attn_fwd(
    else:
        bias_ptr = None
    if ENABLE_DROPOUT:
-       batch_philox_offset = philox_offset_base \
-                             + (off_z * HQ + off_h_q) \
-                             * seqlen_q * seqlen_k
+       batch_philox_offset = (
+           philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
+       )
    else:
        batch_philox_offset = 0
    # We can ask to return the dropout mask without actually doing any dropout.
@@ -578,8 +579,9 @@ def attn_fwd(
        if bias_ptr is not None:
            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
        if RETURN_ENCODED_SOFTMAX:
-           encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
-                                                  (0, n_full_blocks))
+           encoded_softmax_block_ptr = tl.advance(
+               encoded_softmax_block_ptr, (0, n_full_blocks)
+           )
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
@@ -626,12 +628,11 @@ def attn_fwd(
    acc = acc.to(Out.type.element_ty)
    if IS_CAUSAL:  # noqa: SIM102
        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
-           out_mask_boundary = tl.full((BLOCK_DMODEL, ),
-                                       causal_start_idx,
-                                       dtype=tl.int32)
+           out_mask_boundary = tl.full(
+               (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
+           )
            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
-           out_ptrs_mask = (mask_m_offsets[:, None] >=
-                            out_mask_boundary[None, :])
+           out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
            z = 0.0
            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
    # write back LSE
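The last hunk above zeroes accumulator rows whose query index falls before `causal_start_idx`, that is, rows which attend to nothing under the causal mask. A PyTorch sketch of that row masking (sizes invented for the example):

```python
import torch

BLOCK_M, BLOCK_DMODEL = 4, 8
start_m_idx, causal_start_idx = 0, 2

acc = torch.randn(BLOCK_M, BLOCK_DMODEL)
out_mask_boundary = torch.full((BLOCK_DMODEL,), causal_start_idx, dtype=torch.int32)
mask_m_offsets = start_m_idx + torch.arange(BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]

# Rows 0 and 1 are zeroed, rows 2 and 3 keep their values.
acc = torch.where(out_ptrs_mask, acc, torch.tensor(0.0))
```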
@@ -649,8 +650,7 @@ def attn_fwd(
    # tl.store(l_ptrs, m_i + tl.math.log2(l_i))

    # write back O
-   o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
-               off_h_q * stride_oh)
+   o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
@@ -813,4 +813,4 @@ class _attention(torch.autograd.Function):
        return o, encoded_softmax


triton_attention = _attention.apply
@@ -10,7 +10,9 @@ else:
        from vllm._C import cache_ops
        from vllm._C import ops
    except Exception as e:
-       raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+       raise ImportError(
+           f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+       )


def reshape_and_cache(