Update tensor_parallel.py

Fix abnormal conversation output with the Baichuan large model by normalizing the head weight with `F.normalize` when `config.model_type == "baichuan"`.
This commit is contained in:
Kaixiong Happy 2024-12-03 19:00:28 +08:00 committed by GitHub
parent 2c74c55637
commit 731f890887
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -66,6 +66,11 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
should_gather = False
if config.model_type == "baichuan":
# Resolve the issue of abnormal conversation performance in the Baichuan large model.
# https://github.com/huggingface/text-generation-inference/issues/2780
weight = F.normalize(weight)
return TensorParallelHead(
get_linear(weight, bias=None),
process_group=weights.process_group,