fix(server): do not warp prefill logits

This commit is contained in:
OlivierDehaene 2023-03-09 11:33:28 +01:00
parent 941cd42e0c
commit f49786ccba

View File

@@ -70,7 +70,11 @@ class NextTokenChooser:
def __call__(self, input_ids, scores):
# Warp logits
-scores = self.warpers(input_ids, scores)
+if scores.shape[0] > 1:
+# only warp the last token logits
+scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
+else:
+scores = self.warpers(input_ids, scores)
# Compute logprobs
logprobs = torch.log_softmax(scores, -1)