diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index d424de35..36de89b4 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -263,5 +263,8 @@ class TransformersFlashCausalLM(FlashCausalLM): # For Granite while next transformers version is released and we can use `lm_head_indices` natively if hasattr(self.model.config, "logits_scaling"): logits = logits / self.model.config.logits_scaling + # For Cohere for similar reasons + elif hasattr(self.model, "logit_scale"): + logits = logits * self.model.logit_scale return logits, None