fix(server): do not warp prefill logits

This commit is contained in:
OlivierDehaene 2023-03-09 11:33:28 +01:00
parent 941cd42e0c
commit f49786ccba

View File

@ -70,6 +70,10 @@ class NextTokenChooser:
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
# Warp logits # Warp logits
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores) scores = self.warpers(input_ids, scores)
# Compute logprobs # Compute logprobs