import torch

# Opaque token identifying a CUDA-graph memory pool, created once at import
# time via ``torch.cuda.graph_pool_handle()``.
# NOTE(review): presumably passed as ``pool=`` to ``torch.cuda.graph`` /
# ``CUDAGraph.capture_begin`` elsewhere so that several captures share one
# memory pool — confirm at the call sites.
#
# On PyTorch builds without CUDA support, ``graph_pool_handle`` is backed by
# a placeholder that raises ``RuntimeError`` when called, which would make
# this whole module unimportable. Fall back to ``None`` in that case: it is
# the default value of the ``pool`` argument of the CUDA-graph APIs, and
# graph capture cannot happen on such builds anyway.
try:
    MEM_POOL = torch.cuda.graph_pool_handle()
except RuntimeError:
    MEM_POOL = None