diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
index c8481eb2..26224118 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json
@@ -11,57 +11,57 @@
       },
       {
         "id": 3226,
-        "logprob": -9.0234375,
+        "logprob": -8.9453125,
         "text": " ge"
       },
       {
         "id": 21017,
-        "logprob": -9.0859375,
+        "logprob": -8.8515625,
         "text": "ometric"
       },
       {
         "id": 81,
-        "logprob": -0.2578125,
+        "logprob": -0.21875,
         "text": "_"
       },
       {
         "id": 6009,
-        "logprob": -2.1835938,
+        "logprob": -1.2773438,
         "text": "mean"
       },
       {
         "id": 26,
-        "logprob": -0.3005371,
+        "logprob": -0.25195312,
         "text": "("
       },
       {
         "id": 62,
-        "logprob": -5.625,
+        "logprob": -4.8203125,
         "text": "L"
       },
       {
         "id": 44,
-        "logprob": -3.0644531,
+        "logprob": -3.7734375,
         "text": ":"
       },
       {
         "id": 1682,
-        "logprob": -0.6845703,
+        "logprob": -0.8310547,
         "text": " List"
       },
       {
         "id": 77,
-        "logprob": -0.3869629,
+        "logprob": -0.22766113,
         "text": "["
       },
       {
         "id": 1808,
-        "logprob": -0.94628906,
+        "logprob": -0.46240234,
         "text": "float"
       },
       {
         "id": 10794,
-        "logprob": -2.5371094,
+        "logprob": -3.0234375,
         "text": "]):"
       }
     ],
@@ -69,7 +69,7 @@
     "tokens": [
       {
         "id": 284,
-        "logprob": -1.171875,
+        "logprob": -0.04626465,
         "special": false,
         "text": "\n "
       },
diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
index 06f19bca..015912f8 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json
@@ -11,57 +11,57 @@
       },
       {
         "id": 3226,
-        "logprob": -9.0234375,
+        "logprob": -8.9453125,
         "text": " ge"
       },
       {
         "id": 21017,
-        "logprob": -9.0859375,
+        "logprob": -8.859375,
         "text": "ometric"
       },
       {
         "id": 81,
-        "logprob": -0.25634766,
+        "logprob": -0.21984863,
         "text": "_"
       },
       {
         "id": 6009,
-        "logprob": -2.1835938,
+        "logprob": -1.2861328,
         "text": "mean"
       },
       {
         "id": 26,
-        "logprob": -0.29956055,
+        "logprob": -0.25219727,
         "text": "("
       },
       {
         "id": 62,
-        "logprob": -5.625,
+        "logprob": -4.8007812,
         "text": "L"
       },
       {
         "id": 44,
-        "logprob": -3.09375,
+        "logprob": -3.7949219,
         "text": ":"
       },
       {
         "id": 1682,
-        "logprob": -0.67578125,
+        "logprob": -0.8046875,
         "text": " List"
       },
       {
         "id": 77,
-        "logprob": -0.38256836,
+        "logprob": -0.22424316,
         "text": "["
       },
       {
         "id": 1808,
-        "logprob": -0.9458008,
+        "logprob": -0.46191406,
         "text": "float"
       },
       {
         "id": 10794,
-        "logprob": -2.5371094,
+        "logprob": -3.0253906,
         "text": "]):"
       }
     ],
@@ -69,7 +69,7 @@
     "tokens": [
       {
         "id": 284,
-        "logprob": -0.05831909,
+        "logprob": 0.0,
         "special": false,
         "text": "\n "
       },
diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json
index 70a73423..d9072c52 100644
--- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json
+++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json
@@ -12,57 +12,57 @@
         },
         {
           "id": 3226,
-          "logprob": -9.015625,
+          "logprob": -8.9453125,
           "text": " ge"
         },
         {
           "id": 21017,
-          "logprob": -9.0859375,
+          "logprob": -8.8515625,
           "text": "ometric"
         },
         {
           "id": 81,
-          "logprob": -0.25585938,
+          "logprob": -0.22033691,
           "text": "_"
         },
        {
           "id": 6009,
-          "logprob": -2.1894531,
+          "logprob": -1.2939453,
           "text": "mean"
         },
         {
           "id": 26,
-          "logprob": -0.29882812,
+          "logprob": -0.25268555,
           "text": "("
         },
         {
           "id": 62,
-          "logprob": -5.6210938,
+          "logprob": -4.796875,
           "text": "L"
         },
         {
           "id": 44,
-          "logprob": -3.078125,
+          "logprob": -3.796875,
           "text": ":"
         },
         {
           "id": 1682,
-          "logprob": -0.6699219,
+          "logprob": -0.8066406,
           "text": " List"
         },
         {
           "id": 77,
-          "logprob": -0.38232422,
+          "logprob": -0.22644043,
           "text": "["
         },
         {
           "id": 1808,
-          "logprob": -0.9379883,
+          "logprob": -0.46166992,
           "text": "float"
         },
         {
           "id": 10794,
-          "logprob": -2.5371094,
+          "logprob": -3.0253906,
           "text": "]):"
         }
       ],
@@ -70,7 +70,7 @@
       "tokens": [
         {
           "id": 284,
-          "logprob": -1.1826172,
+          "logprob": -0.046844482,
           "special": false,
           "text": "\n "
         },
@@ -98,57 +98,57 @@
         },
         {
           "id": 3226,
-          "logprob": -9.0234375,
+          "logprob": -8.9375,
           "text": " ge"
         },
         {
           "id": 21017,
-          "logprob": -9.09375,
+          "logprob": -8.8515625,
           "text": "ometric"
         },
         {
           "id": 81,
-          "logprob": -0.25610352,
+          "logprob": -0.21826172,
           "text": "_"
         },
         {
           "id": 6009,
-          "logprob": -2.1933594,
+          "logprob": -1.2871094,
           "text": "mean"
         },
         {
           "id": 26,
-          "logprob": -0.29907227,
+          "logprob": -0.25390625,
           "text": "("
         },
         {
           "id": 62,
-          "logprob": -5.640625,
+          "logprob": -4.8085938,
           "text": "L"
         },
         {
           "id": 44,
-          "logprob": -3.09375,
+          "logprob": -3.7890625,
           "text": ":"
         },
         {
           "id": 1682,
-          "logprob": -0.67626953,
+          "logprob": -0.8076172,
           "text": " List"
         },
         {
           "id": 77,
-          "logprob": -0.39038086,
+          "logprob": -0.22302246,
           "text": "["
         },
         {
           "id": 1808,
-          "logprob": -0.94384766,
+          "logprob": -0.46435547,
           "text": "float"
         },
         {
           "id": 10794,
-          "logprob": -2.5507812,
+          "logprob": -3.0234375,
           "text": "]):"
         }
       ],
@@ -156,7 +156,7 @@
       "tokens": [
         {
           "id": 284,
-          "logprob": -1.1865234,
+          "logprob": -0.046722412,
           "special": false,
           "text": "\n "
         },
@@ -184,57 +184,57 @@
         },
         {
           "id": 3226,
-          "logprob": -9.015625,
+          "logprob": -8.9453125,
           "text": " ge"
         },
         {
           "id": 21017,
-          "logprob": -9.0859375,
+          "logprob": -8.8515625,
           "text": "ometric"
         },
         {
           "id": 81,
-          "logprob": -0.25561523,
+          "logprob": -0.21813965,
           "text": "_"
         },
         {
           "id": 6009,
-          "logprob": -2.1933594,
+          "logprob": -1.2744141,
           "text": "mean"
         },
         {
           "id": 26,
-          "logprob": -0.296875,
+          "logprob": -0.2512207,
           "text": "("
         },
         {
           "id": 62,
-          "logprob": -5.6367188,
+          "logprob": -4.8046875,
           "text": "L"
         },
         {
           "id": 44,
-          "logprob": -3.0800781,
+          "logprob": -3.7851562,
           "text": ":"
         },
         {
           "id": 1682,
-          "logprob": -0.6875,
+          "logprob": -0.81396484,
           "text": " List"
         },
         {
           "id": 77,
-          "logprob": -0.3840332,
+          "logprob": -0.22570801,
           "text": "["
         },
         {
           "id": 1808,
-          "logprob": -0.93847656,
+          "logprob": -0.46044922,
           "text": "float"
         },
         {
           "id": 10794,
-          "logprob": -2.5371094,
+          "logprob": -3.0234375,
           "text": "]):"
         }
       ],
@@ -242,7 +242,7 @@
       "tokens": [
        {
           "id": 284,
-          "logprob": -1.1777344,
+          "logprob": -0.04650879,
           "special": false,
           "text": "\n "
         },
@@ -270,57 +270,57 @@
         },
         {
           "id": 3226,
-          "logprob": -9.015625,
+          "logprob": -8.9453125,
           "text": " ge"
         },
         {
           "id": 21017,
-          "logprob": -9.0859375,
+          "logprob": -8.8515625,
           "text": "ometric"
         },
         {
           "id": 81,
-          "logprob": -0.25610352,
+          "logprob": -0.21960449,
           "text": "_"
         },
         {
           "id": 6009,
-          "logprob": -2.1933594,
+          "logprob": -1.2890625,
           "text": "mean"
         },
         {
           "id": 26,
-          "logprob": -0.3010254,
+          "logprob": -0.25073242,
           "text": "("
         },
         {
           "id": 62,
-          "logprob": -5.6484375,
+          "logprob": -4.8085938,
           "text": "L"
         },
         {
           "id": 44,
-          "logprob": -3.0820312,
+          "logprob": -3.8046875,
           "text": ":"
         },
         {
           "id": 1682,
-          "logprob": -0.6801758,
+          "logprob": -0.8071289,
           "text": " List"
         },
         {
           "id": 77,
-          "logprob": -0.39257812,
+          "logprob": -0.22570801,
           "text": "["
         },
         {
           "id": 1808,
-          "logprob": -0.92626953,
+          "logprob": -0.46118164,
           "text": "float"
         },
         {
           "id": 10794,
-          "logprob": -2.5234375,
+          "logprob": -3.0097656,
           "text": "]):"
         }
       ],
@@ -328,7 +328,7 @@
       "tokens": [
         {
           "id": 284,
-          "logprob": -1.171875,
+          "logprob": -0.046539307,
           "special": false,
           "text": "\n "
         },
diff --git a/server/text_generation_server/layers/awq/quantize/__init__.py b/server/text_generation_server/layers/awq/quantize/__init__.py
new file mode 100644
index 00000000..3e72881b
--- /dev/null
+++ b/server/text_generation_server/layers/awq/quantize/__init__.py
@@ -0,0 +1,8 @@
+from text_generation_server.utils.import_utils import SYSTEM
+
+if SYSTEM == "ipex":
+    from .ipex import WQLinear
+elif SYSTEM == "cuda":
+    from .cuda import WQLinear
+
+__all__ = ["WQLinear"]
diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/cuda.py
similarity index 100%
rename from server/text_generation_server/layers/awq/quantize/qmodule.py
rename to server/text_generation_server/layers/awq/quantize/cuda.py
diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index 020467f2..2049f777 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -8,6 +8,11 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
+if SYSTEM == "ipex":
+    from .ipex import QuantLinear
+elif SYSTEM == "cuda":
+    from .cuda import QuantLinear
+
 
 @dataclass
 class GPTQWeight(Weight):
@@ -36,12 +41,7 @@
                     "to use Exllama/GPTQ kernels for AWQ inference."
                 )
             try:
-                if SYSTEM == "ipex":
-                    from text_generation_server.layers.awq.quantize.ipex import WQLinear
-                else:
-                    from text_generation_server.layers.awq.quantize.qmodule import (
-                        WQLinear,
-                    )
+                from text_generation_server.layers.awq.quantize import WQLinear
 
                 return WQLinear(
                     w_bit=self.bits,
@@ -65,10 +65,7 @@ class GPTQWeight(Weight):
 
             return ExllamaQuantLinear(self, bias)
         else:
-            if SYSTEM == "ipex":
-                from text_generation_server.layers.gptq.ipex import QuantLinear
-            else:
-                from text_generation_server.layers.gptq.quant_linear import QuantLinear
+            from text_generation_server.layers.gptq import QuantLinear
 
             return QuantLinear(
                 self.qweight,
diff --git a/server/text_generation_server/layers/gptq/quant_linear.py b/server/text_generation_server/layers/gptq/cuda.py
similarity index 100%
rename from server/text_generation_server/layers/gptq/quant_linear.py
rename to server/text_generation_server/layers/gptq/cuda.py
diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py
index d87df5f2..6261792d 100644
--- a/server/text_generation_server/layers/gptq/quantize.py
+++ b/server/text_generation_server/layers/gptq/quantize.py
@@ -13,11 +13,7 @@ from accelerate import init_empty_weights
 from text_generation_server.utils import initialize_torch_distributed, Weights
 from text_generation_server.utils.hub import weight_files
 from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM == "ipex":
-    from text_generation_server.layers.gptq.ipex import QuantLinear
-else:
-    from text_generation_server.layers.gptq.quant_linear import QuantLinear
+from text_generation_server.layers.gptq import QuantLinear
 from loguru import logger
 from typing import Optional
 from text_generation_server.layers.gptq.utils import torch_snr_error
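
Note on the refactor: both quantize packages now select their backend implementation once, at import time in __init__.py, so call sites import QuantLinear / WQLinear from the package root instead of branching on SYSTEM at every use. The sketch below illustrates that pattern in a self-contained form. The detection helper and the two stand-in classes are hypothetical, for illustration only; the real code imports SYSTEM from text_generation_server.utils.import_utils, imports the implementations from the sibling ipex/cuda modules, and covers only the "ipex" and "cuda" cases.

    # Illustrative sketch of the "dispatch once in __init__" pattern above.
    import importlib.util


    def _detect_system() -> str:
        # Simplified detection for this sketch: prefer CUDA when torch
        # reports a GPU, otherwise assume the IPEX (Intel) backend.
        if importlib.util.find_spec("torch") is not None:
            import torch

            if torch.cuda.is_available():
                return "cuda"
        return "ipex"


    SYSTEM = _detect_system()


    class _IpexQuantLinear:
        """Stand-in for the IPEX implementation (really .ipex.QuantLinear)."""


    class _CudaQuantLinear:
        """Stand-in for the CUDA implementation (really .cuda.QuantLinear)."""


    # Single dispatch point: callers now write
    #     from text_generation_server.layers.gptq import QuantLinear
    # instead of repeating the SYSTEM branch at every call site.
    QuantLinear = _CudaQuantLinear if SYSTEM == "cuda" else _IpexQuantLinear

    __all__ = ["QuantLinear"]

Renaming quant_linear.py and qmodule.py to cuda.py is what lets the conditional import read uniformly: each backend lives in a module named after its SYSTEM value (ipex.py, cuda.py), so the __init__.py dispatch is effectively a table from SYSTEM to module.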