mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

commit 8cf7c89910
parent 0860394489

    Small polish.
@@ -448,7 +448,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
-        self.config = config
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
@@ -73,7 +73,9 @@ def _load_multi_mqa_gptq(
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
         bits, groupsize = weights._get_gptq_qparams()
 
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+        from text_generation_server.utils.layers import HAS_EXLLAMA
+        use_exllama = HAS_EXLLAMA
+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
 
         if bias:
             slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
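Note on the hunk above: the last element of the GPTQ weight tuple now carries HAS_EXLLAMA instead of a hard-coded False, so the Exllama kernel path is enabled whenever the extension is importable. A minimal, purely illustrative sketch of how a consumer of that tuple might branch on the flag (the function name and return values are hypothetical, not the repo's actual layer code):

    def build_gptq_linear(weight, bias):
        # Illustrative only: unpack the tuple built in _load_multi_mqa_gptq above.
        qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        if use_exllama:
            # Hypothetical: dispatch to an Exllama-backed layer when the
            # extension was importable (HAS_EXLLAMA is True).
            return ("exllama", (qweight, qzeros, scales, g_idx, bias))
        # Hypothetical fallback: the generic quantized-linear path.
        return ("quant_linear", (qweight, qzeros, scales, g_idx, bias, bits, groupsize))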
@@ -350,7 +352,6 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
 
         hidden_states = self.attn(
             hidden_states,
             cu_seqlen_prefill,
@@ -7,7 +7,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,7 +20,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
-
 tracer = trace.get_tracer(__name__)
 
 BLOCK_SIZE = 16
@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.gpt_neox.layers),
             num_kv_heads=model.gpt_neox.num_heads,
             head_size=model.gpt_neox.head_size,
@@ -143,7 +143,10 @@ def serve(
 
    if quantize == "gptq":
        try:
-            from text_generation_server.utils.gptq.quant_linear import create_exllama_buffers
+            # When using GPTQ, Exllama kernels need some global kernels
+            # For which we have the finale shapes only after the model has loaded
+            # This will allocate those buffers.
+            from text_generation_server.utils.gptq.exllama import create_exllama_buffers
            create_exllama_buffers()
        except ImportError:
            pass
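The comments added in this hunk explain the intent: Exllama needs global scratch buffers whose final shapes are only known once the model weights are loaded, so the allocation is deferred to after loading and is skipped when the extension is not available. A small self-contained sketch of that same optional-setup pattern (the wrapper name is hypothetical; only the import path and create_exllama_buffers come from the diff):

    def maybe_create_exllama_buffers() -> bool:
        # Hypothetical wrapper around the pattern used in serve() above:
        # import the buffer allocator lazily, after the model is loaded,
        # and treat a missing exllama build as a soft failure.
        try:
            from text_generation_server.utils.gptq.exllama import create_exllama_buffers
        except ImportError:
            return False
        create_exllama_buffers()
        return True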
@@ -4,10 +4,6 @@ import torch
 import torch.nn as nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 
-import torch
-
-from loguru import logger
-
 try:
     import triton
     import triton.language as tl
@@ -256,7 +252,6 @@ class QuantLinear(nn.Module):
         self.register_buffer("qzeros", qzeros)
         self.register_buffer("scales", scales)
         self.register_buffer("g_idx", g_idx)
 
         if bias is not None:
             self.register_buffer("bias", bias)
         else:
@@ -362,4 +357,3 @@ class QuantLinear(nn.Module):
         )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-