mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Update tensor_parallel.py
Resolve the issue of abnormal conversation performance in the Baichuan large model.
This commit is contained in:
parent
2c74c55637
commit
731f890887
@ -66,6 +66,11 @@ class TensorParallelHead(SuperLayer):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
should_gather = False
|
||||
|
||||
if config.model_type == "baichuan":
|
||||
# Resolve the issue of abnormal conversation performance in the Baichuan large model.
|
||||
# https://github.com/huggingface/text-generation-inference/issues/2780
|
||||
weight = F.normalize(weight)
|
||||
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None),
|
||||
process_group=weights.process_group,
|
||||
|
Loading…
Reference in New Issue
Block a user