From 55cf4d257c9af7468405e230cb763bfb2a5f3728 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 14 Jun 2023 09:29:44 +0200 Subject: [PATCH] Tiny fixes for falcon. --- .../models/custom_modeling/flash_rw_modeling.py | 3 ++- server/text_generation_server/utils/gptq/quantize.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 4a9063eb..efc9548c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -21,7 +21,8 @@ from text_generation_server.utils.layers import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_sharded(f"{prefix}.weight", dim=1) + weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=1) + if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process bias = weights.get_tensor(f"{prefix}.bias") diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index 54d996dd..4956f7c2 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -205,7 +205,7 @@ class GPTQ: def print_loss(self, name, q_weight, weight_error, timecost): table = Texttable() - length = 30 + length = 28 name = ( (name + " " * (length - len(name))) if len(name) <= length @@ -1165,10 +1165,12 @@ def quantize( f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) - config = AutoConfig.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) config.save_pretrained(output_dir) logger.info("Saved config") logger.info("Saving tokenizer") - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) tokenizer.save_pretrained(output_dir) logger.info("Saved tokenizer")