mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Remove lots of dead code, move triton to hard requirement
- Added option to upload to hub directly after quantizing.
This commit is contained in:
parent
5de6863756
commit
732da6942b
@ -166,7 +166,6 @@ FROM base as sagemaker
|
||||
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
RUN pip install triton
|
||||
|
||||
ENTRYPOINT ["./entrypoint.sh"]
|
||||
|
||||
|
2031
server/poetry.lock
generated
2031
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -27,6 +27,7 @@ sentencepiece = "^0.1.97"
|
||||
tokenizers = "0.13.3"
|
||||
huggingface-hub = "^0.14.1"
|
||||
transformers = "^4.29.2"
|
||||
triton = "^2.0.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
accelerate = ["accelerate"]
|
||||
|
@ -1,21 +1,27 @@
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
|
||||
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
|
||||
deprecated==1.2.13 ; python_version >= "3.9" and python_version < "4.0"
|
||||
filelock==3.12.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
fsspec==2023.5.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
googleapis-common-protos==1.59.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
cmake==3.26.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
|
||||
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-reflection==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-status==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio==1.54.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
lit==16.0.5.post0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
numpy==1.24.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
@ -26,17 +32,21 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
protobuf==4.23.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
protobuf==4.23.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
|
||||
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
transformers==4.30.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
triton==2.0.0.post1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
urllib3==2.0.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typing-extensions==4.6.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
|
||||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
|
@ -150,6 +150,7 @@ def download_weights(
|
||||
# Convert pytorch weights to safetensors
|
||||
utils.convert_files(local_pt_files, local_st_files)
|
||||
|
||||
|
||||
@app.command()
|
||||
def quantize(
|
||||
model_id: str,
|
||||
@ -158,25 +159,24 @@ def quantize(
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
upload_to_model_id: Optional[str] = None,
|
||||
):
|
||||
extension: str = ".safetensors",
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
format="{message}",
|
||||
filter="text_generation_server",
|
||||
level=logger_level,
|
||||
serialize=json_output,
|
||||
backtrace=True,
|
||||
diagnose=False,
|
||||
download_weights(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
logger_level=logger_level,
|
||||
json_output=json_output,
|
||||
)
|
||||
download_weights(model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output)
|
||||
from text_generation_server.utils.gptq.quantize import quantize
|
||||
quantize(model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, trust_remote_code=trust_remote_code)
|
||||
|
||||
|
||||
|
||||
quantize(
|
||||
model_id=model_id,
|
||||
bits=4,
|
||||
groupsize=128,
|
||||
output_dir=output_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
upload_to_model_id=upload_to_model_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -12,66 +12,121 @@ try:
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
|
||||
'perf_model': None,
|
||||
'top_k': None,
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
def matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
g_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
@ -97,10 +152,15 @@ try:
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_k[:, None] // infearure_per_bits) * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
@ -114,13 +174,17 @@ try:
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
scales = tl.load(
|
||||
scales_ptrs + g_idx[:, None] * stride_scales
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(
|
||||
zeros_ptrs + g_idx[:, None] * stride_zeros
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
zeros = zeros + 1
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
@ -136,170 +200,50 @@ try:
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
@custom_autotune.autotune(configs=[
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 256,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=4, num_warps=4),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 128,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=3, num_warps=8),
|
||||
triton.Config({
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32,
|
||||
'GROUP_SIZE_M': 8
|
||||
}, num_stages=2, num_warps=4),
|
||||
],
|
||||
key=['M', 'N', 'K'],
|
||||
nearest_power_of_two=True)
|
||||
@triton.jit
|
||||
def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
|
||||
stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, N) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, K) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_k
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_k = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
a_mask = (offs_am[:, None] < M)
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_bk
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
|
||||
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
|
||||
|
||||
shifter = (offs_bk % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_n % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
for n in range(0, num_pid_n):
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
b = tl.trans(b)
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_N
|
||||
b_ptrs += BLOCK_SIZE_N
|
||||
scales_ptrs += BLOCK_SIZE_N
|
||||
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
except:
|
||||
print('triton not installed.')
|
||||
print("triton not installed.")
|
||||
|
||||
|
||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
|
||||
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
return output
|
||||
|
||||
|
||||
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output_dim = (qweight.shape[0] * 32) // bits
|
||||
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
|
||||
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
|
||||
transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
|
||||
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
|
||||
output = torch.empty(
|
||||
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
g_idx,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
input.shape[1],
|
||||
bits,
|
||||
maxq,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
|
||||
ctx.bits, ctx.maxq = bits, maxq
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, qzeros, g_idx = ctx.saved_tensors
|
||||
bits, maxq = ctx.bits, ctx.maxq
|
||||
grad_input = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return grad_input, None, None, None, None, None, None
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
@ -326,18 +270,23 @@ class QuantLinear(nn.Module):
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
|
||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||
qzeros = torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 32 * bits), dtype=torch.int32)
|
||||
scales = torch.zeros((math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16)
|
||||
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
@ -350,11 +299,18 @@ class QuantLinear(nn.Module):
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||
/ self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
@ -371,7 +327,9 @@ class QuantLinear(nn.Module):
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
qzeros = np.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
@ -386,9 +344,16 @@ class QuantLinear(nn.Module):
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures, )
|
||||
out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = QuantLinearFunction.apply(
|
||||
x.reshape(-1, x.shape[-1]),
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.g_idx,
|
||||
self.bits,
|
||||
self.maxq,
|
||||
)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
|
||||
|
@ -10,10 +10,12 @@ import os
|
||||
from texttable import Texttable
|
||||
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
|
||||
import transformers
|
||||
from huggingface_hub import HfApi
|
||||
import numpy as np
|
||||
import torch
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
from loguru import logger
|
||||
from typings import Optional
|
||||
|
||||
DEV = torch.device("cuda:0")
|
||||
|
||||
@ -613,13 +615,6 @@ def sequential(
|
||||
layers = model.transformer.h
|
||||
prefix = "transformer.h"
|
||||
|
||||
# embeddings = model.get_input_embeddings()
|
||||
# embeddings = embeddings.to(dev)
|
||||
# model.set_input_embeddings(embeddings)
|
||||
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
|
||||
# model.model.norm = model.model.norm.to(dev)
|
||||
# layers[0] = layers[0].to(dev)
|
||||
|
||||
dtype = next(iter(model.parameters())).dtype
|
||||
inps = torch.zeros(
|
||||
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
|
||||
@ -728,136 +723,11 @@ def sequential(
|
||||
print("+------------------+--------------+------------+-----------+-------+")
|
||||
print("\n")
|
||||
|
||||
# if args.observe:
|
||||
# observer.print()
|
||||
# conditions = gen_conditions(args.bits, args.groupsize)
|
||||
# for item in observer.items():
|
||||
# name = item[0]
|
||||
# layerid = item[1]
|
||||
# gptq = item[2]['gptq']
|
||||
# error = item[2]['error']
|
||||
# target = error / 2
|
||||
|
||||
# table = Texttable()
|
||||
# table.header(['bits', 'groupsize', 'error'])
|
||||
# table.set_cols_dtype(['i', 'i', 'f'])
|
||||
# table.add_row([args.bits, args.groupsize, error])
|
||||
|
||||
# print('Optimizing {} {} ..'.format(name, layerid))
|
||||
# for bits, groupsize in conditions:
|
||||
|
||||
# if error < target:
|
||||
# # if error dropped 50%, skip
|
||||
# break
|
||||
|
||||
# gptq.quantizer.configure(bits, perchannel=True, sym=args.sym, mse=False)
|
||||
|
||||
# scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name)
|
||||
|
||||
# table.add_row([bits, groupsize, error])
|
||||
# quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), bits, groupsize)
|
||||
|
||||
# print(table.draw())
|
||||
# print('\n')
|
||||
# gptq.layer.to('cpu')
|
||||
# gptq.free()
|
||||
|
||||
model.config.use_cache = use_cache
|
||||
|
||||
return quantizers
|
||||
|
||||
|
||||
# @torch.no_grad()
|
||||
# def llama_eval(model, testenc, dev):
|
||||
# print('Evaluating ...')
|
||||
#
|
||||
# testenc = testenc.input_ids
|
||||
# nsamples = testenc.numel() // model.seqlen
|
||||
#
|
||||
# use_cache = model.config.use_cache
|
||||
# model.config.use_cache = False
|
||||
# layers = model.model.layers
|
||||
#
|
||||
# model.model.embed_tokens = model.model.embed_tokens.to(dev)
|
||||
# layers[0] = layers[0].to(dev)
|
||||
#
|
||||
# dtype = next(iter(model.parameters())).dtype
|
||||
# inps = torch.zeros((nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev)
|
||||
# cache = {'i': 0, 'attention_mask': None}
|
||||
#
|
||||
# class Catcher(nn.Module):
|
||||
#
|
||||
# def __init__(self, module):
|
||||
# super().__init__()
|
||||
# self.module = module
|
||||
#
|
||||
# def forward(self, inp, **kwargs):
|
||||
# inps[cache['i']] = inp
|
||||
# cache['i'] += 1
|
||||
# cache['attention_mask'] = kwargs['attention_mask']
|
||||
# cache['position_ids'] = kwargs['position_ids']
|
||||
# raise ValueError
|
||||
#
|
||||
# layers[0] = Catcher(layers[0])
|
||||
# for i in range(nsamples):
|
||||
# batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
|
||||
# try:
|
||||
# model(batch)
|
||||
# except ValueError:
|
||||
# pass
|
||||
# layers[0] = layers[0].module
|
||||
#
|
||||
# layers[0] = layers[0].cpu()
|
||||
# model.model.embed_tokens = model.model.embed_tokens.cpu()
|
||||
# torch.cuda.empty_cache()
|
||||
#
|
||||
# outs = torch.zeros_like(inps)
|
||||
# attention_mask = cache['attention_mask']
|
||||
# position_ids = cache['position_ids']
|
||||
#
|
||||
# for i in range(len(layers)):
|
||||
# print(i)
|
||||
# layer = layers[i].to(dev)
|
||||
#
|
||||
# if args.nearest:
|
||||
# subset = find_layers(layer)
|
||||
# for name in subset:
|
||||
# quantizer = quant.Quantizer()
|
||||
# quantizer.configure(args.bits, perchannel=True, sym=args.sym, mse=False)
|
||||
# W = subset[name].weight.data
|
||||
# quantizer.find_params(W, weight=True)
|
||||
# subset[name].weight.data = quantizer.quantize(W).to(next(iter(layer.parameters())).dtype)
|
||||
#
|
||||
# for j in range(nsamples):
|
||||
# outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
|
||||
# layers[i] = layer.cpu()
|
||||
# del layer
|
||||
# torch.cuda.empty_cache()
|
||||
# inps, outs = outs, inps
|
||||
#
|
||||
# if model.model.norm is not None:
|
||||
# model.model.norm = model.model.norm.to(dev)
|
||||
# model.lm_head = model.lm_head.to(dev)
|
||||
#
|
||||
# testenc = testenc.to(dev)
|
||||
# nlls = []
|
||||
# for i in range(nsamples):
|
||||
# hidden_states = inps[i].unsqueeze(0)
|
||||
# if model.model.norm is not None:
|
||||
# hidden_states = model.model.norm(hidden_states)
|
||||
# lm_logits = model.lm_head(hidden_states)
|
||||
# shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
# shift_labels = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)][:, 1:]
|
||||
# loss_fct = nn.CrossEntropyLoss()
|
||||
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
# neg_log_likelihood = loss.float() * model.seqlen
|
||||
# nlls.append(neg_log_likelihood)
|
||||
# ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
|
||||
# print(ppl.item())
|
||||
#
|
||||
# model.config.use_cache = use_cache
|
||||
|
||||
|
||||
def make_quant_linear(module, names, bits, groupsize, name=""):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
@ -898,170 +768,13 @@ def pack(model, quantizers, bits, groupsize):
|
||||
return model
|
||||
|
||||
|
||||
# def load_quant(model, checkpoint, bits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
|
||||
# from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils
|
||||
# config = LlamaConfig.from_pretrained(model)
|
||||
#
|
||||
# def noop(*args, **kwargs):
|
||||
# pass
|
||||
#
|
||||
# torch.nn.init.kaiming_uniform_ = noop
|
||||
# torch.nn.init.uniform_ = noop
|
||||
# torch.nn.init.normal_ = noop
|
||||
#
|
||||
# torch.set_default_dtype(torch.half)
|
||||
# modeling_utils._init_weights = False
|
||||
# torch.set_default_dtype(torch.half)
|
||||
# model = LlamaForCausalLM(config)
|
||||
# torch.set_default_dtype(torch.float)
|
||||
# if eval:
|
||||
# model = model.eval()
|
||||
# layers = find_layers(model)
|
||||
# for name in ['lm_head']:
|
||||
# if name in layers:
|
||||
# del layers[name]
|
||||
# quant.make_quant_linear(model, layers, bits, groupsize)
|
||||
#
|
||||
# del layers
|
||||
#
|
||||
# print('Loading model ...')
|
||||
# if checkpoint.endswith('.safetensors'):
|
||||
# from safetensors.torch import load_file as safe_load
|
||||
# model.load_state_dict(safe_load(checkpoint))
|
||||
# else:
|
||||
# model.load_state_dict(torch.load(checkpoint))
|
||||
#
|
||||
# if eval:
|
||||
# quant.make_quant_attn(model)
|
||||
# quant.make_quant_norm(model)
|
||||
# if fused_mlp:
|
||||
# quant.make_fused_mlp(model)
|
||||
#
|
||||
# if warmup_autotune:
|
||||
# quant.autotune_warmup_linear(model, transpose=not (eval))
|
||||
# if eval and fused_mlp:
|
||||
# quant.autotune_warmup_fused(model)
|
||||
# model.seqlen = 2048
|
||||
# print('Done.')
|
||||
#
|
||||
# return model
|
||||
|
||||
|
||||
# def llama_multigpu(model, gpus, gpu_dist):
|
||||
# model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
|
||||
# if hasattr(model.model, 'norm') and model.model.norm:
|
||||
# model.model.norm = model.model.norm.to(gpus[0])
|
||||
# import copy
|
||||
# model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
|
||||
#
|
||||
# cache = {'mask': None, 'position_ids': None}
|
||||
#
|
||||
# class MoveModule(nn.Module):
|
||||
#
|
||||
# def __init__(self, module, invalidate_cache):
|
||||
# super().__init__()
|
||||
# self.module = module
|
||||
# self.dev = next(iter(self.module.parameters())).device
|
||||
# self.invalidate_cache=invalidate_cache
|
||||
#
|
||||
# def forward(self, *inp, **kwargs):
|
||||
# inp = list(inp)
|
||||
# if inp[0].device != self.dev:
|
||||
# inp[0] = inp[0].to(self.dev)
|
||||
#
|
||||
# if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
|
||||
# cache['mask'] = kwargs['attention_mask'].to(self.dev)
|
||||
# kwargs['attention_mask'] = cache['mask']
|
||||
#
|
||||
# if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
|
||||
# cache['position_ids'] = kwargs['position_ids'].to(self.dev)
|
||||
# kwargs['position_ids'] = cache['position_ids']
|
||||
#
|
||||
# tmp = self.module(*inp, **kwargs)
|
||||
# return tmp
|
||||
#
|
||||
# layers = model.model.layers
|
||||
# from math import ceil
|
||||
# if not gpu_dist:
|
||||
# pergpu = ceil(len(layers) / len(gpus))
|
||||
# for i in range(len(layers)):
|
||||
# layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) -1 else gpus[(i-1) // pergpu]), i==0)
|
||||
# else:
|
||||
# assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
|
||||
# assigned_gpus = [0] * (gpu_dist[0]-1)
|
||||
# for i in range(1, len(gpu_dist)):
|
||||
# assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
|
||||
#
|
||||
# remaining_assignments = len(layers)-len(assigned_gpus) - 1
|
||||
# if remaining_assignments > 0:
|
||||
# assigned_gpus = assigned_gpus + [-1] * remaining_assignments
|
||||
#
|
||||
# assigned_gpus = assigned_gpus + [0]
|
||||
#
|
||||
# for i in range(len(layers)):
|
||||
# layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
|
||||
#
|
||||
# model.gpus = gpus
|
||||
#
|
||||
#
|
||||
# def benchmark(model, input_ids, check=False):
|
||||
# input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
|
||||
# torch.cuda.synchronize()
|
||||
#
|
||||
# cache = {'past': None}
|
||||
#
|
||||
# def clear_past(i):
|
||||
#
|
||||
# def tmp(layer, inp, out):
|
||||
# if cache['past']:
|
||||
# cache['past'][i] = None
|
||||
#
|
||||
# return tmp
|
||||
#
|
||||
# for i, layer in enumerate(model.model.layers):
|
||||
# layer.register_forward_hook(clear_past(i))
|
||||
#
|
||||
# print('Benchmarking ...')
|
||||
#
|
||||
# if check:
|
||||
# loss = nn.CrossEntropyLoss()
|
||||
# tot = 0.
|
||||
#
|
||||
# def sync():
|
||||
# if hasattr(model, 'gpus'):
|
||||
# for gpu in model.gpus:
|
||||
# torch.cuda.synchronize(gpu)
|
||||
# else:
|
||||
# torch.cuda.synchronize()
|
||||
#
|
||||
# max_memory = 0
|
||||
# with torch.no_grad():
|
||||
# attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
|
||||
# times = []
|
||||
# for i in range(input_ids.numel()):
|
||||
# tick = time.time()
|
||||
# out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
|
||||
# sync()
|
||||
# times.append(time.time() - tick)
|
||||
# print(i, times[-1])
|
||||
# if hasattr(model, 'gpus'):
|
||||
# mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
|
||||
# else:
|
||||
# mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
|
||||
# max_memory = max(max_memory, mem_allocated)
|
||||
# if check and i != input_ids.numel() - 1:
|
||||
# tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
|
||||
# cache['past'] = list(out.past_key_values)
|
||||
# del out
|
||||
# sync()
|
||||
# print('Median:', np.median(times))
|
||||
# if check:
|
||||
# print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
|
||||
# print('max memory(MiB):', max_memory)
|
||||
|
||||
|
||||
def quantize(
|
||||
model_id: str, bits: int, groupsize: int, output_dir: str, trust_remote_code: bool
|
||||
model_id: str,
|
||||
bits: int,
|
||||
groupsize: int,
|
||||
output_dir: str,
|
||||
trust_remote_code: bool,
|
||||
upload_to_model_id: Optional[str],
|
||||
):
|
||||
print("loading model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -1085,48 +798,6 @@ def quantize(
|
||||
quantizers = sequential(model, dataloader, DEV, nsamples, bits, groupsize)
|
||||
print(time.time() - tick)
|
||||
|
||||
# if args.benchmark:
|
||||
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
|
||||
# if len(gpus) > 1:
|
||||
# llama_multigpu(model, gpus, gpu_dist)
|
||||
# else:
|
||||
# model = model.to(DEV)
|
||||
# if args.benchmark:
|
||||
# input_ids = next(iter(dataloader))[0][:, :args.benchmark]
|
||||
# benchmark(model, input_ids, check=args.check)
|
||||
|
||||
# if args.eval:
|
||||
# datasets = ['wikitext2', 'ptb', 'c4']
|
||||
# if args.new_eval:
|
||||
# datasets = ['wikitext2', 'ptb-new', 'c4-new']
|
||||
# for dataset in datasets:
|
||||
# dataloader, testloader = get_loaders(dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
|
||||
# print(dataset)
|
||||
# llama_eval(model, testloader, DEV)
|
||||
#
|
||||
# if args.test_generation:
|
||||
# gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
|
||||
# if len(gpus) > 1:
|
||||
# llama_multigpu(model, gpus, gpu_dist)
|
||||
# else:
|
||||
# model = model.to(DEV)
|
||||
|
||||
# from transformers import LlamaTokenizer, TextStreamer
|
||||
# tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
|
||||
# input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0])
|
||||
# streamer = TextStreamer(tokenizer)
|
||||
# with torch.no_grad():
|
||||
# generated_ids = model.generate(input_ids, streamer=streamer)
|
||||
#
|
||||
|
||||
# if args.quant_directory is not None:
|
||||
# export_quant_table(quantizers, args.quant_directory)
|
||||
|
||||
# if not args.observe and args.save:
|
||||
# llama_pack(model, quantizers, args.bits, args.groupsize)
|
||||
# torch.save(model.state_dict(), args.save)
|
||||
|
||||
# if not args.observe and args.save_safetensors:
|
||||
pack(model, quantizers, bits, groupsize)
|
||||
from safetensors.torch import save_file
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
@ -1174,3 +845,11 @@ def quantize(
|
||||
)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
logger.info("Saved tokenizer")
|
||||
|
||||
if upload_to_model_id:
|
||||
|
||||
api = HfApi()
|
||||
|
||||
api.upload_folder(
|
||||
folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
|
||||
)
|
||||
|
@ -134,7 +134,9 @@ def get_linear(weight, bias, quantize):
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(f"The passed weight is not `gptq` compatible, loader needs to be updated.")
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
@ -221,7 +223,9 @@ class TensorParallelColumnLinear(SuperLayer):
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
|
Loading…
Reference in New Issue
Block a user