cohere fix

This commit is contained in:
Cyril Vallez 2025-01-23 12:49:30 +00:00 committed by Nicolas Patry
parent f4dc44b88c
commit bcd9d3a5cb
No known key found for this signature in database
GPG Key ID: D2920555C90F704C

View File

@ -263,5 +263,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
# For Granite while next transformers version is released and we can use `lm_head_indices` natively
if hasattr(self.model.config, "logits_scaling"):
logits = logits / self.model.config.logits_scaling
# For Cohere for similar reasons
elif hasattr(self.model, "logit_scale"):
logits = logits * self.model.logit_scale
return logits, None