From 106d8ee81899e6461648e1efd56f5a873ce7294d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Apr 2024 10:27:57 +0200 Subject: [PATCH] Automatic quantization config. (#1719) # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
--- server/text_generation_server/models/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 06219e7c..ec167303 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -186,6 +186,14 @@ def get_model( raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" ) + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq"}: + logger.info(f"Auto selecting quantization method {method}") + quantize = method + else: + logger.info(f"Unknown quantization method {method}") if model_type == "ssm": return Mamba(