From 5ca0508d02b1996d290016a3bec5694c7c3da59b Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 20 Jul 2023 15:36:53 +0000
Subject: [PATCH] Simpler exllama

---
 Makefile                                      |   3 -
 integration-tests/conftest.py                 |   5 +-
 .../test_flash_llama_gptq.json                | 129 +++++++++---------
 .../models/test_flash_llama_gptq.py           |   3 +-
 .../models/test_flash_starcoder_gptq.py       |   4 +-
 .../custom_kernels/exllama/exllama_ext.cpp    |   2 +-
 server/custom_kernels/setup.py                |   2 +-
 .../models/causal_lm.py                       |   1 -
 .../custom_modeling/flash_llama_modeling.py   |   2 +-
 .../flash_santacoder_modeling.py              |  22 +--
 .../models/flash_causal_lm.py                 |   6 +-
 .../models/flash_llama.py                     |   1 -
 .../models/flash_neox.py                      |   1 -
 .../text_generation_server/models/flash_rw.py |   1 -
 .../models/flash_santacoder.py                |   1 -
 .../models/galactica.py                       |   1 -
 .../text_generation_server/models/gpt_neox.py |   1 -
 server/text_generation_server/models/model.py |  53 +------
 server/text_generation_server/models/mpt.py   |   1 -
 server/text_generation_server/models/opt.py   |   1 -
 server/text_generation_server/models/rw.py    |   1 -
 .../models/santacoder.py                      |   1 -
 .../models/seq2seq_lm.py                      |   1 -
 server/text_generation_server/models/t5.py    |   1 -
 server/text_generation_server/server.py       |   7 +
 .../utils/gptq/exllama.py                     |  89 ++++++++++++
 .../utils/gptq/quant_linear.py                |  78 -----------
 server/text_generation_server/utils/layers.py |  18 ++-
 .../text_generation_server/utils/weights.py   |  63 ++++-----
 29 files changed, 223 insertions(+), 276 deletions(-)
 create mode 100644 server/text_generation_server/utils/gptq/exllama.py

diff --git a/Makefile b/Makefile
index 81b312d1..3c2f2b9d 100644
--- a/Makefile
+++ b/Makefile
@@ -56,6 +56,3 @@ run-bloom:
 
 run-bloom-quantize:
 	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
-
-clean:
-	rm -rf target aml
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index e6b577d4..85dbeeda 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -230,19 +230,16 @@ def launcher(event_loop):
             shard_uds_path,
         ]
-        env = os.environ
         if num_shard is not None:
             args.extend(["--num-shard", str(num_shard)])
 
         if quantize is not None:
             args.append("--quantize")
             args.append(quantize)
-            if quantize == "gptq":
-                env["GPTQ_GROUPSIZE"] = "128"
-                env["GPTQ_BITS"] = "4"
 
         if trust_remote_code:
             args.append("--trust-remote-code")
 
+        env = os.environ
         env["LOG_LEVEL"] = "info,text_generation_router=debug"
 
         if not use_flash_attention:
diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json
index 6bd6c729..db3ad58f 100644
--- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json
+++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json
@@ -1,102 +1,103 @@
 {
-  "generated_text": ", and I am going to visit the Louvre",
   "details": {
+    "best_of_sequences": null,
     "finish_reason": "length",
     "generated_tokens": 10,
-    "seed": null,
     "prefill": [
       {
-        "id": 2,
-        "text": "",
-        "logprob": null
+        "id": 1,
+        "logprob": null,
+        "text": ""
       },
       {
         "id": 20628,
-        "text": "Today",
-        "logprob": -11.2265625
+        "logprob": -10.328125,
+        "text": "Today"
       },
       {
         "id": 306,
-        "text": "I",
-        "logprob": -4.1757812
+        "logprob": -2.390625,
+        "text": "I"
       },
       {
         "id": 626,
-        "text": "am",
-        "logprob": -1.9746094
+        "logprob": -1.8857422,
+        "text": "am"
       },
       {
         "id": 297,
-        "text": "in",
-        "logprob": -5.4648438
+        "logprob": -4.4765625,
+        "text": "in"
       },
       {
         "id": 3444,
-        "text": "France",
-        "logprob": -9.03125
+        "logprob": -7.0703125,
+        "text": "France"
       }
     ],
+    "seed": null,
     "tokens": [
       {
         "id": 29892,
-        "text": ",",
-        "logprob": -0.31298828,
-        "special": false
+        "logprob": -1.2910156,
+        "special": false,
+        "text": ","
       },
       {
-        "id": 322,
-        "text": " and",
-        "logprob": -1.4345703,
-        "special": false
-      },
-      {
-        "id": 306,
-        "text": " I",
-        "logprob": -0.32080078,
-        "special": false
-      },
-      {
-        "id": 626,
-        "text": " am",
-        "logprob": -1.3798828,
-        "special": false
-      },
-      {
-        "id": 2675,
-        "text": " going",
-        "logprob": -1.2304688,
-        "special": false
-      },
-      {
-        "id": 304,
-        "text": " to",
-        "logprob": -0.0014791489,
-        "special": false
-      },
-      {
-        "id": 6493,
-        "text": " visit",
-        "logprob": -1.1503906,
-        "special": false
+        "id": 297,
+        "logprob": -1.9394531,
+        "special": false,
+        "text": " in"
       },
       {
         "id": 278,
-        "text": " the",
-        "logprob": -0.41259766,
-        "special": false
+        "logprob": -0.7597656,
+        "special": false,
+        "text": " the"
       },
       {
-        "id": 4562,
-        "text": " Lou",
-        "logprob": -1.8134766,
-        "special": false
+        "id": 7062,
+        "logprob": -2.9121094,
+        "special": false,
+        "text": " south"
       },
       {
-        "id": 12675,
-        "text": "vre",
-        "logprob": -0.000767231,
-        "special": false
+        "id": 310,
+        "logprob": -1.0302734,
+        "special": false,
+        "text": " of"
+      },
+      {
+        "id": 278,
+        "logprob": -0.58203125,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 4234,
+        "logprob": -0.2944336,
+        "special": false,
+        "text": " country"
+      },
+      {
+        "id": 29892,
+        "logprob": -0.7011719,
+        "special": false,
+        "text": ","
+      },
+      {
+        "id": 297,
+        "logprob": -1.1054688,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 278,
+        "logprob": -0.52490234,
+        "special": false,
+        "text": " the"
       }
     ]
-  }
+  },
+  "generated_text": ", in the south of the country, in the"
 }
diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py
index 2a00a2b1..73ba2785 100644
--- a/integration-tests/models/test_flash_llama_gptq.py
+++ b/integration-tests/models/test_flash_llama_gptq.py
@@ -1,10 +1,9 @@
-
 import pytest
 
 
 @pytest.fixture(scope="module")
 def flash_llama_gptq_handle(launcher):
-    with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle:
+    with launcher("huggingface/llama-7b-gptq", num_shard=4, quantize="gptq") as handle:
         yield handle
 
 
diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py
index b6bed6a6..dd98d660 100644
--- a/integration-tests/models/test_flash_starcoder_gptq.py
+++ b/integration-tests/models/test_flash_starcoder_gptq.py
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_starcoder_gptq_handle(launcher):
-    with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
+    with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
         yield handle
 
 
@@ -46,4 +46,4 @@ async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, re
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-    assert responses == response_snapshot
\ No newline at end of file
+    assert responses == response_snapshot
diff --git a/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp b/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp
index b786988b..4e43d605 100644
--- a/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp
+++ b/server/custom_kernels/custom_kernels/exllama/exllama_ext.cpp
@@ -3,7 +3,7 @@
 #include
 #include
 #include
-#include
+// #include
 #include
 #include
 #include
diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py
index 2af50d94..d881697f 100644
--- a/server/custom_kernels/setup.py
+++ b/server/custom_kernels/setup.py
@@ -14,7 +14,7 @@ setup(
             sources=["custom_kernels/fused_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
         ),
-        CppExtension(
+        CUDAExtension(
            name="custom_kernels.exllama",
            sources=[
                "custom_kernels/exllama/exllama_ext.cpp",
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 3bad2daf..cbdf4808 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -500,7 +500,6 @@ class CausalLM(Model):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 039fe2bf..b2bde282 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -378,6 +378,7 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
+        self.config = config
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
@@ -448,7 +449,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
-        self.config = config
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 4dd76360..6f5c60fc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
 )
 from safetensors import SafetensorError
 
+
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
@@ -71,12 +72,19 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits, groupsize = weights.get_gptq_qparams()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os
 
-        qweight = qweight.to(weights.device)
-        qzeros = qzeros.to(weights.device)
-        scales = scales.to(weights.device)
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
+
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@@ -90,8 +98,6 @@ def _load_multi_mqa_gptq(
             kv_tensor = slice_[-2 * head_size :]
             bias = torch.cat([q_tensor, kv_tensor], dim=0)
 
-        bias = bias.to(weights.device)
-
         return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
     else:
         raise NotImplementedError("Gptq loading with santacoder is not implemented")
@@ -355,7 +361,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        
+
         hidden_states = self.attn(
             hidden_states,
             cu_seqlen_prefill,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 56c21463..517fba68 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -6,8 +6,9 @@ import torch.distributed
 import numpy as np
 
 from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,7 +21,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
-
 tracer = trace.get_tracer(__name__)
 
 BLOCK_SIZE = 16
@@ -684,7 +684,6 @@ class FlashCausalLM(Model):
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
-        config: PretrainedConfig,
         num_layers: int,
         num_kv_heads: int,
         head_size: int,
@@ -700,7 +699,6 @@ class FlashCausalLM(Model):
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 77450cbb..b699799e 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
         super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index ce868904..61004d8e 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.gpt_neox.layers),
             num_kv_heads=model.gpt_neox.num_heads,
             head_size=model.gpt_neox.head_size,
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 2006993f..55d555fc 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
         super(FlashRWSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=model.transformer.cache_size,
             head_size=model.transformer.head_size,
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index ecbd6b31..415ec2df 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
         super(FlashSantacoderSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=1,
             head_size=model.transformer.head_size,
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d6885d5f..01e58bad 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 2b55a70b..91877fa0 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 8ceac511..3827197f 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -3,27 +3,19 @@ import torch
 
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
-from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
-from custom_kernels.exllama import prepare_buffers, set_tuning_params
-
-from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear
-)
-
 
 B = TypeVar("B", bound=Batch)
 
+
 class Model(ABC):
     def __init__(
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
-        config: PretrainedConfig,
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
@@ -46,47 +38,6 @@ class Model(ABC):
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
         )
 
-        self.config = config
-
-        if config.quantize == "gptq":
-            # Buffers need to be persistent to avoid any bug.
-            self.buffers = {}
-            use_exllama_act_order = False
-            max_dq_buffer_size = 1
-            max_inner_outer_dim = 1
-            for name, submodule in model.named_modules():
-                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-                    if submodule.linear.act_order:
-                        max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
-
-                        use_exllama_act_order = True
-
-            if use_exllama_act_order:
-                # TODO: this should be set to rust side `max_total_tokens`, but TGI
-                # does not offer an API to expose this variable to python, as this variable
-                # is handled by the client but it appears the model is initialized by the server.
-                # An alternative could be to initialize the buffers during warmup.
-                max_total_tokens = 2048
-            else:
-                max_total_tokens = 1
-
-            # This temp_state buffer is required to reorder X in the act-order case.
-            self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
-
-            # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
-
-            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
-
-            matmul_recons_thd = 8
-            matmul_fused_remap = False
-            matmul_no_half2 = False
-            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-            torch.cuda.empty_cache()
 
         self.check_initialized()
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 8d7d68ab..a4fe5105 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 8e64b1ab..d407b44a 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 796a9590..92bb135b 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -58,7 +58,6 @@ class RW(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index f8f72e9c..a2b38737 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f6e744f5..9e5c21d1 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 890141e2..133aafd8 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index e0efbcf5..6f48d2de 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -140,6 +140,13 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
+        try:
+            from text_generation_server.utils.gptq.exllama import create_buffers
+            create_buffers()
+            logger.info("Created exllama GPTQ buffers !")
+        except ImportError:
+            pass
+
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py
new file mode 100644
index 00000000..ad80d976
--- /dev/null
+++ b/server/text_generation_server/utils/gptq/exllama.py
@@ -0,0 +1,89 @@
+import torch
+from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers
+from loguru import logger
+
+# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
+
+def ext_q4_matmul(x, q4, q4_width):
+    """Matrix multiplication, returns x @ q4"""
+    outshape = x.shape[:-1] + (q4_width,)
+    x = x.view(-1, x.shape[-1])
+    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+
+    q4_matmul(x, q4, output)
+
+    return output.view(outshape)
+
+
+import os
+RANK = os.getenv("RANK", "0")
+DEVICE = torch.device(f"cuda:{RANK}")
+MAX_TOTAL_TOKENS = 1
+MAX_INNER_OUTER_DIM = 0
+MAX_DQ_BUFFER_SIZE = 0
+
+
+def create_buffers():
+    temp_state = torch.zeros((MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE)
+    temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)
+    logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")
+
+    prepare_buffers(DEVICE, temp_state, temp_dq)
+
+    matmul_recons_thd = 8
+    matmul_fused_remap = False
+    matmul_no_half2 = False
+    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+class Ex4bitLinear:
+    """Linear layer implementation with per-group 4-bit quantization of the weights"""
+    def __init__(self, qweight, qzeros, scales, bias, bits):
+        assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"
+
+        global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
+        dq = qweight.numel() * 8
+        if dq > MAX_DQ_BUFFER_SIZE:
+            MAX_DQ_BUFFER_SIZE = dq
+
+        width = qweight.shape[1]
+        if width > MAX_INNER_OUTER_DIM:
+            MAX_INNER_OUTER_DIM = width
+        height = qweight.shape[0] * 8
+        if height > MAX_INNER_OUTER_DIM:
+            MAX_INNER_OUTER_DIM = height
+
+        # prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)
+
+
+        self.q4 = make_q4(
+            qweight,
+            qzeros,
+            scales,
+            # Never send g_idx, it MUST be like act_order=False, the exllama kernel does not expect it
+            torch.zeros((0, 0), device=torch.device("meta")),
+            DEVICE.index
+        )
+        self.bias = bias if bias is not None else None
+        self.width = width
+
+        # # Infer groupsize from height of qzeros
+        # self.groupsize = None
+        # if self.qzeros.shape[0] > 1:
+        #     self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
+
+        # if self.groupsize is not None:
+        #     assert groupsize == self.groupsize
+
+        # # Handle act-order matrix
+        # if self.g_idx is not None:
+        #     if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+        #     self.act_order = True
+        # else:
+        #     self.act_order = False
+
+    def forward(self, x):
+        out = ext_q4_matmul(x, self.q4, self.width)
+
+        if self.bias is not None:
+            out.add_(self.bias)
+        return out
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 1b807427..4d7814af 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -8,11 +8,6 @@ import torch
 from loguru import logger
 
 
-try:
-    from custom_kernels.exllama import make_q4, q4_matmul
-except Exception as e:
-    logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
-
 try:
     import triton
     import triton.language as tl
@@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
 
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
-# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device = "meta")
-
-def ext_make_q4(qweight, qzeros, scales, g_idx, device):
-    """Construct Q4Matrix, return handle"""
-    return make_q4(qweight,
-                   qzeros,
-                   scales,
-                   g_idx if g_idx is not None else none_tensor,
-                   device)
-
-def ext_q4_matmul(x, q4, q4_width):
-    """Matrix multiplication, returns x @ q4"""
-    outshape = x.shape[:-1] + (q4_width,)
-    x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
-
-    q4_matmul(x, q4, output)
-
-    return output.view(outshape)
-
-
-class Ex4bitLinear:
-    """Linear layer implementation with per-group 4-bit quantization of the weights"""
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
-        assert bits == 4
-
-        self.device = qweight.device
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx.cpu() if g_idx is not None else None
-        self.bias = bias if bias is not None else None
-
-        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
-            self.empty_g_idx = True
-            self.g_idx = None
-
-        assert self.device.type == "cuda"
-        assert self.device.index is not None
-
-        self.q4 = ext_make_q4(
-            self.qweight,
-            self.qzeros,
-            self.scales,
-            self.g_idx,
-            self.device.index
-        )
-
-        self.height = qweight.shape[0] * 8
-        self.width = qweight.shape[1]
-
-        # Infer groupsize from height of qzeros
-        self.groupsize = None
-        if self.qzeros.shape[0] > 1:
-            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
-
-        if self.groupsize is not None:
-            assert groupsize == self.groupsize
-
-        # Handle act-order matrix
-        if self.g_idx is not None:
-            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
-            self.act_order = True
-        else:
-            self.act_order = False
-
-    def forward(self, x):
-        out = ext_q4_matmul(x, self.q4, self.width)
-
-        if self.bias is not None:
-            out.add_(self.bias)
-        return out
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index db392c4a..79c10084 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed
 
+from loguru import logger
 from torch import nn
 from torch.nn import functional as F
@@ -15,7 +16,14 @@ except ImportError:
 
 from accelerate import init_empty_weights
 
-from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
+from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
+HAS_EXLLAMA = True
+try:
+    from text_generation_server.utils.gptq.exllama import Ex4bitLinear
+except ImportError:
+    logger.error("The CUDA kernels custom_kernels.exllama are not installed, using the base triton kernel instead")
+    HAS_EXLLAMA = False
 
 from typing import Optional
 
@@ -145,13 +153,15 @@ def get_linear(weight, bias, quantize):
         linear.bias = nn.Parameter(bias)
     elif quantize == "gptq":
         try:
-            qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama = weight
         except Exception:
            raise NotImplementedError(
                f"The passed weight is not `gptq` compatible, loader needs to be updated."
            )
 
-        if use_triton_kernel or bits != 4:
+        if can_exllama and HAS_EXLLAMA:
+            linear = Ex4bitLinear(qweight, qzeros, scales, bias, bits)
+        else:
             linear = QuantLinear(
                 qweight,
                 qzeros,
@@ -161,8 +171,6 @@ def get_linear(weight, bias, quantize):
                 bits,
                 groupsize,
             )
-        else:
-            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 5bbf04d0..8bf74a40 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -3,7 +3,6 @@ from typing import List, Dict, Optional, Tuple
 from safetensors import safe_open, SafetensorError
 import torch
 
-
 class Weights:
     def __init__(
         self,
@@ -126,9 +125,15 @@ class Weights:
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
+            can_exllama = True
+            bits, groupsize = self._get_gptq_qparams()
+            if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                # Exllama implementation does not support row tensor parallelism with act-order, as
+                # it would require to reorder input activations that are split onto several GPUs
+                can_exllama = False
 
-            bits, groupsize = self.get_gptq_qparams()
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+            bits, groupsize = self._get_gptq_qparams()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
@@ -136,52 +141,32 @@ class Weights:
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
-            use_triton_kernel = False
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                _, groupsize = self.get_gptq_qparams()
-
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_triton_kernel = True
-
             try:
                 qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
             except RuntimeError:
                 raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
+
+            bits, groupsize = self._get_gptq_qparams()
+            g_idx = self.get_tensor(f"{prefix}.g_idx")
 
-            bits, groupsize = self.get_gptq_qparams()
+            can_exllama = True
+            if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                # Exllama implementation does not support row tensor parallelism with act-order, as
+                # it would require to reorder input activations that are split onto several GPUs
+                can_exllama = False
 
-            if use_triton_kernel:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
-                qzeros = self.get_tensor(f"{prefix}.qzeros")
-                scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
-                if groupsize >= 16:
-                    # Exllama reorders the weights in advance and the activations on the fly, thus
-                    # the scales and zero-points do not need to be reordered.
-                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                else:
-                    qzeros = self.get_tensor(f"{prefix}.qzeros")
-                    scales = self.get_tensor(f"{prefix}.scales")
+            # The triton kernel reorders the scales/zero points instead of the weight/activation.
+            # Thus, each rank needs the full qzeros/scales.
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
 
-            # For tp > 1, at this point we know we do not use act-order
-            if self.process_group.size() == 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def get_gptq_qparams(self) -> Tuple[int, int]:
+    def _get_gptq_qparams(self) -> Tuple[int, int]:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
@@ -194,4 +179,4 @@ class Weights:
             except Exception:
                 raise e
 
-        return bits, groupsize
\ No newline at end of file
+        return bits, groupsize