From 8cf7c899109dec3177cc68503295789ef0c441a7 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 20 Jul 2023 17:44:37 +0000
Subject: [PATCH] Small polish.

---
 .../models/custom_modeling/flash_llama_modeling.py       | 1 -
 .../models/custom_modeling/flash_santacoder_modeling.py  | 5 +++--
 server/text_generation_server/models/flash_causal_lm.py  | 3 +--
 server/text_generation_server/models/flash_neox.py       | 1 -
 server/text_generation_server/server.py                  | 5 ++++-
 server/text_generation_server/utils/gptq/quant_linear.py | 6 ------
 6 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 039fe2bf..3bd3dca3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -448,7 +448,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
-        self.config = config
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 5e0580b2..4603d577 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -73,7 +73,9 @@ def _load_multi_mqa_gptq(
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         bits, groupsize = weights._get_gptq_qparams()
 
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+        use_exllama = HAS_EXLLAMA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@@ -350,7 +352,6 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-
         hidden_states = self.attn(
             hidden_states,
             cu_seqlen_prefill,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3b862737..1753d4c0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -7,7 +7,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,7 +20,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
-
 tracer = trace.get_tracer(__name__)
 
 BLOCK_SIZE = 16
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index ce868904..61004d8e 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
-           config=config,
            num_layers=len(model.gpt_neox.layers),
           num_kv_heads=model.gpt_neox.num_heads,
           head_size=model.gpt_neox.head_size,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index b022892d..b279426b 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -143,7 +143,10 @@ def serve(
 
         if quantize == "gptq":
             try:
-                from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
+                # When using GPTQ, Exllama kernels need some global buffers
+                # for which we only have the final shapes after the model has loaded.
+                # This will allocate those buffers.
+                from text_generation_server.utils.gptq.exllama import create_exllama_buffers
                 create_exllama_buffers()
             except ImportError:
                 pass
diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py
index 4d7814af..54fa2014 100644
--- a/server/text_generation_server/utils/gptq/quant_linear.py
+++ b/server/text_generation_server/utils/gptq/quant_linear.py
@@ -4,10 +4,6 @@
 import torch
 import torch.nn as nn
 from torch.cuda.amp import custom_bwd, custom_fwd
-import torch
-
-from loguru import logger
-
 try:
     import triton
     import triton.language as tl
@@ -256,7 +252,6 @@ class QuantLinear(nn.Module):
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
         self.register_buffer("g_idx", g_idx)
-
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
@@ -362,4 +357,3 @@ class QuantLinear(nn.Module):
         )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
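
Note on the server.py hunk: the Exllama buffer allocation is deliberately deferred until after the model has loaded and is guarded against missing kernels. A minimal sketch of that pattern, assuming the module path shown in the patch; the wrapper function name below is illustrative and not part of the change:

def maybe_create_exllama_buffers(quantize: str) -> None:
    # Only GPTQ models use the Exllama kernels.
    if quantize != "gptq":
        return
    try:
        # Buffer shapes are only known once every quantized layer exists,
        # so allocation happens after the model is fully loaded.
        from text_generation_server.utils.gptq.exllama import create_exllama_buffers
    except ImportError:
        # Exllama kernels are optional; fall back silently when unavailable.
        return
    create_exllama_buffers()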