mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Handling bloom prefix.
This commit is contained in:
parent
bd998d8797
commit
f092404830
@ -74,7 +74,7 @@ class BLOOMSharded(CausalLM):
|
|||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(
|
weights = Weights(
|
||||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
|
||||||
)
|
)
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weights._set_gptq_params(model_id)
|
weights._set_gptq_params(model_id)
|
||||||
|
@ -16,6 +16,7 @@ class Weights:
|
|||||||
dtype,
|
dtype,
|
||||||
process_group,
|
process_group,
|
||||||
aliases: Optional[Dict[str, List[str]]] = None,
|
aliases: Optional[Dict[str, List[str]]] = None,
|
||||||
|
prefix: Optional[str] = None
|
||||||
):
|
):
|
||||||
routing = {}
|
routing = {}
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
@ -33,6 +34,7 @@ class Weights:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
self.prefix = prefix
|
||||||
self._handles = {}
|
self._handles = {}
|
||||||
|
|
||||||
def _get_handle(self, filename):
|
def _get_handle(self, filename):
|
||||||
@ -43,15 +45,22 @@ class Weights:
|
|||||||
return self._handles[filename]
|
return self._handles[filename]
|
||||||
|
|
||||||
def get_filename(self, tensor_name: str) -> (str, str):
|
def get_filename(self, tensor_name: str) -> (str, str):
|
||||||
filename = self.routing.get(tensor_name, None)
|
|
||||||
if filename is None:
|
names = [tensor_name]
|
||||||
aliases = self.aliases.get(tensor_name, [])
|
if self.prefix is not None:
|
||||||
|
prefixed = f"{self.prefix}.{tensor_name}"
|
||||||
|
names.append(prefixed)
|
||||||
|
for name in names:
|
||||||
|
filename = self.routing.get(name, None)
|
||||||
|
if filename is not None:
|
||||||
|
return str(filename), name
|
||||||
|
|
||||||
|
aliases = self.aliases.get(name, [])
|
||||||
for alias in aliases:
|
for alias in aliases:
|
||||||
filename = self.routing.get(alias, None)
|
filename = self.routing.get(alias, None)
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
return str(filename), alias
|
return str(filename), alias
|
||||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||||
return str(filename), tensor_name
|
|
||||||
|
|
||||||
def _get_slice(self, tensor_name: str):
|
def _get_slice(self, tensor_name: str):
|
||||||
filename, tensor_name = self.get_filename(tensor_name)
|
filename, tensor_name = self.get_filename(tensor_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user