Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Finishing nits + integration test
This commit is contained in: parent c35f39cf83, commit 8ee9307618
integration-tests/models/test_flash_awq.py (new file, 61 lines added)
@@ -0,0 +1,61 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_gptq_handle(launcher):
+    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_gptq(flash_llama_gptq_handle):
+    await flash_llama_gptq_handle.health(300)
+    return flash_llama_gptq_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
+    response = await flash_llama_gptq.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
+    response = await flash_llama_gptq.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_gptq_load(
+    flash_llama_gptq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_gptq, "Test request", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
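The load test above uses the shared `generate_load` fixture from the integration-test conftest, which is not part of this diff. The standalone sketch below is only an approximation of its shape (the `FakeClient` class and this `generate_load` are invented for illustration, not the repository's fixtures); it shows the fan-out pattern the test relies on: n identical greedy requests issued concurrently, whose outputs are expected to be identical.

# Sketch only: FakeClient and this generate_load are illustrative stand-ins,
# not the fixtures defined in integration-tests/conftest.py.
import asyncio
from typing import List


class FakeClient:
    """Toy stand-in for the async text-generation client used by the tests."""

    async def generate(self, prompt: str, max_new_tokens: int = 10) -> str:
        await asyncio.sleep(0)  # pretend network round-trip
        return f"{prompt} -> {max_new_tokens} greedy tokens"


async def generate_load(client: FakeClient, prompt: str, max_new_tokens: int, n: int) -> List[str]:
    # Fire n identical requests concurrently, mirroring the load test above.
    return await asyncio.gather(
        *[client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)]
    )


if __name__ == "__main__":
    results = asyncio.run(generate_load(FakeClient(), "Test request", 10, 4))
    assert len(results) == 4
    assert all(r == results[0] for r in results)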
@@ -6,14 +6,14 @@ import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels
 
 
-class ScaledActivation(nn.Module):
-    def __init__(self, module, scales):
-        super().__init__()
-        self.act = module
-        self.scales = nn.Parameter(scales.data)
-
-    def forward(self, x):
-        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+# class ScaledActivation(nn.Module):
+#     def __init__(self, module, scales):
+#         super().__init__()
+#         self.act = module
+#         self.scales = nn.Parameter(scales.data)
+#
+#     def forward(self, x):
+#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
 
 
 class WQLinear(nn.Module):
@@ -32,11 +32,11 @@ class WQLinear(nn.Module):
         assert self.in_features % self.group_size == 0
         assert self.out_features % (32 // self.w_bit) == 0
 
-        self.register_buffer('qweight', qweight)
-        self.register_buffer('qzeros', qzeros)
-        self.register_buffer('scales', scales)
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
         if bias:
-            self.register_buffer('bias', bias)
+            self.bias = bias
         else:
             self.bias = None
 
@@ -46,8 +46,3 @@ class WQLinear(nn.Module):
         out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
-    def extra_repr(self) -> str:
-        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
-            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
-        )
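One plausible reading of the `register_buffer` change above (the commit itself gives no rationale): the quantized tensors are handed to `WQLinear` already loaded from the pre-quantized checkpoint, so they do not need to appear in the module's `state_dict`. A small self-contained illustration of the observable difference, not code from the repository:

# Illustration only: a registered buffer shows up in state_dict() and follows
# Module.to(); a plain tensor attribute does neither.
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    def __init__(self, qweight: torch.Tensor):
        super().__init__()
        self.register_buffer("qweight", qweight)


class PlainAttribute(nn.Module):
    def __init__(self, qweight: torch.Tensor):
        super().__init__()
        self.qweight = qweight  # what the diff above switches to


q = torch.zeros(4, 4, dtype=torch.int32)
print(list(WithBuffer(q).state_dict().keys()))      # ['qweight']
print(list(PlainAttribute(q).state_dict().keys()))  # []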
@@ -139,21 +139,16 @@ class Weights:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
-                if quantize == "gptq":
-                    raise RuntimeError(
-                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                    )
-                else:
-                    raise RuntimeError(
-                        "Cannot load `awq` weight, make sure the model is already quantized"
-                    )
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
+                )
 
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
-            try:
+            if quantize == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-            except RuntimeError:
+            else:
                 g_idx = None
 
             bits, groupsize = self._get_gptq_params()
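The two hard-coded error messages collapse into a single f-string parameterized by the quantization mode. A trivial illustration (values chosen here for demonstration only):

# Illustration only: one f-string now covers every quantization mode that
# reaches this branch, instead of separate gptq/awq messages.
for quantize in ("gptq", "awq"):
    print(f"Cannot load `{quantize}` weight, make sure the model is already quantized.")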
@@ -185,14 +180,9 @@ class Weights:
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
                 )
             except RuntimeError:
-                if quantize == "gptq":
-                    raise RuntimeError(
-                        "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                    )
-                else:
-                    raise RuntimeError(
-                        "Cannot load `awq` weight, make sure the model is already quantized"
-                    )
+                raise RuntimeError(
+                    f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+                )
 
             qzeros = torch.cat(
                 [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -201,12 +191,12 @@ class Weights:
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
-            try:
+            if quantize == "gptq":
                 w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
                 for w2 in w[1:]:
                     torch.testing.assert_close(w2, w[0])
                 g_idx = w[0]
-            except RuntimeError:
+            else:
                 g_idx = None
 
             bits, groupsize = self._get_gptq_params()
@@ -233,7 +223,7 @@ class Weights:
         return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize in "gptq":
+        if quantize == "gptq":
            use_exllama = True
            bits, groupsize = self._get_gptq_params()

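The one-line change above fixes a real pitfall: with a string on the right-hand side, `in` performs substring containment rather than an equality check, so unintended values can slip through. A quick demonstration (editor's illustration, not repository code):

# Illustration only: why `quantize in "gptq"` is not an equality check.
print("gptq" in "gptq")   # True  -- what the old code happened to rely on
print("gpt" in "gptq")    # True  -- accidental substring match
print("q" in "gptq")      # True  -- accidental substring match
print("gptq" == "gptq")   # True  -- the corrected check
print("awq" == "gptq")    # False
# A non-string value such as None would raise TypeError with the old `in` form.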
@@ -311,8 +301,10 @@ class Weights:
 
             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = None
+            use_exllama = False
 
-            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
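For readers tracing the tuple above: `get_multi_weights_row` returns the same 7-slot layout for AWQ as for GPTQ, with the AWQ-specific slots now spelled out (`g_idx=None`, `use_exllama=False`) instead of bare `None` literals. The `NamedTuple` below is an editor's readability sketch only; the repository passes a plain tuple, the field names are invented labels, and the example values (4-bit, group size 128) are taken from the test model's name rather than from this diff.

# Sketch only: invented labels for the plain 7-tuple built in the diff above.
from typing import Any, NamedTuple, Optional


class QuantizedRowWeight(NamedTuple):
    qweight: Any            # packed integer weights
    qzeros: Any             # packed zero points
    scales: Any             # per-group scales
    g_idx: Optional[Any]    # GPTQ activation-order index; None for AWQ
    bits: int               # quantization bit-width
    groupsize: int          # quantization group size
    use_exllama: bool       # exllama kernels are a GPTQ-only fast path, so False for AWQ


awq_weight = QuantizedRowWeight(
    qweight=None, qzeros=None, scales=None,  # real code carries tensors here
    g_idx=None, bits=4, groupsize=128, use_exllama=False,
)
print(awq_weight.use_exllama)  # False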