fix: adjust batch for bgmv

drbh 2024-06-06 17:45:08 +00:00
parent 8984ce6c69
commit ad088d51fa
2 changed files with 32 additions and 22 deletions


@@ -1205,8 +1205,13 @@ class FlashCausalLM(Model):
                     if prefill_logprobs
                     else speculative_logits
                 )
+            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+                len(batch)
+            )
         else:
             next_token_logits = out
+            next_adapter_indices = batch.adapter_meta.adapter_indices
 
         speculate = get_speculate()
         (
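For orientation, a minimal sketch of what the new next_adapter_indices buffer is doing, with hypothetical shapes (only the tensor names come from the hunk above): during prefill, batch.adapter_meta.adapter_indices holds one entry per prompt token, but from the next step on each request contributes exactly one token, so a per-request tensor of length len(batch) is allocated instead.

import torch

# Hypothetical batch: two requests with prompt lengths 3 and 2,
# routed to LoRA adapters 0 and 1 respectively.
prefill_adapter_indices = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# After prefill, each request emits exactly one token per step, so the
# per-token view collapses to one entry per request (len(batch) == 2).
next_adapter_indices = prefill_adapter_indices.new_empty(2)
next_adapter_indices[0] = 0  # request 0 keeps adapter 0
next_adapter_indices[1] = 1  # request 1 keeps adapter 1
print(next_adapter_indices)  # tensor([0, 1], dtype=torch.int32)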
@@ -1228,6 +1233,14 @@ class FlashCausalLM(Model):
         )
 
         if prefill:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
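For context, a small sketch of the segment computation this hunk relies on. find_segments (already used by this file) is assumed to return boundary offsets for maximal runs of equal adapter indices, plus the adapter index of each run; the re-implementation below is illustrative, not the repository's code.

import torch

def find_segments_sketch(adapter_indices: torch.Tensor):
    # Boundary offsets for maximal runs of equal values, plus the value of
    # each run, so a BGMV/SGMV kernel can be launched once per segment.
    values = adapter_indices.tolist()
    segments = [0]
    segment_indices = []
    for i in range(1, len(values)):
        if values[i] != values[i - 1]:
            segments.append(i)
            segment_indices.append(values[i - 1])
    if values:
        segments.append(len(values))
        segment_indices.append(values[-1])
    return segments, segment_indices

# During decoding every request holds exactly one token, so e.g. three
# requests on adapters 0, 0 and 1 give:
print(find_segments_sketch(torch.tensor([0, 0, 1])))  # ([0, 2, 3], [0, 1])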
@@ -1297,6 +1310,7 @@ class FlashCausalLM(Model):
         batch.position_ids = next_position_ids + accepted_ids
         batch.input_lengths_tensor += accepted_ids
         batch.slot_indices += accepted_ids
+        batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs


@@ -239,28 +239,24 @@ def serve(
                 max_input_tokens,
             )
-            # TODO: avoid hacky hardcoded adapter id
-            adapter_parameters = AdapterParameters(
-                adapter_ids=lora_adapter_ids,
-                weights=[
-                    # TODO: fill with actual weights
-                    torch.tensor([1.0], dtype=torch.float32)
-                ],
-                merge_strategy=0,
-                density=0.0,
-                majority_sign_method=0,
-            )
-            adapter_source = None
-            adapter_index = 0
-            api_token = None
-
-            model.load_adapter(
-                adapter_parameters,
-                adapter_source,
-                adapter_index,
-                api_token,
-                False,
-            )
+            if len(lora_adapter_ids) > 0:
+                for index, adapter_id in enumerate(lora_adapter_ids):
+                    # TODO: avoid hacky hardcoded adapter id
+                    adapter_parameters = AdapterParameters(
+                        adapter_ids=[adapter_id],
+                        weights=[],
+                        merge_strategy=0,
+                        density=1.0,
+                        majority_sign_method=0,
+                    )
+                    adapter_index = index
+                    model.load_adapter(
+                        adapter_parameters,
+                        None,  # adapter_source
+                        adapter_index,
+                        None,  # api_token
+                        False,  # dynamic
+                    )
         except Exception:
             logger.exception("Error when initializing model")
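To make the effect of the second change concrete: previously all of lora_adapter_ids went into a single AdapterParameters registered at index 0, whereas after this commit each adapter occupies its own integer index, which is what the per-request adapter_indices in the first file refer to. A hedged sketch of the resulting routing, with hypothetical adapter ids (only the id-to-index ordering is taken from the loop above):

import torch

# Hypothetical adapter ids, loaded in order by the loop above, so each
# one occupies the index it was enumerated with.
lora_adapter_ids = ["org/adapter-a", "org/adapter-b"]
adapter_id_to_index = {a: i for i, a in enumerate(lora_adapter_ids)}

# Three incoming requests, each naming the adapter it wants.
request_adapter_ids = ["org/adapter-b", "org/adapter-a", "org/adapter-b"]
adapter_indices = torch.tensor(
    [adapter_id_to_index[a] for a in request_adapter_ids], dtype=torch.int32
)
print(adapter_indices)  # tensor([1, 0, 1], dtype=torch.int32)
# This is the per-request tensor that batch.adapter_meta.adapter_indices
# carries into the BGMV kernels during decoding.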