From f4ce670eb03192d3c6af72e589d1209439ca565d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 14 Aug 2024 16:30:46 +0200 Subject: [PATCH] Fixing exl2 and other quantize tests again. --- server/Makefile-exllamav2 | 4 ++-- server/text_generation_server/models/causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + server/text_generation_server/models/seq2seq_lm.py | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/server/Makefile-exllamav2 b/server/Makefile-exllamav2 index 0d4cc385..38abeffe 100644 --- a/server/Makefile-exllamav2 +++ b/server/Makefile-exllamav2 @@ -1,7 +1,7 @@ -exllamav2_commit := v0.1.8 +exllamav2_commit := 872386c89eaebe0bde5b245a890f1da9522768b3 build-exllamav2: - git clone https://github.com/turboderp/exllamav2.git exllamav2 && \ + git clone https://github.com/Narsil/exllamav2.git exllamav2 && \ cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \ git submodule update --init --recursive && \ pip install -r requirements.txt && \ diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba168b13..28534d0f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -652,6 +652,7 @@ class CausalLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 5d6ce3c7..f6dcde68 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -412,6 +412,7 @@ class Mamba(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + self.quantize = quantize self.process_group, _rank, world_size = initialize_torch_distributed() if world_size > 1: raise RuntimeError("Mamba does not support Tensor Parallelism (TP)") diff --git 
a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3c92128a..04d4c28b 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -676,6 +676,7 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + self.quantize = quantize return self @property