mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Merge branch 'main' into added_cli_docs
This commit is contained in:
commit
29129dc660
@ -39,8 +39,9 @@ RUN cargo build --release
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM debian:bullseye-slim as pytorch-install
|
||||
|
||||
ARG PYTORCH_VERSION=2.0.0
|
||||
ARG PYTORCH_VERSION=2.0.1
|
||||
ARG PYTHON_VERSION=3.9
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG CUDA_VERSION=11.8
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG CUDA_CHANNEL=nvidia
|
||||
|
@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme
|
||||
make run-falcon-7b-instruct-quantize
|
||||
```
|
||||
|
||||
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
|
||||
|
||||
## Develop
|
||||
|
||||
```shell
|
||||
|
@ -1,18 +1,18 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: Text Generation Inference
|
||||
- local: installation_launch
|
||||
title: Installation and Launching
|
||||
- local: quicktour
|
||||
title: Quick Tour
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: supported_models
|
||||
title: Supported Models and Hardware
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic_tutorials/local_launch
|
||||
title: Installing from the Source and Launching TGI
|
||||
- local: basic_tutorials/consuming_tgi
|
||||
title: Consuming TGI
|
||||
- local: basic_tutorials/preparing_model
|
||||
title: Preparing Model for Serving
|
||||
- local: basic_tutorials/using_cli
|
||||
title: Using TGI through CLI
|
||||
- local: basic_tutorials/gated_model_access
|
||||
title: Serving Private & Gated Models
|
||||
title: Tutorials
|
||||
|
5
docs/source/basic_tutorials/gated_model_access.md
Normal file
5
docs/source/basic_tutorials/gated_model_access.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Serving Private & Gated Models
|
||||
|
||||
If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)
|
||||
|
||||
If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable.
|
@ -4,9 +4,7 @@ Text Generation Inference improves the model in several aspects.
|
||||
|
||||
## Quantization
|
||||
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes` or `gptq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq).
|
||||
To run quantization with TGI, refer to [`Using TGI through CLI`](TODO: ADD INTERNAL REF) section.
|
||||
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes` or `gptq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq).
|
||||
|
||||
## RoPE Scaling
|
||||
|
||||
|
71
docs/source/installation.md
Normal file
71
docs/source/installation.md
Normal file
@ -0,0 +1,71 @@
|
||||
# Installation
|
||||
|
||||
This section explains how to install the CLI tool as well as installing TGI from source. **The strongly recommended approach is to use Docker, as it does not require much setup. Check [the Quick Tour](./quicktour) to learn how to run TGI with Docker.**
|
||||
|
||||
## Install CLI
|
||||
|
||||
TODO
|
||||
|
||||
|
||||
## Local Installation from Source
|
||||
|
||||
Before you start, you will need to setup your environment, and install Text Generation Inference. Text Generation Inference is tested on **Python 3.9+**.
|
||||
|
||||
Text Generation Inference is available on pypi, conda and GitHub.
|
||||
|
||||
To install and launch locally, first [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
|
||||
Python 3.9, e.g. using conda:
|
||||
|
||||
```shell
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
|
||||
conda create -n text-generation-inference python=3.9
|
||||
conda activate text-generation-inference
|
||||
```
|
||||
|
||||
You may also need to install Protoc.
|
||||
|
||||
On Linux:
|
||||
|
||||
```shell
|
||||
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
|
||||
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
|
||||
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
|
||||
rm -f $PROTOC_ZIP
|
||||
```
|
||||
|
||||
On MacOS, using Homebrew:
|
||||
|
||||
```shell
|
||||
brew install protobuf
|
||||
```
|
||||
|
||||
Then run to install Text Generation Inference:
|
||||
|
||||
```shell
|
||||
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
On some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
|
||||
|
||||
```shell
|
||||
sudo apt-get install libssl-dev gcc -y
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
Once installation is done, simply run:
|
||||
|
||||
```shell
|
||||
make run-falcon-7b-instruct
|
||||
```
|
||||
|
||||
This will serve Falcon 7B Instruct model from the port 8080, which we can query.
|
||||
|
||||
To see all options to serve your models, check in the [codebase](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or the CLI:
|
||||
```
|
||||
text-generation-launcher --help
|
||||
```
|
34
docs/source/quicktour.md
Normal file
34
docs/source/quicktour.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Quick Tour
|
||||
|
||||
The easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).
|
||||
|
||||
Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct) model with TGI. Here is an example on how to do that:
|
||||
|
||||
```shell
|
||||
model=tiiuae/falcon-7b-instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.0 --model-id $model
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) . We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section.
|
||||
|
||||
```shell
|
||||
curl 127.0.0.1:8080/generate -X POST -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' -H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
To see all possible flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```shell
|
||||
docker run ghcr.io/huggingface/text-generation-inference:1.0.0 --help
|
||||
```
|
||||
|
||||
</Tip>
|
@ -22,6 +22,8 @@ mod env_runtime;
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Quantization {
|
||||
Bitsandbytes,
|
||||
BitsandbytesNF4,
|
||||
BitsandbytesFP4,
|
||||
Gptq,
|
||||
}
|
||||
|
||||
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
|
||||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::BitsandbytesNF4 => {
|
||||
write!(f, "bitsandbytes-nf4")
|
||||
}
|
||||
Quantization::BitsandbytesFP4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
@ -60,6 +68,26 @@ impl std::fmt::Display for Dtype {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum RopeScaling {
|
||||
Linear,
|
||||
Dynamic,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RopeScaling {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// To keep in track with `server`.
|
||||
match self {
|
||||
RopeScaling::Linear => {
|
||||
write!(f, "linear")
|
||||
}
|
||||
RopeScaling::Dynamic => {
|
||||
write!(f, "dynamic")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
@ -96,7 +124,8 @@ struct Args {
|
||||
num_shard: Option<usize>,
|
||||
|
||||
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
|
||||
/// quantization on the fly, or `gptq`.
|
||||
/// quantization on the fly, or `gptq`. 4bit quantization is available through
|
||||
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
|
||||
#[clap(long, env, value_enum)]
|
||||
quantize: Option<Quantization>,
|
||||
|
||||
@ -250,6 +279,26 @@ struct Args {
|
||||
#[clap(default_value = "1.0", long, env)]
|
||||
cuda_memory_fraction: f32,
|
||||
|
||||
/// Rope scaling will only be used for RoPE models
|
||||
/// and allow rescaling the position rotary to accomodate for
|
||||
/// larger prompts.
|
||||
///
|
||||
/// Goes together with `rope_factor`.
|
||||
///
|
||||
/// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
|
||||
/// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
|
||||
/// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
|
||||
/// basically)
|
||||
///
|
||||
/// `--rope-scaling linear --rope-factor` fully describes the scaling you want
|
||||
#[clap(long, env)]
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
|
||||
/// Rope scaling will only be used for RoPE models
|
||||
/// See `rope_scaling`
|
||||
#[clap(long, env)]
|
||||
rope_factor: Option<f32>,
|
||||
|
||||
/// Outputs the logs in JSON format (useful for telemetry)
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
@ -305,6 +354,8 @@ fn shard_manager(
|
||||
watermark_gamma: Option<f32>,
|
||||
watermark_delta: Option<f32>,
|
||||
cuda_memory_fraction: f32,
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
rope_factor: Option<f32>,
|
||||
otlp_endpoint: Option<String>,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
@ -358,6 +409,12 @@ fn shard_manager(
|
||||
shard_args.push(revision)
|
||||
}
|
||||
|
||||
let rope = match (rope_scaling, rope_factor) {
|
||||
(None, None) => None,
|
||||
(Some(scaling), None) => Some((scaling, 1.0)),
|
||||
(Some(scaling), Some(factor)) => Some((scaling, factor)),
|
||||
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
|
||||
};
|
||||
// OpenTelemetry
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
shard_args.push("--otlp-endpoint".to_string());
|
||||
@ -395,6 +452,15 @@ fn shard_manager(
|
||||
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
|
||||
};
|
||||
|
||||
// Detect rope scaling
|
||||
// Sending as env instead of CLI args to not bloat everything
|
||||
// those only can be used by RoPE models, so passing information around
|
||||
// for all models will complexify code unnecessarily
|
||||
if let Some((scaling, factor)) = rope {
|
||||
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
|
||||
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
|
||||
}
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
@ -659,6 +725,11 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
||||
download_args.push(revision.to_string())
|
||||
}
|
||||
|
||||
// Trust remote code for automatic peft fusion
|
||||
if args.trust_remote_code {
|
||||
download_args.push("--trust-remote-code".to_string());
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
@ -784,6 +855,8 @@ fn spawn_shards(
|
||||
let watermark_gamma = args.watermark_gamma;
|
||||
let watermark_delta = args.watermark_delta;
|
||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
@ -802,6 +875,8 @@ fn spawn_shards(
|
||||
watermark_gamma,
|
||||
watermark_delta,
|
||||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
otlp_endpoint,
|
||||
status_sender,
|
||||
shutdown,
|
||||
|
@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
|
||||
)]
|
||||
#[instrument(skip(infer, req))]
|
||||
async fn compat_generate(
|
||||
default_return_full_text: Extension<bool>,
|
||||
Extension(default_return_full_text): Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
req: Json<CompatGenerateRequest>,
|
||||
Json(mut req): Json<CompatGenerateRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let mut req = req.0;
|
||||
|
||||
// default return_full_text given the pipeline_tag
|
||||
if req.parameters.return_full_text.is_none() {
|
||||
req.parameters.return_full_text = Some(default_return_full_text.0)
|
||||
req.parameters.return_full_text = Some(default_return_full_text)
|
||||
}
|
||||
|
||||
// switch on stream
|
||||
@ -71,9 +69,9 @@ async fn compat_generate(
|
||||
.await
|
||||
.into_response())
|
||||
} else {
|
||||
let (headers, generation) = generate(infer, Json(req.into())).await?;
|
||||
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(vec![generation.0])).into_response())
|
||||
Ok((headers, Json(vec![generation])).into_response())
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
parameters = ? req.0.parameters,
|
||||
parameters = ? req.parameters,
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
@ -146,29 +144,29 @@ seed,
|
||||
)]
|
||||
async fn generate(
|
||||
infer: Extension<Infer>,
|
||||
req: Json<GenerateRequest>,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
tracing::debug!("Input: {}", req.0.inputs);
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
|
||||
let compute_characters = req.0.inputs.chars().count();
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||
add_prompt = Some(req.0.inputs.clone());
|
||||
if req.parameters.return_full_text.unwrap_or(false) {
|
||||
add_prompt = Some(req.inputs.clone());
|
||||
}
|
||||
|
||||
let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
|
||||
let details = req.parameters.details || req.parameters.decoder_input_details;
|
||||
|
||||
// Inference
|
||||
let (response, best_of_responses) = match req.0.parameters.best_of {
|
||||
let (response, best_of_responses) = match req.parameters.best_of {
|
||||
Some(best_of) if best_of > 1 => {
|
||||
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
|
||||
let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
|
||||
(response, Some(best_of_responses))
|
||||
}
|
||||
_ => (infer.generate(req.0).await?, None),
|
||||
_ => (infer.generate(req).await?, None),
|
||||
};
|
||||
|
||||
// Token details
|
||||
@ -321,7 +319,7 @@ content_type = "text/event-stream"),
|
||||
#[instrument(
|
||||
skip_all,
|
||||
fields(
|
||||
parameters = ? req.0.parameters,
|
||||
parameters = ? req.parameters,
|
||||
total_time,
|
||||
validation_time,
|
||||
queue_time,
|
||||
@ -331,8 +329,8 @@ seed,
|
||||
)
|
||||
)]
|
||||
async fn generate_stream(
|
||||
infer: Extension<Infer>,
|
||||
req: Json<GenerateRequest>,
|
||||
Extension(infer): Extension<Infer>,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
) -> (
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
@ -341,9 +339,9 @@ async fn generate_stream(
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
tracing::debug!("Input: {}", req.0.inputs);
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
|
||||
let compute_characters = req.0.inputs.chars().count();
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||
@ -359,24 +357,24 @@ async fn generate_stream(
|
||||
let mut error = false;
|
||||
|
||||
let mut add_prompt = None;
|
||||
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||
add_prompt = Some(req.0.inputs.clone());
|
||||
if req.parameters.return_full_text.unwrap_or(false) {
|
||||
add_prompt = Some(req.inputs.clone());
|
||||
}
|
||||
let details = req.0.parameters.details;
|
||||
let details = req.parameters.details;
|
||||
|
||||
let best_of = req.0.parameters.best_of.unwrap_or(1);
|
||||
let best_of = req.parameters.best_of.unwrap_or(1);
|
||||
if best_of != 1 {
|
||||
let err = InferError::from(ValidationError::BestOfStream);
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
} else if req.0.parameters.decoder_input_details {
|
||||
} else if req.parameters.decoder_input_details {
|
||||
let err = InferError::from(ValidationError::PrefillDetailsStream);
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
} else {
|
||||
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, mut response_stream)) => {
|
||||
// Server-Sent Event stream
|
||||
|
3037
server/poetry.lock
generated
3037
server/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
text-generation-server = 'text_generation_server.cli:app'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
python = ">=3.9,<3.13"
|
||||
protobuf = "^4.21.7"
|
||||
grpcio = "^1.51.1"
|
||||
grpcio-status = "^1.51.1"
|
||||
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
|
||||
grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.19.0", optional = true }
|
||||
bitsandbytes = { version = "^0.38.1", optional = true }
|
||||
bitsandbytes = { version = "^0.40.0", optional = true }
|
||||
safetensors = "0.3.1"
|
||||
loguru = "^0.6.0"
|
||||
opentelemetry-api = "^1.15.0"
|
||||
@ -30,6 +30,9 @@ transformers = "4.29.2"
|
||||
einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
peft = "^0.4.0"
|
||||
torch = {version = "^2.0.1+cu118", source = "pytorch-gpu-src"}
|
||||
scipy = "^1.11.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
accelerate = ["accelerate"]
|
||||
@ -40,6 +43,12 @@ quantize = ["texttable", "datasets", "accelerate"]
|
||||
grpcio-tools = "^1.51.1"
|
||||
pytest = "^7.3.0"
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "pytorch-gpu-src"
|
||||
url = "https://download.pytorch.org/whl/cu118"
|
||||
priority = "explicit"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
|
||||
|
||||
|
@ -1,72 +1,59 @@
|
||||
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
|
||||
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
|
||||
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
|
||||
datasets==2.14.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
|
||||
dill==0.3.7 ; python_version >= "3.9" and python_version < "4.0"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "4.0"
|
||||
networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
numpy==1.25.0 ; python_version < "4.0" and python_version >= "3.9"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pandas==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pyarrow==12.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pytz==2023.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
|
||||
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
|
||||
texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
|
||||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
xxhash==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
@ -13,6 +14,8 @@ app = typer.Typer()
|
||||
|
||||
class Quantization(str, Enum):
|
||||
bitsandbytes = "bitsandbytes"
|
||||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
|
||||
|
||||
@ -88,6 +91,7 @@ def download_weights(
|
||||
auto_convert: bool = True,
|
||||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
@ -118,6 +122,12 @@ def download_weights(
|
||||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
|
||||
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
# Try to download weights from the hub
|
||||
try:
|
||||
filenames = utils.weight_hub_files(model_id, revision, extension)
|
||||
|
@ -89,7 +89,7 @@ def get_model(
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
dtypetrust_remote_code=trust_remote_code,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_id.startswith("bigcode/"):
|
||||
@ -255,7 +255,10 @@ def get_model(
|
||||
raise ValueError(
|
||||
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||
raise ValueError(
|
||||
"4bit quantization is not supported for AutoModel"
|
||||
)
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
|
@ -185,8 +185,11 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
dim=self.head_size, base=10000.0, device=weights.device
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
self.head_size, base=10000.0, device=weights.device
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
|
@ -297,7 +297,7 @@ def triton_flash_attn_fn(
|
||||
class MultiheadAttention(nn.Module):
|
||||
"""Multi-head self attention.
|
||||
|
||||
Using torch or triton attention implemetation enables user to also use
|
||||
Using torch or triton attention implementation enables user to also use
|
||||
additive bias.
|
||||
"""
|
||||
|
||||
@ -386,7 +386,7 @@ class MultiheadAttention(nn.Module):
|
||||
class MultiQueryAttention(nn.Module):
|
||||
"""Multi-Query self attention.
|
||||
|
||||
Using torch or triton attention implemetation enables user to also use
|
||||
Using torch or triton attention implementation enables user to also use
|
||||
additive bias.
|
||||
"""
|
||||
|
||||
|
@ -28,6 +28,7 @@ from transformers.modeling_outputs import (
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import OPTConfig
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
|
@ -54,7 +54,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
|
||||
aliases={"lm_head.weight": ["transformer.word_embeddings.weight"]},
|
||||
)
|
||||
|
||||
config.quantize = quantize
|
||||
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
|
||||
from text_generation_server.models.types import Batch, GeneratedText
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
@ -52,7 +52,7 @@ class Model(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
|
||||
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def warmup(self, batch: B) -> Optional[int]:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from text_generation_server.utils.convert import convert_file, convert_files
|
||||
from text_generation_server.utils.dist import initialize_torch_distributed
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.utils.peft import download_and_unload_peft
|
||||
from text_generation_server.utils.hub import (
|
||||
weight_files,
|
||||
weight_hub_files,
|
||||
@ -26,6 +27,7 @@ __all__ = [
|
||||
"weight_files",
|
||||
"weight_hub_files",
|
||||
"download_weights",
|
||||
"download_and_unload_peft",
|
||||
"EntryNotFoundError",
|
||||
"HeterogeneousNextTokenChooser",
|
||||
"LocalEntryNotFoundError",
|
||||
|
@ -263,7 +263,7 @@ class QuantLinear(nn.Module):
|
||||
self.groupsize = groupsize
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // 4
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
|
@ -360,15 +360,21 @@ class GPTQ:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def get_wikitext2(nsamples, seed, seqlen, model_id):
|
||||
def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
|
||||
from datasets import load_dataset
|
||||
|
||||
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
||||
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=False, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
|
||||
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
|
||||
|
||||
@ -386,18 +392,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, testenc
|
||||
|
||||
|
||||
def get_ptb(nsamples, seed, seqlen, model_id):
|
||||
def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
|
||||
from datasets import load_dataset
|
||||
|
||||
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
|
||||
valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=False, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
|
||||
testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
|
||||
|
||||
@ -415,7 +424,7 @@ def get_ptb(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, testenc
|
||||
|
||||
|
||||
def get_c4(nsamples, seed, seqlen, model_id):
|
||||
def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
|
||||
from datasets import load_dataset
|
||||
|
||||
traindata = load_dataset(
|
||||
@ -433,12 +442,14 @@ def get_c4(nsamples, seed, seqlen, model_id):
|
||||
use_auth_token=False,
|
||||
)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=False, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
import random
|
||||
|
||||
@ -481,18 +492,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, valenc
|
||||
|
||||
|
||||
def get_ptb_new(nsamples, seed, seqlen, model_id):
|
||||
def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
|
||||
from datasets import load_dataset
|
||||
|
||||
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
|
||||
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=False, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
|
||||
testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
|
||||
|
||||
@ -510,7 +524,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, testenc
|
||||
|
||||
|
||||
def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
|
||||
from datasets import load_dataset
|
||||
|
||||
traindata = load_dataset(
|
||||
@ -526,12 +540,14 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
split="validation",
|
||||
)
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=False, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
import random
|
||||
|
||||
@ -562,17 +578,17 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
|
||||
return trainloader, valenc
|
||||
|
||||
|
||||
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
|
||||
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
|
||||
if "wikitext2" in name:
|
||||
return get_wikitext2(nsamples, seed, seqlen, model_id)
|
||||
return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
|
||||
if "ptb" in name:
|
||||
if "new" in name:
|
||||
return get_ptb_new(nsamples, seed, seqlen, model_id)
|
||||
return get_ptb(nsamples, seed, seqlen, model_id)
|
||||
return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
|
||||
return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
|
||||
if "c4" in name:
|
||||
if "new" in name:
|
||||
return get_c4_new(nsamples, seed, seqlen, model_id)
|
||||
return get_c4(nsamples, seed, seqlen, model_id)
|
||||
return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
|
||||
return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
|
||||
|
||||
|
||||
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
|
||||
@ -906,7 +922,12 @@ def quantize(
|
||||
seed = None
|
||||
|
||||
dataloader, testloader = get_loaders(
|
||||
dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
|
||||
dataset,
|
||||
nsamples=nsamples,
|
||||
seed=seed,
|
||||
model_id=model_id,
|
||||
seqlen=model.seqlen,
|
||||
trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
tick = time.time()
|
||||
|
@ -9,7 +9,7 @@ from typing import List
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class Linear4bit(nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
self.weight = Params4bit(
|
||||
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
|
||||
)
|
||||
self.compute_dtype = None
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if getattr(self.weight, "quant_state", None) is None:
|
||||
print(
|
||||
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||
)
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(
|
||||
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||
)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
linear = FastLinear(weight, bias)
|
||||
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
|
||||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="fp4",
|
||||
)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
@ -381,33 +426,65 @@ try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq):
|
||||
super().__init__()
|
||||
def _create_inv_freq(dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _get_rope_config(config):
|
||||
if os.getenv("ROPE_SCALING", None) is not None:
|
||||
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
|
||||
return rope_scaling
|
||||
return getattr(config, "rope_scaling", None)
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq, scaling_factor):
|
||||
super().__init__()
|
||||
self.inv_freq = inv_freq
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dynamic_args = None
|
||||
|
||||
@classmethod
|
||||
def static(cls, dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return cls(inv_freq)
|
||||
def static(cls, config, dim, base, device):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
@classmethod
|
||||
def load(cls, prefix, weights):
|
||||
def load(cls, config, prefix, weights):
|
||||
# XXX: Always load this in float32 !
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
|
||||
weights.dtype = dtype
|
||||
return cls(inv_freq)
|
||||
|
||||
scaling_factor = None
|
||||
rope_scaling = _get_rope_config(config)
|
||||
if rope_scaling is not None:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if rope_scaling["type"] == "linear":
|
||||
pass
|
||||
elif rope_scaling["type"] == "dynamic":
|
||||
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
|
||||
return cls(inv_freq, scaling_factor)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
@ -419,8 +496,11 @@ try:
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
@ -446,5 +526,36 @@ try:
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
if seqlen > self.max_position_embeddings:
|
||||
newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
|
||||
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
if self.scaling_factor is not None:
|
||||
t /= self.scaling_factor
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
46
server/text_generation_server/utils/peft.py
Normal file
46
server/text_generation_server/utils/peft.py
Normal file
@ -0,0 +1,46 @@
|
||||
import os
|
||||
import json
|
||||
from loguru import logger
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
|
||||
|
||||
def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||
torch_dtype = torch.float16
|
||||
|
||||
logger.info("Peft model detected.")
|
||||
logger.info("Loading the model it might take a while without feedback")
|
||||
try:
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
except Exception:
|
||||
model = AutoPeftModelForSeq2SeqLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
logger.info(f"Loaded.")
|
||||
logger.info(f"Merging the lora weights.")
|
||||
|
||||
base_model_id = model.peft_config["default"].base_model_name_or_path
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
os.makedirs(model_id, exist_ok=True)
|
||||
cache_dir = model_id
|
||||
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
|
||||
model.save_pretrained(cache_dir, safe_serialization=True)
|
||||
model.config.save_pretrained(cache_dir)
|
||||
tokenizer.save_pretrained(cache_dir)
|
||||
|
||||
|
||||
|
@ -224,7 +224,7 @@ class Weights:
|
||||
def _set_gptq_params(self, model_id):
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if not os.path.exists(os.path.join(model_id, filename)):
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename)
|
||||
|
Loading…
Reference in New Issue
Block a user