diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py index a00338e7..d4ef8858 100644 --- a/server/text_generation_server/adapters/lora.py +++ b/server/text_generation_server/adapters/lora.py @@ -24,6 +24,7 @@ from text_generation_server.utils.sgmv import ( orient_for_rank, pad_rank, use_cutlass_shrink, + has_sgmv, ) @@ -325,6 +326,22 @@ class BatchLoraWeights(BatchAdapterWeights): default=0, ) + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + use_sgmv = False + rank_data = {} + + if not has_sgmv(): + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + use_sgmv=use_sgmv, + ) if prefill or max_rank > BGMV_MAX_RANK: use_sgmv = True lora_a_ptr = torch.tensor( @@ -378,12 +395,6 @@ class BatchLoraWeights(BatchAdapterWeights): device=device, ) - adapter_index_configs = { - idx: adapter_weights[idx].adapter_config - for idx in segment_indices - if idx in adapter_weights - } - adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} rank_indices = defaultdict(list)