Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Finishing nits + integration test
This commit is contained in:
parent c35f39cf83
commit 8ee9307618
integration-tests/models/test_flash_awq.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_awq_handle(launcher):
+    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_awq(flash_llama_awq_handle):
+    await flash_llama_awq_handle.health(300)
+    return flash_llama_awq_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
+    response = await flash_llama_awq.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
+    response = await flash_llama_awq.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_awq_load(
+    flash_llama_awq, generate_load, response_snapshot
+):
+    responses = await generate_load(
+        flash_llama_awq, "Test request", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
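The load test fans four identical greedy requests at the server; since no sampling parameters are set, all four completions should be identical, which is what the all(...) assertion checks. generate_load itself comes from the harness' conftest; a minimal, hypothetical sketch of what such a helper does, assuming asyncio (not the real fixture):

import asyncio

# Hypothetical stand-in for the harness' generate_load fixture: fire n
# identical requests concurrently and collect the responses in order.
async def generate_load(client, prompt, max_new_tokens, n):
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)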
server/text_generation_server/utils/awq/quantize/qmodule.py

@@ -6,14 +6,14 @@ import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels


-class ScaledActivation(nn.Module):
-    def __init__(self, module, scales):
-        super().__init__()
-        self.act = module
-        self.scales = nn.Parameter(scales.data)
-
-    def forward(self, x):
-        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+# class ScaledActivation(nn.Module):
+#     def __init__(self, module, scales):
+#         super().__init__()
+#         self.act = module
+#         self.scales = nn.Parameter(scales.data)
+#
+#     def forward(self, x):
+#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


 class WQLinear(nn.Module):
@@ -32,11 +32,11 @@ class WQLinear(nn.Module):
         assert self.in_features % self.group_size == 0
         assert self.out_features % (32 // self.w_bit) == 0

-        self.register_buffer('qweight', qweight)
-        self.register_buffer('qzeros', qzeros)
-        self.register_buffer('scales', scales)
+        self.qweight = qweight
+        self.qzeros = qzeros
+        self.scales = scales
         if bias:
-            self.register_buffer('bias', bias)
+            self.bias = bias
         else:
             self.bias = None
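Dropping register_buffer here is plausibly deliberate: a buffer becomes part of the module's state_dict and is swept up in module-wide moves and casts such as .to() and .half(), whereas a plain attribute is left untouched, which is what you want when the packed int32 tensors were already loaded on the right device with the right dtypes. A small demonstration of the observable difference (illustrative, not from this diff):

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # register_buffer: tracked in state_dict, moved/cast by .to()/.half()
        self.register_buffer("qweight", torch.zeros(2, 2, dtype=torch.int32))

class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # plain attribute: invisible to state_dict and to module-wide casts
        self.qweight = torch.zeros(2, 2, dtype=torch.int32)

print(list(WithBuffer().state_dict()))  # ['qweight']
print(list(PlainAttr().state_dict()))   # []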
@@ -46,8 +46,3 @@ class WQLinear(nn.Module):
         out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
-
-    def extra_repr(self) -> str:
-        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
-            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
-        )
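The two asserts in WQLinear's constructor encode the 4-bit packing arithmetic: with w_bit=4, 32 // w_bit = 8 quantized values fit in each int32, and every group_size input channels share one scale and zero point. A sketch of the resulting tensor shapes, assuming the packing convention of the upstream AWQ WQLinear (illustrative numbers, not taken from this diff):

# Shape arithmetic for a hypothetical 4-bit AWQ linear layer.
w_bit = 4
group_size = 128
in_features, out_features = 4096, 11008
pack_factor = 32 // w_bit  # 8 four-bit values per int32

assert in_features % group_size == 0
assert out_features % pack_factor == 0

qweight_shape = (in_features, out_features // pack_factor)               # int32
qzeros_shape = (in_features // group_size, out_features // pack_factor)  # int32
scales_shape = (in_features // group_size, out_features)                 # fp16

print(qweight_shape, qzeros_shape, scales_shape)
# (4096, 1376) (32, 1376) (32, 11008)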
server/text_generation_server/utils/weights.py

@@ -139,21 +139,16 @@ class Weights:
         try:
             qweight = self._get_qweight(f"{prefix}.qweight")
         except RuntimeError:
-            if quantize == "gptq":
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-            else:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
+            raise RuntimeError(
+                f"Cannot load `{quantize}` weight, make sure the model is already quantized."
+            )

         qzeros = self._get_qweight(f"{prefix}.qzeros")
         scales = self._get_qweight(f"{prefix}.scales")
         scales = scales.to(dtype=self.dtype)
-        try:
+        if quantize == "gptq":
             g_idx = self.get_tensor(f"{prefix}.g_idx")
-        except RuntimeError:
+        else:
             g_idx = None

         bits, groupsize = self._get_gptq_params()
@@ -185,14 +180,9 @@ class Weights:
                 [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
             )
         except RuntimeError:
-            if quantize == "gptq":
-                raise RuntimeError(
-                    "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-                )
-            else:
-                raise RuntimeError(
-                    "Cannot load `awq` weight, make sure the model is already quantized"
-                )
+            raise RuntimeError(
+                f"Cannot load `{quantize}` weight, make sure the model is already quantized"
+            )

         qzeros = torch.cat(
             [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
@@ -201,12 +191,12 @@
             [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
         )

-        try:
+        if quantize == "gptq":
             w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
             for w2 in w[1:]:
                 torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
-        except RuntimeError:
+        else:
             g_idx = None

         bits, groupsize = self._get_gptq_params()
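Both g_idx changes above swap a try/except for an explicit branch: only GPTQ checkpoints carry g_idx (the activation-order permutation), so keying on quantize states the intent directly and no longer treats every RuntimeError from tensor loading as "no g_idx". A runnable sketch of the before/after control flow, with a hypothetical dict-backed get_tensor standing in for the real Weights API:

import torch

# Hypothetical stand-ins; the real code reads safetensors via Weights.
checkpoint = {"proj.qweight": torch.zeros(1, dtype=torch.int32)}  # AWQ: no g_idx

def get_tensor(name):
    if name not in checkpoint:
        raise RuntimeError(f"tensor {name} not found")
    return checkpoint[name]

quantize = "awq"

# Before: any RuntimeError was read as "no g_idx", which could also
# mask genuine loading failures.
try:
    g_idx = get_tensor("proj.g_idx")
except RuntimeError:
    g_idx = None

# After: the branch is keyed on the quantization scheme itself.
g_idx = get_tensor("proj.g_idx") if quantize == "gptq" else None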
@@ -233,7 +223,7 @@ class Weights:
         return tensor

     def get_multi_weights_row(self, prefix: str, quantize: str):
-        if quantize in "gptq":
+        if quantize == "gptq":
             use_exllama = True
             bits, groupsize = self._get_gptq_params()
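This one-character fix matters: for strings, in tests substring containment, not equality, so quantize in "gptq" is also True for values like "gpt" or the empty string, and was only False for "awq" by accident of spelling. A quick demonstration:

print("gptq" in "gptq")  # True, but only because every string contains itself
print("gpt" in "gptq")   # True: substring containment, not equality
print("" in "gptq")      # True: the empty string is a substring of everything
print("awq" in "gptq")   # False, so the old check worked for awq by luck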
@@ -311,8 +301,10 @@

             qzeros = self.get_tensor(f"{prefix}.qzeros")
             scales = self.get_tensor(f"{prefix}.scales")
+            g_idx = None
+            use_exllama = False

-            weight = (qweight, qzeros, scales, None, bits, groupsize, None)
+            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
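A readability note on the tuple built above: naming g_idx and use_exllama before packing them makes the positional 7-tuple self-documenting. For readers tracking its layout downstream, a hypothetical typed container spells the fields out (TGI itself passes a plain tuple; this is not its API):

from typing import NamedTuple, Optional

import torch

# Hypothetical mirror of the positional (qweight, qzeros, scales, g_idx,
# bits, groupsize, use_exllama) tuple; illustrative only.
class GPTQWeight(NamedTuple):
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    use_exllama: bool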