From 0fc4f99379f61cae4bbcc09633913daae84ddc8c Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 19 Apr 2023 12:32:25 +0200
Subject: [PATCH] fix santacoder sharded

---
 server/text_generation_server/models/__init__.py    | 4 ++--
 server/text_generation_server/models/flash_llama.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index babc4d29..13c74c91 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -90,10 +90,10 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)
 
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type

diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 0526ea9b..9cbf1b57 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -75,9 +75,9 @@ class FlashLlama(FlashCausalLM):
         dtype: torch.dtype,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, "cpu")
+            state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(dtype).to(device if not quantize else "cpu")
+                value = value.to(device if not quantize else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])