Fix generation responses when top_n_tokens > 0

Fix issue #1340
This commit is contained in:
gduhamel 2024-01-19 20:50:17 +01:00
parent 3ccb3bb0b5
commit fd8b42678d
5 changed files with 7 additions and 5 deletions

View File

@ -719,7 +719,7 @@ class CausalLM(Model):
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
[top_tokens],
)
generations.append(generation)

View File

@ -1089,7 +1089,7 @@ class FlashCausalLM(Model):
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
[top_tokens],
)
generations.append(generation)

View File

@ -809,7 +809,7 @@ class IdeficsCausalLM(Model):
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
[top_tokens],
)
generations.append(generation)

View File

@ -773,7 +773,7 @@ class Seq2SeqLM(Model):
[next_token_id_squeezed.item() in self.all_special_ids],
),
generated_text,
top_tokens,
[top_tokens],
)
generations.append(generation)

View File

@ -95,5 +95,7 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
top_tokens=([top_token.to_pb() for top_token in self.top_tokens]
if self.top_tokens is not None
else None),
)