diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 54634e4a..a71c0061 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, device=device, dtype=dtype, process_group=self.process_group, + aliases = {"transformer.wte.weight": ["lm_head.weight"]} ) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 76a4f65a..88347a6a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import List +from typing import List, Dict, Optional from safetensors import safe_open class Weights: - def __init__(self, filenames: List[Path], device, dtype, process_group): + def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None): routing = {} for filename in filenames: with safe_open(filename, framework="pytorch") as f: @@ -14,6 +14,9 @@ class Weights: f"Key {k} was found in multiple files: {filename} and {routing[k]}" ) routing[k] = filename + if aliases is None: + aliases = {} + self.aliases = aliases self.routing = routing self.device = device self.dtype = dtype @@ -27,14 +30,19 @@ class Weights: return self._handles[filename] - def get_filename(self, tensor_name: str) -> str: + def get_filename(self, tensor_name: str) -> (str, str): filename = self.routing.get(tensor_name, None) if filename is None: + aliases = self.aliases.get(tensor_name, []) + for alias in aliases: + filename = self.routing.get(alias, None) + if filename is not None: + return str(filename), alias raise RuntimeError(f"weight {tensor_name} does not exist") - return str(filename) + return str(filename), tensor_name def _get_slice(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name= self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) return slice_ @@ -43,7 +51,7 @@ class Weights: return self._get_slice(tensor_name).get_shape() def get_tensor(self, tensor_name: str): - filename = self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) tensor = tensor.to(dtype=self.dtype) @@ -51,7 +59,7 @@ class Weights: return tensor def get_sharded(self, tensor_name: str, dim: int): - filename = self.get_filename(tensor_name) + filename, tensor_name = self.get_filename(tensor_name) world_size = self.process_group.size() rank = self.process_group.rank()