mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
Fixing changed names for santacoder.
This commit is contained in:
parent
34eadb54e9
commit
f282d1bdbc
@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group,
|
||||
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
|
||||
)
|
||||
|
||||
model = FlashSantacoderForCausalLM(config, weights)
|
||||
|
@ -1,10 +1,10 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group):
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
@ -14,6 +14,9 @@ class Weights:
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
self.aliases = aliases
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
@ -27,14 +30,19 @@ class Weights:
|
||||
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> str:
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
for alias in aliases:
|
||||
filename = self.routing.get(alias, None)
|
||||
if filename is not None:
|
||||
return str(filename), alias
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
return str(filename)
|
||||
return str(filename), tensor_name
|
||||
|
||||
def _get_slice(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name= self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
return slice_
|
||||
@ -43,7 +51,7 @@ class Weights:
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
@ -51,7 +59,7 @@ class Weights:
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user