Fixing the transformers backend.

`inference_mode` forces the use of `aten.matmul` instead of `aten.mm`; the former has no sharding support, which crashes the transformers tensor-parallel (TP) path.
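
Not part of the diff, but a minimal sketch of how to see the dispatch difference locally (the `LogAten` name is made up, and the exact ops printed vary across Torch versions):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LogAten(TorchDispatchMode):
        # Log every aten op that reaches the dispatcher, then run it unchanged.
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)
            return func(*args, **(kwargs or {}))

    linear = torch.nn.Linear(16, 16, bias=False)
    x = torch.randn(2, 16)

    with torch.no_grad(), LogAten():
        linear(x)  # regular eval path
    with torch.inference_mode(), LogAten():
        linear(x)  # inference_mode path; the commit reports aten.matmul being selected here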

Calling `lm_head.forward` directly also crashes because it skips the hook that casts/un-casts the DTensor.
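
For context, a generic illustration (not code from this repository; the hooks below are hypothetical stand-ins for the DTensor cast hooks): calling the module goes through `Module.__call__`, which runs registered forward hooks, while calling `.forward()` directly bypasses them, which is why the second hunk below switches from `lm_head.forward(...)` to `lm_head(...)`:

    import torch

    lm_head = torch.nn.Linear(16, 32, bias=False)

    # Hypothetical stand-ins for the DTensor cast/un-cast hooks mentioned above.
    lm_head.register_forward_pre_hook(lambda mod, args: print("pre-hook: cast input"))
    lm_head.register_forward_hook(lambda mod, args, out: print("post-hook: un-cast output"))

    x = torch.randn(1, 16)
    lm_head(x)          # __call__ path: both hooks fire
    lm_head.forward(x)  # direct .forward(): hooks are skipped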

Torch 2.5.1 is required for sharding support.
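
One way (not in this commit) to surface that requirement explicitly at startup; the exception type and message are illustrative, and `packaging` is assumed to be installed:

    import torch
    from packaging.version import Version

    # Per the commit message, DTensor sharding needs torch >= 2.5.1.
    if Version(torch.__version__.split("+")[0]) < Version("2.5.1"):
        raise RuntimeError("transformers TP backend requires torch >= 2.5.1 for sharding support")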
Nicolas Patry 2025-01-22 17:47:20 +01:00
parent 859d2f0464
commit 6fe37d61d0
3 changed files with 5 additions and 5 deletions

@@ -79,7 +79,7 @@ __all__ = [
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM

@@ -262,6 +262,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         # To update with full Transformers support asap
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        logits = self.model.lm_head(hidden_states)
         return logits, None

@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        #     # Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info