mirror of https://github.com/huggingface/text-generation-inference.git
Adding Quantization GPTQ as an option (step 1)
- Just modifies the boolean into a proper enum.
- Use `str` within Python (prevents importing the enum everywhere).
This commit is contained in:
parent b67908e0cf
commit 1185f66205
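The second bullet relies on a standard Python trick: an enum that also subclasses `str` compares equal to plain strings, so downstream modules can branch on "bitsandbytes" or "gptq" without importing the enum. A minimal sketch (names mirror the diff below):

from enum import Enum

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

# Members compare equal to their plain-string values, so call sites
# only ever need the string:
assert Quantization.gptq == "gptq"
assert Quantization("bitsandbytes") is Quantization.bitsandbytes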
@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Whether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
    }
 
     // Model optional revision
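For illustration, a hedged Python mirror of the launcher's new argument building (the real code above is Rust; `shard_args` is a hypothetical name): the flag and its value are pushed as two separate argv entries, and nothing is pushed when quantization is off.

from typing import List, Optional

def shard_args(quantize: Optional[str]) -> List[str]:
    # Mirrors the `if let Some(quantize)` branch: emit the flag, then its value.
    argv: List[str] = []
    if quantize is not None:
        argv.append("--quantize")
        argv.append(quantize)
    return argv

assert shard_args("gptq") == ["--quantize", "gptq"]
assert shard_args(None) == []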
@@ -5,17 +5,23 @@ import typer
 from pathlib import Path
 from loguru import logger
 from typing import Optional
+from enum import Enum
 
 
 app = typer.Typer()
 
 
+class Quantization(str, Enum):
+    bitsandbytes = "bitsandbytes"
+    gptq = "gptq"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[Quantization] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
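Because the option type is now `Optional[Quantization]`, typer validates the value against the enum and lists the choices in `--help`. A self-contained sketch of the same pattern (assumes typer is installed; `demo` is a hypothetical command, not part of this commit):

import typer
from enum import Enum
from typing import Optional

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

app = typer.Typer()

@app.command()
def demo(quantize: Optional[Quantization] = None):
    # typer rejects any --quantize value other than "bitsandbytes" or "gptq"
    print(quantize.value if quantize is not None else "no quantization")

if __name__ == "__main__":
    app()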
@@ -55,7 +61,7 @@ def serve(
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize.value, uds_path)
 
 
 @app.command()
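One caveat: `quantize.value` assumes the option was set, while the new default is `None`. A hedged sketch of a `None`-safe variant (`unwrap` is a hypothetical helper, an assumption on my part rather than part of this commit):

from enum import Enum
from typing import Optional

class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    gptq = "gptq"

def unwrap(quantize: Optional[Quantization]) -> Optional[str]:
    # None-safe counterpart of `quantize.value`
    return quantize.value if quantize is not None else None

assert unwrap(None) is None
assert unwrap(Quantization.gptq) == "gptq"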
@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
@@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
         self.bnb_linear = None
 
     def prepare_weights(self, quantize: bool = False):
-        if quantize:
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
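The dispatch in `prepare_weights` now has four outcomes. A self-contained mirror of just the branching (the tensor work is elided; `dispatch` is a hypothetical stand-in):

from typing import Optional

def dispatch(quantize: Optional[str]) -> str:
    if quantize == "bitsandbytes":
        return "int8 via bitsandbytes"
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        return "plain transposed weights"
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")

assert dispatch(None) == "plain transposed weights"
assert dispatch("bitsandbytes") == "int8 via bitsandbytes"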
@@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
@@ -100,14 +100,14 @@ def serve(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
-    quantize: bool,
+    quantize: Optional[str],
     uds_path: Path,
 ):
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
         sharded: bool = False,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded: