Small polish.

Nicolas Patry 2023-07-20 17:44:37 +00:00
parent 0860394489
commit 8cf7c89910
6 changed files with 8 additions and 13 deletions

@@ -448,7 +448,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
-        self.config = config
 
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,

@@ -73,7 +73,9 @@ def _load_multi_mqa_gptq(
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         bits, groupsize = weights._get_gptq_qparams()
 
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+        use_exllama = HAS_EXLLAMA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
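For context: the last element of the weight tuple now reflects whether the Exllama kernels are usable instead of being hard-coded to False. HAS_EXLLAMA is a module-level flag in text_generation_server.utils.layers, and a guarded import is the usual way such a flag is set. A minimal sketch of that pattern (an assumption; the imported symbol is illustrative and the actual check in layers.py may differ):

    # Sketch only: Ex4bitLinear is an assumed symbol, not confirmed by this diff.
    try:
        from text_generation_server.utils.gptq.exllama import Ex4bitLinear

        HAS_EXLLAMA = True
    except ImportError:
        HAS_EXLLAMA = False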
@@ -350,7 +352,6 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-
         hidden_states = self.attn(
             hidden_states,
             cu_seqlen_prefill,

@@ -7,7 +7,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,7 +20,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
-
 
 tracer = trace.get_tracer(__name__)
 
 BLOCK_SIZE = 16

@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.gpt_neox.layers),
             num_kv_heads=model.gpt_neox.num_heads,
             head_size=model.gpt_neox.head_size,

@@ -143,7 +143,10 @@ def serve(
         if quantize == "gptq":
             try:
-                from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
+                # When using GPTQ, the Exllama kernels need some global buffers
+                # for which we have the final shapes only after the model has loaded.
+                # This will allocate those buffers.
+                from text_generation_server.utils.gptq.exllama import create_exllama_buffers
 
                 create_exllama_buffers()
             except ImportError:
                 pass
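The added comment explains the late import: the Exllama kernels share scratch buffers whose final sizes are only known once every quantized layer has been loaded, so allocation is deferred to this point in serve(). A hypothetical sketch of that deferred-allocation pattern (illustrative names only, not the actual contents of gptq/exllama.py):

    import torch

    # Largest scratch size any quantized layer has requested; layers report
    # this at construction time, before create_exllama_buffers() runs.
    MAX_TEMP_DQ = 1

    def record_temp_dq(size: int) -> None:
        global MAX_TEMP_DQ
        MAX_TEMP_DQ = max(MAX_TEMP_DQ, size)

    def create_exllama_buffers() -> None:
        # Allocate the shared scratch buffer once, now that the final
        # shapes are known for the whole model.
        global TEMP_DQ
        TEMP_DQ = torch.zeros(MAX_TEMP_DQ, dtype=torch.float16, device="cuda")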

@@ -4,10 +4,6 @@ import torch
 import torch.nn as nn
 from torch.cuda.amp import custom_bwd, custom_fwd
-
-import torch
-from loguru import logger
-
 
 try:
     import triton
     import triton.language as tl
@@ -256,7 +252,6 @@ class QuantLinear(nn.Module):
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
         self.register_buffer("g_idx", g_idx)
-
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
@@ -362,4 +357,3 @@ class QuantLinear(nn.Module):
         )
         out = out + self.bias if self.bias is not None else out
-
         return out.reshape(out_shape)