From b7299e1b7fac91904cdfb761c8aee7b8caa1396a Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Mon, 18 Dec 2023 16:07:05 +0100
Subject: [PATCH] fix: fix gpt-q with groupsize = -1 (#1358)

---
 proto/generate.proto                               |  3 +++
 router/client/src/client.rs                        | 14 ++++++++++++--
 server/text_generation_server/server.py            |  7 ++++++-
 .../text_generation_server/utils/gptq/exllama.py   | 11 ++---------
 .../text_generation_server/utils/gptq/exllamav2.py | 13 +------------
 server/text_generation_server/utils/weights.py     | 12 ++++++------
 6 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/proto/generate.proto b/proto/generate.proto
index fc4617f9..ceb421c4 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -213,6 +213,9 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     repeated Batch batches = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
 }
 
 /// Empty response
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index f0ecb05a..ba7b7565 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -167,7 +167,12 @@ impl Client {
                 );
                 num_batches
             ];
-            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
@@ -188,7 +193,12 @@ impl Client {
                 );
                 num_batches
             ];
-            let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+            let request = tonic::Request::new(WarmupRequest {
+                batches,
+                max_input_length,
+                max_prefill_tokens,
+                max_total_tokens,
+            }).inject_context();
             let _response = self.stub.warmup(request).await?.into_inner();
         }
         Ok(None) // No support for maximum total tokens
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 83a65251..4a07733a 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -21,7 +21,12 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
+    def __init__(
+        self,
+        model: Model,
+        cache: Cache,
+        server_urls: List[str],
+    ):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py
index 7353afb5..32f817db 100644
--- a/server/text_generation_server/utils/gptq/exllama.py
+++ b/server/text_generation_server/utils/gptq/exllama.py
@@ -37,19 +37,12 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
 
     assert DEVICE is not None, "call set_device first"
 
-    if ACT_ORDER:
-        # TODO: this should be set to rust side `max_total_tokens`, but TGI
-        # does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = 2048
-    else:
+    if not ACT_ORDER:
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index f820f0d9..dd41b269 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -101,7 +101,7 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global FIXED_BYTES, LAYERS, DEVICE
 
     temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
@@ -138,17 +138,6 @@ class QuantLinear(nn.Module):
         self.bias = bias if bias is not None else None
         self.group_size = groupsize
 
-        infeatures = self.infeatures
-        outfeatures = self.outfeatures
-        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
-        assert infeatures % self.group_size == 0
-        assert qzeros.shape == (
-            infeatures // self.group_size,
-            outfeatures // 32 * self.bits,
-        )
-        assert scales.shape == (infeatures // self.group_size, outfeatures)
-        assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
-
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 67fda511..a2cca2ea 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -281,17 +281,17 @@ class Weights:
                 else:
                     logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
-            if use_exllama:
+            if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-                g_idx = g_idx - g_idx[0]
             else:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            if use_exllama:
+                g_idx = g_idx - g_idx[0]
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
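
The new WarmupRequest fields are what let the server drop the hard-coded 2048 in exllama.py: the router already knows max_input_length, max_prefill_tokens and max_total_tokens, and now forwards them at warmup time. The patch does not show the server-side call site, so the snippet below is only a sketch of how a warmup path could pass the value through; prepare_exllama_for_warmup, model_device and quantize are illustrative names, not code from this change.

from text_generation_server.utils.gptq.exllama import create_exllama_buffers, set_device


def prepare_exllama_for_warmup(request, model_device, quantize: str) -> None:
    # Sketch only: forward the router-provided budget from the WarmupRequest
    # (field 4 added by this patch) into the exllama buffer allocation, so
    # act-order models no longer rely on the old dummy value of 2048.
    if quantize == "gptq":
        set_device(model_device)
        create_exllama_buffers(request.max_total_tokens)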
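
With the token budget available, the only case exllama.py still special-cases is non-act-order models, which never reorder X and therefore only need a one-row temp_state buffer. Below is a minimal sketch of that sizing rule, assuming the (rows, MAX_INNER) float16 shape the comment at the end of the hunk refers to; the real allocation sits just below the shown context and is not part of this diff.

import torch


def alloc_temp_state(act_order: bool, max_total_tokens: int, max_inner: int, device=None) -> torch.Tensor:
    # Assumed shape/dtype for illustration: act-order checkpoints must be able
    # to reorder activations for every token in flight, so they get a buffer
    # sized by the warmup budget; everything else keeps a single dummy row.
    rows = max_total_tokens if act_order else 1
    return torch.zeros((rows, max_inner), dtype=torch.float16, device=device)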
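
The removed exllamav2 assertions and the weights.py change address the same corner case: with groupsize = -1 the whole input dimension forms a single quantization group, so qzeros and scales have exactly one row, the infeatures % group_size style checks stop making sense, and a tensor-parallel rank cannot take a dim-0 shard of them; it has to load the full tensors, while g_idx stays sharded and is rebased with g_idx - g_idx[0] for exllama. A toy illustration of the shape difference follows; num_groups is a throwaway helper and the numbers are made up.

def num_groups(in_features: int, groupsize: int) -> int:
    # qzeros/scales carry one row per quantization group along in_features.
    return 1 if groupsize == -1 else in_features // groupsize


# Per-group quantization: 32 rows, so 4 ranks can each take a dim-0 shard of 8.
assert num_groups(4096, 128) == 32
# groupsize = -1: a single row covering all of in_features, so every rank must
# load the full qzeros/scales (get_tensor instead of get_sharded in weights.py).
assert num_groups(4096, -1) == 1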