Signed-off-by: yuanwu <yuan.wu@intel.com>
This commit is contained in:
yuanwu 2025-05-11 18:23:23 +00:00
parent 4e95db304f
commit d98116db6e

View File

@ -469,6 +469,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module): class FlashLlamaModel(torch.nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()