mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
fix all_input_ids shape
This commit is contained in:
parent
34fc1e5cc6
commit
c25fd1e2e8
@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
|
||||
def default_pb_parameters():
|
||||
return generate_pb2.NextTokenChooserParameters(
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
do_sample=False,
|
||||
|
@ -336,7 +336,7 @@ class CausalLM(Model):
|
||||
all_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
tokens, logprobs = next_token_chooser(all_input_ids, logits)
|
||||
tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
|
||||
next_token_id = tokens[-1].view(1, 1)
|
||||
|
||||
# Append next token to all tokens
|
||||
|
@ -418,7 +418,7 @@ class Seq2SeqLM(Model):
|
||||
decoder_input_ids,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
|
||||
next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits)
|
||||
|
||||
# Append next token to decoder tokens
|
||||
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
|
||||
|
@ -78,12 +78,8 @@ class NextTokenChooser:
|
||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
scores = scores.unsqueeze(-1)
|
||||
|
||||
# Warp logits
|
||||
scores = self.warpers(input_ids, scores)
|
||||
scores = scores.squeeze(-1)
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
|
Loading…
Reference in New Issue
Block a user