mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fixing sharding.
This commit is contained in:
parent
298500a08e
commit
b2fb845923
@ -881,7 +881,9 @@ class FlashCausalLM(Model):
|
||||
model = model_class(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
|
||||
# Validation is done in the model itself
|
||||
self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
|
||||
self.head_size = config.hidden_size // config.num_attention_heads
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
Loading…
Reference in New Issue
Block a user