Remove lots of dead code, move triton to hard requirement

- Added option to upload to hub directly after quantizing.
Nicolas Patry 2023-06-14 14:55:45 +02:00
parent 5de6863756
commit 732da6942b
8 changed files with 1354 additions and 1519 deletions


@@ -166,7 +166,6 @@ FROM base as sagemaker
 COPY sagemaker-entrypoint.sh entrypoint.sh
 RUN chmod +x entrypoint.sh
-RUN pip install triton
 ENTRYPOINT ["./entrypoint.sh"]

server/poetry.lock (generated): 2031 lines changed. File diff suppressed because it is too large.


@@ -27,6 +27,7 @@ sentencepiece = "^0.1.97"
 tokenizers = "0.13.3"
 huggingface-hub = "^0.14.1"
 transformers = "^4.29.2"
+triton = "^2.0.0"
 
 [tool.poetry.extras]
 accelerate = ["accelerate"]


@@ -1,21 +1,27 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
-bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
 certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
 charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
 click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
-colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
-deprecated==1.2.13 ; python_version >= "3.9" and python_version < "4.0"
-filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0"
-fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "4.0"
-googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "4.0"
+cmake==3.26.4 ; python_version >= "3.9" and python_version < "4.0"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
+filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
+fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
+googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
 grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
-grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
-grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
+grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
+grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
 hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
 huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
-idna==3.4 ; python_version >= "3.9" and python_version < "4"
+idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
+jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
+lit==16.0.5.post0 ; python_version >= "3.9" and python_version < "4.0"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
+markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
+networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
+numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -26,17 +32,21 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
 packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
-protobuf==4.23.1 ; python_version >= "3.9" and python_version < "4.0"
+protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
 pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
+regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
 safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
 setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
+sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
 tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
+torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
 tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
-transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
+triton==2.0.0.post1 ; python_version >= "3.9" and python_version < "4.0"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
-typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
+typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
-urllib3==2.0.2 ; python_version >= "3.9" and python_version < "4.0"
+urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
 wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
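With triton now pinned directly (alongside torch, cmake and lit), the GPTQ path no longer relies on an optional install. A quick sanity check one might run in an environment built from this file; this snippet is mine, not part of the commit:

# Hypothetical check, not from the repository: confirm the hard requirement is satisfied.
import torch
import triton

print(torch.__version__, triton.__version__)  # expect the versions pinned above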


@@ -150,6 +150,7 @@ def download_weights(
         # Convert pytorch weights to safetensors
         utils.convert_files(local_pt_files, local_st_files)
 
 @app.command()
 def quantize(
     model_id: str,
@@ -158,25 +159,24 @@ def quantize(
     logger_level: str = "INFO",
     json_output: bool = False,
     trust_remote_code: bool = False,
+    upload_to_model_id: Optional[str] = None,
 ):
-    extension: str = ".safetensors",
-    # Remove default handler
-    logger.remove()
-    logger.add(
-        sys.stdout,
-        format="{message}",
-        filter="text_generation_server",
-        level=logger_level,
-        serialize=json_output,
-        backtrace=True,
-        diagnose=False,
+    download_weights(
+        model_id=model_id,
+        revision=revision,
+        logger_level=logger_level,
+        json_output=json_output,
     )
-    download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
     from text_generation_server.utils.gptq.quantize import quantize
-    quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
+    quantize(
+        model_id=model_id,
+        bits=4,
+        groupsize=128,
+        output_dir=output_dir,
+        trust_remote_code=trust_remote_code,
+        upload_to_model_id=upload_to_model_id,
+    )
 
 
 if __name__ == "__main__":
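Since the command is declared with Typer, the new upload_to_model_id argument should surface as an --upload-to-model-id option on the quantize subcommand; the exact flag spelling is an assumption on my part. A minimal sketch of driving the command function directly from Python, with placeholder model and repo ids:

# Sketch only: assumes the text-generation-server package and a CUDA GPU are available.
from text_generation_server.cli import quantize

quantize(
    model_id="my-org/my-model",                 # placeholder model id
    output_dir="/data/my-model-gptq",           # where the packed safetensors are written
    trust_remote_code=False,
    upload_to_model_id="my-org/my-model-gptq",  # when set, output_dir is pushed to the Hub
)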


@@ -12,66 +12,121 @@ try:
    # code based https://github.com/fpgaminer/GPTQ-triton
    @custom_autotune.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 256,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 32,
                    "BLOCK_SIZE_N": 32,
                    "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=2,
                num_warps=4,
            ),
        ],
        key=["M", "N", "K"],
        nearest_power_of_two=True,
        prune_configs_by={
            "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
            "perf_model": None,
            "top_k": None,
        },
    )
    @triton.jit
    def matmul_248_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        scales_ptr,
        zeros_ptr,
        g_ptr,
        M,
        N,
        K,
        bits,
        maxq,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_scales,
        stride_zeros,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
""" """
Compute the matrix multiplication C = A x B. Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16 A is of shape (M, K) float16
@ -79,7 +134,7 @@ try:
C is of shape (M, N) float16 C is of shape (M, N) float16
scales is of shape (G, N) float16 scales is of shape (G, N) float16
zeros is of shape (G, N) float16 zeros is of shape (G, N) float16
g_ptr is of shape (K) int32 g_ptr is of shape (K) int32
""" """
infearure_per_bits = 32 // bits infearure_per_bits = 32 // bits
@@ -97,10 +152,15 @@ try:
        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (
            offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
        )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
        a_mask = offs_am[:, None] < M
        # b_ptrs is set up such that it repeats elements along the K axis 8 times
        b_ptrs = b_ptr + (
            (offs_k[:, None] // infearure_per_bits) * stride_bk
            + offs_bn[None, :] * stride_bn
        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)
        g_ptrs = g_ptr + offs_k
        # shifter is used to extract the N bits of each element in the 32-bit word from B
        scales_ptrs = scales_ptr + offs_bn[None, :]
@@ -114,13 +174,17 @@ try:
            g_idx = tl.load(g_ptrs)
            # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
            scales = tl.load(
                scales_ptrs + g_idx[:, None] * stride_scales
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = tl.load(
                zeros_ptrs + g_idx[:, None] * stride_zeros
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
            zeros = (zeros >> zeros_shifter[None, :]) & maxq
            zeros = zeros + 1
            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
            # Now we need to unpack b (which is N-bit values) into 32-bit values
@@ -136,170 +200,50 @@ try:
        c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
        tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(configs=[
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 256,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
}, num_stages=3, num_warps=8),
triton.Config({
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}, num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True)
@triton.jit
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for n in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except:
    print("triton not installed.")
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output = torch.empty(
            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
        )
        grid = lambda META: (
            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
        )
        matmul_248_kernel[grid](
            input,
            qweight,
            output,
            scales,
            qzeros,
            g_idx,
            input.shape[0],
            qweight.shape[1],
            input.shape[1],
            bits,
            maxq,
            input.stride(0),
            input.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            output.stride(0),
            output.stride(1),
            scales.stride(0),
            qzeros.stride(0),
        )
        return output

def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    with torch.cuda.device(input.device):
        output_dim = (qweight.shape[0] * 32) // bits
        output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
        grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
        transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
        return output
class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
        ctx.save_for_backward(qweight, scales, qzeros, g_idx)
        ctx.bits, ctx.maxq = bits, maxq
        return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
@@ -326,18 +270,23 @@ class QuantLinear(nn.Module):
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
            dtype=torch.int32,
        )
        scales = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
        )
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        if bias:
            bias = torch.zeros((outfeatures), dtype=torch.float16)
        else:
            bias = None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
    def pack(self, linear, scales, zeros, g_idx=None):
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
@@ -350,11 +299,18 @@ class QuantLinear(nn.Module):
        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
@@ -371,7 +327,9 @@ class QuantLinear(nn.Module):
        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
@@ -386,9 +344,16 @@ class QuantLinear(nn.Module):
        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        out = QuantLinearFunction.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.bits,
            self.maxq,
        )
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
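For readers who do not want to parse the Triton kernel, the dequantization it performs can be sketched in plain PyTorch. This is an illustrative re-derivation of the kernel above, not code from the repository; the function name and intermediate variable names are mine, and it trades the fused blocked loop for a dense, memory-hungry materialization of the weight matrix.

# Illustrative, unoptimized reference for what matmul_248_kernel computes.
import torch

def matmul248_reference(a, qweight, scales, qzeros, g_idx, bits):
    # a: (M, K) float16, qweight: (K * bits // 32, N) int32 packed,
    # scales: (G, N), qzeros: (G, N * bits // 32), g_idx: (K,) group index per input feature.
    pack = 32 // bits
    maxq = (1 << bits) - 1
    k = torch.arange(a.shape[1], device=a.device)
    n = torch.arange(qweight.shape[1], device=a.device)
    g = g_idx.long()
    # Unpack the bits-wide values; `& maxq` discards the sign-extension bits from the shift.
    w = (qweight[k // pack] >> ((k % pack) * bits).unsqueeze(1)) & maxq  # (K, N)
    z = (qzeros[g][:, n // pack] >> ((n % pack) * bits)) & maxq          # (K, N)
    # Dequantize (note the +1 on zeros, exactly as in the kernel) and accumulate in float32.
    w = (w - (z + 1)).float() * scales[g].float()
    return (a.float() @ w).to(a.dtype)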


@@ -10,10 +10,12 @@ import os
from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import transformers
from huggingface_hub import HfApi
import numpy as np
import torch
from text_generation_server.utils.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional

DEV = torch.device("cuda:0")
@@ -613,13 +615,6 @@ def sequential(
layers = model.transformer.h
prefix = "transformer.h"
# embeddings = model.get_input_embeddings()
# embeddings = embeddings.to(dev)
# model.set_input_embeddings(embeddings)
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
# model.model.norm = model.model.norm.to(dev)
# layers[0] = layers[0].to(dev)
dtype = next(iter(model.parameters())).dtype
inps = torch.zeros(
    (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
@@ -728,136 +723,11 @@ def sequential(
print("+------------------+--------------+------------+-----------+-------+") print("+------------------+--------------+------------+-----------+-------+")
print("\n") print("\n")
# if args.observe:
# observer.print()
# conditions = gen_conditions(args.bits, args.groupsize)
# for item in observer.items():
# name = item[0]
# layerid = item[1]
# gptq = item[2]['gptq']
# error = item[2]['error']
# target = error / 2
# table = Texttable()
# table.header(['bits', 'groupsize', 'error'])
# table.set_cols_dtype(['i', 'i', 'f'])
# table.add_row([args.bits, args.groupsize, error])
# print('Optimizing {} {} ..'.format(name, layerid))
# for bits, groupsize in conditions:
# if error < target:
# # if error dropped 50%, skip
# break
# gptq.quantizer.configure(bits, perchannel=True, sym=args.sym, mse=False)
# scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)
# table.add_row([bits, groupsize, error])
# quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)
# print(table.draw())
# print('\n')
# gptq.layer.to('cpu')
# gptq.free()
model.config.use_cache = use_cache
return quantizers
# @torch.no_grad()
# def llama_eval(model, testenc, dev):
# print('Evaluating ...')
#
# testenc = testenc.input_ids
# nsamples = testenc.numel() // model.seqlen
#
# use_cache = model.config.use_cache
# model.config.use_cache = False
# layers = model.model.layers
#
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
# layers[0] = layers[0].to(dev)
#
# dtype = next(iter(model.parameters())).dtype
# inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
# cache = {'i': 0, 'attention_mask': None}
#
# class Catcher(nn.Module):
#
# def __init__(self, module):
# super().__init__()
# self.module = module
#
# def forward(self, inp, **kwargs):
# inps[cache['i']] = inp
# cache['i'] += 1
# cache['attention_mask'] = kwargs['attention_mask']
# cache['position_ids'] = kwargs['position_ids']
# raise ValueError
#
# layers[0] = Catcher(layers[0])
# for i in range(nsamples):
# batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
# try:
# model(batch)
# except ValueError:
# pass
# layers[0] = layers[0].module
#
# layers[0] = layers[0].cpu()
# model.model.embed_tokens = model.model.embed_tokens.cpu()
# torch.cuda.empty_cache()
#
# outs = torch.zeros_like(inps)
# attention_mask = cache['attention_mask']
# position_ids = cache['position_ids']
#
# for i in range(len(layers)):
# print(i)
# layer = layers[i].to(dev)
#
# if args.nearest:
# subset = find_layers(layer)
# for name in subset:
# quantizer = quant.Quantizer()
# quantizer.configure(args.bits, perchannel=True, sym=args.sym, mse=False)
# W = subset[name].weight.data
# quantizer.find_params(W, weight=True)
# subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
#
# for j in range(nsamples):
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
# layers[i] = layer.cpu()
# del layer
# torch.cuda.empty_cache()
# inps, outs = outs, inps
#
# if model.model.norm is not None:
# model.model.norm = model.model.norm.to(dev)
# model.lm_head = model.lm_head.to(dev)
#
# testenc = testenc.to(dev)
# nlls = []
# for i in range(nsamples):
# hidden_states = inps[i].unsqueeze(0)
# if model.model.norm is not None:
# hidden_states = model.model.norm(hidden_states)
# lm_logits = model.lm_head(hidden_states)
# shift_logits = lm_logits[:, :-1, :].contiguous()
# shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# neg_log_likelihood = loss.float() * model.seqlen
# nlls.append(neg_log_likelihood)
# ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
# print(ppl.item())
#
# model.config.use_cache = use_cache
def make_quant_linear(module, names, bits, groupsize, name=""):
    if isinstance(module, QuantLinear):
        return
@@ -898,170 +768,13 @@ def pack(model, quantizers, bits, groupsize):
return model
# def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
# from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
# config = LlamaConfig.from_pretrained(model)
#
# def noop(*args, **kwargs):
# pass
#
# torch.nn.init.kaiming_uniform_ = noop
# torch.nn.init.uniform_ = noop
# torch.nn.init.normal_ = noop
#
# torch.set_default_dtype(torch.half)
# modeling_utils._init_weights = False
# torch.set_default_dtype(torch.half)
# model = LlamaForCausalLM(config)
# torch.set_default_dtype(torch.float)
# if eval:
# model = model.eval()
# layers = find_layers(model)
# for name in ['lm_head']:
# if name in layers:
# del layers[name]
# quant.make_quant_linear(model, layers, bits, groupsize)
#
# del layers
#
# print('Loading model ...')
# if checkpoint.endswith('.safetensors'):
# from safetensors.torch import load_file as safe_load
# model.load_state_dict(safe_load(checkpoint))
# else:
# model.load_state_dict(torch.load(checkpoint))
#
# if eval:
# quant.make_quant_attn(model)
# quant.make_quant_norm(model)
# if fused_mlp:
# quant.make_fused_mlp(model)
#
# if warmup_autotune:
# quant.autotune_warmup_linear(model, transpose=not (eval))
# if eval and fused_mlp:
# quant.autotune_warmup_fused(model)
# model.seqlen = 2048
# print('Done.')
#
# return model
# def llama_multigpu(model, gpus, gpu_dist):
# model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
# if hasattr(model.model, 'norm') and model.model.norm:
# model.model.norm = model.model.norm.to(gpus[0])
# import copy
# model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
#
# cache = {'mask': None, 'position_ids': None}
#
# class MoveModule(nn.Module):
#
# def __init__(self, module, invalidate_cache):
# super().__init__()
# self.module = module
# self.dev = next(iter(self.module.parameters())).device
# self.invalidate_cache=invalidate_cache
#
# def forward(self, *inp, **kwargs):
# inp = list(inp)
# if inp[0].device != self.dev:
# inp[0] = inp[0].to(self.dev)
#
# if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
# cache['mask'] = kwargs['attention_mask'].to(self.dev)
# kwargs['attention_mask'] = cache['mask']
#
# if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
# cache['position_ids'] = kwargs['position_ids'].to(self.dev)
# kwargs['position_ids'] = cache['position_ids']
#
# tmp = self.module(*inp, **kwargs)
# return tmp
#
# layers = model.model.layers
# from math import ceil
# if not gpu_dist:
# pergpu = ceil(len(layers) / len(gpus))
# for i in range(len(layers)):
# layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) -1 else gpus[(i-1) // pergpu]), i==0)
# else:
# assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
# assigned_gpus = [0] * (gpu_dist[0]-1)
# for i in range(1, len(gpu_dist)):
# assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
#
# remaining_assignments = len(layers)-len(assigned_gpus) - 1
# if remaining_assignments > 0:
# assigned_gpus = assigned_gpus + [-1] * remaining_assignments
#
# assigned_gpus = assigned_gpus + [0]
#
# for i in range(len(layers)):
# layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
#
# model.gpus = gpus
#
#
# def benchmark(model, input_ids, check=False):
# input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
# torch.cuda.synchronize()
#
# cache = {'past': None}
#
# def clear_past(i):
#
# def tmp(layer, inp, out):
# if cache['past']:
# cache['past'][i] = None
#
# return tmp
#
# for i, layer in enumerate(model.model.layers):
# layer.register_forward_hook(clear_past(i))
#
# print('Benchmarking ...')
#
# if check:
# loss = nn.CrossEntropyLoss()
# tot = 0.
#
# def sync():
# if hasattr(model, 'gpus'):
# for gpu in model.gpus:
# torch.cuda.synchronize(gpu)
# else:
# torch.cuda.synchronize()
#
# max_memory = 0
# with torch.no_grad():
# attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
# times = []
# for i in range(input_ids.numel()):
# tick = time.time()
# out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
# sync()
# times.append(time.time() - tick)
# print(i, times[-1])
# if hasattr(model, 'gpus'):
# mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
# else:
# mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
# max_memory = max(max_memory, mem_allocated)
# if check and i != input_ids.numel() - 1:
# tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
# cache['past'] = list(out.past_key_values)
# del out
# sync()
# print('Median:', np.median(times))
# if check:
# print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
# print('max memory(MiB):', max_memory)
def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
):
    print("loading model")
    model = AutoModelForCausalLM.from_pretrained(
@@ -1085,48 +798,6 @@ def quantize(
quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
print(time.time() - tick)
# if args.benchmark:
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
# if len(gpus) > 1:
# llama_multigpu(model, gpus, gpu_dist)
# else:
# model = model.to(DEV)
# if args.benchmark:
# input_ids = next(iter(dataloader))[0][:, :args.benchmark]
# benchmark(model, input_ids, check=args.check)
# if args.eval:
# datasets = ['wikitext2', 'ptb', 'c4']
# if args.new_eval:
# datasets = ['wikitext2', 'ptb-new', 'c4-new']
# for dataset in datasets:
# dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
# print(dataset)
# llama_eval(model, testloader, DEV)
#
# if args.test_generation:
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
# if len(gpus) > 1:
# llama_multigpu(model, gpus, gpu_dist)
# else:
# model = model.to(DEV)
# from transformers import LlamaTokenizer, TextStreamer
# tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
# input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0])
# streamer = TextStreamer(tokenizer)
# with torch.no_grad():
# generated_ids = model.generate(input_ids, streamer=streamer)
#
# if args.quant_directory is not None:
# export_quant_table(quantizers, args.quant_directory)
# if not args.observe and args.save:
# llama_pack(model, quantizers, args.bits, args.groupsize)
# torch.save(model.state_dict(), args.save)
# if not args.observe and args.save_safetensors:
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
@@ -1174,3 +845,11 @@ def quantize(
)
tokenizer.save_pretrained(output_dir)
logger.info("Saved tokenizer")
if upload_to_model_id:
api = HfApi()
api.upload_folder(
folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
)
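One practical note on the new upload step: upload_folder pushes the contents of output_dir to a repository that must already exist and for which a Hub token is configured. A hedged sketch of the calling side; the helper name and the create_repo call are mine, not part of this diff:

# Assumes a valid Hub token is already configured (e.g. via `huggingface-cli login`).
from huggingface_hub import HfApi

def push_quantized(output_dir: str, repo_id: str, private: bool = True):
    api = HfApi()
    # Create the destination repo if it does not exist yet, then upload every produced file.
    api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
    api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")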


@@ -134,13 +134,15 @@ def get_linear(weight, bias, quantize):
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize = weight
         except Exception:
-            raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.")
+            raise NotImplementedError(
+                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+            )
 
         linear = QuantLinear(
             qweight,
             qzeros,
             scales,
             g_idx,
             bias,
             bits,
             groupsize,
@@ -221,7 +223,9 @@ class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
-        weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)
+        weight = weights.get_multi_weights_col(
+            prefixes, quantize=config.quantize, dim=dim
+        )
         if bias:
             b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
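To tie the pieces together, this is roughly how the gptq branch above hands a weight tuple to QuantLinear; a sketch with made-up shapes, assuming bits=4 and groupsize=128 as hard-coded in the quantize command, and zero-initialized tensors that stand in for real packed weights.

# Sketch only; shapes follow QuantLinear.new() in quant_linear.py above.
import math
import torch
from text_generation_server.utils.gptq.quant_linear import QuantLinear

bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
groups = math.ceil(infeatures / groupsize)
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros((groups, outfeatures // 32 * bits), dtype=torch.int32)
scales = torch.ones((groups, outfeatures), dtype=torch.float16)
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)

layer = QuantLinear(qweight, qzeros, scales, g_idx, None, bits, groupsize)  # bias=None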