From 9ad7b6a1a12f8cd6b715be9f0ca85603e0a2b002 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 1 Feb 2024 13:29:04 +0100 Subject: [PATCH 01/21] Hotfix the / health - route. (#1515) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- router/src/server.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/router/src/server.rs b/router/src/server.rs index 52ed03df..b4d26158 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -936,6 +936,7 @@ pub async fn run( // Define base and health routes let base_routes = Router::new() .route("/", post(compat_generate)) + .route("/", get(health)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) From 1e03b61b5c56e2ed5c723457df21cc18d48c1854 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 1 Feb 2024 14:36:10 +0000 Subject: [PATCH 02/21] Revert "Modify default for max_new_tokens in python client (#1336)" This reverts commit 2d56f106a60c7b698705494e7539f8a7e4c85dd9. It causes a breaking in our integrations-tests. --- clients/python/tests/test_client.py | 16 ---------------- clients/python/text_generation/client.py | 8 ++++---- clients/python/text_generation/types.py | 2 +- 3 files changed, 5 insertions(+), 21 deletions(-) diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 775e7a6c..1e25e1b1 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -21,22 +21,6 @@ def test_generate(flan_t5_xxl_url, hf_headers): assert not response.details.tokens[0].special -def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) - response = client.generate("test", decoder_input_details=True) - - assert response.generated_text != "" - assert response.details.finish_reason == FinishReason.EndOfSequenceToken - assert response.details.generated_tokens > 1 - assert response.details.seed is None - assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) - assert len(response.details.tokens) > 1 - assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == " " - assert not response.details.tokens[0].special - - def test_generate_best_of(flan_t5_xxl_url, hf_headers): client = Client(flan_t5_xxl_url, hf_headers) response = client.generate( diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 63b5258d..0bf80f8c 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -62,7 +62,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -157,7 +157,7 @@ class Client: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -312,7 +312,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -405,7 +405,7 @@ class AsyncClient: self, prompt: str, do_sample: bool = False, - max_new_tokens: Optional[int] = None, + max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 7fa8033e..aa02d8d8 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -9,7 +9,7 @@ class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens - max_new_tokens: Optional[int] = None + max_new_tokens: int = 20 # The parameter for repetition penalty. 1.0 means no penalty. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. repetition_penalty: Optional[float] = None From ee1cf51ce796e4b034eedaf3e909b4c902eae70c Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 1 Feb 2024 09:39:32 -0500 Subject: [PATCH 03/21] fix: tokenizer config should use local model path when possible (#1518) This PR fixes the issue with loading a local tokenizer config. Previously the default functionality would look in the current working directory. Now if a local model path is specified we will check that directory for the tokenizer_config. ## Examples of valid commands uses tokenizer_config from hub ``` text-generation-launcher --model-id HuggingFaceH4/zephyr-7b-beta ``` use tokenizer_config from local model path ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ ``` use specific tokenizer_config file ``` text-generation-launcher \ --model-id ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/ \ --tokenizer-config-path ~/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/dc24cabd13eacd3ae3a5fe574bd645483a335a4a/tokenizer_config.json ``` --------- Co-authored-by: Nicolas Patry --- router/src/lib.rs | 2 +- router/src/main.rs | 49 +++++++++++++++++++++++++--------------------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index fc5670a0..07360e78 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -37,7 +37,7 @@ pub struct HubTokenizerConfig { } impl HubTokenizerConfig { - pub fn from_file(filename: &str) -> Self { + pub fn from_file(filename: &std::path::Path) -> Self { let content = std::fs::read_to_string(filename).unwrap(); serde_json::from_str(&content).unwrap_or_default() } diff --git a/router/src/main.rs b/router/src/main.rs index 495fd5bc..2a080468 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -154,12 +154,6 @@ async fn main() -> Result<(), RouterError> { let local_path = Path::new(&tokenizer_name); let local_model = local_path.exists() && local_path.is_dir(); - // Load tokenizer config - // This will be used to format the chat template - let local_tokenizer_config_path = - tokenizer_config_path.unwrap_or("tokenizer_config.json".to_string()); - let local_tokenizer_config = Path::new(&local_tokenizer_config_path).exists(); - // Shared API builder initialization let api_builder = || { let mut builder = ApiBuilder::new() @@ -230,24 +224,35 @@ async fn main() -> Result<(), RouterError> { }; // Load tokenizer config if found locally, or check if we can get it from the API if needed - let tokenizer_config = if local_tokenizer_config { + let tokenizer_config = if let Some(path) = tokenizer_config_path { + tracing::info!("Using local tokenizer config from user specified path"); + HubTokenizerConfig::from_file(&std::path::PathBuf::from(path)) + } else if local_model { tracing::info!("Using local tokenizer config"); - HubTokenizerConfig::from_file(&local_tokenizer_config_path) - } else if let Some(api) = api { - tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); - get_tokenizer_config(&api.repo(Repo::with_revision( - tokenizer_name.to_string(), - RepoType::Model, - revision.unwrap_or_else(|| "main".to_string()), - ))) - .await - .unwrap_or_else(|| { - tracing::warn!("Could not retrieve tokenizer config from the Hugging Face hub."); - HubTokenizerConfig::default() - }) + HubTokenizerConfig::from_file(&local_path.join("tokenizer_config.json")) } else { - tracing::warn!("Could not find tokenizer config locally and no revision specified"); - HubTokenizerConfig::default() + match api { + Some(api) => { + tracing::info!("Using the Hugging Face API to retrieve tokenizer config"); + let repo = Repo::with_revision( + tokenizer_name.to_string(), + RepoType::Model, + revision.unwrap_or("main".to_string()), + ); + get_tokenizer_config(&api.repo(repo)) + .await + .unwrap_or_else(|| { + tracing::warn!( + "Could not retrieve tokenizer config from the Hugging Face hub." + ); + HubTokenizerConfig::default() + }) + } + None => { + tracing::warn!("Could not find tokenizer config locally and no API specified"); + HubTokenizerConfig::default() + } + } }; if tokenizer.is_none() { From 0e97af456af3102ed4f927f7b7e870ec976079ae Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 1 Feb 2024 16:26:48 +0100 Subject: [PATCH 04/21] Updating tokenizers. (#1517) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --- router/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/Cargo.toml b/router/Cargo.toml index f6f16dae..1a7ceb70 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -32,7 +32,7 @@ reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" thiserror = "1.0.48" -tokenizers = { version = "0.14.0", features = ["http"] } +tokenizers = { version = "0.15.1", features = ["http"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.14" tower-http = { version = "0.4.4", features = ["cors"] } From 3ab578b4160b200ad601bbd30bd8ecf39b979326 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 2 Feb 2024 14:05:30 +0100 Subject: [PATCH 05/21] [docs] Fix link to Install CLI (#1526) # What does this PR do? Attempts to fix a link from Using TGI CLI to Installation. ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? --- docs/source/basic_tutorials/using_cli.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/basic_tutorials/using_cli.md b/docs/source/basic_tutorials/using_cli.md index 82c10e6b..a3a65f60 100644 --- a/docs/source/basic_tutorials/using_cli.md +++ b/docs/source/basic_tutorials/using_cli.md @@ -1,6 +1,6 @@ # Using TGI CLI -You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](./installation#install-cli). +You can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli). `text-generation-server` lets you download the model with `download-weights` command like below 👇 From 0da00be52c9e591f8890ab07eea05cc15b9b127b Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 2 Feb 2024 10:31:11 -0500 Subject: [PATCH 06/21] feat: add ie update to message docs (#1523) update messages api docs and add Hugging Face Inference Endpoints integrations section/instructions --------- Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> --- docs/source/messages_api.md | 45 +++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/docs/source/messages_api.md b/docs/source/messages_api.md index 1e342686..939850aa 100644 --- a/docs/source/messages_api.md +++ b/docs/source/messages_api.md @@ -4,6 +4,15 @@ Text Generation Inference (TGI) now supports the Messages API, which is fully co > **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature. +#### Table of Contents + +- [Making a Request](#making-a-request) +- [Streaming](#streaming) +- [Synchronous](#synchronous) +- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints) +- [Cloud Providers](#cloud-providers) + - [Amazon SageMaker](#amazon-sagemaker) + ## Making a Request You can make a request to TGI's Messages API using `curl`. Here's an example: @@ -81,6 +90,38 @@ chat_completion = client.chat.completions.create( print(chat_completion) ``` +## Hugging Face Inference Endpoints + +The Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated). +Every endpoint that uses "Text Generation Inference" with an LLM, which has a chat template can now be used. Below is an example of how to use IE with TGI using OpenAI's Python client library: + +> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key. + +```python +from openai import OpenAI + +# init the client but point it to TGI +client = OpenAI( + # replace with your endpoint url, make sure to include "v1/" at the end + base_url="https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/", + # replace with your API key + api_key="hf_XXX" +) + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + {"role": "system", "content": "You are a helpful assistant." }, + {"role": "user", "content": "What is deep learning?"} + ], + stream=True +) + +# iterate and print stream +for message in chat_completion: + print(message.choices[0].delta.content, end="") +``` + ## Cloud Providers TGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker: @@ -114,7 +155,7 @@ hub = { huggingface_model = HuggingFaceModel( image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"), env=hub, - role=role, + role=role, ) # deploy model to SageMaker Inference @@ -123,7 +164,7 @@ predictor = huggingface_model.deploy( instance_type="ml.g5.2xlarge", container_startup_health_check_timeout=300, ) - + # send request predictor.predict({ "messages": [ From 17345402114e3d4a1f1c0a7fd650c734f7a992f9 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 7 Feb 2024 03:35:53 -0500 Subject: [PATCH 07/21] =?UTF-8?q?feat:=20use=20existing=20add=5Fgeneration?= =?UTF-8?q?=5Fprompt=20variable=20from=20config=20in=20temp=E2=80=A6=20(#1?= =?UTF-8?q?533)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support to read the `add_generation_prompt` from the config and use it in the chat template. If `add_generation_prompt` does not exist we default to false --- router/src/infer.rs | 64 ++++++++++++++++++++++++++++++++++++++------- router/src/lib.rs | 1 + 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 5f078ba0..4da0da0a 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -198,6 +198,7 @@ impl Infer { messages, eos_token: eos_token.as_deref(), bos_token: bos_token.as_deref(), + add_generation_prompt: true, }) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); @@ -806,21 +807,14 @@ mod tests { ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!( result, - r#"### User: -Hi! - -### Assistant: -Hello how can I help?### User: -What is Deep Learning? - -### Assistant: -magic!"# + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" ); } @@ -878,6 +872,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -943,9 +938,60 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + Message { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + Message { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 07360e78..e85519cc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, + add_generation_prompt: bool, } #[derive(Clone, Deserialize, ToSchema, Serialize)] From bd405e035b9e05d4b1e74e029ff1d5de86854ea0 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 8 Feb 2024 04:19:45 -0500 Subject: [PATCH 08/21] Impl simple mamba model (#1480) This draft PR is a work in progress implementation of the mamba model. This PR currently loads weights, and produces correct logits after a single pass. This PR still needs to correctly integrate this model so it produces tokens as expected, and apply optimization to avoid all copies during runtime/unnecessary operations. #### Helpful resources [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752) https://github.com/johnma2006/mamba-minimal https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs https://github.com/huggingface/transformers/pull/28094 Notes: this dev work is currently targeting `state-spaces/mamba-130m`, so if you want to test please use that model. Additionally when starting the router the prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768` ## Update / Current State Integration tests have been added and basic functionality such as model loading is supported. ```bash cd integration-tests pytest -vv models/test_fused_kernel_mamba.py ``` - [x] add tests - [x] load model - [x] make simple request - [ ] resolve warmup issue - [ ] resolve output issues fetching models tested during dev ```bash text-generation-server download-weights state-spaces/mamba-130m text-generation-server download-weights state-spaces/mamba-1.4b text-generation-server download-weights state-spaces/mamba-2.8b ``` The server can be run ```bash cd server MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b ``` router ```bash cargo run ``` make a request ```bash curl -s localhost:3000/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' | jq ``` response ```json { "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data." } ``` --------- Co-authored-by: Nicolas Patry --- Dockerfile | 10 + .../__snapshots__/test_mamba/test_mamba.json | 73 ++ .../test_mamba/test_mamba_all_params.json | 99 +++ .../test_mamba/test_mamba_load.json | 398 +++++++++++ integration-tests/models/test_mamba.py | 59 ++ server/.gitignore | 1 + server/Makefile | 1 + server/Makefile-selective-scan | 28 + .../text_generation_server/models/__init__.py | 29 +- .../models/custom_modeling/mamba_modeling.py | 194 ++++++ server/text_generation_server/models/mamba.py | 656 ++++++++++++++++++ 11 files changed, 1547 insertions(+), 1 deletion(-) create mode 100644 integration-tests/models/__snapshots__/test_mamba/test_mamba.json create mode 100644 integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json create mode 100644 integration-tests/models/test_mamba.py create mode 100644 server/Makefile-selective-scan create mode 100644 server/text_generation_server/models/custom_modeling/mamba_modeling.py create mode 100644 server/text_generation_server/models/mamba.py diff --git a/Dockerfile b/Dockerfile index b6c5b2ed..6818005f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile # Build specific version of vllm RUN make build-vllm-cuda +# Build mamba kernels +FROM kernel-builder as mamba-builder +WORKDIR /usr/src +COPY server/Makefile-selective-scan Makefile +RUN make build-all + # Build megablocks FROM kernel-builder as megablocks-builder @@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from mamba builder +COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json new file mode 100644 index 00000000..4435f215 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.140625, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5551758, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59033203, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.70654297, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0410156, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026435852, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2841797, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json new file mode 100644 index 00000000..052c1c69 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2502, + "logprob": null, + "text": " red" + }, + { + "id": 13, + "logprob": -2.5234375, + "text": "," + }, + { + "id": 8862, + "logprob": -3.4433594, + "text": " yellow" + }, + { + "id": 13, + "logprob": -0.43017578, + "text": "," + }, + { + "id": 209, + "logprob": -8.21875, + "text": " " + } + ], + "seed": 0, + "tokens": [ + { + "id": 187, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 395, + "logprob": -0.46411133, + "special": false, + "text": "and" + }, + { + "id": 13735, + "logprob": -2.1132812, + "special": false, + "text": " orange" + }, + { + "id": 313, + "logprob": -1.2128906, + "special": false, + "text": " (" + }, + { + "id": 249, + "logprob": -2.3671875, + "special": false, + "text": "in" + }, + { + "id": 253, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 1340, + "logprob": -1.640625, + "special": false, + "text": " order" + }, + { + "id": 597, + "logprob": -0.5488281, + "special": false, + "text": " they" + }, + { + "id": 3176, + "logprob": -0.48608398, + "special": false, + "text": " appear" + }, + { + "id": 275, + "logprob": 0.0, + "special": false, + "text": " in" + } + ], + "top_tokens": null + }, + "generated_text": "blue, red, yellow, \nand orange (in the order they appear in" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json new file mode 100644 index 00000000..014210b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.8125, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.828125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -3.0, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1484375, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1279297, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5595703, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.60253906, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7050781, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0488281, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3808594, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026416779, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + } +] diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py new file mode 100644 index 00000000..d86faeff --- /dev/null +++ b/integration-tests/models/test_mamba.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def fused_kernel_mamba_handle(launcher): + with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def fused_kernel_mamba(fused_kernel_mamba_handle): + await fused_kernel_mamba_handle.health(300) + return fused_kernel_mamba_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "What is Deep Learning?", max_new_tokens=10 + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "\n\nDeep learning is a new type of machine" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "blue, red, yellow, ", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): + responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" + + assert responses == response_snapshot diff --git a/server/.gitignore b/server/.gitignore index dcb8fe67..576746ee 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -161,3 +161,4 @@ flash-attention-v2/ vllm/ llm-awq/ eetq/ +mamba/ diff --git a/server/Makefile b/server/Makefile index b1926828..31d55c41 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq +include Makefile-selective-scan unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-selective-scan b/server/Makefile-selective-scan new file mode 100644 index 00000000..f4dec868 --- /dev/null +++ b/server/Makefile-selective-scan @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan \ No newline at end of file diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 68096709..a952f060 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -76,6 +76,15 @@ if FLASH_ATTENTION: __all__.append(FlashMixtral) __all__.append(FlashPhi) +MAMBA_AVAILABLE = True +try: + from text_generation_server.models.mamba import Mamba +except ImportError as e: + logger.warning(f"Could not import Mamba: {e}") + MAMBA_AVAILABLE = False + +if MAMBA_AVAILABLE: + __all__.append(Mamba) def get_model( model_id: str, @@ -164,7 +173,25 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict["model_type"] + model_type = config_dict.get("model_type", None) + if model_type is None: + # TODO: fix how we determine model type for Mamba + if "ssm_cfg" in config_dict: + # *only happens in Mamba case + model_type = "ssm" + else: + raise RuntimeError( + f"Could not determine model type for {model_id} revision {revision}" + ) + + if model_type == "ssm": + return Mamba( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "gpt_bigcode": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py new file mode 100644 index 00000000..1773f04d --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -0,0 +1,194 @@ +import torch +import torch.distributed + +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.utils.generation import InferenceParams +from torch import nn +from typing import Optional, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +import torch.nn.functional as F + +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, + FastRMSNorm, + FastLinear, +) + +from einops import rearrange +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +import math + +class MambaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=50280, + d_model=768, + d_state=16, + n_layer=32, + layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + expand=2, + dt_rank="auto", + **kwargs, + ): + self.vocab_size = vocab_size + self.n_layer = n_layer + self.layer_norm_epsilon = layer_norm_epsilon + self.d_model = d_model + self.d_inner = d_model * 2 + self.d_conv = 4 + self.d_state = d_state + self.expand = expand + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +class MambaBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.layer_idx = int(prefix.split(".")[2]) + self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) + self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) + self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) + self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False) + self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False) + self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) + self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) + self.D = weights.get_tensor(f"{prefix}.D") + self.activation = "silu" + self.dt_rank = config.dt_rank + self.d_state = config.d_state + self.d_conv = config.d_conv + self.act = nn.SiLU() + + # inference_params + def forward(self, hidden_states: torch.Tensor, inference_params=None): + _, seqlen, _ = hidden_states.shape + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + + if inference_params.seqlen_offset > 0: + out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) + return out, conv_state, ssm_state + + projected_states = self.in_proj(hidden_states).transpose(1,2) + x, z = projected_states.chunk(2, dim=1) + conv_state = F.pad(x, (self.d_conv - seqlen, 0)) + x = causal_conv1d_fn( + x=x, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y, last_state = selective_scan_fn( + x, + dt, + self.negA, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, + ) + y = rearrange(y, "b d l -> b l d") + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, last_state + + def step(self, hidden_states, conv_state, ssm_state): + _xz = self.in_proj(hidden_states) + _x, _z = _xz.chunk(2, dim=-1) # (B D) + conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) + conv_out = causal_conv1d_fn( + x=conv_state_new, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation + ) + conv_state = conv_state_new[:, :, 1:] + bsz, seqlen, dim = hidden_states.shape + output_tensor = torch.zeros( + (bsz, seqlen, dim), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + for i in range(0, bsz): + x = conv_out[i:i+1,:,-1] + z = _z[i:i+1, -1, :] + x_db = self.x_proj(x) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + y = selective_state_update( + ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + out = self.out_proj(y) + output_tensor[i] = out + + return output_tensor, conv_state, ssm_state + + + +class ResidualBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + ): + residual = (hidden_states + residual) if residual is not None else hidden_states + shape = residual.shape + hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) + hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params) + return hidden_states, residual, conv_state, last_ssm_state + +class MambaModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + prefix = "backbone" + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + self.blocks = nn.ModuleList( + [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] + ) + self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon) + self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False) + self.config = config + + def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + hidden_states = self.embed_tokens(input_ids) + for block in self.blocks: + hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params) + inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) + + hidden_states = hidden_states + residual if residual is not None else hidden_states + hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) + hidden_states = hidden_states.view(residual.shape) + logits = self.lm_head(hidden_states) + + # update the offset for the next inference using these params + inference_params.seqlen_offset += input_ids.size(1) + return logits, input_ids, inference_params \ No newline at end of file diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py new file mode 100644 index 00000000..c10910aa --- /dev/null +++ b/server/text_generation_server/models/mamba.py @@ -0,0 +1,656 @@ +import torch +import torch.distributed +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing import Optional +from text_generation_server.models.custom_modeling.mamba_modeling import ( + MambaConfig, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +import time +from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel +from text_generation_server.models import Model +from typing import Any, List, Optional, Tuple, Type, Dict +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.tokens import batch_top_tokens, Sampling +from dataclasses import dataclass +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from mamba_ssm.utils.generation import InferenceParams + +@dataclass +class MambaBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + # Inference params + inference_params: Optional[Dict[str, Any]] = None + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "MambaBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.inputs) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, stopping_criteria.max_new_tokens + ) + + tokenized_inputs = tokenizer( + inputs, + return_tensors="pt", + padding=True, + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(input_len - 5) + read_offsets.append(input_len) + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + input_ids = tokenized_inputs["input_ids"] + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + # past_input_ids=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + indices = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + indices.append(idx) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + # TODO + # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. + key_value_memory_dict = {} + for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items(): + key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices]) + self.inference_params.key_value_memory_dict = key_value_memory_dict + + return self + + @classmethod + def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": + # Used for padding + total_batch_size = 0 + max_input_length = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + max_tokens = 0 + max_seqlen = 0 + batch_size = 0 + seqlen_offset = 0 + + # Batch tensors + input_ids = None + top_n_tokens_tensor = None + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen) + seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset) + batch_size += batch.inference_params.max_batch_size + + start_index = end_index + + + (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape + (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape + n_blocks = len(batches[0].inference_params.key_value_memory_dict) + dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype + device = batches[0].inference_params.key_value_memory_dict[0][0].device + + key_value_memory_dict = {} + for i in range(n_blocks): + conv_state = torch.zeros( + batch_size, + d_model, + d_conv, + device=device, + dtype=dtype, + ) + ssm_state = torch.zeros( + batch_size, + d_model, + d_state, + device=device, + dtype=dtype, + ) + key_value_memory_dict[i] = (conv_state, ssm_state) + lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device) + + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_offset, + key_value_memory_dict=key_value_memory_dict, + lengths_per_sample=lengths_per_sample, + ) + + current_batch = 0 + for batch in batches: + for i in range(n_blocks): + conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i] + batch_size = batch.inference_params.max_batch_size + inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state + inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state + inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample + current_batch += batch_size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + inference_params=inference_params + ) + + def __len__(self): + return len(self.requests) + +class Mamba(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, _rank, _world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/gpt-neox-20b", + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + config = MambaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + tokenizer.bos_token_id = config.bos_token_id + tokenizer.eos_token_id = config.eos_token_id + tokenizer.pad_token = tokenizer.eos_token + + config.quantize = quantize + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + model = MambaModel(config, weights) + torch.distributed.barrier(group=self.process_group) + super(Mamba, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[MambaBatch]: + return MambaBatch + + def warmup(self, batch) -> Optional[int]: + # TODO: implement warmup for Mamba if needed + return None + + def forward( + self, + input_ids: torch.Tensor, + past: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.model( + input_ids, + past=past, + ) + + def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: + start = time.time_ns() + input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids + + batch_size = input_ids.shape[0] + max_seqlen = input_ids.shape[1] + dtype = input_ids.dtype + + # Inference params + seqlen_og = 0 + inf_cache = {} + lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen + + if batch.inference_params is None: + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + + # Allocate inference cache + for res_block in self.model.blocks: + block = res_block.mamba_block + conv_state = torch.zeros( + batch_size, + self.model.config.d_model * self.model.config.expand, + self.model.config.d_conv, + device=block.conv1d.weight.device, + dtype=block.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.model.config.d_model * self.model.config.expand, + self.model.config.d_state, + device=block.dt_proj.weight.device, + dtype=block.dt_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state) + batch.inference_params = inference_params + + # Forward pass + logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params) + + batch.inference_params = new_inference_params + # Results + generations: List[Generation] = [] + stopped = True + + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.log_softmax(logits[:, -1], -1), + accepted_ids, + ) + + start_decode = time.time_ns() + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns) From 39af000cb9bfa12628c21c6e65d3ce8930250171 Mon Sep 17 00:00:00 2001 From: Jason Stillerman Date: Thu, 8 Feb 2024 06:44:04 -0500 Subject: [PATCH 09/21] Update to peft 0.8.2 (#1537) # What does this PR do? ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [x] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [x] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @OlivierDehaene OR @Narsil --- server/poetry.lock | 17 ++++++++++------- server/pyproject.toml | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/server/poetry.lock b/server/poetry.lock index 64b1b74f..32031f89 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -1589,30 +1589,32 @@ xml = ["lxml (>=4.9.2)"] [[package]] name = "peft" -version = "0.4.0" +version = "0.8.2" description = "Parameter-Efficient Fine-Tuning (PEFT)" optional = true python-versions = ">=3.8.0" files = [ - {file = "peft-0.4.0-py3-none-any.whl", hash = "sha256:2cf992772a6d703814477e0bdcdadd68cb8ea388111ce2d793dd2ff0e438f357"}, - {file = "peft-0.4.0.tar.gz", hash = "sha256:e768fa22d6e9f32aa7e891f0d06f355960278ca4dc0cdd96bff71f6f06269207"}, + {file = "peft-0.8.2-py3-none-any.whl", hash = "sha256:4a9c81c38e689fd4043b2757cd0e2b526a9b8b8fd04f8442df2c4824b32c2505"}, + {file = "peft-0.8.2.tar.gz", hash = "sha256:bbdf61db2d8ca503e894edc64016038e6f34b7b522374bad09a22af41882e7ac"}, ] [package.dependencies] -accelerate = "*" +accelerate = ">=0.21.0" +huggingface-hub = ">=0.17.0" numpy = ">=1.17" packaging = ">=20.0" psutil = "*" pyyaml = "*" safetensors = "*" torch = ">=1.13.0" +tqdm = "*" transformers = "*" [package.extras] dev = ["black (>=22.0,<23.0)", "hf-doc-builder", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] docs-specific = ["hf-doc-builder"] quality = ["black (>=22.0,<23.0)", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] -test = ["black (>=22.0,<23.0)", "datasets", "diffusers", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "urllib3 (<=2.0.0)"] +test = ["black (>=22.0,<23.0)", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameterized", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.0.241)", "scipy", "urllib3 (<=2.0.0)"] [[package]] name = "pillow" @@ -1893,6 +1895,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2962,4 +2965,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "33d533d21d14c258678a8c4bb28e2a15e8ebe5ca35d8589cbfe4a7b7d2e79a90" +content-hash = "f7529125bdd7ce142082ce4969edbda5d9b67b6209f199194c54198829f5dc64" diff --git a/server/pyproject.toml b/server/pyproject.toml index 72a7afb0..b8ebf2e3 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -30,7 +30,7 @@ transformers = "^4.37.1" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } -peft = { version = "^0.4.0", optional = true } +peft = { version = "^0.8.2", optional = true } torch = { version = "^2.1.1", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" From 461dd6f1c7d6ef1cdbe8363ce71d198c06c6f390 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 12:05:36 +0100 Subject: [PATCH 10/21] fix exllama overflows --- .../exllama_kernels/cuda_func/q4_matmul.cu | 8 ++++---- .../exllama_kernels/cuda_func/q4_matrix.cu | 2 +- .../cuda/q_gemm_kernel_gptq.cuh | 16 ++++++++-------- .../exllamav2_kernels/cuda/q_matrix.cu | 16 ++++++++-------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu index 61380f42..09126efe 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel if constexpr (use_half2) { half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); @@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel else { half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); @@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel { int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); @@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu index f3d1564f..2867a8d0 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -189,7 +189,7 @@ __global__ void reconstruct_kernel int group = row / groupsize; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; uint32_t w_read = w_.item_uint32_t(row, column); half* out_ptr = out_.item_ptr(row, column); diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh index 74b0db2b..f816fd9d 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); // __syncthreads(); @@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } #pragma unroll diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index ae08cc1f..7a0038b4 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); __syncthreads(); @@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } for (int p = 0; p < 4; p++) From 75086526d3ef320cfaa1905e88c0dff044070326 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 12:06:02 +0100 Subject: [PATCH 11/21] awq fallback to exllama --- .../utils/awq/pack_utils.py | 146 ++++++++++++++++++ server/text_generation_server/utils/layers.py | 23 ++- 2 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 server/text_generation_server/utils/awq/pack_utils.py diff --git a/server/text_generation_server/utils/awq/pack_utils.py b/server/text_generation_server/utils/awq/pack_utils.py new file mode 100644 index 00000000..9b15e1db --- /dev/null +++ b/server/text_generation_server/utils/awq/pack_utils.py @@ -0,0 +1,146 @@ +import torch +from typing import List + + +AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def pack(imatrix: torch.Tensor, direction: str = "column"): + """ + Packs a 4-bit integer matrix into a packed 32-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of packing, either "column" or "row" + Returns: + qmatrix (torch.Tensor): packed matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=imatrix.device) + + imatrix = imatrix.to(torch.int8) + imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow + + if direction == "column": + imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) + + elif direction == "row": + imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) + + qmatrix = qmatrix.to(torch.int32) + + return qmatrix + + +def unpack(qmatrix: torch.Tensor, direction: str = "column"): + """ + Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. + Args: + qmatrix (torch.Tensor): matrix of packed integers + direction (str): direction of unpacking, either "column" or "row" + Returns: + imatrix (torch.Tensor): matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=qmatrix.device) + + if direction == "column": + imatrix = torch.bitwise_right_shift( + qmatrix[:, :, None], shifts[None, None, :] + ).view(qmatrix.shape[0], -1) + + elif direction == "row": + imatrix = torch.bitwise_right_shift( + qmatrix[:, None, :], shifts[None, :, None] + ).view(-1, qmatrix.shape[-1]) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + return imatrix + + +def quantize(fmatrix, scales, zeros, group_size): + """ + Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers. + Args: + fmatrix (torch.Tensor): matrix of 16-bit floats + scales (torch.Tensor): matrix of 16-bit floats + zeros (torch.Tensor): matrix of 4-bit integers + group_size (int): group size + Returns: + imatrix (torch.Tensor): matrix of 4-bit integers + """ + zeros = zeros.to(torch.int8) & 0x0F + + imatrix = torch.round( + ( + fmatrix / scales.repeat_interleave(group_size, dim=0) + + zeros.repeat_interleave(group_size, dim=0) + ) + ) + + imatrix = imatrix.to(torch.int8) & 0x0F + + return imatrix + + +def dequantize(imatrix, scales, zeros, group_size): + """ + Dequantizes a 4-bit integer matrix into a float matrix. + Args: + imatrix (torch.Tensor): matrix of 4-bit integers + scales (torch.Tensor): matrix of 16-bit floats + zeros (torch.Tensor): matrix of 4-bit integers + group_size (int): group size + Returns: + fmatrix (torch.Tensor): matrix of 16-bit floats + """ + zeros = zeros.to(torch.int8) & 0x0F + imatrix = imatrix.to(torch.int8) & 0x0F + + fmatrix = ( + imatrix - zeros.repeat_interleave(group_size, dim=0) + ) * scales.repeat_interleave(group_size, dim=0) + + fmatrix = fmatrix.to(torch.float16) + + return fmatrix + + +def apply_order( + imatrix: torch.Tensor, + direction: str = "column", + order: List[int] = AWQ_PACK_ORDER, +): + """ + Applies the order to a 4-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of applying order, either "column" or "row" + order (List[int]): order to apply, default is AWQ_PACK_ORDER + Returns: + imatrix (torch.Tensor): matrix of integers + """ + if direction == "column": + imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) + elif direction == "row": + imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) + + return imatrix + + +def fast_awq_to_exllama(qweight, qzeros): + # awq uses column packing for both weights and zeros + izeros = unpack(qzeros, direction="column") + iweights = unpack(qweight, direction="column") + + # Reverse the order of the iweight and izeros tensors + izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) + iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) + # Subtract 1 from the izeros tensor (exllama adds 1 during inference) + izeros = izeros - 1 + # exllama uses row packing for weights and column packing for zeros + qzeros = pack(izeros, direction="column") + qweight = pack(iweights, direction="row") + + return qweight, qzeros diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 010d6143..782744ed 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -25,6 +25,7 @@ HAS_AWQ = True try: from text_generation_server.utils.awq.quantize.qmodule import WQLinear except ImportError: + from text_generation_server.utils.awq.pack_utils import fast_awq_to_exllama HAS_AWQ = False try: @@ -349,14 +350,20 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) - linear = WQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, - bias=bias is not None, - ) + if HAS_AWQ: + linear = WQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) + elif HAS_EXLLAMA: + qweight, qzeros = fast_awq_to_exllama(qweight, qzeros) + linear = ExllamaQuantLinear( + qweight, qzeros, scales, None, bias, bits, groupsize + ) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear From aa2014fc79cb3a5e7764bcf5d383dda8a47179c0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 12:48:17 +0100 Subject: [PATCH 12/21] post process exllama model --- server/text_generation_server/server.py | 33 +++++++++++++++---------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index d5adbd32..08d672f3 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -63,20 +63,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - if self.quantize == "gptq": - try: - # When using GPTQ, Exllama kernels need some global kernels - # For which we have the finale shapes only after the model has loaded - # This will allocate those buffers. - from text_generation_server.utils.layers import ( - create_exllama_buffers, - set_device, - ) + if self.quantize in ["gptq", "awq"]: + has_exllama_layers = False + for _, module in self.model.model.named_modules(): + if hasattr(module, "QUANT_TYPE"): + has_exllama_layers = True + break - set_device(self.model.device) - create_exllama_buffers(request.max_prefill_tokens) - except ImportError: - pass + if has_exllama_layers: + try: + # When using GPTQ or AWQ, Exllama kernels need some global kernels + # For which we have the finale shapes only after the model has loaded + # This will allocate those buffers. + from text_generation_server.utils.layers import ( + create_exllama_buffers, + set_device, + ) + + set_device(self.model.device) + create_exllama_buffers(request.max_prefill_tokens) + except ImportError: + pass if ( self.model.batch_type == IdeficsCausalLMBatch From 3963074cebf55e8113256b0e5c656607905a7b8c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 13:30:13 +0000 Subject: [PATCH 13/21] add triton fallback to awq --- .../utils/awq/pack_utils.py | 56 ++----------------- server/text_generation_server/utils/layers.py | 8 ++- 2 files changed, 10 insertions(+), 54 deletions(-) diff --git a/server/text_generation_server/utils/awq/pack_utils.py b/server/text_generation_server/utils/awq/pack_utils.py index 9b15e1db..d144b3cd 100644 --- a/server/text_generation_server/utils/awq/pack_utils.py +++ b/server/text_generation_server/utils/awq/pack_utils.py @@ -15,10 +15,10 @@ def pack(imatrix: torch.Tensor, direction: str = "column"): Returns: qmatrix (torch.Tensor): packed matrix of integers """ - shifts = torch.arange(0, 32, 4, device=imatrix.device) - imatrix = imatrix.to(torch.int8) imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow + + shifts = torch.arange(0, 32, 4, device=imatrix.device) if direction == "column": imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) @@ -59,54 +59,6 @@ def unpack(qmatrix: torch.Tensor, direction: str = "column"): return imatrix -def quantize(fmatrix, scales, zeros, group_size): - """ - Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers. - Args: - fmatrix (torch.Tensor): matrix of 16-bit floats - scales (torch.Tensor): matrix of 16-bit floats - zeros (torch.Tensor): matrix of 4-bit integers - group_size (int): group size - Returns: - imatrix (torch.Tensor): matrix of 4-bit integers - """ - zeros = zeros.to(torch.int8) & 0x0F - - imatrix = torch.round( - ( - fmatrix / scales.repeat_interleave(group_size, dim=0) - + zeros.repeat_interleave(group_size, dim=0) - ) - ) - - imatrix = imatrix.to(torch.int8) & 0x0F - - return imatrix - - -def dequantize(imatrix, scales, zeros, group_size): - """ - Dequantizes a 4-bit integer matrix into a float matrix. - Args: - imatrix (torch.Tensor): matrix of 4-bit integers - scales (torch.Tensor): matrix of 16-bit floats - zeros (torch.Tensor): matrix of 4-bit integers - group_size (int): group size - Returns: - fmatrix (torch.Tensor): matrix of 16-bit floats - """ - zeros = zeros.to(torch.int8) & 0x0F - imatrix = imatrix.to(torch.int8) & 0x0F - - fmatrix = ( - imatrix - zeros.repeat_interleave(group_size, dim=0) - ) * scales.repeat_interleave(group_size, dim=0) - - fmatrix = fmatrix.to(torch.float16) - - return fmatrix - - def apply_order( imatrix: torch.Tensor, direction: str = "column", @@ -129,7 +81,7 @@ def apply_order( return imatrix -def fast_awq_to_exllama(qweight, qzeros): +def fast_awq_to_gptq(qweight, qzeros): # awq uses column packing for both weights and zeros izeros = unpack(qzeros, direction="column") iweights = unpack(qweight, direction="column") @@ -137,7 +89,7 @@ def fast_awq_to_exllama(qweight, qzeros): # Reverse the order of the iweight and izeros tensors izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) - # Subtract 1 from the izeros tensor (exllama adds 1 during inference) + # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) izeros = izeros - 1 # exllama uses row packing for weights and column packing for zeros qzeros = pack(izeros, direction="column") diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 782744ed..7f20081d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -25,7 +25,8 @@ HAS_AWQ = True try: from text_generation_server.utils.awq.quantize.qmodule import WQLinear except ImportError: - from text_generation_server.utils.awq.pack_utils import fast_awq_to_exllama + from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq + HAS_AWQ = False try: @@ -360,10 +361,13 @@ def get_linear(weight, bias, quantize): bias=bias is not None, ) elif HAS_EXLLAMA: - qweight, qzeros = fast_awq_to_exllama(qweight, qzeros) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) linear = ExllamaQuantLinear( qweight, qzeros, scales, None, bias, bits, groupsize ) + else: + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear From 3ceeb858420deae3e8779caed9f6ee660ba55bb1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 13:30:43 +0000 Subject: [PATCH 14/21] fix missing g_idx and eventual overflow in triton kernel --- .../utils/gptq/quant_linear.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index bfc91c00..34895c01 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -182,7 +182,7 @@ try: ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 + zeros = (zeros + 1) & maxq # add 1 and avoid overflow a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated @@ -251,7 +251,17 @@ class QuantLinear(nn.Module): self.register_buffer("qweight", qweight) self.register_buffer("qzeros", qzeros) self.register_buffer("scales", scales) - self.register_buffer("g_idx", g_idx) + if g_idx is not None: + self.register_buffer("g_idx", g_idx) + else: + self.register_buffer( + "g_idx", + torch.tensor( + [i // groupsize for i in range(qweight.shape[0] * 32 // bits)], + device=qweight.device, + dtype=torch.int32, + ), + ) if bias is not None: self.register_buffer("bias", bias) else: From 212fdfffad95fcb08a7e0f0863acdc3d59f808ab Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 18:35:04 +0000 Subject: [PATCH 15/21] revert changes --- server/text_generation_server/server.py | 33 +++---- .../utils/awq/pack_utils.py | 98 ------------------- server/text_generation_server/utils/layers.py | 30 +++--- 3 files changed, 25 insertions(+), 136 deletions(-) delete mode 100644 server/text_generation_server/utils/awq/pack_utils.py diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 08d672f3..d5adbd32 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -63,27 +63,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - if self.quantize in ["gptq", "awq"]: - has_exllama_layers = False - for _, module in self.model.model.named_modules(): - if hasattr(module, "QUANT_TYPE"): - has_exllama_layers = True - break + if self.quantize == "gptq": + try: + # When using GPTQ, Exllama kernels need some global kernels + # For which we have the finale shapes only after the model has loaded + # This will allocate those buffers. + from text_generation_server.utils.layers import ( + create_exllama_buffers, + set_device, + ) - if has_exllama_layers: - try: - # When using GPTQ or AWQ, Exllama kernels need some global kernels - # For which we have the finale shapes only after the model has loaded - # This will allocate those buffers. - from text_generation_server.utils.layers import ( - create_exllama_buffers, - set_device, - ) - - set_device(self.model.device) - create_exllama_buffers(request.max_prefill_tokens) - except ImportError: - pass + set_device(self.model.device) + create_exllama_buffers(request.max_prefill_tokens) + except ImportError: + pass if ( self.model.batch_type == IdeficsCausalLMBatch diff --git a/server/text_generation_server/utils/awq/pack_utils.py b/server/text_generation_server/utils/awq/pack_utils.py deleted file mode 100644 index d144b3cd..00000000 --- a/server/text_generation_server/utils/awq/pack_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -from typing import List - - -AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] -REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] - - -def pack(imatrix: torch.Tensor, direction: str = "column"): - """ - Packs a 4-bit integer matrix into a packed 32-bit integer matrix. - Args: - imatrix (torch.Tensor): matrix of integers - direction (str): direction of packing, either "column" or "row" - Returns: - qmatrix (torch.Tensor): packed matrix of integers - """ - imatrix = imatrix.to(torch.int8) - imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow - - shifts = torch.arange(0, 32, 4, device=imatrix.device) - - if direction == "column": - imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) - qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) - - elif direction == "row": - imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) - qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) - - qmatrix = qmatrix.to(torch.int32) - - return qmatrix - - -def unpack(qmatrix: torch.Tensor, direction: str = "column"): - """ - Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. - Args: - qmatrix (torch.Tensor): matrix of packed integers - direction (str): direction of unpacking, either "column" or "row" - Returns: - imatrix (torch.Tensor): matrix of integers - """ - shifts = torch.arange(0, 32, 4, device=qmatrix.device) - - if direction == "column": - imatrix = torch.bitwise_right_shift( - qmatrix[:, :, None], shifts[None, None, :] - ).view(qmatrix.shape[0], -1) - - elif direction == "row": - imatrix = torch.bitwise_right_shift( - qmatrix[:, None, :], shifts[None, :, None] - ).view(-1, qmatrix.shape[-1]) - - imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow - - return imatrix - - -def apply_order( - imatrix: torch.Tensor, - direction: str = "column", - order: List[int] = AWQ_PACK_ORDER, -): - """ - Applies the order to a 4-bit integer matrix. - Args: - imatrix (torch.Tensor): matrix of integers - direction (str): direction of applying order, either "column" or "row" - order (List[int]): order to apply, default is AWQ_PACK_ORDER - Returns: - imatrix (torch.Tensor): matrix of integers - """ - if direction == "column": - imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) - elif direction == "row": - imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) - - return imatrix - - -def fast_awq_to_gptq(qweight, qzeros): - # awq uses column packing for both weights and zeros - izeros = unpack(qzeros, direction="column") - iweights = unpack(qweight, direction="column") - - # Reverse the order of the iweight and izeros tensors - izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) - iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) - # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) - izeros = izeros - 1 - # exllama uses row packing for weights and column packing for zeros - qzeros = pack(izeros, direction="column") - qweight = pack(iweights, direction="row") - - return qweight, qzeros diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 7f20081d..b9b1dfac 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -25,8 +25,6 @@ HAS_AWQ = True try: from text_generation_server.utils.awq.quantize.qmodule import WQLinear except ImportError: - from text_generation_server.utils.awq.pack_utils import fast_awq_to_gptq - HAS_AWQ = False try: @@ -351,23 +349,19 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) - if HAS_AWQ: - linear = WQLinear( - w_bit=bits, - group_size=groupsize, - qweight=qweight, - qzeros=qzeros, - scales=scales, - bias=bias is not None, + if IS_ROCM_SYSTEM: + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." ) - elif HAS_EXLLAMA: - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - linear = ExllamaQuantLinear( - qweight, qzeros, scales, None, bias, bits, groupsize - ) - else: - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - linear = QuantLinear(qweight, qzeros, scales, None, bias, bits, groupsize) + linear = WQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear From 8074c40473ad6fb67ece92815cad0d4525cf4c3c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 18:35:41 +0000 Subject: [PATCH 16/21] adapt awq weights to exllama/gptq kernels --- .../utils/pack_utils.py | 98 +++++++++++++++++++ .../text_generation_server/utils/weights.py | 79 ++++++++++----- 2 files changed, 155 insertions(+), 22 deletions(-) create mode 100644 server/text_generation_server/utils/pack_utils.py diff --git a/server/text_generation_server/utils/pack_utils.py b/server/text_generation_server/utils/pack_utils.py new file mode 100644 index 00000000..d144b3cd --- /dev/null +++ b/server/text_generation_server/utils/pack_utils.py @@ -0,0 +1,98 @@ +import torch +from typing import List + + +AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def pack(imatrix: torch.Tensor, direction: str = "column"): + """ + Packs a 4-bit integer matrix into a packed 32-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of packing, either "column" or "row" + Returns: + qmatrix (torch.Tensor): packed matrix of integers + """ + imatrix = imatrix.to(torch.int8) + imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow + + shifts = torch.arange(0, 32, 4, device=imatrix.device) + + if direction == "column": + imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) + + elif direction == "row": + imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) + + qmatrix = qmatrix.to(torch.int32) + + return qmatrix + + +def unpack(qmatrix: torch.Tensor, direction: str = "column"): + """ + Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. + Args: + qmatrix (torch.Tensor): matrix of packed integers + direction (str): direction of unpacking, either "column" or "row" + Returns: + imatrix (torch.Tensor): matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=qmatrix.device) + + if direction == "column": + imatrix = torch.bitwise_right_shift( + qmatrix[:, :, None], shifts[None, None, :] + ).view(qmatrix.shape[0], -1) + + elif direction == "row": + imatrix = torch.bitwise_right_shift( + qmatrix[:, None, :], shifts[None, :, None] + ).view(-1, qmatrix.shape[-1]) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + return imatrix + + +def apply_order( + imatrix: torch.Tensor, + direction: str = "column", + order: List[int] = AWQ_PACK_ORDER, +): + """ + Applies the order to a 4-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of applying order, either "column" or "row" + order (List[int]): order to apply, default is AWQ_PACK_ORDER + Returns: + imatrix (torch.Tensor): matrix of integers + """ + if direction == "column": + imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) + elif direction == "row": + imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) + + return imatrix + + +def fast_awq_to_gptq(qweight, qzeros): + # awq uses column packing for both weights and zeros + izeros = unpack(qzeros, direction="column") + iweights = unpack(qweight, direction="column") + + # Reverse the order of the iweight and izeros tensors + izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) + iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) + # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) + izeros = izeros - 1 + # exllama uses row packing for weights and column packing for zeros + qzeros = pack(izeros, direction="column") + qweight = pack(iweights, direction="row") + + return qweight, qzeros diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 186733f3..aabd52f4 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -7,6 +7,7 @@ from loguru import logger from huggingface_hub import hf_hub_download import json from text_generation_server.utils.log import log_once +from text_generation_server.utils.pack_utils import fast_awq_to_gptq class Weights: @@ -46,7 +47,6 @@ class Weights: return self._handles[filename] def get_filename(self, tensor_name: str) -> (str, str): - names = [tensor_name] if self.prefix is not None: prefixed = f"{self.prefix}.{tensor_name}" @@ -157,12 +157,20 @@ class Weights: qzeros = self._get_qweight(f"{prefix}.qzeros") scales = self._get_qweight(f"{prefix}.scales") scales = scales.to(dtype=self.dtype) - if quantize == "gptq": + if quantize == "gptq" and self.quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") else: g_idx = None - bits, groupsize, _ = self._get_gptq_params() + bits, groupsize, _, _ = self._get_gptq_params() + + if quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, + "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", + ) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: slice_ = self._get_slice(f"{prefix}.weight") @@ -204,7 +212,7 @@ class Weights: [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) - if quantize == "gptq": + if quantize == "gptq" and self.quant_method == "gptq": w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -212,12 +220,20 @@ class Weights: else: g_idx = None - bits, groupsize, desc_act = self._get_gptq_params() + bits, groupsize, desc_act, quant_method = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA use_exllama = ( bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act ) + + if quantize == "gptq" and quant_method == "awq": + log_once( + logger.info, + "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", + ) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -243,7 +259,7 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True - bits, groupsize, desc_act = self._get_gptq_params() + bits, groupsize, desc_act, quant_method = self._get_gptq_params() if bits != 4: use_exllama = False @@ -252,8 +268,19 @@ class Weights: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if quant_method == "gptq": + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + if self.process_group.size() > 1: - g_idx = self.get_tensor(f"{prefix}.g_idx") if g_idx is not None: if ( not torch.equal( @@ -269,13 +296,6 @@ class Weights: # it would require to reorder input activations that are split unto several GPUs use_exllama = False - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA if use_exllama: @@ -289,8 +309,6 @@ class Weights: else: log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if use_exllama and groupsize != -1: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0) @@ -298,12 +316,19 @@ class Weights: qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") - if use_exllama: + if use_exllama and g_idx is not None: g_idx = g_idx - g_idx[0] + if quant_method == "awq": + log_once( + logger.info, + "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", + ) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": - bits, groupsize, _ = self._get_gptq_params() + bits, groupsize, _, _ = self._get_gptq_params() try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) @@ -331,11 +356,12 @@ class Weights: try: bits = self.gptq_bits groupsize = self.gptq_groupsize + quant_method = self.quant_method desc_act = getattr(self, "gptq_desc_act", False) except Exception: raise e - return bits, groupsize, desc_act + return bits, groupsize, desc_act, quant_method def _set_gptq_params(self, model_id, revision): filename = "config.json" @@ -350,7 +376,8 @@ class Weights: data = json.load(f) self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] + self.gptq_desc_act = data["quantization_config"].get("desc_act", False) + self.quant_method = data["quantization_config"]["quant_method"] except Exception: filename = "quantize_config.json" try: @@ -364,7 +391,11 @@ class Weights: data = json.load(f) self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] - self.gptq_desc_act = data["desc_act"] + self.gptq_desc_act = data.get("desc_act", False) + if "version" in data and data["version"] == "GEMM": + self.quant_method = "awq" + else: + self.quant_method = "gptq" except Exception: filename = "quant_config.json" try: @@ -378,6 +409,10 @@ class Weights: data = json.load(f) self.gptq_bits = data["w_bit"] self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data["desc_act"] + self.gptq_desc_act = data.get("desc_act", False) + if "version" in data and data["version"] == "GEMM": + self.quant_method = "awq" + else: + self.quant_method = "gptq" except Exception: pass From 646ab282855e39ab674619d13bf609791d123cb5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 1 Feb 2024 19:37:02 +0000 Subject: [PATCH 17/21] typing --- server/text_generation_server/utils/weights.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index aabd52f4..875ac464 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -154,17 +154,18 @@ class Weights: f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) + bits, groupsize, _, quant_method = self._get_gptq_params() + qzeros = self._get_qweight(f"{prefix}.qzeros") scales = self._get_qweight(f"{prefix}.scales") scales = scales.to(dtype=self.dtype) - if quantize == "gptq" and self.quant_method == "gptq": + + if quantize == "gptq" and quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") else: g_idx = None - bits, groupsize, _, _ = self._get_gptq_params() - - if quantize == "gptq" and self.quant_method == "awq": + if quantize == "gptq" and quant_method == "awq": log_once( logger.info, "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", @@ -212,7 +213,9 @@ class Weights: [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) - if quantize == "gptq" and self.quant_method == "gptq": + bits, groupsize, desc_act, quant_method = self._get_gptq_params() + + if quantize == "gptq" and quant_method == "gptq": w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) @@ -220,7 +223,6 @@ class Weights: else: g_idx = None - bits, groupsize, desc_act, quant_method = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA use_exllama = ( @@ -347,7 +349,7 @@ class Weights: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> Tuple[int, int, int]: + def _get_gptq_params(self) -> Tuple[int, int, int, str]: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() From bbe5bedea5f07aeee0ca8954d03ec8e7fbf0a950 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 2 Feb 2024 14:34:15 +0100 Subject: [PATCH 18/21] pass g_idx instead of changing triton kernel --- .../conversion_utils.py} | 7 +- .../utils/gptq/quant_linear.py | 14 +--- .../text_generation_server/utils/weights.py | 66 +++++++++++++------ 3 files changed, 50 insertions(+), 37 deletions(-) rename server/text_generation_server/utils/{pack_utils.py => awq/conversion_utils.py} (94%) diff --git a/server/text_generation_server/utils/pack_utils.py b/server/text_generation_server/utils/awq/conversion_utils.py similarity index 94% rename from server/text_generation_server/utils/pack_utils.py rename to server/text_generation_server/utils/awq/conversion_utils.py index d144b3cd..b19eafbb 100644 --- a/server/text_generation_server/utils/pack_utils.py +++ b/server/text_generation_server/utils/awq/conversion_utils.py @@ -15,10 +15,9 @@ def pack(imatrix: torch.Tensor, direction: str = "column"): Returns: qmatrix (torch.Tensor): packed matrix of integers """ - imatrix = imatrix.to(torch.int8) - imatrix = torch.bitwise_and(imatrix, 0x0F) # eventually correct overflow - - shifts = torch.arange(0, 32, 4, device=imatrix.device) + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow if direction == "column": imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index 34895c01..8ad0dd80 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -182,7 +182,7 @@ try: ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) & maxq # add 1 and avoid overflow + zeros = (zeros + 1) & maxq # eventually avoid overflow a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated @@ -251,17 +251,7 @@ class QuantLinear(nn.Module): self.register_buffer("qweight", qweight) self.register_buffer("qzeros", qzeros) self.register_buffer("scales", scales) - if g_idx is not None: - self.register_buffer("g_idx", g_idx) - else: - self.register_buffer( - "g_idx", - torch.tensor( - [i // groupsize for i in range(qweight.shape[0] * 32 // bits)], - device=qweight.device, - dtype=torch.int32, - ), - ) + self.register_buffer("g_idx", g_idx) if bias is not None: self.register_buffer("bias", bias) else: diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 875ac464..759ea602 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -7,7 +7,6 @@ from loguru import logger from huggingface_hub import hf_hub_download import json from text_generation_server.utils.log import log_once -from text_generation_server.utils.pack_utils import fast_awq_to_gptq class Weights: @@ -162,15 +161,22 @@ class Weights: if quantize == "gptq" and quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") - else: - g_idx = None - - if quantize == "gptq" and quant_method == "awq": + elif quantize == "gptq" and quant_method == "awq": log_once( logger.info, - "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", + "Converting AWQ weights to Exllama/GPTQ packing format, " + "in order used with Exllama/GPTQ kernels.", ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = torch.zeros( + (qweight.shape[0] * 32 // bits), + dtype=torch.int32, + device=qweight.device, + ) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: @@ -220,8 +226,22 @@ class Weights: for w2 in w[1:]: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - else: - g_idx = None + elif quantize == "gptq" and quant_method == "awq": + log_once( + logger.info, + "Converting AWQ weights to Exllama/GPTQ packing format, " + "in order used with Exllama/GPTQ kernels.", + ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = torch.zeros( + (qweight.shape[0] * 32 // bits), + dtype=torch.int32, + device=qweight.device, + ) from text_generation_server.utils.layers import HAS_EXLLAMA @@ -229,13 +249,6 @@ class Weights: bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act ) - if quantize == "gptq" and quant_method == "awq": - log_once( - logger.info, - "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", - ) - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -279,7 +292,7 @@ class Weights: if quant_method == "gptq": g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - else: + elif quant_method == "awq": g_idx = None if self.process_group.size() > 1: @@ -324,9 +337,19 @@ class Weights: if quant_method == "awq": log_once( logger.info, - "Converting AWQ weights to Exllama/GPTQ format to be used with Exllama/GPTQ kernels", + "Converting AWQ weights to Exllama/GPTQ packing format, " + "in order used with Exllama/GPTQ kernels.", ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = torch.zeros( + (qweight.shape[0] * 32 // bits), + dtype=torch.int32, + device=qweight.device, + ) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": @@ -353,13 +376,14 @@ class Weights: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() + quant_method = "gptq" desc_act = False except (SafetensorError, RuntimeError) as e: try: bits = self.gptq_bits groupsize = self.gptq_groupsize - quant_method = self.quant_method desc_act = getattr(self, "gptq_desc_act", False) + quant_method = getattr(self, "quant_method", "gptq") except Exception: raise e @@ -378,8 +402,8 @@ class Weights: data = json.load(f) self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] - self.gptq_desc_act = data["quantization_config"].get("desc_act", False) self.quant_method = data["quantization_config"]["quant_method"] + self.gptq_desc_act = data["quantization_config"]["desc_act"] except Exception: filename = "quantize_config.json" try: @@ -393,11 +417,11 @@ class Weights: data = json.load(f) self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] - self.gptq_desc_act = data.get("desc_act", False) if "version" in data and data["version"] == "GEMM": self.quant_method = "awq" else: self.quant_method = "gptq" + self.gptq_desc_act = data["desc_act"] except Exception: filename = "quant_config.json" try: @@ -411,10 +435,10 @@ class Weights: data = json.load(f) self.gptq_bits = data["w_bit"] self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data.get("desc_act", False) if "version" in data and data["version"] == "GEMM": self.quant_method = "awq" else: self.quant_method = "gptq" + self.gptq_desc_act = data["desc_act"] except Exception: pass From 76834c9989875749915eb3e40ef4177f565f8f57 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 2 Feb 2024 14:42:42 +0100 Subject: [PATCH 19/21] none g_idx --- server/text_generation_server/utils/weights.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 759ea602..f600a296 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -177,6 +177,8 @@ class Weights: dtype=torch.int32, device=qweight.device, ) + else: + g_idx = None weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: @@ -242,6 +244,8 @@ class Weights: dtype=torch.int32, device=qweight.device, ) + else: + g_idx = None from text_generation_server.utils.layers import HAS_EXLLAMA From 2629193efa949b59f446bad99ba77968a04f09f0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 5 Feb 2024 09:26:47 +0100 Subject: [PATCH 20/21] log message --- .../text_generation_server/utils/weights.py | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index f600a296..767a23b2 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -163,20 +163,17 @@ class Weights: g_idx = self.get_tensor(f"{prefix}.g_idx") elif quantize == "gptq" and quant_method == "awq": log_once( - logger.info, - "Converting AWQ weights to Exllama/GPTQ packing format, " - "in order used with Exllama/GPTQ kernels.", + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) from text_generation_server.utils.awq.conversion_utils import ( fast_awq_to_gptq, ) qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = torch.zeros( - (qweight.shape[0] * 32 // bits), - dtype=torch.int32, - device=qweight.device, - ) + g_idx = ( + torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) + // groupsize + ).to(dtype=torch.int32) else: g_idx = None @@ -230,20 +227,17 @@ class Weights: g_idx = w[0] elif quantize == "gptq" and quant_method == "awq": log_once( - logger.info, - "Converting AWQ weights to Exllama/GPTQ packing format, " - "in order used with Exllama/GPTQ kernels.", + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) from text_generation_server.utils.awq.conversion_utils import ( fast_awq_to_gptq, ) qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = torch.zeros( - (qweight.shape[0] * 32 // bits), - dtype=torch.int32, - device=qweight.device, - ) + g_idx = ( + torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) + // groupsize + ).to(dtype=torch.int32) else: g_idx = None @@ -340,20 +334,17 @@ class Weights: if quant_method == "awq": log_once( - logger.info, - "Converting AWQ weights to Exllama/GPTQ packing format, " - "in order used with Exllama/GPTQ kernels.", + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." ) from text_generation_server.utils.awq.conversion_utils import ( fast_awq_to_gptq, ) qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = torch.zeros( - (qweight.shape[0] * 32 // bits), - dtype=torch.int32, - device=qweight.device, - ) + g_idx = ( + torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) + // groupsize + ).to(dtype=torch.int32) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": From 04d38a83be979c19641c87992c7fd3848ad84274 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 8 Feb 2024 14:59:35 +0000 Subject: [PATCH 21/21] Updating the tests. --- .../test_flash_starcoder_gptq.json | 247 ++++++------- ...t_flash_starcoder_gptq_default_params.json | 63 ++-- .../test_flash_starcoder_gptq_load.json | 332 +++++++++--------- .../flash_santacoder_modeling.py | 14 +- 4 files changed, 335 insertions(+), 321 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 53055e42..5e537bb7 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -1,193 +1,194 @@ { - "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L", "details": { "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 20, - "seed": null, "prefill": [ { "id": 589, - "text": "def", - "logprob": null + "logprob": null, + "text": "def" }, { "id": 3226, - "text": " ge", - "logprob": -9.0234375 + "logprob": -8.5859375, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -9.0859375 + "logprob": -7.5859375, + "text": "ometric" }, { "id": 81, - "text": "_", - "logprob": -0.25878906 + "logprob": -0.2668457, + "text": "_" }, { "id": 6009, - "text": "mean", - "logprob": -2.2109375 + "logprob": -1.6416016, + "text": "mean" }, { "id": 26, - "text": "(", - "logprob": -0.30371094 + "logprob": -0.22705078, + "text": "(" }, { "id": 62, - "text": "L", - "logprob": -5.6054688 + "logprob": -5.2304688, + "text": "L" }, { "id": 44, - "text": ":", - "logprob": -3.0722656 + "logprob": -3.0976562, + "text": ":" }, { "id": 1682, - "text": " List", - "logprob": -0.6879883 + "logprob": -1.1044922, + "text": " List" }, { "id": 77, - "text": "[", - "logprob": -0.38500977 + "logprob": -0.14294434, + "text": "[" }, { "id": 1808, - "text": "float", - "logprob": -0.984375 + "logprob": -0.32299805, + "text": "float" }, { "id": 10794, - "text": "]):", - "logprob": -2.5351562 + "logprob": -2.8164062, + "text": "]):" } ], + "seed": null, "tokens": [ { "id": 284, - "text": "\n ", - "logprob": -1.1738281, - "special": false + "logprob": -0.1282959, + "special": false, + "text": "\n " }, { - "id": 442, - "text": " return", - "logprob": -0.95947266, - "special": false + "id": 1524, + "logprob": -0.97998047, + "special": false, + "text": " \"\"\"" }, { - "id": 3632, - "text": " sum", - "logprob": -1.4199219, - "special": false + "id": 284, + "logprob": -0.7006836, + "special": false, + "text": "\n " }, { - "id": 26, - "text": "(", - "logprob": -0.085876465, - "special": false + "id": 14883, + "logprob": -2.1933594, + "special": false, + "text": " Calculate" }, { - "id": 62, - "text": "L", - "logprob": -0.09875488, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.30517578, - "special": false - }, - { - "id": 517, - "text": " /", - "logprob": -0.42089844, - "special": false - }, - { - "id": 2069, - "text": " len", - "logprob": -0.042053223, - "special": false - }, - { - "id": 26, - "text": "(", - "logprob": -0.0011806488, - "special": false - }, - { - "id": 62, - "text": "L", - "logprob": -0.0005259514, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.0017633438, - "special": false - }, - { - "id": 478, - "text": "\n\n", - "logprob": -0.69189453, - "special": false - }, - { - "id": 203, - "text": "\n", - "logprob": -0.041870117, - "special": false - }, - { - "id": 589, - "text": "def", - "logprob": -0.27856445, - "special": false + "id": 322, + "logprob": -0.2697754, + "special": false, + "text": " the" }, { "id": 3226, - "text": " ge", - "logprob": -1.7255859, - "special": false + "logprob": -0.0836792, + "special": false, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -0.011291504, - "special": false + "logprob": -0.018737793, + "special": false, + "text": "ometric" }, { - "id": 81, - "text": "_", - "logprob": -0.008430481, - "special": false + "id": 5651, + "logprob": -0.028640747, + "special": false, + "text": " mean" }, { - "id": 6009, - "text": "mean", - "logprob": -0.025787354, - "special": false + "id": 432, + "logprob": -0.29467773, + "special": false, + "text": " of" }, { - "id": 26, - "text": "(", - "logprob": -0.073913574, - "special": false + "id": 312, + "logprob": -0.31518555, + "special": false, + "text": " a" }, { - "id": 62, - "text": "L", - "logprob": -0.09967041, - "special": false + "id": 1149, + "logprob": -0.20605469, + "special": false, + "text": " list" + }, + { + "id": 432, + "logprob": -0.23254395, + "special": false, + "text": " of" + }, + { + "id": 7515, + "logprob": -0.4489746, + "special": false, + "text": " numbers" + }, + { + "id": 32, + "logprob": -0.6044922, + "special": false, + "text": "." + }, + { + "id": 446, + "logprob": -0.63964844, + "special": false, + "text": "\n\n " + }, + { + "id": 499, + "logprob": -1.1953125, + "special": false, + "text": " :" + }, + { + "id": 753, + "logprob": -0.03515625, + "special": false, + "text": "param" + }, + { + "id": 498, + "logprob": -0.06311035, + "special": false, + "text": " L" + }, + { + "id": 44, + "logprob": -0.003414154, + "special": false, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.3310547, + "special": false, + "text": " List" } - ] - } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index 1ace3814..bf0f5146 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5898438, "text": "ometric" }, { "id": 81, - "logprob": -0.25830078, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.1875, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30004883, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6171875, + "logprob": -5.2382812, "text": "L" }, { "id": 44, - "logprob": -3.078125, + "logprob": -3.0996094, "text": ":" }, { "id": 1682, - "logprob": -0.68066406, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.38745117, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.9453125, + "logprob": -0.32226562, "text": "float" }, { "id": 10794, - "logprob": -2.5371094, + "logprob": -2.8164062, "text": "]):" } ], @@ -69,19 +69,19 @@ "tokens": [ { "id": 284, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "\n " }, { "id": 442, - "logprob": 0.0, + "logprob": -1.3134766, "special": false, "text": " return" }, { "id": 11665, - "logprob": -1.2236328, + "logprob": -0.10021973, "special": false, "text": " reduce" }, @@ -129,7 +129,7 @@ }, { "id": 319, - "logprob": 0.0, + "logprob": -0.42871094, "special": false, "text": " *" }, @@ -158,36 +158,37 @@ "text": ")" }, { - "id": 203, - "logprob": -0.12695312, - "special": false, - "text": "\n" - }, - { - "id": 203, + "id": 1115, "logprob": 0.0, "special": false, - "text": "\n" + "text": " **" }, { - "id": 589, + "id": 308, "logprob": 0.0, "special": false, - "text": "def" + "text": " (" }, { - "id": 3226, + "id": 35, "logprob": 0.0, "special": false, - "text": " ge" + "text": "1" }, { - "id": 21017, + "id": 32, + "logprob": -0.31323242, + "special": false, + "text": "." + }, + { + "id": 34, "logprob": 0.0, "special": false, - "text": "ometric" + "text": "0" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric" + "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 5381ce5a..46a21ed8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5820312, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26708984, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22717285, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1015625, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1083984, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -70,67 +70,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12817383, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91796875, + "id": 1524, + "logprob": -0.9863281, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3291016, + "id": 284, + "logprob": -0.7011719, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.097717285, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.29003906, + "id": 3226, + "logprob": -0.08465576, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.03829956, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011987686, + "id": 432, + "logprob": -0.29418945, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -145,57 +146,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.59375, "text": "ometric" }, { "id": 81, - "logprob": -0.25878906, + "logprob": -0.26953125, "text": "_" }, { "id": 6009, - "logprob": -2.2109375, + "logprob": -1.640625, "text": "mean" }, { "id": 26, - "logprob": -0.30371094, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6054688, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0722656, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6879883, + "logprob": -1.1123047, "text": " List" }, { "id": 77, - "logprob": -0.38500977, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.984375, + "logprob": -0.32299805, "text": "float" }, { "id": 10794, - "logprob": -2.5351562, + "logprob": -2.8164062, "text": "]):" } ], @@ -203,67 +204,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1738281, + "logprob": -0.12854004, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9584961, + "id": 1524, + "logprob": -0.9897461, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.4169922, + "id": 284, + "logprob": -0.69970703, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.085876465, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.0982666, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.3022461, + "id": 3226, + "logprob": -0.08496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.40504883, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.041656494, + "id": 5651, + "logprob": -0.029037476, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011844635, + "id": 432, + "logprob": -0.2939453, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005264282, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -278,57 +280,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22766113, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.2265625, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.0976562, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.1427002, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -336,67 +338,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.13012695, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9165039, + "id": 1524, + "logprob": -0.98046875, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.328125, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.07946777, + "id": 14883, + "logprob": -2.1992188, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09820557, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28930664, + "id": 3226, + "logprob": -0.083496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34592773, + "id": 21017, + "logprob": -0.01902771, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038330078, + "id": 5651, + "logprob": -0.029006958, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011940002, + "id": 432, + "logprob": -0.29248047, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -411,57 +414,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26904297, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1074219, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14477539, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.3256836, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8027344, "text": "]):" } ], @@ -469,66 +472,67 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12915039, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91259766, + "id": 1524, + "logprob": -0.98535156, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3251953, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2011719, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09906006, + "id": 322, + "logprob": -0.26708984, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28979492, + "id": 3226, + "logprob": -0.08502197, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.35958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038604736, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011901855, + "id": 432, + "logprob": -0.29589844, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005078316, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" } ] diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 22d03adf..81041046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -69,9 +69,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - g_idx = g_idx.to(device=weights.device) - bits, groupsize, _ = weights._get_gptq_params() + bits, groupsize, _, quant_method, = weights._get_gptq_params() + if quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = g_idx.to(device=weights.device) + elif quant_method == "awq": + g_idx = None + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) from text_generation_server.utils.layers import HAS_EXLLAMA