Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Tmp branch to test transformers backend with 2.5.1 and TP>1

commit 859d2f0464
parent 6d335ca7ce
@@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
 FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install

 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
+ARG PYTORCH_VERSION=2.5.1

 ARG PYTHON_VERSION=3.11
 # Keep in sync with `server/pyproject.toml
@@ -235,8 +235,8 @@ RUN cd server && \
     make gen-server && \
     python -c "from text_generation_server.pb import generate_pb2" && \
     pip install -U pip uv && \
-    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
-    uv pip install nvidia-nccl-cu12==2.22.3
+    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir # && \
+    # uv pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
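The two hunks above (apparently the TGI Dockerfile) belong together: the base PyTorch is bumped to 2.5.1 and the explicit nvidia-nccl-cu12==2.22.3 pin, plus its trailing `&& \` continuation, is commented out, presumably to rely on the NCCL build that torch 2.5.1 ships with. A minimal sanity check, not part of the diff, that one could run inside the resulting image to confirm the combination:

# Quick sanity check (assumption: run inside the built image, not part of this commit).
import torch

print("torch:", torch.__version__)   # expected to start with "2.5.1"
print("cuda:", torch.version.cuda)   # CUDA runtime torch was built against
print("nccl:", ".".join(str(v) for v in torch.cuda.nccl.version()))  # NCCL bundled with torch

If the bundled NCCL differs from what LD_PRELOAD points at (the libnccl.so.2 path kept in the context lines), that mismatch would show up here first.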
@@ -2,4 +2,4 @@ install-flashinfer:
	# We need fsspec as an additional dependency, but
	# `pip install flashinfer` cannot resolve it.
	pip install fsspec
-	pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install flashinfer-python==0.2.0.post1
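This hunk (what looks like server/Makefile-flashinfer) swaps the torch-2.4 wheel-index install of `flashinfer` for the `flashinfer-python` package at the same 0.2.0.post1 version. A small smoke test, with the assumption that the wheel exposes a `flashinfer` module and a `__version__` attribute:

# Hedged smoke test: confirm the flashinfer-python wheel is importable as "flashinfer".
try:
    import flashinfer
    print("flashinfer imported from:", flashinfer.__file__)
    print("version:", getattr(flashinfer, "__version__", "unknown"))
except ImportError as exc:
    print("flashinfer not importable:", exc)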
@@ -79,7 +79,7 @@ __all__ = [

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-FLASH_ATTENTION = True
+FLASH_ATTENTION = False

 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
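Hard-coding FLASH_ATTENTION to False in the models __init__ disables the custom flash-attention model classes so every request falls through to the transformers-based path; upstream, the flag is normally flipped by an import guard rather than set by hand. An illustrative sketch of that guard pattern, not the file's exact contents:

# Illustrative sketch of the usual import-guard pattern around FLASH_ATTENTION.
# In this temporary branch the flag is simply forced to False for testing.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM  # noqa: F401
except ImportError as exc:
    print(f"Flash attention models are unavailable: {exc}")
    FLASH_ATTENTION = False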
@@ -931,10 +931,10 @@ def get_model(
                trust_remote_code=trust_remote_code,
                lora_adapter_ids=lora_adapter_ids,
            )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-            )
+        # elif sharded:
+        #     raise NotImplementedError(
+        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+        #     )
        else:
            return transformers_causal_lm_class.fallback(
                model_id,
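With the `elif sharded:` guard commented out, a sharded model whose type has no flash implementation no longer raises; it drops into the `else:` branch and is served through `transformers_causal_lm_class.fallback(...)`, which appears to be the point of testing TP>1 through the transformers backend. A condensed sketch of the resulting control flow, with names taken from the hunk and the structure assumed:

# Condensed, assumed control flow after the change (not the real get_model body).
def choose_backend(flash_attention: bool, sharded: bool) -> str:
    if flash_attention:
        return "FlashCausalLM"          # custom flash kernels
    # The "sharded -> NotImplementedError" guard is gone, so sharded models
    # also reach the transformers fallback below.
    return "transformers fallback"      # transformers_causal_lm_class.fallback(...)

print(choose_backend(flash_attention=False, sharded=True))  # -> "transformers fallback"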
@@ -3,7 +3,7 @@ from typing import List, Optional

 import torch
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers.modeling_utils

 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -36,9 +36,11 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs, # This is needed to "absorb" other args passed by Transformers modeling
 ):

     kv_cache = kv_cache[module.layer_idx]

+    import ipdb
+
+    ipdb.set_trace()
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
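This hunk drops an `ipdb` breakpoint into the attention forward for debugging. Under TP>1 every rank would hit it at the same time, so a common precaution, an assumption here rather than something this commit does, is to gate the breakpoint on the rank:

# Hedged example: gate an ipdb breakpoint to a single rank so a TP>1 run does not
# open several competing debugger prompts on the same terminal.
import os

if int(os.environ.get("RANK", "0")) == 0:
    import ipdb
    ipdb.set_trace()  # only rank 0 stops here; other ranks keep running

The other ranks will still block at their next collective call while rank 0 sits in the debugger, which is usually acceptable for a short inspection.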
@@ -95,7 +97,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         default_dtype=torch.float16,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
-        config_class=AutoConfig,
         kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.quantize = quantize
@@ -105,17 +106,17 @@ class TransformersFlashCausalLM(FlashCausalLM):
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")

         if torch.cuda.is_available():
-            device = torch.device("cuda:0")
-            dtype = torch.float16 if dtype is None else dtype
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
                 "Flash `Transformers` modeling backend is not available on cpu."
             )

-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = tokenizer_class.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
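The constructor changes pin each process to its own GPU (`cuda:{rank}` instead of `cuda:0`), honor the `default_dtype` argument instead of hard-coding float16, and actually use the `tokenizer_class` parameter. A standalone sketch of the per-rank device selection, assuming `RANK` and `WORLD_SIZE` come from the launcher environment as in torchrun-style setups:

# Minimal per-rank device selection sketch (assumes torchrun-style env vars).
import os
import torch

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)  # make allocations land on this rank's GPU
    dtype = torch.float16
else:
    raise ValueError("Flash `Transformers` modeling backend is not available on cpu.")

print(f"rank {rank}/{world_size} -> {device}, {dtype}")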
@@ -126,10 +127,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto",
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            # load_in_8bit=quantize == "bitsandbytes",
+            # trust_remote_code=trust_remote_code,
+            # attn_implementation="tgi",
+            device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
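The from_pretrained call now leans on `tp_plan="auto"` (tensor parallelism handled inside transformers, which requires transformers >= 4.47 together with torch 2.5) when `world_size > 1`, and only passes an explicit `device_map` in the single-GPU case; the bitsandbytes, remote-code and `attn_implementation="tgi"` arguments are parked behind comments for this experiment. A hedged, standalone sketch of how such a load might be driven, where the model id and the launch command are placeholders and not taken from the diff:

# Hedged sketch: load a causal LM with transformers' built-in tensor parallelism.
# Assumes transformers >= 4.47 and torch >= 2.5; launch with, e.g.:
#   torchrun --nproc-per-node=2 this_script.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model id
world_size = int(os.environ.get("WORLD_SIZE", "1"))

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto" if world_size == 1 else None,  # single process: let accelerate place it
    tp_plan="auto" if world_size > 1 else None,      # TP>1: shard weights across the ranks
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

Passing `device_map` and `tp_plan` as mutually exclusive options mirrors the diff: with tensor parallelism active, each rank owns its shard and no accelerate-style device map is wanted.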