Merge branch 'main' into added_cli_docs

Merve Noyan 2023-08-10 13:42:02 +03:00 committed by GitHub
commit 29129dc660
27 changed files with 2086 additions and 1666 deletions


@ -39,8 +39,9 @@ RUN cargo build --release
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM debian:bullseye-slim as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTORCH_VERSION=2.0.1
ARG PYTHON_VERSION=3.9
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1
ARG CUDA_CHANNEL=nvidia


@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme
make run-falcon-7b-instruct-quantize
```
4-bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
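For example, a launch with NF4 quantization could look roughly like the following sketch (the model id is only a placeholder):
```shell
# Placeholder model id; substitute the model you actually want to serve.
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes-nf4
```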
## Develop
```shell


@ -1,18 +1,18 @@
- sections:
- local: index
title: Text Generation Inference
- local: installation_launch
title: Installation and Launching
- local: quicktour
title: Quick Tour
- local: installation
title: Installation
- local: supported_models
title: Supported Models and Hardware
title: Getting started
- sections:
- local: basic_tutorials/local_launch
title: Installing from the Source and Launching TGI
- local: basic_tutorials/consuming_tgi
title: Consuming TGI
- local: basic_tutorials/preparing_model
title: Preparing Model for Serving
- local: basic_tutorials/using_cli
title: Using TGI through CLI
- local: basic_tutorials/gated_model_access
title: Serving Private & Gated Models
title: Tutorials


@ -0,0 +1,5 @@
# Serving Private & Gated Models
If the model you wish to serve is behind gated access or its repository on the Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from the [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens).
If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable.
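A minimal sketch, assuming a hypothetical private model id:
```shell
# Placeholder token and model id; the token must have read access to the repository.
export HUGGING_FACE_HUB_TOKEN=<your-read-token>
text-generation-launcher --model-id <your-org/your-private-model>
```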


@ -5,8 +5,6 @@ Text Generation Inference improves the model in several aspects.
## Quantization
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323) quantization. To speed up inference with quantization, simply set the `quantize` flag to `bitsandbytes` or `gptq`, depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models listed [here](https://huggingface.co/models?search=gptq).
To run quantization with TGI, refer to the [Using TGI through CLI](TODO: ADD INTERNAL REF) section.
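As a rough sketch (the model ids below are placeholders), the flag is passed directly to the launcher:
```shell
# On-the-fly bitsandbytes quantization
text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes
# GPT-Q requires pointing to a model that has already been quantized with GPT-Q
text-generation-launcher --model-id <a-gptq-model-from-the-hub> --quantize gptq
```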
## RoPE Scaling


@ -0,0 +1,71 @@
# Installation
This section explains how to install the CLI tool as well as how to install TGI from source. **The strongly recommended approach is to use Docker, as it does not require much setup. Check [the Quick Tour](./quicktour) to learn how to run TGI with Docker.**
## Install CLI
TODO
## Local Installation from Source
Before you start, you will need to set up your environment and install Text Generation Inference. Text Generation Inference is tested on **Python 3.9+**.
Text Generation Inference is available on PyPI, conda, and GitHub.
To install and launch locally, first [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
Python 3.9, e.g. using conda:
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
conda create -n text-generation-inference python=3.9
conda activate text-generation-inference
```
You may also need to install Protoc.
On Linux:
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
On macOS, using Homebrew:
```shell
brew install protobuf
```
Then run the following to install Text Generation Inference:
```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformers fork with CUDA kernels
```
<Tip warning={true}>
On some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
```shell
sudo apt-get install libssl-dev gcc -y
```
</Tip>
Once installation is done, simply run:
```shell
make run-falcon-7b-instruct
```
This will serve the Falcon 7B Instruct model on port 8080, which you can then query.
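For instance, you can send it the same request shown in the [Quick Tour](./quicktour):
```shell
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```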
To see all the options available for serving your models, check the [codebase](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or the CLI:
```
text-generation-launcher --help
```

docs/source/quicktour.md (new file)

@ -0,0 +1,34 @@
# Quick Tour
The easiest way to get started is to use the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).
Let's say you want to deploy the [Falcon-7B Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct) model with TGI. Here is an example of how to do that:
```shell
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.0 --model-id $model
```
<Tip warning={true}>
To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.
</Tip>
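As a quick sanity check (the CUDA base image tag below is just an example), you can verify that Docker can see your GPUs before launching TGI:
```shell
# If this lists your GPUs, the NVIDIA Container Toolkit is set up correctly.
docker run --rm --gpus all nvidia/cuda:11.8.0-base-ubuntu22.04 nvidia-smi
```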
Once TGI is running, you can use the `generate` endpoint by sending requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section.
```shell
curl 127.0.0.1:8080/generate -X POST -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' -H 'Content-Type: application/json'
```
<Tip>
To see all possible flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```shell
docker run ghcr.io/huggingface/text-generation-inference:1.0.0 --help
```
</Tip>


@ -22,6 +22,8 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
Bitsandbytes,
BitsandbytesNF4,
BitsandbytesFP4,
Gptq,
}
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::BitsandbytesNF4 => {
write!(f, "bitsandbytes-nf4")
}
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Gptq => {
write!(f, "gptq")
}
@ -60,6 +68,26 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
Dynamic,
}
impl std::fmt::Display for RopeScaling {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Keep in sync with `server`.
match self {
RopeScaling::Linear => {
write!(f, "linear")
}
RopeScaling::Dynamic => {
write!(f, "dynamic")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -96,7 +124,8 @@ struct Args {
num_shard: Option<usize>,
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
/// quantization on the fly, or `gptq`. 4bit quantization is available through
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
@ -250,6 +279,26 @@ struct Args {
#[clap(default_value = "1.0", long, env)]
cuda_memory_fraction: f32,
/// Rope scaling will only be used for RoPE models
/// and allow rescaling the position rotary to accommodate for
/// larger prompts.
///
/// Goes together with `rope_factor`.
///
/// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
/// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
/// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
/// basically)
///
/// `--rope-scaling linear --rope-factor` fully describes the scaling you want
#[clap(long, env)]
rope_scaling: Option<RopeScaling>,
/// Rope scaling will only be used for RoPE models
/// See `rope_scaling`
#[clap(long, env)]
rope_factor: Option<f32>,
/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
@ -305,6 +354,8 @@ fn shard_manager(
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -358,6 +409,12 @@ fn shard_manager(
shard_args.push(revision)
}
let rope = match (rope_scaling, rope_factor) {
(None, None) => None,
(Some(scaling), None) => Some((scaling, 1.0)),
(Some(scaling), Some(factor)) => Some((scaling, factor)),
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
@ -395,6 +452,15 @@ fn shard_manager(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// Detect rope scaling
// Sending as env instead of CLI args to avoid bloating everything.
// These can only be used by RoPE models, so passing the information around
// for all models would complicate the code unnecessarily.
if let Some((scaling, factor)) = rope {
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -659,6 +725,11 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
download_args.push(revision.to_string())
}
// Trust remote code for automatic peft fusion
if args.trust_remote_code {
download_args.push("--trust-remote-code".to_string());
}
// Copy current process env
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@ -784,6 +855,8 @@ fn spawn_shards(
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
thread::spawn(move || {
shard_manager(
model_id,
@ -802,6 +875,8 @@ fn spawn_shards(
watermark_gamma,
watermark_delta,
cuda_memory_fraction,
rope_scaling,
rope_factor,
otlp_endpoint,
status_sender,
shutdown,


@ -54,15 +54,13 @@ example = json ! ({"error": "Incomplete generation"})),
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
default_return_full_text: Extension<bool>,
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
// default return_full_text given the pipeline_tag
if req.parameters.return_full_text.is_none() {
req.parameters.return_full_text = Some(default_return_full_text.0)
req.parameters.return_full_text = Some(default_return_full_text)
}
// switch on stream
@ -71,9 +69,9 @@ async fn compat_generate(
.await
.into_response())
} else {
let (headers, generation) = generate(infer, Json(req.into())).await?;
let (headers, Json(generation)) = generate(infer, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation.0])).into_response())
Ok((headers, Json(vec![generation])).into_response())
}
}
@ -135,7 +133,7 @@ example = json ! ({"error": "Incomplete generation"})),
#[instrument(
skip_all,
fields(
parameters = ? req.0.parameters,
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
@ -146,29 +144,29 @@ seed,
)]
async fn generate(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs);
tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.0.inputs.chars().count();
let compute_characters = req.inputs.chars().count();
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
if req.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.inputs.clone());
}
let details = req.0.parameters.details || req.0.parameters.decoder_input_details;
let details = req.parameters.details || req.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
let (response, best_of_responses) = match req.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
_ => (infer.generate(req).await?, None),
};
// Token details
@ -321,7 +319,7 @@ content_type = "text/event-stream"),
#[instrument(
skip_all,
fields(
parameters = ? req.0.parameters,
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
@ -331,8 +329,8 @@ seed,
)
)]
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
Extension(infer): Extension<Infer>,
Json(req): Json<GenerateRequest>,
) -> (
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
@ -341,9 +339,9 @@ async fn generate_stream(
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs);
tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.0.inputs.chars().count();
let compute_characters = req.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
@ -359,24 +357,24 @@ async fn generate_stream(
let mut error = false;
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
if req.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.inputs.clone());
}
let details = req.0.parameters.details;
let details = req.parameters.details;
let best_of = req.0.parameters.best_of.unwrap_or(1);
let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.0.parameters.decoder_input_details {
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, mut response_stream)) => {
// Server-Sent Event stream

server/poetry.lock (generated): file diff suppressed because it is too large.


@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.19.0", optional = true }
bitsandbytes = { version = "^0.38.1", optional = true }
bitsandbytes = { version = "^0.40.0", optional = true }
safetensors = "0.3.1"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
@ -30,6 +30,9 @@ transformers = "4.29.2"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
peft = "^0.4.0"
torch = {version = "^2.0.1+cu118", source = "pytorch-gpu-src"}
scipy = "^1.11.1"
[tool.poetry.extras]
accelerate = ["accelerate"]
@ -40,6 +43,12 @@ quantize = ["texttable", "datasets", "accelerate"]
grpcio-tools = "^1.51.1"
pytest = "^7.3.0"
[[tool.poetry.source]]
name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu118"
priority = "explicit"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]


@ -1,72 +1,59 @@
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "4.0"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
async-timeout==4.0.2 ; python_version >= "3.9" and python_version < "4.0"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "4.0"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows"
datasets==2.14.0 ; python_version >= "3.9" and python_version < "4.0"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
dill==0.3.7 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
fsspec[http]==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.56.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0"
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "4.0"
networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.25.0 ; python_version < "4.0" and python_version >= "3.9"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
pandas==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.3 ; python_version >= "3.9" and python_version < "4.0"
psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0"
pyarrow==12.0.1 ; python_version >= "3.9" and python_version < "4.0"
python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
pytz==2023.3 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
texttable==1.6.7 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
torch==2.0.1 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
tzdata==2023.3 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.3 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
xxhash==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "4.0"
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.6 ; python_version >= "3.9" and python_version < "3.13"
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13"
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"


@ -6,6 +6,7 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
app = typer.Typer()
@ -13,6 +14,8 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"
@ -88,6 +91,7 @@ def download_weights(
auto_convert: bool = True,
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
):
# Remove default handler
logger.remove()
@ -118,6 +122,12 @@ def download_weights(
) is not None
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to download weights from the hub
try:
filenames = utils.weight_hub_files(model_id, revision, extension)


@ -89,7 +89,7 @@ def get_model(
revision,
quantize=quantize,
dtype=dtype,
dtypetrust_remote_code=trust_remote_code,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
@ -255,7 +255,10 @@ def get_model(
raise ValueError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,


@ -185,8 +185,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
# self.rotary_emb = PositionRotaryEmbedding.load(
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size**-0.5


@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)


@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)


@ -297,7 +297,7 @@ def triton_flash_attn_fn(
class MultiheadAttention(nn.Module):
"""Multi-head self attention.
Using torch or triton attention implemetation enables user to also use
Using torch or triton attention implementation enables user to also use
additive bias.
"""
@ -386,7 +386,7 @@ class MultiheadAttention(nn.Module):
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
Using torch or triton attention implementation enables user to also use
additive bias.
"""


@ -28,6 +28,7 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,


@ -54,7 +54,7 @@ class FlashRWSharded(FlashCausalLM):
device,
dtype,
process_group=self.process_group,
aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
aliases={"lm_head.weight": ["transformer.word_embeddings.weight"]},
)
config.quantize = quantize


@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
@ -52,7 +52,7 @@ class Model(ABC):
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]:


@ -1,6 +1,7 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.peft import download_and_unload_peft
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
@ -26,6 +27,7 @@ __all__ = [
"weight_files",
"weight_hub_files",
"download_weights",
"download_and_unload_peft",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"LocalEntryNotFoundError",


@ -263,7 +263,7 @@ class QuantLinear(nn.Module):
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // 4
self.infeatures = qweight.shape[0] * 32 // bits
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):


@ -360,15 +360,21 @@ class GPTQ:
torch.cuda.empty_cache()
def get_wikitext2(nsamples, seed, seqlen, model_id):
def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
from datasets import load_dataset
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
@ -386,18 +392,21 @@ def get_wikitext2(nsamples, seed, seqlen, model_id):
return trainloader, testenc
def get_ptb(nsamples, seed, seqlen, model_id):
def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
from datasets import load_dataset
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation")
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
@ -415,7 +424,7 @@ def get_ptb(nsamples, seed, seqlen, model_id):
return trainloader, testenc
def get_c4(nsamples, seed, seqlen, model_id):
def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
from datasets import load_dataset
traindata = load_dataset(
@ -433,12 +442,14 @@ def get_c4(nsamples, seed, seqlen, model_id):
use_auth_token=False,
)
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
import random
@ -481,18 +492,21 @@ def get_c4(nsamples, seed, seqlen, model_id):
return trainloader, valenc
def get_ptb_new(nsamples, seed, seqlen, model_id):
def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
from datasets import load_dataset
traindata = load_dataset("ptb_text_only", "penn_treebank", split="train")
testdata = load_dataset("ptb_text_only", "penn_treebank", split="test")
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt")
testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt")
@ -510,7 +524,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id):
return trainloader, testenc
def get_c4_new(nsamples, seed, seqlen, model_id):
def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
from datasets import load_dataset
traindata = load_dataset(
@ -526,12 +540,14 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
split="validation",
)
from transformers import AutoTokenizer
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=False, trust_remote_code=trust_remote_code
)
except:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(
model_id, use_fast=True, trust_remote_code=trust_remote_code
)
import random
@ -562,17 +578,17 @@ def get_c4_new(nsamples, seed, seqlen, model_id):
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id=""):
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
if "wikitext2" in name:
return get_wikitext2(nsamples, seed, seqlen, model_id)
return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
if "ptb" in name:
if "new" in name:
return get_ptb_new(nsamples, seed, seqlen, model_id)
return get_ptb(nsamples, seed, seqlen, model_id)
return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)
return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)
if "c4" in name:
if "new" in name:
return get_c4_new(nsamples, seed, seqlen, model_id)
return get_c4(nsamples, seed, seqlen, model_id)
return get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code)
return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
@ -906,7 +922,12 @@ def quantize(
seed = None
dataloader, testloader = get_loaders(
dataset, nsamples=nsamples, seed=seed, model_id=model_id, seqlen=model.seqlen
dataset,
nsamples=nsamples,
seed=seed,
model_id=model_id,
seqlen=model.seqlen,
trust_remote_code=trust_remote_code
)
tick = time.time()


@ -9,7 +9,7 @@ from typing import List
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
HAS_BITS_AND_BYTES = False
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
return out
class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
)
self.compute_dtype = None
self.weight.cuda(weight.device)
self.bias = bias
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
@ -381,33 +426,65 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, prefix, weights):
def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -419,8 +496,11 @@ try:
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
@ -446,5 +526,36 @@ try:
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
pass


@ -0,0 +1,46 @@
import os
import json
from loguru import logger
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
logger.info("Peft model detected.")
logger.info("Loading the model it might take a while without feedback")
try:
model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info(f"Loaded.")
logger.info(f"Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)


@ -224,7 +224,7 @@ class Weights:
def _set_gptq_params(self, model_id):
filename = "quantize_config.json"
try:
if not os.path.exists(os.path.join(model_id, filename)):
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename)