mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00
fix gptq tests, LLMM1 matrix bound
This commit is contained in:
parent
d3c7f63416
commit
b452620c04
@ -320,7 +320,6 @@ def launcher(event_loop):
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
print("call local_launcher")
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
|
||||
|
@ -1,12 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from testing_utils import SYSTEM
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_awq_handle(launcher):
|
||||
if SYSTEM == "rocm":
|
||||
# On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
|
||||
quantize = "gptq"
|
||||
elif SYSTEM == "xpu":
|
||||
pytest.skiptest("AWQ is not supported on xpu")
|
||||
else:
|
||||
quantize = "awq"
|
||||
|
||||
with launcher(
|
||||
"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
|
||||
num_shard=1,
|
||||
quantize="awq",
|
||||
quantize=quantize,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
@ -1,12 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from testing_utils import SYSTEM, is_flaky_async
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_awq_handle_sharded(launcher):
|
||||
if SYSTEM == "rocm":
|
||||
# On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
|
||||
quantize = "gptq"
|
||||
elif SYSTEM == "xpu":
|
||||
pytest.skiptest("AWQ is not supported on xpu")
|
||||
else:
|
||||
quantize = "awq"
|
||||
|
||||
with launcher(
|
||||
"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
|
||||
num_shard=2,
|
||||
quantize="awq",
|
||||
quantize=quantize,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
@ -17,17 +27,23 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
|
||||
return flash_llama_awq_handle_sharded.client
|
||||
|
||||
|
||||
@is_flaky_async(max_attempts=5)
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
|
||||
response = await flash_llama_awq_sharded.generate(
|
||||
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
# ExllamaV2 (which may be used as an AWQ backend) is highly non-deterministic, see for reference https://github.com/turboderp/exllamav2/issues/232.
|
||||
assert response.details.generated_tokens == 10
|
||||
|
||||
assert (
|
||||
response.generated_text
|
||||
== "\nWhat is the difference between Deep Learning and Machine"
|
||||
)
|
||||
|
||||
if SYSTEM != "rocm":
|
||||
# Logits were taken on an Nvidia GPU, and are too far off to be meaningfully compared.
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@ -35,11 +51,15 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
|
||||
async def test_flash_llama_awq_load_sharded(
|
||||
flash_llama_awq_sharded, generate_load, response_snapshot
|
||||
):
|
||||
if SYSTEM == "rocm":
|
||||
pytest.skiptest(
|
||||
"This test relies on ExllamaV2 on ROCm systems, which is highly non-determinstic (flaky)"
|
||||
)
|
||||
|
||||
responses = await generate_load(
|
||||
flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all(
|
||||
[
|
||||
r.generated_text
|
||||
@ -48,4 +68,5 @@ async def test_flash_llama_awq_load_sharded(
|
||||
]
|
||||
)
|
||||
|
||||
# Logits were taken on an Nvidia GPU, and are too far off to be meaningfully compared.
|
||||
assert responses == response_snapshot
|
||||
|
@ -14,9 +14,6 @@ def flash_llama_marlin_handle(launcher):
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_marlin(flash_llama_marlin_handle):
|
||||
if SYSTEM != "cuda":
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await flash_llama_marlin_handle.health(300)
|
||||
assert exc_info.value.args[0] == "only available on Nvidia"
|
||||
pytest.skip(f"Marlin not supported on SYSTEM={SYSTEM}")
|
||||
else:
|
||||
await flash_llama_marlin_handle.health(300)
|
||||
|
@ -1,5 +1,5 @@
|
||||
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
|
||||
commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
|
||||
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
|
||||
build-vllm-cuda:
|
||||
if [ ! -d 'vllm' ]; then \
|
||||
pip install -U ninja packaging --no-cache-dir && \
|
||||
@ -20,4 +20,4 @@ build-vllm-rocm:
|
||||
|
||||
install-vllm-rocm: build-vllm-rocm
|
||||
cd vllm && git fetch && git checkout $(commit_rocm) && \
|
||||
PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
|
||||
VLLM_TARGET_DEVICE="rocm" PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install --no-build-isolation -e .
|
||||
|
@ -19,7 +19,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
@ -189,7 +188,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
@ -439,5 +437,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -561,6 +561,8 @@ class Weights:
|
||||
|
||||
def _set_gptq_params(self, model_id, revision):
|
||||
filename = "config.json"
|
||||
|
||||
self.quant_method = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
@ -608,7 +610,11 @@ class Weights:
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
self.quant_method = "awq"
|
||||
except Exception:
|
||||
pass
|
||||
if self.quant_method is None:
|
||||
if "awq" in model_id.lower():
|
||||
self.quant_method = "awq"
|
||||
elif "gptq" in model_id.lower():
|
||||
self.quant_method = "gptq"
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
|
Loading…
Reference in New Issue
Block a user