diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json new file mode 100644 index 00000000..760ebf94 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4296875, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8632812, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1328125, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76660156, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3837891, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9746094, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4189453, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.34375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8852539, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json new file mode 100644 index 00000000..7a168b2e --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.65625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 604, + "logprob": -0.36938477, + "special": false, + "text": " for" + }, + { + "id": 235248, + "logprob": -1.8046875, + "special": false, + "text": " " + }, + { + "id": 235274, + "logprob": -0.46240234, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": -1.7460938, + "special": false, + "text": "2" + }, + { + "id": 235265, + "logprob": -1.9443359, + "special": false, + "text": "." + }, + { + "id": 235284, + "logprob": -1.4550781, + "special": false, + "text": "2" + }, + { + "id": 235308, + "logprob": -1.0205078, + "special": false, + "text": "5" + }, + { + "id": 235290, + "logprob": -1.0283203, + "special": false, + "text": "-" + }, + { + "id": 235274, + "logprob": -1.2783203, + "special": false, + "text": "1" + }, + { + "id": 235284, + "logprob": 0.0, + "special": false, + "text": "2" + } + ], + "top_tokens": null + }, + "generated_text": "Test request for 12.25-12" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json new file mode 100644 index 00000000..bcb9b378 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4277344, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4394531, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8613281, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1523438, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76220703, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -2.0175781, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4238281, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.328125, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8881836, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.34375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4238281, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.859375, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7631836, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3642578, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9960938, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4179688, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8847656, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.640625, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.3671875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4257812, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4453125, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8789062, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1367188, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.76171875, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3515625, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9873047, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4169922, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3320312, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.8930664, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -9.6484375, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.359375, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 604, + "logprob": -2.4179688, + "special": false, + "text": " for" + }, + { + "id": 573, + "logprob": -2.4492188, + "special": false, + "text": " the" + }, + { + "id": 2412, + "logprob": -2.8574219, + "special": false, + "text": " following" + }, + { + "id": 235292, + "logprob": -2.1445312, + "special": false, + "text": ":" + }, + { + "id": 109, + "logprob": -0.7519531, + "special": false, + "text": "\n\n" + }, + { + "id": 235287, + "logprob": -1.3623047, + "special": false, + "text": "*" + }, + { + "id": 235248, + "logprob": -1.9707031, + "special": false, + "text": " " + }, + { + "id": 199, + "logprob": -1.4267578, + "special": false, + "text": "" + }, + { + "id": 1232, + "logprob": -4.3359375, + "special": false, + "text": "Name" + }, + { + "id": 208, + "logprob": -0.88427734, + "special": false, + "text": "" + } + ], + "top_tokens": null + }, + "generated_text": " for the following:\n\n* Name" + } +] diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py new file mode 100644 index 00000000..7ed339f4 --- /dev/null +++ b/integration-tests/models/test_flash_gemma_gptq.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_gptq_handle(launcher): + with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma_gptq(flash_gemma_gptq_handle): + await flash_gemma_gptq_handle.health(300) + return flash_gemma_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot): + response = await flash_gemma_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot): + response = await flash_gemma_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_gptq_load( + flash_gemma_gptq, generate_load, response_snapshot +): + responses = await generate_load( + flash_gemma_gptq, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d4a325a9..92a20639 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -263,9 +263,13 @@ def get_model( trust_remote_code: bool, ) -> Model: if dtype is None: - # Keep it as default for now and let - # every model resolve their own default dtype. - dtype = None + if quantize in ["awq", "gptq"]: + # These quantizers only work with float16 params. + dtype = torch.float16 + else: + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16": diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d2f6d9af..5b66823f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -78,7 +78,7 @@ def _load_multi_mqa_gptq( quant_method, ) = weights._get_gptq_params() if quant_method == "gptq": - g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx", to_dtype=False) g_idx = g_idx.to(device=weights.device) elif quant_method == "awq": g_idx = None diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6af7d3fb..d11493f2 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -71,19 +71,19 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) - # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype not in [torch.int32, torch.int64]: + if to_dtype: tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -101,14 +101,12 @@ class Weights: tensor = slice_[:, start:stop] else: raise NotImplementedError("Let's make that generic when needed") - # Special case for gptq which shouldn't convert - # u4 which are disguised as int32 - if tensor.dtype != torch.int32: + if to_dtype: tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_dtype: bool = True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -117,7 +115,7 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def _get_qweight(self, name: str): slice_ = self._get_slice(name) @@ -163,10 +161,9 @@ class Weights: qzeros = self._get_qweight(f"{prefix}.qzeros") scales = self._get_qweight(f"{prefix}.scales") - scales = scales.to(dtype=self.dtype) if quantize == "gptq" and quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") + g_idx = self.get_tensor(f"{prefix}.g_idx", to_dtype=False) elif quantize == "gptq" and quant_method == "awq": log_once( logger.info, "Converting AWQ model to Exllama/GPTQ packing format." @@ -211,7 +208,11 @@ class Weights: if quantize in ["gptq", "awq"]: try: qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + [ + self.get_sharded(f"{p}.qweight", dim=1, to_dtype=False) + for p in prefixes + ], + dim=1, ) except RuntimeError: raise RuntimeError( @@ -219,10 +220,18 @@ class Weights: ) qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + [ + self.get_sharded(f"{p}.qzeros", dim=1, to_dtype=False) + for p in prefixes + ], + dim=1, ) scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + [ + self.get_sharded(f"{p}.scales", dim=1, to_dtype=False) + for p in prefixes + ], + dim=1, ) bits, groupsize, desc_act, quant_method = self._get_gptq_params() @@ -234,7 +243,7 @@ class Weights: ) if quantize == "gptq" and quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + w = [self.get_tensor(f"{p}.g_idx", to_dtype=False) for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) g_idx = w[0] @@ -265,22 +274,6 @@ class Weights: weight = torch.cat(w, dim=dim) return weight - def get_tensor_shard(self, var, dim): - world_size = self.process_group.size() - rank = self.process_group.rank() - block_size = var.size()[dim] // world_size - start = rank * block_size - stop = (rank + 1) * block_size - if dim == 0: - tensor = var[start:stop] - elif dim == 1: - tensor = var[:, start:stop] - else: - raise NotImplementedError("Let's make that generic when needed") - tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) - return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True @@ -294,14 +287,14 @@ class Weights: use_exllama = False try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False) except RuntimeError: raise RuntimeError( "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) if quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0, to_dtype=False) elif quant_method == "awq": g_idx = None @@ -335,11 +328,11 @@ class Weights: log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") if use_exllama and groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False) + scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False) else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") + qzeros = self.get_tensor(f"{prefix}.qzeros", to_dtype=False) + scales = self.get_tensor(f"{prefix}.scales", to_dtype=False) if use_exllama and g_idx is not None: g_idx = g_idx - g_idx[0] @@ -368,14 +361,14 @@ class Weights: bits, groupsize, _, _ = self._get_gptq_params() try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + qweight = self.get_sharded(f"{prefix}.qweight", dim=0, to_dtype=False) except RuntimeError: raise RuntimeError( "Cannot load `awq` weight, make sure the model is already quantized" ) - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0, to_dtype=False) + scales = self.get_sharded(f"{prefix}.scales", dim=0, to_dtype=False) g_idx = None use_exllama = False @@ -386,8 +379,8 @@ class Weights: def _get_gptq_params(self) -> Tuple[int, int, int, str]: try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() + bits = self.get_tensor("gptq_bits", to_dtype=False).item() + groupsize = self.get_tensor("gptq_groupsize", to_dtype=False).item() desc_act = False quant_method = "gptq" except (SafetensorError, RuntimeError) as e: