feat(server): Support BLOOMChat-176B

Tweaks to accommodate config differences in the current version of https://huggingface.co/sambanovasystems/BLOOMChat-176B-v1
Nick Hill 2023-05-21 08:11:49 -07:00
parent 5a58226130
commit 5558dca0ec
2 changed files with 6 additions and 2 deletions


@@ -131,7 +131,10 @@ class BLOOMSharded(BLOOM):
                 file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
-                    full_name = f"transformer.{name}"
+                    if name.startswith("transformer.") or name.startswith("lm_head."):
+                        full_name = name
+                    else:
+                        full_name = f"transformer.{name}"
 
                     module_name, param_name = full_name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
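
For context on this hunk: the BLOOMChat-style safetensors shards appear to ship fully qualified keys (e.g. transformer.word_embeddings.weight, lm_head.weight), while the original BLOOM shards omit the transformer. prefix, which the loader previously added unconditionally. A minimal standalone sketch of that normalization; the helper name normalize_weight_name and the sample keys are illustrative only, not part of the repo:

# Illustrative sketch only; normalize_weight_name is a hypothetical helper.
def normalize_weight_name(name: str) -> str:
    """Return the fully qualified parameter name expected by the sharded model."""
    if name.startswith("transformer.") or name.startswith("lm_head."):
        # BLOOMChat-style checkpoints already carry the full prefix.
        return name
    # Older BLOOM checkpoints store keys relative to the transformer module.
    return f"transformer.{name}"

if __name__ == "__main__":
    for key in ("word_embeddings.weight", "h.0.self_attention.query_key_value.weight", "lm_head.weight"):
        print(key, "->", normalize_weight_name(key))
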
@@ -157,7 +160,7 @@ class BLOOMSharded(BLOOM):
                             # XXX: Hack for Rowlinear to add the bias only once.
                             if rank != 0:
                                 tensor = torch.zeros_like(tensor)
-                    elif isinstance(module, TensorParallelEmbedding):
+                    elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
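
This hunk routes lm_head.weight through the same branch as TensorParallelEmbedding, i.e. it is sharded across ranks along dimension 0 (the vocabulary dimension). A rough, self-contained sketch of that slicing, using a plain tensor in place of the safetensors slice_ and made-up sizes:

# Minimal sketch of dim-0 sharding, analogous to the branch enabled above
# for lm_head.weight. Sizes and tensor contents are made up for illustration.
import torch

def shard_dim0(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Slice a [vocab_size, hidden_size] weight into the rows owned by `rank`."""
    size = weight.shape[0]
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return weight[start:stop]

full = torch.randn(8, 4)  # pretend vocab_size=8, hidden_size=4
shards = [shard_dim0(full, r, 2) for r in range(2)]
assert torch.equal(torch.cat(shards, dim=0), full)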


@@ -504,6 +504,7 @@ class CausalLM(Model):
             position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=True,
+            return_dict=True,
         )
         return outputs.logits, outputs.past_key_values
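
Passing return_dict=True forces the underlying transformers model to return a ModelOutput object rather than a tuple, so the outputs.logits / outputs.past_key_values accesses on the next line keep working even when the checkpoint's config sets return_dict to False (as the current BLOOMChat config apparently does). A quick sketch of the difference, using a tiny randomly initialised BLOOM as a stand-in; the sizes are made up and nowhere near 176B:

# Sketch of what return_dict=True changes in a transformers forward pass.
import torch
from transformers import BloomConfig, BloomForCausalLM

config = BloomConfig(hidden_size=16, n_head=2, n_layer=1, vocab_size=64)
model = BloomForCausalLM(config)
input_ids = torch.tensor([[1, 2, 3]])

as_dict = model(input_ids=input_ids, use_cache=True, return_dict=True)
as_tuple = model(input_ids=input_ids, use_cache=True, return_dict=False)

print(hasattr(as_dict, "logits"), hasattr(as_dict, "past_key_values"))  # True True
print(isinstance(as_tuple, tuple))  # True; attribute access would fail here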