Adding Quantization GPTQ as an option (step 1)

- Just turns the `quantize` boolean into a proper enum.
- Use `str` within Python (prevents having to import the enum everywhere); see the sketch below.
This commit is contained in:
Nicolas Patry 2023-05-04 12:16:40 +02:00
parent b67908e0cf
commit 1185f66205
6 changed files with 50 additions and 16 deletions
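
A minimal standalone sketch of the `str`-mixin enum pattern (illustration only, not part of the diff): because the Python-side `Quantization` enum subclasses `str`, its members compare equal to plain strings, so downstream code can branch on `"bitsandbytes"` / `"gptq"` without importing the enum.

```python
from enum import Enum


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


# Members behave like plain strings, so callers never need the enum type itself.
assert Quantization.gptq == "gptq"
assert Quantization.bitsandbytes.value == "bitsandbytes"
assert isinstance(Quantization.gptq, str)
```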

View File

@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in sync with `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Whether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
    world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
     }
 
     // Model optional revision
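
Taken together with the `Display` impl above, the launcher now appends a `--quantize <mode>` pair to the shard command whenever a quantization mode is selected. Below is a rough Python illustration of the resulting argument list; the actual logic lives in the Rust `shard_manager` above, and the `text-generation-server serve` prefix and the helper name are assumptions for illustration only.

```python
from typing import List, Optional


def build_shard_argv(model_id: str, quantize: Optional[str]) -> List[str]:
    # Hypothetical helper mirroring shard_manager: only the --quantize handling
    # reflects this commit; the command prefix is an assumption.
    argv = ["text-generation-server", "serve", model_id]
    if quantize is not None:
        argv += ["--quantize", quantize]  # "bitsandbytes" or "gptq"
    return argv


print(build_shard_argv("bigscience/bloom-560m", "gptq"))
# ['text-generation-server', 'serve', 'bigscience/bloom-560m', '--quantize', 'gptq']
```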

View File

@@ -5,17 +5,23 @@ import typer
 
 from pathlib import Path
 from loguru import logger
 from typing import Optional
+from enum import Enum
 
 app = typer.Typer()
 
 
+class Quantization(str, Enum):
+    bitsandbytes = "bitsandbytes"
+    gptq = "gptq"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[Quantization] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -55,7 +61,7 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize.value, uds_path)
 
 
 @app.command()
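
A standalone sketch (simplified from the diff above, not the real `cli.py`) of how Typer treats the enum-typed parameter: the option becomes a fixed `bitsandbytes|gptq` choice, and unwrapping it with `.value` keeps the rest of the server working with plain strings. Since `--quantize` is optional, the sketch guards the `None` case before calling `.value`.

```python
from enum import Enum
from typing import Optional

import typer

app = typer.Typer()


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"


@app.command()
def serve(model_id: str, quantize: Optional[Quantization] = None):
    # Unwrap the optional enum to a plain string (or None) before passing it on.
    quantize_str = None if quantize is None else quantize.value
    typer.echo(f"model_id={model_id} quantize={quantize_str}")


if __name__ == "__main__":
    app()

# Example invocation (single-command Typer app):
#   python sketch.py bigscience/bloom-560m --quantize gptq
```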

View File

@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:

View File

@@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
         self.bnb_linear = None
 
     def prepare_weights(self, quantize: bool = False):
-        if quantize:
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
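
A condensed, self-contained sketch of the new dispatch (a stand-in, not the real `FastLinear.prepare_weights`, which actually converts the weights): the `quantize` string selects the existing bitsandbytes int8 path, raises `NotImplementedError` for `gptq` for now, falls back to unquantized weights for `None`, and rejects anything else.

```python
from typing import Optional


def resolve_quantization(quantize: Optional[str]) -> str:
    # Simplified stand-in for FastLinear.prepare_weights: returns a label
    # instead of preparing weights, but mirrors the branching added here.
    if quantize == "bitsandbytes":
        return "int8 (bitsandbytes)"
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        return "unquantized"
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")


assert resolve_quantization(None) == "unquantized"
assert resolve_quantization("bitsandbytes") == "int8 (bitsandbytes)"
```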

View File

@@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()

View File

@@ -100,14 +100,14 @@ def serve(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
-    quantize: bool,
+    quantize: Optional[str],
     uds_path: Path,
 ):
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
         sharded: bool = False,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded: