Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
feat(server): Support BLOOMChat-176B
Tweaks to accommodate config differences in the current version of https://huggingface.co/sambanovasystems/BLOOMChat-176B-v1.
This commit is contained in:
Parent: 5a58226130
Commit: 5558dca0ec
@@ -131,7 +131,10 @@ class BLOOMSharded(BLOOM):
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
full_name = f"transformer.{name}"
|
||||
if name.startswith("transformer.") or name.startswith("lm_head."):
|
||||
full_name = name
|
||||
else:
|
||||
full_name = f"transformer.{name}"
|
||||
|
||||
module_name, param_name = full_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
@@ -157,7 +160,7 @@ class BLOOMSharded(BLOOM):
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
|
@@ -504,6 +504,7 @@ class CausalLM(Model):
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user