From e649bf9a55753a3de4a64c1d6be2cc4ed32c260b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 22 May 2023 13:36:00 +0200 Subject: [PATCH] feat(server): Support BLOOMChat-176B (#348) (#351) @njhill, temporary workaround to be able to run our CI as secrets are not available to runners run by external contributors. I will ask around to see if there is a better way. Co-authored-by: Nick Hill --- server/text_generation_server/models/bloom.py | 7 +++++-- server/text_generation_server/models/causal_lm.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 1f324f771..390f0a0a6 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -131,7 +131,10 @@ class BLOOMSharded(BLOOM): file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): - full_name = f"transformer.{name}" + if name.startswith("transformer.") or name.startswith("lm_head."): + full_name = name + else: + full_name = f"transformer.{name}" module_name, param_name = full_name.rsplit(".", 1) module = model.get_submodule(module_name) @@ -157,7 +160,7 @@ class BLOOMSharded(BLOOM): # XXX: Hack for Rowlinear to add the bias only once. if rank != 0: tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): + elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight": size = slice_.get_shape()[0] block_size = size // world_size start = rank * block_size diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 9d8ae2542..90b1e5ee2 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -504,6 +504,7 @@ class CausalLM(Model): position_ids=position_ids, past_key_values=past_key_values, use_cache=True, + return_dict=True, ) return outputs.logits, outputs.past_key_values