Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)

commit 0ad78d20a5
parent 47447ef017

    style

@@ -92,7 +92,7 @@ RUN chmod +x ~/mambaforge.sh && \
 # Install flash-attention, torch dependencies
 RUN pip install numpy einops ninja --no-cache-dir

-RUN conda install intel::mkl-static intel::mkl-include
+RUN conda install mkl-static mkl-include
 RUN pip uninstall -y triton && \
     git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
     cd triton/python && \

@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel

@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \

@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
-		git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+		git clone https://github.com/mht-sharma/vllm.git vllm; \
 	fi
 	cd vllm && git fetch && git checkout $(commit_rocm) && \
 		PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
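
Note: the build pins PYTORCH_ROCM_ARCH to gfx90a (MI200 series) and gfx942 (MI300 series), so the resulting vLLM extension only runs on those GPUs. A hedged helper for checking a machine before building; `gcnArchName` is the device property exposed by ROCm builds of PyTorch, and the function name is ours:

    import torch

    def rocm_arch_supported(targets=("gfx90a", "gfx942")) -> bool:
        """Return True if GPU 0 matches one of the architectures targeted above."""
        if torch.version.hip is None or not torch.cuda.is_available():
            return False
        # ROCm builds of PyTorch expose the GCN architecture string,
        # e.g. "gfx90a:sramecc+:xnack-".
        arch = torch.cuda.get_device_properties(0).gcnArchName
        return any(arch.startswith(target) for target in targets)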

@@ -14,7 +14,7 @@ use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true",
 ENGINE = "triton" if use_triton else "ck"

 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
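
Note: vLLM moved its compiled kernels from the private `vllm._C` extension to the `vllm._custom_ops` wrapper, which is why this import and the two further down change together. A minimal compatibility sketch (ours, not part of the commit) that tolerates either layout:

    # Hedged sketch: the new wrapper module exposes both the cache kernels and
    # the paged-attention kernels; the legacy layout split them across
    # vllm._C.cache_ops and vllm._C.ops.
    try:
        import vllm._custom_ops as ops  # newer vLLM layout
        cache_ops = ops                 # one module now carries both kernel sets
    except ImportError:
        from vllm._C import cache_ops, ops  # legacy private extension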

@@ -33,9 +33,7 @@ def reshape_and_cache(
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


 def paged_attention(
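
Note: both branches scatter each token's new key/value vectors into the paged KV cache at that token's flat slot index; the Triton path uses a plain view-and-assign, which the fused `ops.reshape_and_cache` kernel performs natively. A self-contained sketch of the view-and-assign semantics, with assumed shapes:

    import torch

    # Assumed layout for the fallback path: the cache is paged into fixed-size
    # blocks, and a token's slot is the flat index block_id * block_size + offset.
    num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
    key = torch.randn(3, num_heads, head_size)  # three new tokens
    slots = torch.tensor([5, 17, 40])           # one flat slot per token

    shape = key_cache.shape
    key_cache.view(-1, shape[-2], shape[-1])[slots] = key  # scatter in place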

@@ -78,7 +76,7 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops

     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
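
Note: `use_v1` chooses between the two vLLM paged-attention kernels: V1 launches one work unit per (sequence, head) pair, while V2 additionally splits long KV sequences into partitions and reduces across them. A hedged restatement as a standalone helper; the 512-element partition size mirrors the `_PARTITION_SIZE` constant used near this code, and the function name is ours:

    _PARTITION_SIZE = 512  # KV-sequence partition size used by the V2 kernel

    def pick_paged_attention_kernel(max_s: int, num_seqs: int, num_heads: int) -> str:
        max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        # Prefer V1 when there is nothing to split (a single partition) or when
        # there are already enough sequence*head pairs to saturate the GPU, as
        # long as the sequence is short enough that skipping V2's reduction wins.
        use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        return "v1" if use_v1 else "v2"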

@@ -180,6 +178,7 @@ if ENGINE == "ck":
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
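
Note: `softcap` is the tanh logit-capping parameter used by models such as Gemma 2; `0.0` conventionally disables it. What a nonzero value means for the attention scores, as a hedged reference sketch:

    import torch

    def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        """Squash attention logits so their magnitude never exceeds softcap."""
        if softcap == 0.0:
            return scores  # 0.0 disables capping
        return softcap * torch.tanh(scores / softcap)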

@@ -194,12 +193,19 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
+            None,
+            None,
+            None,
+            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
             window_size_left,
             0,
+            softcap,
             False,
             None,
         )
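
Note: the call gains four `None` placeholders and threads `softcap` through, matching the wider `varlen_fwd` signature of the newer flash-attention pin. One hedged reading of the positional arguments, written as comments (the names follow the upstream flash-attention 2.6-era signature and are our assumption, not taken from this repo):

    flash_attn_2_cuda.varlen_fwd(
        q, k, v,
        out,               # pre-allocated output tensor
        cu_seqlens,        # cu_seqlens_q: cumulative query lengths
        cu_seqlens,        # cu_seqlens_k: same offsets for self-attention
        None,              # seqused_k
        None,              # leftpad_k
        None,              # block_table (paged KV, unused here)
        None,              # alibi_slopes
        max_s,             # max_seqlen_q
        max_s,             # max_seqlen_k
        0.0,               # dropout probability
        softmax_scale,     # typically 1/sqrt(head_dim)
        False,             # zero_tensors
        causal,            # causal masking flag
        window_size_left,  # sliding-window left bound
        0,                 # window_size_right
        softcap,           # tanh logit softcap, newly threaded through
        False,             # return_softmax
        None,              # RNG generator
    )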

@@ -313,11 +313,15 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
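
Note: the new guard keeps single-token, unquantized SiLU MLPs on the fused ROCm kernel path while excluding hidden size 16384 (Llama 3.1 405B), where the fused `LLMM_Silu` kernel currently misbehaves. The predicate restated as a standalone helper (a hedged sketch; the naming is ours):

    def use_fused_silu_path(system: str, hidden_act: str, num_tokens: int,
                            hidden_size: int, quantize) -> bool:
        return (
            system == "rocm"          # the fused kernel is ROCm-only
            and hidden_act == "silu"
            and num_tokens == 1       # single-token decode batches
            and hidden_size != 16384  # Llama 3.1 405B workaround (kernel bug)
            and not quantize          # unquantized weights only
        )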