Merge branch 'main' into prefer-chat-object-enum

commit 153c8ae60f
Author: Nicolas Patry
Date: 2024-07-01 14:10:45 +02:00 (committed by GitHub)
11 changed files with 186 additions and 734 deletions


@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}


@@ -1,84 +0,0 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "."
},
{
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": " The"
},
{
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": " has"
},
{
"id": 539,
"logprob": 0.0,
"special": false,
"text": " not"
},
{
"id": 3686,
"logprob": 0.0,
"special": false,
"text": " yet"
},
{
"id": 3288,
"logprob": 0.0,
"special": false,
"text": " sent"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}


@@ -1,338 +0,0 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2323,
"logprob": null,
"text": "Test"
},
{
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": " "
},
{
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": ".method"
},
{
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": " =="
},
{
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " '"
},
{
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]


@@ -1,68 +0,0 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_marlin_handle(launcher):
    with launcher(
        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
    await flash_llama_gptq_marlin_handle.health(300)
    return flash_llama_gptq_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
    response = await flash_llama_gptq_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
    flash_llama_gptq_marlin, response_snapshot
):
    response = await flash_llama_gptq_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
    flash_llama_gptq_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot


@@ -898,13 +898,20 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(
+    model_id: &str,
+    revision: Option<&str>,
+    trust_remote_code: bool,
+    huggingface_hub_cache: Option<&str>,
+    weights_cache_override: Option<&str>,
+    running: Arc<AtomicBool>,
+) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();
 
     let mut download_args = vec![
         "download-weights".to_string(),
-        args.model_id.to_string(),
+        model_id.to_string(),
         "--extension".to_string(),
         ".safetensors".to_string(),
         "--logger-level".to_string(),
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     ];
 
     // Model optional revision
-    if let Some(revision) = &args.revision {
+    if let Some(revision) = &revision {
         download_args.push("--revision".to_string());
         download_args.push(revision.to_string())
     }
 
     // Trust remote code for automatic peft fusion
-    if args.trust_remote_code {
+    if trust_remote_code {
         download_args.push("--trust-remote-code".to_string());
     }
@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // If huggingface_hub_cache is set, pass it to the download process
     // Useful when running inside a docker container
-    if let Some(ref huggingface_hub_cache) = args.huggingface_hub_cache {
+    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
         envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // If args.weights_cache_override is some, pass it to the download process
     // Useful when running inside a HuggingFace Inference Endpoint
-    if let Some(weights_cache_override) = &args.weights_cache_override {
+    if let Some(weights_cache_override) = &weights_cache_override {
         envs.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     };
 
     // Start process
-    tracing::info!("Starting download process.");
+    tracing::info!("Starting check and download process for {model_id}");
     let mut download_process = match Command::new("text-generation-server")
         .args(download_args)
         .env_clear()
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     loop {
         if let Some(status) = download_process.try_wait().unwrap() {
             if status.success() {
-                tracing::info!("Successfully downloaded weights.");
+                tracing::info!("Successfully downloaded weights for {model_id}");
                 break;
             }
@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
         .expect("Error setting Ctrl-C handler");
 
     // Download and convert model weights
-    download_convert_model(&args, running.clone())?;
+    download_convert_model(
+        &args.model_id,
+        args.revision.as_deref(),
+        args.trust_remote_code,
+        args.huggingface_hub_cache.as_deref(),
+        args.weights_cache_override.as_deref(),
+        running.clone(),
+    )?;
+
+    // Download and convert lora adapters if any
+    if let Some(lora_adapters) = &args.lora_adapters {
+        for adapter in lora_adapters.split(',') {
+            download_convert_model(
+                adapter,
+                None,
+                args.trust_remote_code,
+                args.huggingface_hub_cache.as_deref(),
+                args.weights_cache_override.as_deref(),
+                running.clone(),
+            )?;
+        }
+    }
 
     if !running.load(Ordering::SeqCst) {
         // Launcher was asked to stop


@@ -309,7 +309,7 @@ async fn main() -> Result<(), RouterError> {
     let mut tokenizer = Tokenizer::from_file(filename).ok();
     if let Some(tokenizer) = &mut tokenizer {
         if let Some(class) = &tokenizer_config.tokenizer_class {
-            if (class == "LlamaTokenizer" || class == "LlamaTokenizerFast") && tokenizer.get_post_processor().is_none() {
+            if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
                 if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
                     tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
                     tokenizer.with_post_processor(post_processor);
@@ -577,7 +577,7 @@ pub fn create_post_processor(
     if add_bos_token {
         if let Some(bos) = bos_token {
-            single.push(format!("{}:1", bos.as_str()));
+            pair.push(format!("{}:1", bos.as_str()));
         }
     }


@@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
 )
 
 
+@dataclass
+class GPTQParams:
+    bits: int
+    checkpoint_format: Optional[str]
+    groupsize: int
+    desc_act: bool
+    quant_method: str
+    sym: bool
+
+
 @dataclass
 class GPTQWeight:
     qweight: torch.Tensor

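Note that GPTQParams only carries quantization metadata; nothing in this file consumes it yet. Below is a minimal sketch of filling such a record from a GPTQ quantization_config dictionary. The key names ("bits", "group_size", "desc_act", "sym", "quant_method", "checkpoint_format") are assumptions about the usual GPTQ config layout, not something this diff defines.

# Sketch only: builds a GPTQParams-like record from a config dict.
# The config keys below are assumed, not taken from this diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GPTQParams:
    bits: int
    checkpoint_format: Optional[str]
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool


def gptq_params_from_config(quantization_config: dict) -> GPTQParams:
    # Map the (assumed) config keys onto the dataclass fields.
    return GPTQParams(
        bits=quantization_config["bits"],
        checkpoint_format=quantization_config.get("checkpoint_format"),
        groupsize=quantization_config.get("group_size", -1),
        desc_act=quantization_config.get("desc_act", False),
        quant_method=quantization_config.get("quant_method", "gptq"),
        sym=quantization_config.get("sym", True),
    )


params = gptq_params_from_config(
    {"bits": 4, "group_size": 128, "desc_act": False, "quant_method": "gptq", "sym": True}
)
print(params)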

@@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         from text_generation_server.layers.gptq import GPTQWeight
+        from text_generation_server.layers.marlin import (
+            GPTQMarlinLinear,
+            GPTQMarlinWeight,
+        )
 
-        if not isinstance(weight, GPTQWeight):
-            raise NotImplementedError(
-                f"The passed weight is not `gptq` compatible, loader needs to be updated."
-            )
+        if isinstance(weight, GPTQMarlinWeight):
+            linear = GPTQMarlinLinear(
+                weight=weight,
+                bias=bias,
+            )
+        elif isinstance(weight, GPTQWeight):
             if weight.use_exllama:
                 try:
                     from text_generation_server.layers.gptq import (
@@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
                     weight.bits,
                     weight.groupsize,
                 )
+        else:
+            raise NotImplementedError(
+                f"The passed weight is not `gptq` compatible, loader needs to be updated."
+            )
+
     elif quantize == "awq":
         from text_generation_server.layers.gptq import GPTQWeight
@@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
         from text_generation_server.layers.marlin import (
             GPTQMarlin24Linear,
             GPTQMarlin24Weight,
-            GPTQMarlinLinear,
-            GPTQMarlinWeight,
             MarlinLinear,
             MarlinWeight,
         )
 
-        if isinstance(weight, GPTQMarlinWeight):
-            linear = GPTQMarlinLinear(
-                weight=weight,
-                bias=bias,
-            )
-        elif isinstance(weight, GPTQMarlin24Weight):
+        if isinstance(weight, GPTQMarlin24Weight):
             linear = GPTQMarlin24Linear(
                 weight=weight,
                 bias=bias,

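The reshuffled "gptq" branch above only changes the dispatch order: Marlin-repacked weights are handled first, plain GPTQWeight objects keep the exllama/QuantLinear path, and anything else still raises. A standalone sketch of that control flow, using hypothetical stand-in classes rather than the real text_generation_server types:

# Stand-in types for illustration only; the real GPTQMarlinWeight/GPTQWeight
# live in text_generation_server.layers and carry packed tensors.
from dataclasses import dataclass


@dataclass
class GPTQMarlinWeight:
    qweight: object = None


@dataclass
class GPTQWeight:
    qweight: object = None
    use_exllama: bool = False


def pick_gptq_linear(weight):
    # Marlin-repacked weights take priority over the plain GPTQ paths.
    if isinstance(weight, GPTQMarlinWeight):
        return "GPTQMarlinLinear"
    elif isinstance(weight, GPTQWeight):
        # exllama kernels when requested, otherwise the generic QuantLinear
        return "ExllamaQuantLinear" if weight.use_exllama else "QuantLinear"
    else:
        raise NotImplementedError(
            "The passed weight is not `gptq` compatible, loader needs to be updated."
        )


assert pick_gptq_linear(GPTQMarlinWeight()) == "GPTQMarlinLinear"
assert pick_gptq_linear(GPTQWeight(use_exllama=True)) == "ExllamaQuantLinear"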

@@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
 import torch
 import torch.nn as nn
 
+from text_generation_server.layers.gptq import GPTQParams
 from text_generation_server.utils.import_utils import SYSTEM
 
 try:
@@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
 
 MARLIN_TILE_SIZE = 16
 
 
+def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
+    return (
+        SYSTEM == "cuda"
+        and marlin_kernels is not None
+        and has_sm_8_0
+        and quantize == "gptq"
+        and gptq_params.quant_method == "gptq"
+        and gptq_params.bits in GPTQ_MARLIN_BITS
+        and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
+        and gptq_params.sym
+    )
+
+
 def _check_marlin_kernels():
     if not (SYSTEM == "cuda" and has_sm_8_0):
         raise NotImplementedError(

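Because can_use_gptq_marlin is a pure predicate over GPTQParams, it can be exercised in isolation. A self-contained sketch with the CUDA/kernel checks stubbed out; GPTQ_MARLIN_BITS = [4, 8] is an assumption, since only GPTQ_MARLIN_GROUP_SIZES appears in this hunk:

# Self-contained sketch of the eligibility check added above.
# SYSTEM, marlin_kernels and has_sm_8_0 are stubbed; GPTQ_MARLIN_BITS = [4, 8]
# is an assumed value, not shown in this diff.
from dataclasses import dataclass
from typing import Optional

SYSTEM = "cuda"
marlin_kernels = object()   # pretend the kernels imported fine
has_sm_8_0 = True           # pretend we are on an Ampere-or-newer GPU

GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]


@dataclass
class GPTQParams:
    bits: int
    checkpoint_format: Optional[str]
    groupsize: int
    desc_act: bool
    quant_method: str
    sym: bool


def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
    # Same boolean chain as the hunk above, evaluated against the stubs.
    return (
        SYSTEM == "cuda"
        and marlin_kernels is not None
        and has_sm_8_0
        and quantize == "gptq"
        and gptq_params.quant_method == "gptq"
        and gptq_params.bits in GPTQ_MARLIN_BITS
        and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
        and gptq_params.sym
    )


params = GPTQParams(
    bits=4, checkpoint_format=None, groupsize=128,
    desc_act=False, quant_method="gptq", sym=True,
)
assert can_use_gptq_marlin(params, "gptq")      # eligible: repack for Marlin
assert not can_use_gptq_marlin(params, "awq")   # wrong quantize flag
assert not can_use_gptq_marlin(
    GPTQParams(3, None, 128, False, "gptq", True), "gptq"
)                                               # unsupported bit width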

@@ -309,7 +309,9 @@ class LlamaMLP(nn.Module):
                 dtype=hidden_states.dtype,
                 device="cuda",
             )
-            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            _custom_C.LLMM_Silu(
+                self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
+            )
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)


@@ -1,25 +1,15 @@
 import os
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 from safetensors import safe_open, SafetensorError
 import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
 
+from text_generation_server.layers.gptq import GPTQParams
 from text_generation_server.utils.log import log_once
 
 
-@dataclass
-class _GPTQParams:
-    bits: int
-    checkpoint_format: Optional[str]
-    groupsize: int
-    desc_act: bool
-    quant_method: str
-    sym: bool
-
-
 class Weights:
     def __init__(
         self,
@@ -212,6 +202,10 @@ class Weights:
         """
         if quantize in ["gptq", "awq"]:
             from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                can_use_gptq_marlin,
+                repack_gptq_for_marlin,
+            )
 
             try:
                 qweight = self.get_packed_sharded(
@@ -221,17 +215,28 @@ class Weights:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )
-            gptq_params = self._get_gptq_params()
 
-            qzeros = self.get_packed_sharded(
-                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
-            )
             scales = self.get_packed_sharded(
                 f"{prefix}.scales", dim=1, block_sizes=block_sizes
             )
             scales = scales.to(dtype=self.dtype)
 
+            gptq_params = self._get_gptq_params()
+            if can_use_gptq_marlin(gptq_params, quantize):
+                g_idx = self.get_tensor(f"{prefix}.g_idx")
+                return repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+
+            qzeros = self.get_packed_sharded(
+                f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
+            )
+
             if quantize == "gptq" and gptq_params.quant_method == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
             elif quantize == "gptq" and gptq_params.quant_method == "awq":
@@ -269,7 +274,6 @@ class Weights:
                 repack_gptq_for_marlin,
             )
 
-            quant_method = getattr(self, "quant_method", "marlin")
             is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
             if is_marlin_24:
                 B = self.get_packed_sharded(
@@ -286,31 +290,6 @@ class Weights:
                 weight = GPTQMarlin24Weight(
                     B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                 )
-            elif quant_method == "gptq":
-                gptq_params = self._get_gptq_params()
-                try:
-                    qweight = self.get_packed_sharded(
-                        f"{prefix}.qweight", dim=1, block_sizes=block_sizes
-                    )
-                except RuntimeError:
-                    raise RuntimeError(
-                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
-                    )
-
-                scales = self.get_packed_sharded(
-                    f"{prefix}.scales", dim=1, block_sizes=block_sizes
-                )
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-                weight = repack_gptq_for_marlin(
-                    qweight=qweight,
-                    scales=scales,
-                    g_idx=g_idx,
-                    bits=gptq_params.bits,
-                    desc_act=gptq_params.desc_act,
-                    groupsize=gptq_params.groupsize,
-                    sym=gptq_params.sym,
-                    sharded_infeatures=False,
-                )
             else:
                 B = self.get_packed_sharded(
                     f"{prefix}.B", dim=1, block_sizes=block_sizes
@@ -356,6 +335,10 @@ class Weights:
             raise ValueError("get_multi_weights_col is not supported for exl2")
         elif quantize in ["gptq", "awq"]:
             from text_generation_server.layers.gptq import GPTQWeight
+            from text_generation_server.layers.marlin import (
+                can_use_gptq_marlin,
+                repack_gptq_for_marlin,
+            )
 
             try:
                 qweight = torch.cat(
@@ -366,14 +349,31 @@ class Weights:
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized"
                 )
 
-            qzeros = torch.cat(
-                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
-            )
             scales = torch.cat(
                 [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
             )
 
             gptq_params = self._get_gptq_params()
+            if can_use_gptq_marlin(gptq_params, quantize):
+                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+                for w2 in w[1:]:
+                    torch.testing.assert_close(w2, w[0])
+                g_idx = w[0]
+
+                return repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=False,
+                )
+
+            qzeros = torch.cat(
+                [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
+            )
 
             from text_generation_server.layers.gptq import HAS_EXLLAMA
@@ -425,10 +425,8 @@ class Weights:
             from text_generation_server.layers.marlin import (
                 GPTQMarlin24Weight,
                 MarlinWeight,
-                repack_gptq_for_marlin,
             )
 
-            quant_method = getattr(self, "quant_method", "marlin")
             is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
             if is_marlin_24:
                 try:
@@ -452,36 +450,6 @@ class Weights:
                 weight = GPTQMarlin24Weight(
                     B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                 )
-            elif quant_method == "gptq":
-                gptq_params = self._get_gptq_params()
-                try:
-                    qweight = torch.cat(
-                        [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
-                        dim=1,
-                    )
-                except RuntimeError:
-                    raise RuntimeError(
-                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
-                    )
-
-                scales = torch.cat(
-                    [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
-                )
-                w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
-                for w2 in w[1:]:
-                    torch.testing.assert_close(w2, w[0])
-                g_idx = w[0]
-
-                weight = repack_gptq_for_marlin(
-                    qweight=qweight,
-                    scales=scales,
-                    g_idx=g_idx,
-                    bits=gptq_params.bits,
-                    desc_act=gptq_params.desc_act,
-                    groupsize=gptq_params.groupsize,
-                    sym=gptq_params.sym,
-                    sharded_infeatures=False,
-                )
             else:
                 try:
                     B = torch.cat(
@@ -544,9 +512,41 @@ class Weights:
             )
         elif quantize == "gptq":
-            use_exllama = True
-            gptq_params = self._get_gptq_params()
+            from text_generation_server.layers.marlin import (
+                can_use_gptq_marlin,
+                repack_gptq_for_marlin,
+            )
+
+            gptq_params = self._get_gptq_params()
+            if can_use_gptq_marlin(gptq_params, quantize):
+                log_once(logger.info, "Using GPTQ-Marlin kernels")
+                try:
+                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+                except RuntimeError:
+                    raise RuntimeError(
+                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
+                    )
+
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+                if gptq_params.desc_act or gptq_params.groupsize == -1:
+                    scales = self.get_tensor(f"{prefix}.scales")
+                else:
+                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
+
+                sharded_in_features = self.process_group.size() > 1
+
+                return repack_gptq_for_marlin(
+                    qweight=qweight,
+                    scales=scales,
+                    g_idx=g_idx,
+                    bits=gptq_params.bits,
+                    desc_act=gptq_params.desc_act,
+                    groupsize=gptq_params.groupsize,
+                    sym=gptq_params.sym,
+                    sharded_infeatures=sharded_in_features,
+                )
+
+            use_exllama = True
             if gptq_params.bits != 4:
                 use_exllama = False
@@ -672,10 +672,8 @@ class Weights:
             from text_generation_server.layers.marlin import (
                 GPTQMarlin24Weight,
                 MarlinWeight,
-                repack_gptq_for_marlin,
             )
 
-            quant_method = getattr(self, "quant_method", "marlin")
             is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
             if is_marlin_24:
                 try:
@@ -698,35 +696,6 @@ class Weights:
                 weight = GPTQMarlin24Weight(
                     B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
                 )
-            elif quant_method == "gptq":
-                log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
-                gptq_params = self._get_gptq_params()
-
-                try:
-                    qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-                except RuntimeError:
-                    raise RuntimeError(
-                        f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
-                    )
-
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-                if gptq_params.desc_act or gptq_params.groupsize == -1:
-                    scales = self.get_tensor(f"{prefix}.scales")
-                else:
-                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
-
-                sharded_in_features = self.process_group.size() > 1
-
-                weight = repack_gptq_for_marlin(
-                    qweight=qweight,
-                    scales=scales,
-                    g_idx=g_idx,
-                    bits=gptq_params.bits,
-                    desc_act=gptq_params.desc_act,
-                    groupsize=gptq_params.groupsize,
-                    sym=gptq_params.sym,
-                    sharded_infeatures=sharded_in_features,
-                )
             else:
                 try:
                     B = self.get_sharded(f"{prefix}.B", dim=0)
@@ -743,18 +712,17 @@ class Weights:
                 else:
                     s = self.get_sharded(f"{prefix}.s", dim=0)
                 weight = MarlinWeight(B=B, s=s)
 
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
-    def _get_gptq_params(self) -> _GPTQParams:
+    def _get_gptq_params(self) -> GPTQParams:
         try:
             bits = self.get_tensor("gptq_bits").item()
             groupsize = self.get_tensor("gptq_groupsize").item()
             checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
             desc_act = False
-            sym = True
+            sym = False
             quant_method = "gptq"
         except (SafetensorError, RuntimeError) as e:
             try:
@@ -767,7 +735,7 @@ class Weights:
         except Exception:
             raise e
 
-        return _GPTQParams(
+        return GPTQParams(
             bits=bits,
             checkpoint_format=checkpoint_format,
             desc_act=desc_act,