Small polish.

Nicolas Patry 2023-07-20 17:44:37 +00:00
parent 0860394489
commit 8cf7c89910
6 changed files with 8 additions and 13 deletions

View File

@@ -448,7 +448,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
- self.config = config
self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,

View File

@@ -73,7 +73,9 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
bits, groupsize = weights._get_gptq_qparams()
- weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+ from text_generation_server.utils.layers import HAS_EXLLAMA
+ use_exllama = HAS_EXLLAMA
+ weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@@ -350,7 +352,6 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlen_prefill,

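Note on the `_load_multi_mqa_gptq` change above: instead of hard-coding `False`, the last element of the quantized-weight tuple now mirrors `HAS_EXLLAMA`, a module-level flag that records whether the optional exllama kernels imported successfully. A minimal sketch of that availability-flag pattern (the import probe and the `pack_gptq_weight` helper below are illustrative, not TGI code):

```python
# Probe for the optional CUDA extension once, at import time, and expose the
# result as a module-level flag that callers can consult.
try:
    import exllama_kernels  # optional GPTQ/exllama CUDA extension

    HAS_EXLLAMA = True
except ImportError:
    HAS_EXLLAMA = False


def pack_gptq_weight(qweight, qzeros, scales, g_idx, bits, groupsize):
    # Callers no longer hard-code use_exllama=False; they inherit whatever
    # the import probe found on this machine.
    use_exllama = HAS_EXLLAMA
    return (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
```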
View File

@@ -7,7 +7,7 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
- from transformers import PreTrainedTokenizerBase, PretrainedConfig
+ from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
@@ -20,7 +20,6 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16

View File

@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
- config=config,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,

View File

@@ -143,7 +143,10 @@ def serve(
if quantize == "gptq":
try:
- from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
+ # When using GPTQ, Exllama kernels need some global buffers
+ # for which we have the final shapes only after the model has loaded.
+ # This will allocate those buffers.
+ from text_generation_server.utils.gptq.exllama import create_exllama_buffers
create_exllama_buffers()
except ImportError:
pass

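The `serve()` hunk above does two things: the import now comes from `utils.gptq.exllama`, and the comment spells out why the call is deferred, since exllama's scratch buffers can only be sized once the final, post-load weight shapes are known. A small sketch of that call-site pattern, wrapped in a hypothetical helper (`maybe_create_exllama_buffers` is not part of TGI):

```python
def maybe_create_exllama_buffers(quantize: str) -> None:
    """Allocate exllama's shared scratch buffers, but only when it makes sense."""
    if quantize != "gptq":
        # Exllama kernels are only relevant for GPTQ-quantized weights.
        return
    try:
        # The gptq extras may not be installed (e.g. CPU-only builds).
        from text_generation_server.utils.gptq.exllama import create_exllama_buffers
    except ImportError:
        return
    # Deferred on purpose: the buffer sizes depend on the largest weight
    # shapes, which are only known after the model has finished loading.
    create_exllama_buffers()
```

Such a helper would be called once, right after model loading, which is exactly where the diff places the call inside `serve()`.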
View File

@@ -4,10 +4,6 @@ import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
- import torch
- from loguru import logger
try:
import triton
import triton.language as tl
@@ -256,7 +252,6 @@ class QuantLinear(nn.Module):
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
@@ -362,4 +357,3 @@ class QuantLinear(nn.Module):
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
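For context on the `QuantLinear` hunks above: the packed GPTQ tensors are attached with `register_buffer`, so they follow `.to(device)` and appear in `state_dict()` without ever becoming trainable parameters, and the bias stays optional. A stripped-down sketch of just that storage pattern (`PackedLinearSketch` is illustrative, not TGI's `QuantLinear`, and omits the actual quantized matmul):

```python
import torch
import torch.nn as nn


class PackedLinearSketch(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias=None):
        super().__init__()
        # Buffers, not nn.Parameters: they move with the module and are
        # serialized, but optimizers never see them.
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None


# The packed tensors land in the state dict, while the module reports no
# trainable parameters at all.
m = PackedLinearSketch(
    qweight=torch.zeros(8, 4, dtype=torch.int32),
    qzeros=torch.zeros(1, 1, dtype=torch.int32),
    scales=torch.ones(1, 4, dtype=torch.float16),
    g_idx=torch.zeros(32, dtype=torch.int32),
)
assert list(m.parameters()) == []
assert set(m.state_dict()) == {"qweight", "qzeros", "scales", "g_idx"}
```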