Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
Fixing changed names for santacoder.
This commit is contained in:
parent 34eadb54e9
commit f282d1bdbc
@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
filenames, device=device, dtype=dtype, process_group=self.process_group,
|
||||||
|
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
model = FlashSantacoderForCausalLM(config, weights)
|
model = FlashSantacoderForCausalLM(config, weights)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List, Dict, Optional
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
|
||||||
class Weights:
|
class Weights:
|
||||||
def __init__(self, filenames: List[Path], device, dtype, process_group):
|
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
||||||
routing = {}
|
routing = {}
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
with safe_open(filename, framework="pytorch") as f:
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
@ -14,6 +14,9 @@ class Weights:
|
|||||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||||
)
|
)
|
||||||
routing[k] = filename
|
routing[k] = filename
|
||||||
|
if aliases is None:
|
||||||
|
aliases = {}
|
||||||
|
self.aliases = aliases
|
||||||
self.routing = routing
|
self.routing = routing
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@ -27,14 +30,19 @@ class Weights:
|
|||||||
|
|
||||||
return self._handles[filename]
|
return self._handles[filename]
|
||||||
|
|
||||||
def get_filename(self, tensor_name: str) -> str:
|
def get_filename(self, tensor_name: str) -> (str, str):
|
||||||
filename = self.routing.get(tensor_name, None)
|
filename = self.routing.get(tensor_name, None)
|
||||||
if filename is None:
|
if filename is None:
|
||||||
|
aliases = self.aliases.get(tensor_name, [])
|
||||||
|
for alias in aliases:
|
||||||
|
filename = self.routing.get(alias, None)
|
||||||
|
if filename is not None:
|
||||||
|
return str(filename), alias
|
||||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||||
return str(filename)
|
return str(filename), tensor_name
|
||||||
|
|
||||||
def _get_slice(self, tensor_name: str):
|
def _get_slice(self, tensor_name: str):
|
||||||
filename = self.get_filename(tensor_name)
|
filename, tensor_name= self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
slice_ = f.get_slice(tensor_name)
|
slice_ = f.get_slice(tensor_name)
|
||||||
return slice_
|
return slice_
|
||||||
@ -43,7 +51,7 @@ class Weights:
|
|||||||
return self._get_slice(tensor_name).get_shape()
|
return self._get_slice(tensor_name).get_shape()
|
||||||
|
|
||||||
def get_tensor(self, tensor_name: str):
|
def get_tensor(self, tensor_name: str):
|
||||||
filename = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
f = self._get_handle(filename)
|
f = self._get_handle(filename)
|
||||||
tensor = f.get_tensor(tensor_name)
|
tensor = f.get_tensor(tensor_name)
|
||||||
tensor = tensor.to(dtype=self.dtype)
|
tensor = tensor.to(dtype=self.dtype)
|
||||||
@ -51,7 +59,7 @@ class Weights:
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def get_sharded(self, tensor_name: str, dim: int):
|
def get_sharded(self, tensor_name: str, dim: int):
|
||||||
filename = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
world_size = self.process_group.size()
|
world_size = self.process_group.size()
|
||||||
rank = self.process_group.rank()
|
rank = self.process_group.rank()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user