Mirror of https://github.com/huggingface/text-generation-inference.git

Commit 0ad78d20a5 ("style"), parent 47447ef017
@@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install flash-attention, torch dependencies
 RUN pip install numpy einops ninja --no-cache-dir
 
-RUN conda install intel::mkl-static intel::mkl-include
+RUN conda install mkl-static mkl-include
 RUN pip uninstall -y triton && \
     git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
     cd triton/python && \
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
-		git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+		git clone https://github.com/mht-sharma/vllm.git vllm; \
 	fi
 	cd vllm && git fetch && git checkout $(commit_rocm) && \
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
@@ -14,7 +14,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true",
 ENGINE = "triton" if use_triton else "ck"
 
 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
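
Note: this hunk, together with the matching one in `paged_attention` below, moves the import from the private `vllm._C` extension to the `vllm._custom_ops` wrapper module. A minimal sketch (not part of this commit; `HAS_CUSTOM_OPS` is a made-up name) of detecting which layout is installed, using only the module paths that appear in the diff:

# Sketch only: prefer the wrapper module this commit moves to, and fall back
# to the older private extension it replaces.
try:
    import vllm._custom_ops as ops  # new layout
    HAS_CUSTOM_OPS = True
except ImportError:
    from vllm._C import cache_ops, ops  # old layout seen on the removed lines
    HAS_CUSTOM_OPS = False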
@@ -33,9 +33,7 @@ def reshape_and_cache(
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def paged_attention(
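
For reference, the unchanged context lines above already show what the cache write does in pure PyTorch: each token's key/value vector is scattered into its assigned slot of the KV cache. A self-contained sketch with made-up shapes (illustrative only; the fused `ops.reshape_and_cache(...)` call performs the equivalent write against vLLM's own cache layout):

import torch

num_tokens, num_heads, head_size = 4, 2, 8
num_slots = 16  # illustrative cache capacity

key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.zeros(num_slots, num_heads, head_size)
value_cache = torch.zeros(num_slots, num_heads, head_size)
slots = torch.tensor([3, 7, 8, 12])  # one destination slot per token

# Same scatter as the context lines in the hunk above.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value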
@@ -78,7 +76,7 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops
 
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
@@ -180,6 +178,7 @@ if ENGINE == "ck":
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
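
The new `softcap=0.0` parameter is threaded through to the CK flash-attention call in the next hunk. As a hedged reference only (the capping happens inside the kernel; this helper is not from the commit), attention logit soft-capping is conventionally a scaled tanh applied to the scores, with 0.0 meaning disabled:

import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Assumed conventional definition: bound logits to (-softcap, softcap).
    # softcap == 0.0, the default added above, is treated as "no capping".
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)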
@@ -194,12 +193,19 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
+            window_size_left,
+            0,
+            softcap,
             False,
             None,
         )
@@ -313,11 +313,15 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
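
The added `hidden_size` check narrows the fused ROCm SiLU path. Written out as a standalone predicate purely for readability (a hypothetical helper, not code from this commit), the gate now reads:

def use_fused_silu_mlp(
    system: str, hidden_act: str, batch_size: int, hidden_size: int, quantize
) -> bool:
    # Mirrors the condition above: fused LLMM_Silu only on ROCm, for SiLU,
    # at batch size 1, unquantized, and now skipping hidden_size == 16384,
    # the Llama 3.1 405B case the TODO marks as broken.
    return (
        system == "rocm"
        and hidden_act == "silu"
        and batch_size == 1
        and hidden_size != 16384
        and not quantize
    )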