turn tp embed back on

OlivierDehaene 2023-03-24 18:20:47 +01:00
parent fc778e46fb
commit 15a6b79c7e

@@ -499,11 +499,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         if config.vocab_size % self.tp_world_size == 0:
             self.tp_embeddings = True
-        # if self.tp_embeddings:
-        #     self.embed_in = TensorParallelEmbedding(
-        #         config.vocab_size, config.hidden_size, process_group=process_group
-        #     )
-        # else:
-        self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+        if self.tp_embeddings:
+            self.embed_in = TensorParallelEmbedding(
+                config.vocab_size, config.hidden_size, process_group=process_group
+            )
+        else:
+            self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
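
For context, TensorParallelEmbedding shards the embedding table across the tensor-parallel ranks, which is why the toggle above requires config.vocab_size % self.tp_world_size == 0. The snippet below is a minimal sketch of that idea, not the implementation from this repository: the class name ShardedEmbedding and all of its internals are hypothetical. It assumes each rank holds a contiguous vocab_size // world_size slice of the rows, zeroes out tokens owned by other ranks, and all-reduces the partial results.

    import torch
    import torch.nn as nn
    import torch.distributed as dist

    class ShardedEmbedding(nn.Module):
        # Hypothetical sketch of a vocab-sharded embedding; not the actual
        # TensorParallelEmbedding used in this diff.
        def __init__(self, vocab_size, hidden_size, process_group):
            super().__init__()
            self.process_group = process_group
            world_size = dist.get_world_size(process_group)
            rank = dist.get_rank(process_group)
            # Mirrors the divisibility check that guards self.tp_embeddings above.
            assert vocab_size % world_size == 0
            self.shard_size = vocab_size // world_size
            self.start = rank * self.shard_size
            self.weight = nn.Parameter(
                torch.empty(self.shard_size, hidden_size).normal_()
            )

        def forward(self, input_ids):
            # Shift global token ids into this rank's shard and mask out
            # ids that belong to other ranks' shards.
            local_ids = input_ids - self.start
            mask = (local_ids < 0) | (local_ids >= self.shard_size)
            out = nn.functional.embedding(
                local_ids.clamp(0, self.shard_size - 1), self.weight
            )
            out[mask] = 0.0
            # Summing the partial results across ranks reconstructs the
            # full embedding lookup on every rank.
            dist.all_reduce(out, group=self.process_group)
            return out

When the vocabulary does not divide evenly across ranks, the diff above falls back to a plain nn.Embedding replicated on every rank instead.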