diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index a00338e7..f1edd9a0 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -24,6 +24,7 @@ from text_generation_server.utils.sgmv import ( orient_for_rank, pad_rank, use_cutlass_shrink, + has_sgmv, ) @@ -325,8 +326,10 @@ class BatchLoraWeights(BatchAdapterWeights): default=0, ) + use_sgmv = False if prefill or max_rank > BGMV_MAX_RANK: - use_sgmv = True + if has_sgmv(): + use_sgmv = True lora_a_ptr = torch.tensor( [ ( @@ -352,7 +355,6 @@ class BatchLoraWeights(BatchAdapterWeights): device=device, ) else: - use_sgmv = False lora_a_ptr = torch.tensor( [ (