This commit is contained in:
Joel Lamy-Poirier 2023-05-05 16:26:50 -04:00
parent a5bf08f6e2
commit e29bb90e88
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF

View File

@@ -441,7 +441,7 @@ class VectorizedNextTokenChooser:
if return_logprobs:
# Compute logprobs
if scores.size(1)==1:
scores=scores.unsqueeze(1)
scores=last_token_scores.unsqueeze(1)
else:
# TODO: Post-process all the tokens?
scores[:, -1, :]=last_token_scores
@@ -560,6 +560,9 @@ class VectorizedCausalLM(Model):
self, batch: VectorizedCausalLMBatch
) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
key_length=batch.max_input_length
if key_length>batch.input_ids.size(1):
raise RuntimeError("Cannot generate more than `max_tokens`.")
query_length=key_length if batch.past_key_values is None else 1
input_ids=batch.input_ids[:, key_length-query_length: key_length]
@@ -572,16 +575,15 @@ class VectorizedCausalLM(Model):
# TODO: Post-processing
next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
next_token_ids=next_token_ids.cpu().tolist()
if batch.generate_stream:
# TODO: self.decode_token, offsets?
next_token_texts=self.tokenizer.batch_decode(next_token_ids)
next_token_texts=self.tokenizer.batch_decode(next_token_ids.tolist())
if batch.details:
token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
if query_length>1:
prefill_token_ids=batch.input_ids[:, :key_length].tolist()
prefill_logprobs=logprobs.gather(1, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
prefill_tokens=[]
for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
prefill_token_ids_=prefill_token_ids_[-input_length:]