turn tp embed back on

commit 15a6b79c7e
parent fc778e46fb
Author: OlivierDehaene
Date:   2023-03-24 18:20:47 +01:00

@@ -499,12 +499,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             if config.vocab_size % self.tp_world_size == 0:
                 self.tp_embeddings = True

-        # if self.tp_embeddings:
-        #     self.embed_in = TensorParallelEmbedding(
-        #         config.vocab_size, config.hidden_size, process_group=process_group
-        #     )
-        # else:
-        self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+        if self.tp_embeddings:
+            self.embed_in = TensorParallelEmbedding(
+                config.vocab_size, config.hidden_size, process_group=process_group
+            )
+        else:
+            self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)

         self.layers = nn.ModuleList(
             [
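The re-enabled TensorParallelEmbedding shards the vocabulary across tensor-parallel ranks, which is why the diff only flips tp_embeddings on when config.vocab_size divides evenly by tp_world_size. Below is a minimal sketch of a Megatron-style vocabulary-parallel embedding with that interface; it illustrates the technique and is not the repository's actual implementation.

```python
import torch
import torch.distributed as dist
from torch import nn


class TensorParallelEmbedding(nn.Module):
    # Sketch only: each rank stores vocab_size // world_size rows; the
    # partial lookups are summed with an all-reduce so every rank ends
    # up with the full embedding output.

    def __init__(self, vocab_size, hidden_size, process_group):
        super().__init__()
        self.process_group = process_group
        world_size = process_group.size()
        rank = process_group.rank()
        # Mirrors the guard in the diff: shards must divide the vocab evenly.
        assert vocab_size % world_size == 0
        block_size = vocab_size // world_size
        # This rank owns token ids in [min_id, max_id).
        self.min_id = rank * block_size
        self.max_id = (rank + 1) * block_size
        self.embedding = nn.Embedding(block_size, hidden_size)

    def forward(self, input_ids):
        # Ids outside this rank's shard are looked up as row 0 and then
        # zeroed, so the cross-rank sum reconstructs the full lookup.
        mask = (input_ids >= self.min_id) & (input_ids < self.max_id)
        local_ids = torch.where(mask, input_ids - self.min_id, torch.zeros_like(input_ids))
        out = self.embedding(local_ids)
        out = out * mask.unsqueeze(-1).to(out.dtype)
        dist.all_reduce(out, group=self.process_group)
        return out
```

Routing out-of-shard ids to row 0 and masking the result to zero keeps the lookup branch-free, at the cost of one all-reduce per forward pass.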