Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Fixing the transformers backend.
`inference_mode` forces the use of `aten.matmul` instead of `aten.mm`; the former has no sharding support, which crashes the transformers TP support. `lm_head.forward` also crashes because it skips the hooks that cast/decast the DTensor. Torch 2.5.1 is required for sharding support.
This commit is contained in:
parent 859d2f0464
commit 6fe37d61d0
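The `lm_head` part of the fix comes down to the standard `nn.Module` hook mechanism: only `lm_head(x)` goes through `nn.Module.__call__`, which runs registered forward pre/post hooks, while `lm_head.forward(x)` bypasses them, so the DTensor cast/decast hooks that transformers installs for tensor parallelism never run. A minimal sketch of that behaviour, with a dummy hook standing in for the real TP hooks (the layer, shapes and hook bodies below are illustrative, not taken from the repository):

import torch
from torch import nn

lm_head = nn.Linear(16, 32, bias=False)
calls = []

# Stand-ins for the hooks that would cast inputs to DTensor and decast outputs.
lm_head.register_forward_pre_hook(lambda module, args: calls.append("pre_hook"))
lm_head.register_forward_hook(lambda module, args, output: calls.append("post_hook"))

x = torch.randn(2, 16)
lm_head(x)          # goes through __call__: both hooks fire
lm_head.forward(x)  # bypasses __call__: no hooks fire
print(calls)        # ['pre_hook', 'post_hook']

This is why the second hunk below calls `self.model.lm_head(hidden_states)` instead of `self.model.lm_head.forward(hidden_states)`.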
@@ -79,7 +79,7 @@ __all__ = [

 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

-FLASH_ATTENTION = False
+FLASH_ATTENTION = True

 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
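For context, the flag touched above is a capability gate: the module sets a default and clears it if the flash-attention models cannot be imported. A simplified sketch of that pattern, assuming the usual try/except ImportError fallback (the real file also logs the error and imports more symbols):

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = True  # default after this hunk
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
except ImportError:
    # Without a flash-attention-capable environment the import fails and the
    # flag is cleared, so callers fall back to non-flash models.
    FLASH_ATTENTION = False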
@@ -262,6 +262,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         # To update with full Transformers support asap
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        logits = self.model.lm_head(hidden_states)

         return logits, None
@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        #     # Force inference mode for the lifetime of TextGenerationService
+        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def Info(self, request, context):
         return self.model.info
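The guard commented out above is the private binding used by `torch.inference_mode()`: holding a `torch._C._InferenceMode(True)` object keeps inference mode enabled for the lifetime of the service, which is what forced the `aten.matmul` dispatch described in the commit message. A rough sketch of the same lifetime-scoped behaviour using only the public API (class and argument names here are illustrative, not the code TGI uses):

import torch

class Service:
    def __init__(self, enable_inference_mode: bool) -> None:
        if enable_inference_mode:
            # Enter the public context manager manually and never exit, so
            # inference mode stays on for the lifetime of the service object,
            # mirroring what holding the private RAII guard did.
            self._guard = torch.inference_mode()
            self._guard.__enter__()

service = Service(enable_inference_mode=True)
print(torch.is_inference_mode_enabled())  # True for this thread

Disabling the guard on CUDA, as this hunk does, means forward passes no longer run under global inference mode, which sidesteps the sharded `aten.matmul` issue the commit message describes.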