Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 11:24:53 +00:00)
better handling of inference mode
This commit is contained in:
parent d8b84cc025
commit 5fb826ca14
@@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Disable gradients
+torch.set_grad_enabled(False)
+
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
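Context for the hunk above, not part of the commit: calling torch.set_grad_enabled(False) once at import time disables gradient tracking for all subsequent operations, so forward passes no longer record an autograd graph even without an explicit no_grad/inference_mode wrapper. A minimal sketch (the toy Linear module is an illustrative assumption, not the repository's model):

import torch

# Globally disable gradient tracking, as the hunk above does at module import.
torch.set_grad_enabled(False)

model = torch.nn.Linear(4, 2)       # illustrative stand-in for a real model
output = model(torch.randn(1, 4))   # forward pass with autograd disabled
print(output.requires_grad)         # False: no autograd graph was recorded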
@@ -289,11 +289,6 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, past = self.forward(
+        logits, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
@@ -364,11 +364,6 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
+        logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
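The two generate_token hunks above drop a per-call workaround: on CPU (where GLOO is used for collectives) the code fell back to torch.no_grad, while on GPU it used torch.inference_mode. A hedged sketch of that removed pattern in isolation (run_forward and its arguments are illustrative names, not the repository's API):

import torch

def run_forward(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Pattern removed by this commit: choose the no-autograd context per device,
    # because inference_mode reportedly misbehaved with GLOO on CPU.
    context_manager = torch.no_grad if x.device.type == "cpu" else torch.inference_mode
    with context_manager():
        return module(x)

With the commit applied, this per-call choice is no longer needed: gradients are disabled globally, and inference mode is enabled once at the service level for CUDA devices (see the last hunk below).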
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import torch
 
 from grpc import aio
 from loguru import logger
@@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
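The added guard uses the private torch._C._InferenceMode RAII object: constructing it with True enables inference mode and keeps it enabled for as long as the guard stays alive (here, the lifetime of the service), instead of entering torch.inference_mode() around every request. A rough sketch of that behavior in isolation (the toy module is an assumption for illustration, not the repository's code):

import torch

model = torch.nn.Linear(8, 8).eval()    # illustrative stand-in for a real model

guard = torch._C._InferenceMode(True)   # inference mode is now on for this thread
y = model(torch.randn(1, 8))
print(y.is_inference())                 # True: the output is an inference tensor

del guard                               # dropping the guard restores the previous mode

Keeping the guard as an instance attribute, as the hunk does, ties its lifetime to the TextGenerationService object, which is what "force inference mode for the lifetime of TextGenerationService" means here.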