Adding Quantization GPTQ as an option (step 1)

- Just turns the `quantize` boolean into a proper enum
- Use `str` as the enum's base class in Python (lets call sites compare plain strings, so they don't need to import the enum everywhere).
Nicolas Patry 2023-05-04 12:16:40 +02:00
parent b67908e0cf
commit 1185f66205
6 changed files with 50 additions and 16 deletions
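The `str` base class called out in the second bullet is what lets the server code compare `quantize` against plain string literals, so only the CLI module ever needs to import the enum. A minimal standalone sketch of the pattern (the class mirrors the `Quantization` enum added in the diff below):

from enum import Enum

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

# Members are genuine `str` instances, so plain-string comparisons work
# anywhere without importing `Quantization`:
assert Quantization.gptq == "gptq"
assert isinstance(Quantization.gptq, str)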


@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with the strings expected by `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Whether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
     }
 
     // Model optional revision
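The `Display` impl above exists precisely so the strings the launcher emits match the Python enum's values; on the other side, typer parses `--quantize gptq` back into the enum. A minimal, self-contained sketch of that round trip (the script name and model id are made up; `Quantization` and `serve` mirror the diff):

from enum import Enum
from typing import Optional

import typer

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

def serve(model_id: str, quantize: Optional[Quantization] = None):
    # typer only accepts values that are members of the enum,
    # so a typo like `--quantize gtpq` fails at parse time.
    typer.echo(f"quantize={quantize.value if quantize else None}")

if __name__ == "__main__":
    typer.run(serve)  # e.g. `python sketch.py some/model --quantize gptq`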


@@ -5,17 +5,23 @@ import typer
 from pathlib import Path
 from loguru import logger
 from typing import Optional
+from enum import Enum
 
 app = typer.Typer()
 
+
+class Quantization(str, Enum):
+    bitsandbytes = "bitsandbytes"
+    gptq = "gptq"
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[Quantization] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -55,7 +61,7 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize.value, uds_path)
 
 @app.command()
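One subtlety in the hunk above: `quantize` is `Optional`, so `quantize.value` raises `AttributeError` whenever `--quantize` is omitted. A guard along these lines (a hypothetical helper, not part of this commit) would cover the None case:

from enum import Enum
from typing import Optional

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

def to_server_arg(quantize: Optional[Quantization]) -> Optional[str]:
    # Hypothetical guard: unwrap the enum only when the flag was given.
    return quantize.value if quantize is not None else None

assert to_server_arg(None) is None
assert to_server_arg(Quantization.gptq) == "gptq"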


@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:


@@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
         self.bnb_linear = None
 
     def prepare_weights(self, quantize: bool = False):
-        if quantize:
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
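Condensed into a standalone sketch, the new branch structure shows the point of step 1: `gptq` becomes a legal value that fails loudly instead of silently falling into the bitsandbytes path, and any unknown string is rejected (note the `quantize: bool = False` annotation in the context line above is now stale; callers pass `Optional[str]`). The return values here are just illustrative labels:

from typing import Optional

def prepare_weights(quantize: Optional[str] = None) -> str:
    # Mirrors the dispatch in FastLinear.prepare_weights above.
    if quantize == "bitsandbytes":
        return "8-bit quantized weights"
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        return "transposed full-precision weights"
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")

assert prepare_weights() == "transposed full-precision weights"
try:
    prepare_weights("gptq")
except NotImplementedError as err:
    print(err)  # `gptq` is not implemented for now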


@@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()


@@ -100,14 +100,14 @@ def serve(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
-    quantize: bool,
+    quantize: Optional[str],
     uds_path: Path,
 ):
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
         sharded: bool = False,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded: