mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Adding Quantization GPTQ as an option (step 1)
- Just modifies the boolean into a proper enum - Use `str` within python (prevents importing enum everywhere).
This commit is contained in:
parent
b67908e0cf
commit
1185f66205
@ -1,4 +1,4 @@
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Quantization {
|
||||
Bitsandbytes,
|
||||
Gptq,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Quantization {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// To keep in track with `server`.
|
||||
match self {
|
||||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
@ -46,10 +66,10 @@ struct Args {
|
||||
#[clap(long, env)]
|
||||
num_shard: Option<usize>,
|
||||
|
||||
/// Wether you want the model to be quantized or not. This will use bitsandbytes for
|
||||
/// quantization on the fly.
|
||||
#[clap(long, env)]
|
||||
quantize: bool,
|
||||
/// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
|
||||
/// quantization on the fly, or `gptq`.
|
||||
#[clap(long, env, value_enum)]
|
||||
quantize: Option<Quantization>,
|
||||
|
||||
/// The maximum amount of concurrent requests for this particular deployment.
|
||||
/// Having a low limit will refuse clients requests instead of having them
|
||||
@ -218,7 +238,7 @@ enum ShardStatus {
|
||||
fn shard_manager(
|
||||
model_id: String,
|
||||
revision: Option<String>,
|
||||
quantize: bool,
|
||||
quantize: Option<Quantization>,
|
||||
uds_path: String,
|
||||
rank: usize,
|
||||
world_size: usize,
|
||||
@ -257,8 +277,9 @@ fn shard_manager(
|
||||
shard_argv.push("--sharded".to_string());
|
||||
}
|
||||
|
||||
if quantize {
|
||||
shard_argv.push("--quantize".to_string())
|
||||
if let Some(quantize) = quantize {
|
||||
shard_argv.push("--quantize".to_string());
|
||||
shard_argv.push(quantize.to_string())
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
|
@ -5,17 +5,23 @@ import typer
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
class Quantization(str, Enum):
|
||||
bitsandbytes = "bitsandbytes"
|
||||
gptq = "gptq"
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
sharded: bool = False,
|
||||
quantize: bool = False,
|
||||
quantize: Optional[Quantization] = None,
|
||||
uds_path: Path = "/tmp/text-generation-server",
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
@ -55,7 +61,7 @@ def serve(
|
||||
if otlp_endpoint is not None:
|
||||
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
|
||||
|
||||
server.serve(model_id, revision, sharded, quantize, uds_path)
|
||||
server.serve(model_id, revision, sharded, quantize.value, uds_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def get_model(
|
||||
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
|
||||
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
|
||||
) -> Model:
|
||||
if "facebook/galactica" in model_id:
|
||||
if sharded:
|
||||
|
@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
|
||||
self.bnb_linear = None
|
||||
|
||||
def prepare_weights(self, quantize: bool = False):
|
||||
if quantize:
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
|
||||
# Delete reference to data
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
else:
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
|
@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
|
||||
|
||||
class FlashLlamaSharded(FlashLlama):
|
||||
def __init__(
|
||||
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
self.past_pad = None
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
|
@ -100,14 +100,14 @@ def serve(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: bool,
|
||||
quantize: Optional[str],
|
||||
uds_path: Path,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
sharded: bool = False,
|
||||
quantize: bool = False,
|
||||
quantize: Optional[str] = None,
|
||||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
|
Loading…
Reference in New Issue
Block a user