mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00

Simpler exllama
commit 5ca0508d02 (parent 6bf7090ecd)

Makefile (3)
@@ -56,6 +56,3 @@ run-bloom:
 
 run-bloom-quantize:
 	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
-
-clean:
-	rm -rf target aml
@@ -230,19 +230,16 @@ def launcher(event_loop):
            shard_uds_path,
        ]

-        env = os.environ

        if num_shard is not None:
            args.extend(["--num-shard", str(num_shard)])
        if quantize is not None:
            args.append("--quantize")
            args.append(quantize)
-            if quantize == "gptq":
-                env["GPTQ_GROUPSIZE"] = "128"
-                env["GPTQ_BITS"] = "4"
        if trust_remote_code:
            args.append("--trust-remote-code")

+        env = os.environ
        env["LOG_LEVEL"] = "info,text_generation_router=debug"

        if not use_flash_attention:
@@ -1,102 +1,103 @@
 {
-  "generated_text": ", and I am going to visit the Louvre",
   "details": {
+    "best_of_sequences": null,
     "finish_reason": "length",
     "generated_tokens": 10,
-    "seed": null,
     "prefill": [
       {
-        "id": 2,
-        "text": "</s>",
-        "logprob": null
+        "id": 1,
+        "logprob": null,
+        "text": "<s>"
       },
       {
         "id": 20628,
-        "text": "Today",
-        "logprob": -11.2265625
+        "logprob": -10.328125,
+        "text": "Today"
       },
       {
         "id": 306,
-        "text": "I",
-        "logprob": -4.1757812
+        "logprob": -2.390625,
+        "text": "I"
       },
       {
         "id": 626,
-        "text": "am",
-        "logprob": -1.9746094
+        "logprob": -1.8857422,
+        "text": "am"
       },
       {
         "id": 297,
-        "text": "in",
-        "logprob": -5.4648438
+        "logprob": -4.4765625,
+        "text": "in"
       },
       {
         "id": 3444,
-        "text": "France",
-        "logprob": -9.03125
+        "logprob": -7.0703125,
+        "text": "France"
       }
     ],
+    "seed": null,
     "tokens": [
       {
         "id": 29892,
-        "text": ",",
-        "logprob": -0.31298828,
-        "special": false
+        "logprob": -1.2910156,
+        "special": false,
+        "text": ","
       },
       {
-        "id": 322,
-        "text": " and",
-        "logprob": -1.4345703,
-        "special": false
-      },
-      {
-        "id": 306,
-        "text": " I",
-        "logprob": -0.32080078,
-        "special": false
-      },
-      {
-        "id": 626,
-        "text": " am",
-        "logprob": -1.3798828,
-        "special": false
-      },
-      {
-        "id": 2675,
-        "text": " going",
-        "logprob": -1.2304688,
-        "special": false
-      },
-      {
-        "id": 304,
-        "text": " to",
-        "logprob": -0.0014791489,
-        "special": false
-      },
-      {
-        "id": 6493,
-        "text": " visit",
-        "logprob": -1.1503906,
-        "special": false
+        "id": 297,
+        "logprob": -1.9394531,
+        "special": false,
+        "text": " in"
       },
       {
         "id": 278,
-        "text": " the",
-        "logprob": -0.41259766,
-        "special": false
+        "logprob": -0.7597656,
+        "special": false,
+        "text": " the"
       },
       {
-        "id": 4562,
-        "text": " Lou",
-        "logprob": -1.8134766,
-        "special": false
+        "id": 7062,
+        "logprob": -2.9121094,
+        "special": false,
+        "text": " south"
       },
       {
-        "id": 12675,
-        "text": "vre",
-        "logprob": -0.000767231,
-        "special": false
+        "id": 310,
+        "logprob": -1.0302734,
+        "special": false,
+        "text": " of"
+      },
+      {
+        "id": 278,
+        "logprob": -0.58203125,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 4234,
+        "logprob": -0.2944336,
+        "special": false,
+        "text": " country"
+      },
+      {
+        "id": 29892,
+        "logprob": -0.7011719,
+        "special": false,
+        "text": ","
+      },
+      {
+        "id": 297,
+        "logprob": -1.1054688,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 278,
+        "logprob": -0.52490234,
+        "special": false,
+        "text": " the"
+      }
     ]
-  }
+  },
+  "generated_text": ", in the south of the country, in the"
 }
@@ -1,10 +1,9 @@
 import pytest
 
 
 @pytest.fixture(scope="module")
 def flash_llama_gptq_handle(launcher):
-    with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle:
+    with launcher("huggingface/llama-7b-gptq", num_shard=4, quantize="gptq") as handle:
         yield handle
 
-
 
@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_starcoder_gptq_handle(launcher):
-    with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
+    with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
         yield handle
 
 
@@ -46,4 +46,4 @@ async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
 
     assert responses == response_snapshot
@@ -3,7 +3,7 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <cuda_runtime.h>
+// #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cstdint>
 #include <cstdio>
@@ -14,7 +14,7 @@ setup(
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
        ),
-        CppExtension(
+        CUDAExtension(
            name="custom_kernels.exllama",
            sources=[
                "custom_kernels/exllama/exllama_ext.cpp",
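Note: the switch from CppExtension to CUDAExtension above is what routes the exllama kernel sources through nvcc. A minimal, hypothetical setup.py sketch of that pattern (everything except exllama_ext.cpp is illustrative, not the repo's exact file):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_kernels",
    ext_modules=[
        # CUDAExtension (unlike CppExtension) compiles .cu sources with nvcc,
        # which the exllama q4 matmul kernels require.
        CUDAExtension(
            name="custom_kernels.exllama",
            sources=[
                "custom_kernels/exllama/exllama_ext.cpp",
                # additional .cu kernel sources would be listed here (illustrative)
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)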
@@ -500,7 +500,6 @@ class CausalLM(Model):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -378,6 +378,7 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
+        self.config = config
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
@@ -448,7 +449,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
-        self.config = config
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
@@ -20,6 +20,7 @@ from text_generation_server.utils.layers import (
 )
 from safetensors import SafetensorError
 
+
 def load_multi_mqa(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
@@ -71,12 +72,19 @@ def _load_multi_mqa_gptq(
        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)

        g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        bits, groupsize = weights.get_gptq_qparams()
+        try:
+            bits = weights.get_tensor("gptq_bits").item()
+            groupsize = weights.get_tensor("gptq_groupsize").item()
+        except SafetensorError as e:
+            try:
+                import os

-        qweight = qweight.to(weights.device)
-        qzeros = qzeros.to(weights.device)
-        scales = scales.to(weights.device)
-        weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e

+        weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
+
        if bias:
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
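Note: the loading path above resolves the GPTQ quantization parameters from scalars stored in the checkpoint (gptq_bits / gptq_groupsize) and only falls back to the GPTQ_BITS / GPTQ_GROUPSIZE environment variables when those tensors are missing, which is why the test fixture no longer exports them. A self-contained sketch of that resolution order, with get_tensor standing in for Weights.get_tensor and KeyError standing in for SafetensorError:

import os
from typing import Callable, Tuple

import torch


def resolve_gptq_qparams(get_tensor: Callable[[str], torch.Tensor]) -> Tuple[int, int]:
    """Sketch of the lookup order: checkpoint scalars first, env vars second."""
    try:
        bits = int(get_tensor("gptq_bits").item())
        groupsize = int(get_tensor("gptq_groupsize").item())
    except KeyError as err:  # the real code catches SafetensorError instead
        try:
            bits = int(os.environ["GPTQ_BITS"])
            groupsize = int(os.environ["GPTQ_GROUPSIZE"])
        except KeyError:
            raise err
    return bits, groupsize


# Example: a toy checkpoint that does carry the scalars.
fake_ckpt = {"gptq_bits": torch.tensor(4), "gptq_groupsize": torch.tensor(128)}
print(resolve_gptq_qparams(lambda name: fake_ckpt[name]))  # (4, 128)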
@@ -90,8 +98,6 @@ def _load_multi_mqa_gptq(
            kv_tensor = slice_[-2 * head_size :]
            bias = torch.cat([q_tensor, kv_tensor], dim=0)

-            bias = bias.to(weights.device)
-
        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
    else:
        raise NotImplementedError("Gptq loading with santacoder is not implemented")
@@ -355,7 +361,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
 
         hidden_states = self.attn(
             hidden_states,
             cu_seqlen_prefill,
@@ -6,8 +6,9 @@ import torch.distributed
 import numpy as np
 
 from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -20,7 +21,6 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
-
 tracer = trace.get_tracer(__name__)
 
 BLOCK_SIZE = 16
@@ -684,7 +684,6 @@ class FlashCausalLM(Model):
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
-        config: PretrainedConfig,
         num_layers: int,
         num_kv_heads: int,
         head_size: int,
@@ -700,7 +699,6 @@ class FlashCausalLM(Model):
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
@@ -68,7 +68,6 @@ class FlashLlama(FlashCausalLM):
         super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
@@ -59,7 +59,6 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.gpt_neox.layers),
             num_kv_heads=model.gpt_neox.num_heads,
             head_size=model.gpt_neox.head_size,
@@ -65,7 +65,6 @@ class FlashRWSharded(FlashCausalLM):
         super(FlashRWSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=model.transformer.cache_size,
             head_size=model.transformer.head_size,
@@ -66,7 +66,6 @@ class FlashSantacoderSharded(FlashCausalLM):
         super(FlashSantacoderSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
-            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=1,
             head_size=model.transformer.head_size,
@@ -198,7 +198,6 @@ class GalacticaSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -63,7 +63,6 @@ class GPTNeoxSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -3,27 +3,19 @@ import torch
 
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from transformers import PreTrainedTokenizerBase, PretrainedConfig
+from transformers import PreTrainedTokenizerBase
 
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
-from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
-from custom_kernels.exllama import prepare_buffers, set_tuning_params
-from text_generation_server.utils.layers import (
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear
-)
-
 B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
     def __init__(
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
-        config: PretrainedConfig,
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
@@ -46,47 +38,6 @@ class Model(ABC):
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )
-        self.config = config
-
-        if config.quantize == "gptq":
-            # Buffers need to be persistent to avoid any bug.
-            self.buffers = {}
-            use_exllama_act_order = False
-            max_dq_buffer_size = 1
-            max_inner_outer_dim = 1
-            for name, submodule in model.named_modules():
-                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-                    if submodule.linear.act_order:
-                        max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
-
-                        use_exllama_act_order = True
-
-            if use_exllama_act_order:
-                # TODO: this should be set to rust side `max_total_tokens`, but TGI
-                # does not offer an API to expose this variable to python, as this variable
-                # is handled by the client but it appears the model is initialized by the server.
-                # An alternative could be to initialize the buffers during warmup.
-                max_total_tokens = 2048
-            else:
-                max_total_tokens = 1
-
-            # This temp_state buffer is required to reorder X in the act-order case.
-            self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
-
-            # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
-
-            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
-
-            matmul_recons_thd = 8
-            matmul_fused_remap = False
-            matmul_no_half2 = False
-            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-            torch.cuda.empty_cache()

        self.check_initialized()

@@ -86,7 +86,6 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
@@ -61,7 +61,6 @@ class OPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -58,7 +58,6 @@ class RW(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -63,7 +63,6 @@ class SantaCoder(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -542,7 +542,6 @@ class Seq2SeqLM(Model):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -73,7 +73,6 @@ class T5Sharded(Seq2SeqLM):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
-            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -140,6 +140,13 @@ def serve(
            logger.exception("Error when initializing model")
            raise

+        try:
+            from text_generation_server.utils.gptq.exllama import create_buffers
+            create_buffers()
+            logger.info("Created exllama GPTQ buffers !")
+        except ImportError:
+            pass
+
        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
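Note: create_buffers() is invoked only after the model has been built, because every Ex4bitLinear grows module-level maxima in exllama.py at construction time and the scratch buffers must be sized to the largest layer. A dependency-free sketch of that ordering (names simplified; FakeQuantLayer is hypothetical):

# Each quantized layer registers its size at construction time; the buffers are
# allocated once afterwards, sized to the largest registered layer.
MAX_DQ_BUFFER_SIZE = 0


class FakeQuantLayer:
    def __init__(self, numel: int) -> None:
        global MAX_DQ_BUFFER_SIZE
        # Mirrors Ex4bitLinear.__init__: qweight.numel() * 8 dequantized elements.
        MAX_DQ_BUFFER_SIZE = max(MAX_DQ_BUFFER_SIZE, numel * 8)


def create_buffers_sketch() -> int:
    # Called once, after every layer has been built (as serve() does above).
    return MAX_DQ_BUFFER_SIZE


layers = [FakeQuantLayer(n) for n in (1024, 4096, 2048)]
assert create_buffers_sketch() == 4096 * 8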
server/text_generation_server/utils/gptq/exllama.py (new file, 89 lines)

@@ -0,0 +1,89 @@
+import torch
+from custom_kernels.exllama import make_q4, q4_matmul, set_tuning_params, prepare_buffers
+from loguru import logger
+
+# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
+
+def ext_q4_matmul(x, q4, q4_width):
+    """Matrix multiplication, returns x @ q4"""
+    outshape = x.shape[:-1] + (q4_width,)
+    x = x.view(-1, x.shape[-1])
+    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
+
+    q4_matmul(x, q4, output)
+
+    return output.view(outshape)
+
+
+import os
+RANK = os.getenv("RANK", "0")
+DEVICE = torch.device(f"cuda:{RANK}")
+MAX_TOTAL_TOKENS = 1
+MAX_INNER_OUTER_DIM = 0
+MAX_DQ_BUFFER_SIZE = 0
+
+
+def create_buffers():
+    temp_state = torch.zeros((MAX_TOTAL_TOKENS, MAX_INNER_OUTER_DIM), dtype=torch.float16, device=DEVICE)
+    temp_dq = torch.zeros((1, MAX_DQ_BUFFER_SIZE), dtype=torch.float16, device=DEVICE)
+    logger.info(f"Creating buffers {temp_state.shape} - {temp_dq.shape} - {DEVICE}")
+
+    prepare_buffers(DEVICE, temp_state, temp_dq)
+
+    matmul_recons_thd = 8
+    matmul_fused_remap = False
+    matmul_no_half2 = False
+    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+class Ex4bitLinear:
+    """Linear layer implementation with per-group 4-bit quantization of the weights"""
+    def __init__(self, qweight, qzeros, scales, bias, bits):
+        assert bits == 4, "We cannot run exllama GPTQ kernels if bits != 4"
+
+        global MAX_INNER_OUTER_DIM, MAX_DQ_BUFFER_SIZE
+        dq = qweight.numel() * 8
+        if dq > MAX_DQ_BUFFER_SIZE:
+            MAX_DQ_BUFFER_SIZE = dq
+
+        width = qweight.shape[1]
+        if width > MAX_INNER_OUTER_DIM:
+            MAX_INNER_OUTER_DIM = width
+        height = qweight.shape[0] * 8
+        if height > MAX_INNER_OUTER_DIM:
+            MAX_INNER_OUTER_DIM = height
+
+        # prepare_buffers(DEVICE, TEMP_STATE, TEMP_DQ)
+
+
+        self.q4 = make_q4(
+            qweight,
+            qzeros,
+            scales,
+            # Never send g_idx, it MUST be like act_order=False, the exllama kernel does not expect it
+            torch.zeros((0, 0), device=torch.device("meta")),
+            DEVICE.index
+        )
+        self.bias = bias if bias is not None else None
+        self.width = width
+
+        # # Infer groupsize from height of qzeros
+        # self.groupsize = None
+        # if self.qzeros.shape[0] > 1:
+        #     self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
+
+        # if self.groupsize is not None:
+        #     assert groupsize == self.groupsize
+
+        # # Handle act-order matrix
+        # if self.g_idx is not None:
+        #     if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
+        #     self.act_order = True
+        # else:
+        #     self.act_order = False
+
+    def forward(self, x):
+        out = ext_q4_matmul(x, self.q4, self.width)
+
+        if self.bias is not None:
+            out.add_(self.bias)
+        return out
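Note: ext_q4_matmul above only handles shapes around the CUDA call: it flattens the input to 2-D, lets q4_matmul fill a preallocated fp16 output, and restores the leading dimensions. A pure-PyTorch analogue of that reshaping, with a dense weight standing in for the opaque q4 handle so the sketch runs without the compiled extension:

import torch


def dense_q4_matmul_analogue(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Shape-for-shape analogue of ext_q4_matmul; a dense weight replaces the
    packed q4 handle used by the real CUDA kernel."""
    q4_width = weight.shape[1]
    outshape = x.shape[:-1] + (q4_width,)
    x2d = x.view(-1, x.shape[-1])
    output = torch.empty((x2d.shape[0], q4_width), dtype=x.dtype, device=x.device)
    torch.mm(x2d, weight, out=output)  # the real code calls q4_matmul(x, q4, output)
    return output.view(outshape)


x = torch.randn(2, 3, 16)   # (batch, seq, hidden)
w = torch.randn(16, 32)     # output features play the role of q4_width
print(dense_q4_matmul_analogue(x, w).shape)  # torch.Size([2, 3, 32])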
@@ -8,11 +8,6 @@ import torch
 
 from loguru import logger
 
-try:
-    from custom_kernels.exllama import make_q4, q4_matmul
-except Exception as e:
-    logger.error(f"The CUDA kernels custom_kernels.exllama not installed, got the error: {e}")
-
 try:
     import triton
     import triton.language as tl
@@ -368,76 +363,3 @@ class QuantLinear(nn.Module):
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)

-# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
-none_tensor = torch.empty((1, 1), device = "meta")
-
-def ext_make_q4(qweight, qzeros, scales, g_idx, device):
-    """Construct Q4Matrix, return handle"""
-    return make_q4(qweight,
-                   qzeros,
-                   scales,
-                   g_idx if g_idx is not None else none_tensor,
-                   device)
-
-def ext_q4_matmul(x, q4, q4_width):
-    """Matrix multiplication, returns x @ q4"""
-    outshape = x.shape[:-1] + (q4_width,)
-    x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
-
-    q4_matmul(x, q4, output)
-
-    return output.view(outshape)
-
-
-class Ex4bitLinear:
-    """Linear layer implementation with per-group 4-bit quantization of the weights"""
-    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
-        assert bits == 4
-
-        self.device = qweight.device
-        self.qweight = qweight
-        self.qzeros = qzeros
-        self.scales = scales
-        self.g_idx = g_idx.cpu() if g_idx is not None else None
-        self.bias = bias if bias is not None else None
-
-        if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
-            self.empty_g_idx = True
-            self.g_idx = None
-
-        assert self.device.type == "cuda"
-        assert self.device.index is not None
-
-        self.q4 = ext_make_q4(
-            self.qweight,
-            self.qzeros,
-            self.scales,
-            self.g_idx,
-            self.device.index
-        )
-
-        self.height = qweight.shape[0] * 8
-        self.width = qweight.shape[1]
-
-        # Infer groupsize from height of qzeros
-        self.groupsize = None
-        if self.qzeros.shape[0] > 1:
-            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
-
-        if self.groupsize is not None:
-            assert groupsize == self.groupsize
-
-        # Handle act-order matrix
-        if self.g_idx is not None:
-            if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
-            self.act_order = True
-        else:
-            self.act_order = False
-
-    def forward(self, x):
-        out = ext_q4_matmul(x, self.q4, self.width)
-
-        if self.bias is not None:
-            out.add_(self.bias)
-        return out
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed
+from loguru import logger
 
 from torch import nn
 from torch.nn import functional as F
@@ -15,7 +16,14 @@ except ImportError:
 
 from accelerate import init_empty_weights
 
-from text_generation_server.utils.gptq.quant_linear import QuantLinear, Ex4bitLinear
+from text_generation_server.utils.gptq.quant_linear import QuantLinear
+
+HAS_EXLLAMA = True
+try:
+    from text_generation_server.utils.gptq.exllama import Ex4bitLinear
+except ImportError:
+    logger.error(f"The CUDA kernels custom_kernels.exllama not installed using base triton kernel")
+    HAS_EXLLAMA = False
 
 from typing import Optional
 
@@ -145,13 +153,15 @@ def get_linear(weight, bias, quantize):
            linear.bias = nn.Parameter(bias)
    elif quantize == "gptq":
        try:
-            qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel = weight
+            qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama = weight
        except Exception:
            raise NotImplementedError(
                f"The passed weight is not `gptq` compatible, loader needs to be updated."
            )

-        if use_triton_kernel or bits != 4:
+        if can_exllama and HAS_EXLLAMA:
+            linear = Ex4bitLinear(qweight, qzeros, scales, bias, bits)
+        else:
            linear = QuantLinear(
                qweight,
                qzeros,
@@ -161,8 +171,6 @@ def get_linear(weight, bias, quantize):
                bits,
                groupsize,
            )
-        else:
-            linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
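Note: after these two hunks, get_linear sends GPTQ weights to the exllama kernel only when the weight was flagged as compatible (can_exllama) and the extension actually imported (HAS_EXLLAMA); otherwise it falls back to the triton QuantLinear. A dependency-free sketch of that dispatch with stub classes (the stubs are illustrative, not the repo's API):

from dataclasses import dataclass

HAS_EXLLAMA = False  # in layers.py this is True only when the CUDA extension imports


@dataclass
class StubExllamaLinear:  # stands in for Ex4bitLinear(qweight, qzeros, scales, bias, bits)
    bits: int


@dataclass
class StubTritonLinear:  # stands in for QuantLinear(...)
    bits: int


def pick_gptq_linear(bits: int, can_exllama: bool):
    """Sketch of the new branch in get_linear for quantize == "gptq"."""
    if can_exllama and HAS_EXLLAMA:
        return StubExllamaLinear(bits)
    return StubTritonLinear(bits)


print(pick_gptq_linear(4, can_exllama=True))  # StubTritonLinear here, since HAS_EXLLAMA is False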
@@ -3,7 +3,6 @@ from typing import List, Dict, Optional, Tuple
 from safetensors import safe_open, SafetensorError
 import torch
 
-
 class Weights:
     def __init__(
         self,
@@ -126,9 +125,15 @@ class Weights:
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
+            can_exllama = True
+            bits, groupsize = self._get_gptq_qparams()
+            if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                # Exllama implementation does not support row tensor parallelism with act-order, as
+                # it would require to reorder input activations that are split unto several GPUs
+                can_exllama = False

-            bits, groupsize = self.get_gptq_qparams()
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
+            bits, groupsize = self._get_gptq_qparams()
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
            weight = torch.cat(w, dim=dim)
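Note: the can_exllama flag computed above comes purely from g_idx: the exllama path is allowed only when the group index is trivial (all zeros or i // groupsize), i.e. the checkpoint was not quantized with act-order, since row tensor parallelism would otherwise have to reorder activations that are split across GPUs. A runnable sketch of the same check on synthetic indices:

import torch


def is_trivial_g_idx(g_idx: torch.Tensor, groupsize: int) -> bool:
    """True when g_idx carries no act-order information (same test as in the diff)."""
    expected = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return bool((g_idx == 0).all() or torch.equal(g_idx.cpu(), expected))


groupsize = 128
trivial = torch.arange(1024, dtype=torch.int32) // groupsize   # act_order=False layout
reordered = torch.flip(trivial, dims=[0])                      # act-order style permutation
print(is_trivial_g_idx(trivial, groupsize))     # True  -> exllama kernel allowed
print(is_trivial_g_idx(reordered, groupsize))   # False -> fall back to the triton kernel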
@@ -136,52 +141,32 @@ class Weights:

    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
-            use_triton_kernel = False
-            if self.process_group.size() > 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                _, groupsize = self.get_gptq_qparams()
-
-                if g_idx is not None:
-                    if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
-                        # Exllama implementation does not support row tensor parallelism with act-order, as
-                        # it would require to reorder input activations that are split unto several GPUs
-                        use_triton_kernel = True
-
            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
            except RuntimeError:
                raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")

-            bits, groupsize = self.get_gptq_qparams()
+            bits, groupsize = self._get_gptq_qparams()
+            g_idx = self.get_tensor(f"{prefix}.g_idx")
+
+            can_exllama = True
+            if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
+                # Exllama implementation does not support row tensor parallelism with act-order, as
+                # it would require to reorder input activations that are split unto several GPUs
+                can_exllama = False

-            if use_triton_kernel:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
-                qzeros = self.get_tensor(f"{prefix}.qzeros")
-                scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-            else:
-                if groupsize >= 16:
-                    # Exllama reorders the weights in advance and the activations on the fly, thus
-                    # the scales and zero-points do not need to be reordered.
-                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                else:
-                    qzeros = self.get_tensor(f"{prefix}.qzeros")
-                    scales = self.get_tensor(f"{prefix}.scales")
-
-            # For tp > 1, at this point we know we do not use act-order
-            if self.process_group.size() == 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
-
-            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
+            # The triton kernel reorders the scales/zero points instead of the weight/activation.
+            # Thus, each rank needs the full qzeros/scales.
+            qzeros = self.get_tensor(f"{prefix}.qzeros")
+            scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, can_exllama)
        else:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

-    def get_gptq_qparams(self) -> Tuple[int, int]:
+    def _get_gptq_qparams(self) -> Tuple[int, int]:
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
@@ -194,4 +179,4 @@ class Weights:
            except Exception:
                raise e

        return bits, groupsize