diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 36c837cc..dd116c9e 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -79,7 +79,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 45d51d82..6f242ca4 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -262,6 +262,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         # To update with full Transformers support asap
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        logits = self.model.lm_head(hidden_states)
 
         return logits, None
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 45b48df8..935e0985 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        #     # Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def Info(self, request, context):
         return self.model.info
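
Aside (not part of the patch): the `lm_head.forward(...)` → `lm_head(...)` change is relevant because calling an `nn.Module` directly goes through `nn.Module.__call__`, which dispatches any registered forward hooks, whereas invoking `.forward()` explicitly bypasses them. A minimal sketch of that difference, using a plain `nn.Linear` as a hypothetical stand-in for the real `lm_head`:

```python
import torch
from torch import nn

# Hypothetical module standing in for model.lm_head.
lm_head = nn.Linear(16, 32)

# A forward hook such as a quantization or logging wrapper might register.
def log_hook(module, inputs, output):
    print("forward hook fired, output shape:", tuple(output.shape))

lm_head.register_forward_hook(log_hook)

hidden_states = torch.randn(4, 16)

# Direct call routes through nn.Module.__call__, so the hook runs.
logits = lm_head(hidden_states)          # prints the hook message

# Calling .forward() explicitly skips the hook machinery.
logits = lm_head.forward(hidden_states)  # no hook output
```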