diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 1f324f771..390f0a0a6 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -131,7 +131,10 @@ class BLOOMSharded(BLOOM): file, framework="pt", device=str(device) if quantize is None else "cpu" ) as f: for name in f.keys(): - full_name = f"transformer.{name}" + if name.startswith("transformer.") or name.startswith("lm_head."): + full_name = name + else: + full_name = f"transformer.{name}" module_name, param_name = full_name.rsplit(".", 1) module = model.get_submodule(module_name) @@ -157,7 +160,7 @@ class BLOOMSharded(BLOOM): # XXX: Hack for Rowlinear to add the bias only once. if rank != 0: tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): + elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight": size = slice_.get_shape()[0] block_size = size // world_size start = rank * block_size diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 9d8ae2542..90b1e5ee2 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -504,6 +504,7 @@ class CausalLM(Model): position_ids=position_ids, past_key_values=past_key_values, use_cache=True, + return_dict=True, ) return outputs.logits, outputs.past_key_values