Tmp branch to test transformers backend with 2.5.1 and TP>1

Nicolas Patry 2025-01-22 17:33:08 +01:00
parent 6d335ca7ce
commit 859d2f0464
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
5 changed files with 22 additions and 21 deletions

View File

@@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
 FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
+ARG PYTORCH_VERSION=2.5.1
 ARG PYTHON_VERSION=3.11
 # Keep in sync with `server/pyproject.toml
@@ -235,8 +235,8 @@ RUN cd server && \
     make gen-server && \
     python -c "from text_generation_server.pb import generate_pb2" && \
     pip install -U pip uv && \
-    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
-    uv pip install nvidia-nccl-cu12==2.22.3
+    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir # && \
+    # uv pip install nvidia-nccl-cu12==2.22.3
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
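
The PyTorch bump to 2.5.1 is also why the explicit `nvidia-nccl-cu12==2.22.3` pin gets commented out: the NOTE a few lines up asks for exactly that whenever the PyTorch version changes. A quick way to check which torch/NCCL build the image actually ends up with (a hedged sketch, assuming the conda Python that `LD_PRELOAD` points at):

    # Run inside the built image; purely illustrative sanity check.
    import torch

    print(torch.__version__)          # expected to report 2.5.1 after this change
    print(torch.cuda.nccl.version())  # NCCL version bundled with this torch build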

View File

@@ -2,4 +2,4 @@ install-flashinfer:
	# We need fsspec as an additional dependency, but
	# `pip install flashinfer` cannot resolve it.
	pip install fsspec
-	pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install flashinfer-python==0.2.0.post1

View File

@@ -79,7 +79,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = True
+FLASH_ATTENTION = False
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -931,10 +931,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-            )
+        # elif sharded:
+        #     raise NotImplementedError(
+        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+        #     )
         else:
             return transformers_causal_lm_class.fallback(
                 model_id,
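
Together, the two hunks in this file push models onto the Transformers backend: `FLASH_ATTENTION` is hard-coded to `False`, and the sharded guard in `get_model` is commented out so sharded models no longer raise and instead reach the fallback. A minimal sketch of the resulting dispatch (names simplified and illustrative; the real `get_model` branches per `model_type`):

    # Illustrative only: the effective routing after this change, not the real get_model body.
    def route(flash_attention: bool, sharded: bool) -> str:
        if flash_attention:  # always False on this tmp branch
            return "FlashCausalLM"
        # The old `elif sharded: raise NotImplementedError(...)` guard is commented out,
        # so sharded models also fall through to the Transformers fallback.
        return "transformers_causal_lm_class.fallback"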

View File

@@ -1571,7 +1571,7 @@ class FlashCausalLM(Model):
         real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
         log_master(
             logger.debug,
-            f"Free memory {free_memory/1e9:.2f}GB , (real: {real_free_memory/1e9:.2f}GB",
+            f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
         )
         _, _batch, _ = self.generate_token(batch)

View File

@@ -3,7 +3,7 @@ from typing import List, Optional
 import torch
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers.modeling_utils
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -36,9 +36,11 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
+    import ipdb
+    ipdb.set_trace()
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
@@ -95,7 +97,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         default_dtype=torch.float16,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
-        config_class=AutoConfig,
         kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.quantize = quantize
@@ -105,17 +106,17 @@ class TransformersFlashCausalLM(FlashCausalLM):
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
         if torch.cuda.is_available():
-            device = torch.device("cuda:0")
-            dtype = torch.float16 if dtype is None else dtype
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
                 "Flash `Transformers` modeling backend is not available on cpu."
             )
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = tokenizer_class.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -126,10 +127,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto",
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            # load_in_8bit=quantize == "bitsandbytes",
+            # trust_remote_code=trust_remote_code,
+            # attn_implementation="tgi",
+            device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
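
This last hunk is the core of the experiment: instead of `device_map="auto"`, a single process pins the whole model to its own rank's GPU, while with several ranks the weights are sharded by transformers itself through `tp_plan="auto"`. A standalone sketch of that loading pattern, assuming a transformers release with `tp_plan` support (4.47 or newer), torch 2.5.1, and a `torchrun` launch; the model id and environment handling are illustrative only:

    import os

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical example model
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        # Single process: keep the model on this rank's GPU, like the new device_map above.
        device_map=f"cuda:{rank}" if world_size == 1 else None,
        # Multiple processes: let transformers tensor-parallelize the weights across ranks.
        tp_plan="auto" if world_size > 1 else None,
    )

Launched as e.g. `torchrun --nproc-per-node=2 load.py`, each rank then holds only a shard of the weights rather than a full copy, which is the TP>1 path this tmp branch is exercising.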