fix lora failure on platforms that do not have punica_kernels

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2024-11-28 00:14:22 -08:00
parent caff779dd4
commit 98d0093660

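The guard added below hinges on `has_sgmv()` reporting whether the punica SGMV kernels are usable on the current platform. As a rough illustration only (the repository's actual helper lives in `text_generation_server/utils/sgmv.py` and may apply extra conditions such as environment flags), such a probe can be written as an optional-import check:

```python
# Illustrative capability probe in the spirit of has_sgmv(); not the
# repository's implementation, just a sketch of the idea.
import importlib.util


def has_sgmv() -> bool:
    # SGMV is only usable when the optional punica_kernels extension is installed.
    return importlib.util.find_spec("punica_kernels") is not None
```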

@@ -24,6 +24,7 @@ from text_generation_server.utils.sgmv import (
    orient_for_rank,
    pad_rank,
    use_cutlass_shrink,
    has_sgmv,
)
@@ -325,6 +326,22 @@ class BatchLoraWeights(BatchAdapterWeights):
            default=0,
        )
        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }
        use_sgmv = False
        rank_data = {}
        if not has_sgmv():
            return BatchLoraWeights(
                lora_a=lora_a,
                lora_b=lora_b,
                adapter_index_configs=adapter_index_configs,
                rank_data=rank_data,
                use_sgmv=use_sgmv,
            )
        if prefill or max_rank > BGMV_MAX_RANK:
            use_sgmv = True
            lora_a_ptr = torch.tensor(
@@ -378,12 +395,6 @@ class BatchLoraWeights(BatchAdapterWeights):
                device=device,
            )
        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }
        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
        rank_indices = defaultdict(list)
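For context, when the early return fires, downstream layers see `use_sgmv=False` and an empty `rank_data`, so adapters have to be applied without the punica kernels. A minimal sketch of that kind of fallback, with hypothetical names (`apply_lora_fallback`, the per-adapter tensor dicts) that are not taken from the repository:

```python
import torch


def apply_lora_fallback(
    y: torch.Tensor,                  # base layer output, shape (tokens, out_features)
    x: torch.Tensor,                  # layer input, shape (tokens, in_features)
    lora_a: dict[int, torch.Tensor],  # adapter index -> (in_features, rank)
    lora_b: dict[int, torch.Tensor],  # adapter index -> (rank, out_features)
    token_adapter: torch.Tensor,      # adapter index per token, shape (tokens,)
    scaling: float = 1.0,
) -> torch.Tensor:
    # Plain-PyTorch path for when SGMV/BGMV kernels are unavailable:
    # loop over the adapters present in the batch and apply x @ A @ B
    # to the tokens routed to each adapter.
    for idx, a in lora_a.items():
        mask = token_adapter == idx
        if mask.any():
            y[mask] += scaling * (x[mask] @ a @ lora_b[idx])
    return y
```

The design point of the patch is simply that the SGMV pointer and rank-segment setup is skipped entirely on such platforms, instead of failing when the kernels are missing.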