diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index ee1a86de..caa7d62d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -286,7 +286,6 @@ class TensorParallelRowLinear(SuperLayer): class TensorParallelEmbedding(nn.Module): def __init__(self, prefix: str, weights, reduce=True): super().__init__() - # weight = weights.get_sharded(f"{prefix}.weight", dim=0) weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0) num_embeddings = weights.get_shape(f"{prefix}.weight")[0]