fix: allocate tmp based on sgmv kernel if available (#2345)

* fix: allocate tmp based on sgmv kernel if available

* fix: re-add copy build artifacts step for punica kernels
drbh 2024-08-12 11:24:32 -04:00 committed by yuanwu
parent 8e6bfa2fc5
commit 3079865b60

@@ -151,13 +151,17 @@ def get_tmp_expand_size(size: int) -> int:
 def get_tmp_tensors(
     nsegments: int, lora_rank: int, device: torch.device
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if use_cutlass_shrink(lora_rank) and has_sgmv():
+    use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
+    has_sgmv_available = has_sgmv()
+
+    if use_cutlass:
         tmp = get_tmp_tensor_for_size(nsegments, device)
         return tmp, tmp
+    elif has_sgmv_available:
+        return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
     else:
-        tmp_shrink = get_tmp_tensor(device)
-        tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
-        return tmp_shrink, tmp_expand
+        tmp = get_tmp_tensor_for_size(nsegments, device)
+        return tmp, tmp
 
 
 def lora_a_sgmv_cutlass(
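
For context, here is a minimal, self-contained sketch of the allocation logic after this fix. The helper bodies, the `KERNELS_AVAILABLE` toggle, and the buffer sizing rules below are illustrative stand-ins only; the real implementations live alongside `get_tmp_tensors` in the server's sgmv utilities and depend on the punica kernel package.

```python
# Sketch of the post-fix tmp allocation, assuming simplified stand-ins
# for the punica/sgmv helpers (not the real implementations).
from typing import Tuple

import torch

KERNELS_AVAILABLE = False  # assumption: the sgmv kernels failed to import


def has_sgmv() -> bool:
    return KERNELS_AVAILABLE


def use_cutlass_shrink(lora_rank: int) -> bool:
    # Illustrative rule: pretend cutlass shrink is used for small ranks.
    return lora_rank < 32


def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    # Fixed-size scratch buffer for the shrink kernel (size is illustrative).
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    # Scratch buffer scaled to the number of segments (scaling is illustrative).
    return torch.empty((size * 1024,), dtype=torch.uint8, device=device)


def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()
    has_sgmv_available = has_sgmv()

    if use_cutlass:
        # Cutlass shrink and expand share one segment-sized buffer.
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    elif has_sgmv_available:
        # Non-cutlass sgmv path: fixed buffer for shrink, sized buffer for expand.
        return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
    else:
        # No sgmv kernels at all: fall back to a single shared buffer instead
        # of assuming kernel-specific sizing, which is the point of this fix.
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp


tmp_shrink, tmp_expand = get_tmp_tensors(4, lora_rank=64, device=torch.device("cpu"))
print(tmp_shrink.shape, tmp_expand.shape)  # same fallback buffer in both slots
```

The design change is that kernel availability is now checked once, up front, and drives all three branches: the old code only consulted `has_sgmv()` inside the cutlass condition, so the `else` branch could allocate kernel-shaped buffers even when no sgmv kernel was present.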