Update server/text_generation_server/utils/layers.py

This commit is contained in:
OlivierDehaene 2023-07-12 11:05:07 +02:00 committed by GitHub
parent 63f03b4b7d
commit 6193512c4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,7 +286,6 @@ class TensorParallelRowLinear(SuperLayer):
class TensorParallelEmbedding(nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
# weight = weights.get_sharded(f"{prefix}.weight", dim=0)
weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]