Fix all_input_ids shape

This commit is contained in:
OlivierDehaene 2023-02-01 15:30:09 +01:00
parent 34fc1e5cc6
commit c25fd1e2e8
4 changed files with 3 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters(
temperature=1.0,
repetition_penalty=1.0,
top_k=0,
top_p=1.0,
do_sample=False,

View File

@@ -336,7 +336,7 @@ class CausalLM(Model):
all_input_ids,
) in enumerate(iterator):
# Select next token
tokens, logprobs = next_token_chooser(all_input_ids, logits)
tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
next_token_id = tokens[-1].view(1, 1)
# Append next token to all tokens

View File

@@ -418,7 +418,7 @@ class Seq2SeqLM(Model):
decoder_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
next_token_id, logprobs = next_token_chooser(decoder_input_ids.view(1, -1), logits)
# Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])

View File

@@ -78,12 +78,8 @@ class NextTokenChooser:
self.choice = Sampling(seed, device) if sampling else Greedy()
def __call__(self, input_ids, scores):
input_ids = input_ids.unsqueeze(0)
scores = scores.unsqueeze(-1)
# Warp logits
scores = self.warpers(input_ids, scores)
scores = scores.squeeze(-1)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)