From 8ee9307618b1240e0fa73845079216b5aaea6e4a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 25 Sep 2023 10:07:45 +0200
Subject: [PATCH] Finishing nits + integration test

---
 integration-tests/models/test_flash_awq.py    | 61 +++++++++++++++++++
 .../utils/awq/quantize/qmodule.py             | 29 ++++-----
 .../text_generation_server/utils/weights.py   | 36 +++++------
 3 files changed, 87 insertions(+), 39 deletions(-)
 create mode 100644 integration-tests/models/test_flash_awq.py

diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py
new file mode 100644
index 00000000..f25f7f4e
--- /dev/null
+++ b/integration-tests/models/test_flash_awq.py
@@ -0,0 +1,61 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_awq_handle(launcher):
+    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_awq(flash_llama_awq_handle):
+    await flash_llama_awq_handle.health(300)
+    return flash_llama_awq_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
+    response = await flash_llama_awq.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
+    response = await flash_llama_awq.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq_load(
+    flash_llama_awq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_awq, "Test request", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot

diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/utils/awq/quantize/qmodule.py
index fb1adf5c..c658e17f 100644
--- a/server/text_generation_server/utils/awq/quantize/qmodule.py
+++ b/server/text_generation_server/utils/awq/quantize/qmodule.py
@@ -6,14 +6,14 @@ import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels
 
 
-class ScaledActivation(nn.Module):
-    def __init__(self, module, scales):
-        super().__init__()
-        self.act = module
-        self.scales = nn.Parameter(scales.data)
-
-    def forward(self, x):
-        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+# class ScaledActivation(nn.Module):
+#     def __init__(self, module, scales):
+#         super().__init__()
+#         self.act = module
+#         self.scales = nn.Parameter(scales.data)
+#
+#     def forward(self, x):
+#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
 
 
 class WQLinear(nn.Module):
@@ -32,11 +32,11 @@ class WQLinear(nn.Module):
         assert self.in_features % self.group_size == 0
         assert self.out_features % (32 // self.w_bit) == 0
 
-        self.register_buffer('qweight', qweight)
-        self.register_buffer('qzeros', qzeros)
-        self.register_buffer('scales', scales)
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
         if bias:
-            self.register_buffer('bias', bias)
+            self.bias = bias
         else:
             self.bias = None
 
@@ -46,8 +46,3 @@ class WQLinear(nn.Module):
         out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
-    def extra_repr(self) -> str:
-        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
-            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
-        )
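A note on the qmodule.py hunk above: swapping register_buffer for plain attribute assignment means the packed tensors no longer appear in the module's state_dict and are no longer moved automatically by Module.to()/cuda(); presumably that is fine here because the Weights loader already places the sharded tensors on the right device and dtype before WQLinear is built. The snippet below is an illustrative sketch only (WithBuffer/PlainAttribute are not TGI code) showing the behavioural difference:

import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self, qweight):
        super().__init__()
        # registered buffer: serialized in state_dict and moved by .to()/.cuda()
        self.register_buffer("qweight", qweight)


class PlainAttribute(nn.Module):
    def __init__(self, qweight):
        super().__init__()
        # plain tensor attribute: ignored by the nn.Module bookkeeping
        self.qweight = qweight


q = torch.zeros(4, 4, dtype=torch.int32)
print(list(WithBuffer(q).state_dict().keys()))      # ['qweight']
print(list(PlainAttribute(q).state_dict().keys()))  # []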
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index fdeabbe6..7b492ee7 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -139,21 +139,16 @@ class Weights:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
-                if quantize == "gptq":
-                    raise RuntimeError(
-                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                    )
-                else:
-                    raise RuntimeError(
-                        "Cannot load `awq` weight, make sure the model is already quantized"
-                    )
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
+                )
 
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            try:
+            if quantize == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            except RuntimeError:
+            else:
                 g_idx = None
 
             bits, groupsize = self._get_gptq_params()
@@ -185,14 +180,9 @@ class Weights:
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                 )
             except RuntimeError:
-                if quantize == "gptq":
-                    raise RuntimeError(
-                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                    )
-                else:
-                    raise RuntimeError(
-                        "Cannot load `awq` weight, make sure the model is already quantized"
-                    )
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                )
 
             qzeros = torch.cat(
                 [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -201,12 +191,12 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            try:
+            if quantize == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            except RuntimeError:
+            else:
                 g_idx = None
 
             bits, groupsize = self._get_gptq_params()
@@ -233,7 +223,7 @@ class Weights:
         return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize in "gptq":
+        if quantize == "gptq":
             use_exllama = True
 
             bits, groupsize = self._get_gptq_params()
@@ -311,8 +301,10 @@ class Weights:
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
 
+            g_idx = None
+            use_exllama = False
 
-            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
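For readers following the weights.py hunks: the quantized branch of get_multi_weights_row now returns (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama), and the last hunk replaces the two anonymous None placeholders with named locals so the AWQ case explicitly carries g_idx=None and use_exllama=False (only GPTQ has an act-order index, and the exllama fast path applies to GPTQ only). The sketch below is purely illustrative; QuantWeight is not a TGI type, and the bits/groupsize values simply mirror the w4-g128 test model used above.

from typing import NamedTuple, Optional
import torch


class QuantWeight(NamedTuple):
    """Descriptive stand-in for the positional 7-tuple; field names are mine, not TGI's."""
    qweight: torch.Tensor          # packed integer weights
    qzeros: torch.Tensor           # packed zero points
    scales: torch.Tensor           # per-group dequantization scales
    g_idx: Optional[torch.Tensor]  # GPTQ act-order index; absent for AWQ
    bits: int
    groupsize: int
    use_exllama: bool              # exllama kernels are a GPTQ-only fast path


dummy = torch.zeros(1, dtype=torch.int32)
# AWQ case, mirroring the last hunk (4-bit, group size 128 as in the test model):
awq_weight = QuantWeight(dummy, dummy, dummy, g_idx=None, bits=4, groupsize=128, use_exllama=False)
# GPTQ case keeps its act-order tensor and may enable exllama:
gptq_weight = QuantWeight(dummy, dummy, dummy, g_idx=dummy, bits=4, groupsize=128, use_exllama=True)

Naming the last two slots at the construction site makes the tuple self-documenting where it previously relied on positional None literals.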