mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Update tensor_parallel.py
Resolve the issue of abnormal conversation performance in the Baichuan large model.
This commit is contained in:
parent
2c74c55637
commit
731f890887
@@ -66,6 +66,11 @@ class TensorParallelHead(SuperLayer):
|
|||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
should_gather = False
|
should_gather = False
|
||||||
|
|
||||||
|
if config.model_type == "baichuan":
|
||||||
|
# Resolve the issue of abnormal conversation performance in the Baichuan large model.
|
||||||
|
# https://github.com/huggingface/text-generation-inference/issues/2780
|
||||||
|
weight = F.normalize(weight)
|
||||||
|
|
||||||
return TensorParallelHead(
|
return TensorParallelHead(
|
||||||
get_linear(weight, bias=None),
|
get_linear(weight, bias=None),
|
||||||
process_group=weights.process_group,
|
process_group=weights.process_group,
|
||||||
|
Loading…
Reference in New Issue
Block a user