Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
fixes

commit e29bb90e88, parent a5bf08f6e2
@@ -441,7 +441,7 @@ class VectorizedNextTokenChooser:
         if return_logprobs:
             # Compute logprobs
             if scores.size(1)==1:
-                scores=scores.unsqueeze(1)
+                scores=last_token_scores.unsqueeze(1)
             else:
                 # TODO: Post-process all the tokens?
                 scores[:, -1, :]=last_token_scores
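The replaced line matters because the single-position (decode) branch was storing the raw `scores` rather than the post-processed `last_token_scores`. A minimal sketch of the shape convention both branches now follow, with hypothetical sizes (not code from this repository):

import torch

batch_size, vocab_size = 2, 8
# Stand-in for the post-processed logits of the last position.
last_token_scores = torch.randn(batch_size, vocab_size)  # [batch, vocab]

# Decode branch: a single position, so the processed scores themselves are
# stored, with the sequence dimension restored -> [batch, 1, vocab].
scores = last_token_scores.unsqueeze(1)
assert scores.shape == (batch_size, 1, vocab_size)

# Prefill branch: several positions; only the last one has been
# post-processed, so the processed scores overwrite the final row.
seq_len = 4
scores = torch.randn(batch_size, seq_len, vocab_size)
scores[:, -1, :] = last_token_scores
assert scores.shape == (batch_size, seq_len, vocab_size)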
@@ -560,6 +560,9 @@ class VectorizedCausalLM(Model):
         self, batch: VectorizedCausalLMBatch
     ) -> Tuple[List[Generation], Optional[VectorizedCausalLMBatch]]:
         key_length=batch.max_input_length
+        if key_length>batch.input_ids.size(1):
+            raise RuntimeError("Cannot generate more than `max_tokens`.")
+
         query_length=key_length if batch.past_key_values is None else 1
         input_ids=batch.input_ids[:, key_length-query_length: key_length]
 
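The added guard keeps a decode step from indexing past the pre-allocated `input_ids` buffer. A minimal sketch of what it catches, assuming (hypothetically) that `input_ids` is pre-sized to `max_tokens` columns:

import torch

batch_size, max_tokens = 2, 4
input_ids = torch.zeros(batch_size, max_tokens, dtype=torch.long)

key_length = max_tokens + 1  # one step past the buffer
try:
    if key_length > input_ids.size(1):
        raise RuntimeError("Cannot generate more than `max_tokens`.")
except RuntimeError as err:
    print(err)  # Cannot generate more than `max_tokens`.

# Within bounds, a step with a KV cache reads a single trailing column:
key_length, past_key_values = 3, object()  # cache present after prefill
query_length = key_length if past_key_values is None else 1
step_input = input_ids[:, key_length - query_length : key_length]
assert step_input.shape == (batch_size, 1)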
@@ -572,16 +575,15 @@ class VectorizedCausalLM(Model):
         # TODO: Post-processing
         next_token_ids, logprobs = batch.next_token_chooser(input_ids, outputs.logits, batch.details)
 
-        next_token_ids=next_token_ids.cpu().tolist()
         if batch.generate_stream:
             # TODO: self.decode_token, offsets?
-            next_token_texts=self.tokenizer.batch_decode(next_token_ids)
+            next_token_texts=self.tokenizer.batch_decode(next_token_ids.tolist())
 
         if batch.details:
             token_logprobs=logprobs[:, -1, :].gather(1, next_token_ids.unsqueeze(1)).squeeze(1).tolist()
             if query_length>1:
                 prefill_token_ids=batch.input_ids[:, :key_length].tolist()
-                prefill_logprobs=logprobs.gather(1, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
+                prefill_logprobs=logprobs.gather(2, batch.input_ids[:, 1:key_length, None]).squeeze(2).tolist()
                 prefill_tokens=[]
                 for prefill_token_ids_, prefill_logprobs_, input_length in zip(prefill_token_ids, prefill_logprobs, batch.input_lengths):
                     prefill_token_ids_=prefill_token_ids_[-input_length:]
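Two of the changes above go together: `next_token_ids` now stays a tensor (only converted with `.tolist()` where the tokenizer needs Python ints), which is what lets the unchanged `token_logprobs` line call `.unsqueeze(1)` on it, and the prefill `gather` moves from dim 1 to dim 2 because `logprobs` is laid out as `[batch, seq, vocab]` and the lookup selects along the vocabulary axis. A minimal sketch with hypothetical sizes:

import torch

batch_size, key_length, vocab_size = 2, 5, 11
logprobs = torch.log_softmax(
    torch.randn(batch_size, key_length, vocab_size), dim=-1
)  # [batch, seq, vocab]
input_ids = torch.randint(vocab_size, (batch_size, key_length))

# Position t predicts token t+1, so index with the shifted input ids.
index = input_ids[:, 1:key_length, None]  # [batch, key_length-1, 1]

# gather(2, ...) selects along the vocab axis; since the index is shorter
# along dim 1, gather uses only the first key_length-1 positions there.
prefill_logprobs = logprobs.gather(2, index).squeeze(2)
assert prefill_logprobs.shape == (batch_size, key_length - 1)

# gather(1, index) would instead treat the token ids as sequence positions,
# yielding wrong values or an out-of-range error once an id >= seq length.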