Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Fixing the transformers backend.

inference_mode forces the use of `aten.matmul` instead of `aten.mm`; the former has no sharding support, which breaks the transformers TP support. Calling `lm_head.forward` directly also crashes, because it skips the hook that casts to and from DTensor. Torch 2.5.1 is required for sharding support.
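To illustrate the `lm_head.forward` point, here is a minimal standalone sketch (not code from this repository; the hooks are hypothetical stand-ins for the DTensor cast hooks that Tensor Parallelism registers): calling a module directly goes through `nn.Module.__call__`, which runs the registered forward hooks, while calling `.forward()` bypasses them.

import torch
import torch.nn as nn

# Stand-in for the real lm_head; the hooks below only record that they ran,
# mimicking the pre/post hooks that cast inputs/outputs to and from DTensor.
lm_head = nn.Linear(16, 32, bias=False)

calls = []
lm_head.register_forward_pre_hook(lambda module, args: calls.append("pre"))
lm_head.register_forward_hook(lambda module, args, output: calls.append("post"))

x = torch.randn(4, 16)

lm_head(x)          # goes through __call__, so both hooks fire
lm_head.forward(x)  # bypasses __call__, so no hooks fire

print(calls)  # ['pre', 'post']: only the direct call triggered the hooks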
This commit is contained in:
parent 859d2f0464
commit 6fe37d61d0
@@ -79,7 +79,7 @@ __all__ = [
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
+FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -262,6 +262,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
         # To update with full Transformers support asap
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head.forward(hidden_states)
+        logits = self.model.lm_head(hidden_states)
 
         return logits, None
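Calling the module itself instead of its `forward` method routes the call through `nn.Module.__call__`, so the Tensor Parallel hooks that cast the hidden states to and from DTensor still run (see the sketch after the commit message above).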
@@ -68,9 +68,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.quantize = model.quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        if model.device.type == "cuda":
-            # Force inference mode for the lifetime of TextGenerationService
-            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+        # if model.device.type == "cuda":
+        # # Force inference mode for the lifetime of TextGenerationService
+        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def Info(self, request, context):
         return self.model.info
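For context on the guard that is now commented out, here is a minimal sketch of its public counterpart, assuming nothing beyond stock PyTorch: `torch.inference_mode()` is the documented equivalent of the private `torch._C._InferenceMode(True)` RAII guard, and tensors produced under it are marked as inference tensors.

import torch

lin = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

# Scoped use of inference mode; the server previously held the private RAII
# guard for the whole lifetime of TextGenerationService instead.
with torch.inference_mode():
    y = lin(x)

print(y.requires_grad)        # False: no autograd graph is recorded
print(torch.is_inference(y))  # True: y is an inference tensor

Per the commit message, keeping inference mode active changes how the matmuls dispatch (`aten.matmul` instead of `aten.mm`), and the former has no sharding support, so the guard is disabled entirely rather than kept for CUDA only.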