Fixing sharding.

This commit is contained in:
Nicolas Patry 2024-07-02 15:37:27 +00:00
parent 298500a08e
commit b2fb845923
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674

View File

@@ -881,7 +881,9 @@ class FlashCausalLM(Model):
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
self.num_layers = config.num_hidden_layers
self.num_kv_heads = config.num_key_value_heads
# Validation is done in the model itself
self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {}