From 4baa6ff59f45cda2bbdc30289cfc4357a0b8b426 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 14 Aug 2024 11:58:08 +0200
Subject: [PATCH] Upgrading exl2. (#2415)

* Upgrading exl2.

* Fixing the other pathways.

* Fix idefics.
---
 .gitignore                                            |  2 +-
 flake.nix                                             |  1 +
 server/Makefile                                       |  1 +
 server/Makefile-exllamav2                             | 12 ++++++++++++
 server/text_generation_server/models/causal_lm.py     |  1 +
 .../text_generation_server/models/flash_causal_lm.py  |  1 +
 server/text_generation_server/models/idefics.py       |  1 +
 .../models/idefics_causal_lm.py                       |  1 +
 server/text_generation_server/models/seq2seq_lm.py    |  1 +
 server/text_generation_server/server.py               |  6 +++---
 10 files changed, 23 insertions(+), 4 deletions(-)
 create mode 100644 server/Makefile-exllamav2

diff --git a/.gitignore b/.gitignore
index bd9d9125..f79d8faa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,7 +9,7 @@ backends/client/src/v3/pb
 
 # ROCm auto-generated files
 *.hip
-server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllamav2
 server/exllama_kernels/exllama_kernels/hip/
 server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
diff --git a/flake.nix b/flake.nix
index e1f44212..229184d2 100644
--- a/flake.nix
+++ b/flake.nix
@@ -93,6 +93,7 @@
           causal-conv1d
           click
           einops
+          exllamav2
           fbgemm-gpu
           flashinfer
           flash-attn
diff --git a/server/Makefile b/server/Makefile
index 209fc44e..51ea8b32 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -6,6 +6,7 @@ include Makefile-eetq
 include Makefile-selective-scan
 include Makefile-lorax-punica
 include Makefile-fbgemm
+include Makefile-exllamav2
 
 unit-tests:
 	pytest -s -vv -m "not private" tests
diff --git a/server/Makefile-exllamav2 b/server/Makefile-exllamav2
new file mode 100644
index 00000000..0d4cc385
--- /dev/null
+++ b/server/Makefile-exllamav2
@@ -0,0 +1,12 @@
+exllamav2_commit := v0.1.8
+
+build-exllamav2:
+	git clone https://github.com/turboderp/exllamav2.git exllamav2 && \
+	cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \
+	git submodule update --init --recursive && \
+	pip install -r requirements.txt && \
+	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
+
+install-exllamav2: build-exllamav2
+	cd exllamav2/ && \
+	CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 212ab7a9..ba168b13 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -511,6 +511,7 @@ class CausalLM(Model):
         config_class=AutoConfig,
         batch_class=CausalLMBatch,
     ):
+        self.quantize = quantize
         self.batch_class = batch_class
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 42d93a12..5e2fd20a 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -872,6 +872,7 @@ class FlashCausalLM(Model):
         head_size: Optional[int] = None,
         skip_special_tokens: bool = True,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index 29929b98..9058cb96 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -33,6 +33,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 8a80ed68..c5480952 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -580,6 +580,7 @@ class IdeficsCausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         from text_generation_server.models.custom_modeling.idefics_modeling import (
             IdeficsForVisionText2Text,
         )
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 79c001b0..3c92128a 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -553,6 +553,7 @@ class Seq2SeqLM(Model):
         tokenizer_class=AutoTokenizer,
         aliases=None,
     ):
+        self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index b92ab572..22871ec5 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -50,12 +50,12 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self,
         model: Model,
         cache: Cache,
-        quantize: Optional[str],
         server_urls: List[str],
     ):
         self.cache = cache
         self.model = model
-        self.quantize = quantize
+        # Quantize is resolved during model loading
+        self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         if model.device.type == "cuda":
@@ -255,7 +255,7 @@ def serve(
             ],
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-            TextGenerationService(model, Cache(), quantize, server_urls), server
+            TextGenerationService(model, Cache(), server_urls), server
         )
         SERVICE_NAMES = (
             generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
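
Note for reviewers: the `quantize` plumbing above boils down to one pattern — each `Model` subclass records the quantization method at the top of its `__init__` while the model loads, and the gRPC servicer reads it back from the model instead of taking it as a constructor argument. A minimal sketch of that pattern, using simplified stand-ins (not the real TGI classes, which carry many more fields):

from typing import List, Optional


class Model:
    # Simplified stand-in for text_generation_server.models.Model.
    def __init__(self, quantize: Optional[str] = None):
        # Resolved once, while the model loads (e.g. "exl2", "gptq", or None).
        self.quantize = quantize


class TextGenerationService:
    # Simplified stand-in for the gRPC servicer after this patch.
    def __init__(self, model: Model, server_urls: List[str]):
        self.model = model
        # Quantize is resolved during model loading, so read it back from
        # the model rather than threading it through as an argument.
        self.quantize = model.quantize
        self.server_urls = server_urls


service = TextGenerationService(Model(quantize="exl2"), ["unix:///tmp/tgi-0"])
assert service.quantize == "exl2"

Assigning `self.quantize` first in each subclass `__init__`, as the hunks above do, guarantees the attribute exists by the time the servicer is constructed.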
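
Separately, a quick way to sanity-check that `make install-exllamav2` produced an importable build at the pinned tag; this relies only on stdlib package metadata, not on any particular exllamav2 attribute:

from importlib.metadata import version

# The import fails fast if the compiled extension or its Python glue is broken.
import exllamav2  # noqa: F401

# Makefile-exllamav2 pins v0.1.8; anything else suggests a stale install.
print("exllamav2 version:", version("exllamav2"))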