fix: adjust batch for bgmv

This commit is contained in:
drbh 2024-06-06 17:45:08 +00:00
parent 8984ce6c69
commit ad088d51fa
2 changed files with 32 additions and 22 deletions

View File

@ -1205,8 +1205,13 @@ class FlashCausalLM(Model):
     if prefill_logprobs
     else speculative_logits
 )
+next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+    len(batch)
+)
 else:
     next_token_logits = out
+    next_adapter_indices = batch.adapter_meta.adapter_indices
 speculate = get_speculate()
 (
@ -1228,6 +1233,14 @@ class FlashCausalLM(Model):
 )
 if prefill:
+    # adjust segment lengths to account for all request lengths being 1 during decoding
+    adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+    batch.adapter_meta.adapter_segments = torch.tensor(
+        adapter_segments,
+        dtype=torch.int32,
+        device=batch.adapter_meta.adapter_segments.device,
+    )
 if len(batch) > 1 and prefill_logprobs:
     # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
     # When batch == 1, we will just use the batch.input_ids values directly
@ -1297,6 +1310,7 @@ class FlashCausalLM(Model):
 batch.position_ids = next_position_ids + accepted_ids
 batch.input_lengths_tensor += accepted_ids
 batch.slot_indices += accepted_ids
+batch.adapter_meta.adapter_indices = next_adapter_indices
 if prefill and prefill_logprobs:
     # Get prefill logprobs

View File

@ -239,27 +239,23 @@ def serve(
 max_input_tokens,
 )
+if len(lora_adapter_ids) > 0:
+    for index, adapter_id in enumerate(lora_adapter_ids):
 # TODO: avoid hacky hardcoded adapter id
 adapter_parameters = AdapterParameters(
-    adapter_ids=lora_adapter_ids,
+    adapter_ids=[adapter_id],
-    weights=[
-        # TODO: fill with actual weights
-        torch.tensor([1.0], dtype=torch.float32)
-    ],
+    weights=[],
     merge_strategy=0,
-    density=0.0,
+    density=1.0,
     majority_sign_method=0,
 )
-adapter_source = None
-adapter_index = 0
-api_token = None
+adapter_index = index
 model.load_adapter(
     adapter_parameters,
-    adapter_source,
+    None,  # adapter_source
     adapter_index,
-    api_token,
+    None,  # api_token
-    False,
+    False,  # dynamic
 )
 except Exception: