Handling bloom prefix.

This commit is contained in:
Nicolas Patry 2023-10-03 09:08:41 +00:00
parent bd998d8797
commit f092404830
2 changed files with 15 additions and 6 deletions

View File

@ -74,7 +74,7 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
)
if config.quantize == "gptq":
weights._set_gptq_params(model_id)

View File

@ -16,6 +16,7 @@ class Weights:
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None
):
routing = {}
for filename in filenames:
@ -33,6 +34,7 @@ class Weights:
self.device = device
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self._handles = {}
def _get_handle(self, filename):
@ -43,15 +45,22 @@ class Weights:
return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str):
filename = self.routing.get(tensor_name, None)
if filename is None:
aliases = self.aliases.get(tensor_name, [])
names = [tensor_name]
if self.prefix is not None:
prefixed = f"{self.prefix}.{tensor_name}"
names.append(prefixed)
for name in names:
filename = self.routing.get(name, None)
if filename is not None:
return str(filename), name
aliases = self.aliases.get(name, [])
for alias in aliases:
filename = self.routing.get(alias, None)
if filename is not None:
return str(filename), alias
raise RuntimeError(f"weight {tensor_name} does not exist")
return str(filename), tensor_name
raise RuntimeError(f"weight {tensor_name} does not exist")
def _get_slice(self, tensor_name: str):
filename, tensor_name = self.get_filename(tensor_name)