Removed redundant and crash causing regions to be a subject to Torch compile (#194)

Co-authored-by: Jacek Czaja <jczaja@habana.ai>
This commit is contained in:
Jacek Czaja 2024-08-08 13:06:20 +02:00 committed by GitHub
parent 4dc67e4ef3
commit 256a97231b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -119,7 +119,6 @@ def roll(tensor, chunk, dim, merge_graphs):
return tensor
@torch_compile_for_eager
def grouped_roll(tensor_groups, chunk, dims, merge_graphs):
tensor_groups = [[roll(t, chunk, dim, merge_graphs) for t in tensors] for tensors, dim in zip(tensor_groups, dims)]
if merge_graphs:
@ -135,7 +134,6 @@ def grouped_shift(tensor_groups, dims, offset, merge_graphs):
return tensor_groups
@torch_compile_for_eager
def move(dst_tensors, dst_indices, src_tensors):
bs_dim = 0
num_indices = dst_indices.size(0)