Finishing nits + integration test

This commit is contained in:
Nicolas Patry 2023-09-25 10:07:45 +02:00
parent c35f39cf83
commit 8ee9307618
3 changed files with 87 additions and 39 deletions

View File

@ -0,0 +1,61 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_gptq(flash_llama_gptq_handle):
await flash_llama_gptq_handle.health(300)
return flash_llama_gptq_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(
flash_llama_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_gptq, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -6,14 +6,14 @@ import torch.nn as nn
import awq_inference_engine # with CUDA kernels
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
@ -32,11 +32,11 @@ class WQLinear(nn.Module):
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.register_buffer('qweight', qweight)
self.register_buffer('qzeros', qzeros)
self.register_buffer('scales', scales)
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
if bias:
self.register_buffer('bias', bias)
self.bias = bias
else:
self.bias = None
@ -46,8 +46,3 @@ class WQLinear(nn.Module):
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)

View File

@ -139,21 +139,16 @@ class Weights:
try:
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
try:
if quantize == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
except RuntimeError:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
@ -185,13 +180,8 @@ class Weights:
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
if quantize == "gptq":
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
else:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
qzeros = torch.cat(
@ -201,12 +191,12 @@ class Weights:
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
try:
if quantize == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
except RuntimeError:
else:
g_idx = None
bits, groupsize = self._get_gptq_params()
@ -233,7 +223,7 @@ class Weights:
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize in "gptq":
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_params()
@ -311,8 +301,10 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, None, bits, groupsize, None)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight