diff --git a/Dockerfile b/Dockerfile
index f7236e72..72005333 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -47,7 +47,7 @@ RUN cargo build --profile release-opt --frozen
 FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
 
 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
-ARG PYTORCH_VERSION=2.4.0
+ARG PYTORCH_VERSION=2.5.1
 
 ARG PYTHON_VERSION=3.11
 # Keep in sync with `server/pyproject.toml
@@ -235,8 +235,8 @@ RUN cd server && \
     make gen-server && \
     python -c "from text_generation_server.pb import generate_pb2" && \
     pip install -U pip uv && \
-    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
-    uv pip install nvidia-nccl-cu12==2.22.3
+    uv pip install -e ".[attention, bnb, accelerate, compressed-tensors, marlin, moe, quantize, peft, outlines]" --no-cache-dir # && \
+    # uv pip install nvidia-nccl-cu12==2.22.3
 
 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
 # Required to find libpython within the rust binaries
diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer
index d5f684ba..200e9738 100644
--- a/server/Makefile-flashinfer
+++ b/server/Makefile-flashinfer
@@ -2,4 +2,4 @@ install-flashinfer:
 	# We need fsspec as an additional dependency, but
 	# `pip install flashinfer` cannot resolve it.
 	pip install fsspec
-	pip install flashinfer==0.2.0.post1 -i https://flashinfer.ai/whl/cu124/torch2.4
+	pip install flashinfer-python==0.2.0.post1
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3b437ce0..36c837cc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -79,7 +79,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = True
+FLASH_ATTENTION = False
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -931,10 +931,10 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
-            )
+        # elif sharded:
+        #     raise NotImplementedError(
+        #         FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
+        #     )
         else:
             return transformers_causal_lm_class.fallback(
                 model_id,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d097c54f..1073f4f9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1571,7 +1571,7 @@ class FlashCausalLM(Model):
         real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
         log_master(
             logger.debug,
-            f"Free memory {free_memory/1e9:.2f}GB , (real: {real_free_memory/1e9:.2f}GB",
+            f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
         )
 
         _, _batch, _ = self.generate_token(batch)
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 647fabc2..45d51d82 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -3,7 +3,7 @@ from typing import List, Optional
 
 import torch
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers.modeling_utils
 
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -36,9 +36,11 @@ def tgi_flash_attention_forward(
     softcap: Optional[float] = None,
     **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
-    kv_cache = kv_cache[module.layer_idx]
+    import ipdb
+
+    ipdb.set_trace()
 
     query_states = query_states.transpose(1, 2).squeeze(dim=0)
     key_states = key_states.transpose(1, 2).squeeze(dim=0)
     value_states = value_states.transpose(1, 2).squeeze(dim=0)
 
@@ -95,7 +97,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         default_dtype=torch.float16,
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
-        config_class=AutoConfig,
         kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.quantize = quantize
@@ -105,17 +106,17 @@ class TransformersFlashCausalLM(FlashCausalLM):
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")
 
         if torch.cuda.is_available():
-            device = torch.device("cuda:0")
-            dtype = torch.float16 if dtype is None else dtype
+            device = torch.device(f"cuda:{rank}")
+            dtype = default_dtype if dtype is None else dtype
         elif hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = default_dtype if dtype is None else dtype
         else:
             raise ValueError(
                 "Flash `Transformers` modeling backend is not available on cpu."
             )
 
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = tokenizer_class.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -126,10 +127,10 @@ class TransformersFlashCausalLM(FlashCausalLM):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto",
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-            attn_implementation="tgi",
+            # load_in_8bit=quantize == "bitsandbytes",
+            # trust_remote_code=trust_remote_code,
+            # attn_implementation="tgi",
+            device_map=device if world_size == 1 else None,
             tp_plan="auto" if world_size > 1 else None,
         )
 
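
Note on the final hunk: the `device_map=device if world_size == 1 else None` / `tp_plan="auto" if world_size > 1 else None` switch leans on the tensor-parallel loading path that `transformers` exposes through `from_pretrained` (`tp_plan="auto"` needs a recent `transformers` and torch >= 2.5, which lines up with the PyTorch 2.5.1 bump in the Dockerfile above). Below is a minimal standalone sketch of that loading logic outside of TGI; the script name and model id are illustrative assumptions, not part of this patch.

# tp_plan_demo.py -- hypothetical repro script, not part of this patch.
# Single GPU:  python tp_plan_demo.py
# Sharded:     torchrun --nproc-per-node=2 tp_plan_demo.py
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"      # example checkpoint (assumption)
rank = int(os.environ.get("RANK", 0))              # set by torchrun
world_size = int(os.environ.get("WORLD_SIZE", 1))  # set by torchrun

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    # Same conditional as the patched TransformersFlashCausalLM: pin the whole
    # model to the local GPU when unsharded, otherwise let transformers shard
    # it across ranks with its built-in tensor-parallel plan.
    device_map=torch.device(f"cuda:{rank}") if world_size == 1 else None,
    tp_plan="auto" if world_size > 1 else None,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(rank, logits.shape)

Run under torchrun, each rank loads only its shard of the weights; this is the sharded code path that the commented-out `elif sharded: raise NotImplementedError(...)` in `get_model` was previously blocking for the non-flash fallback.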