better handling of inference mode

OlivierDehaene 2023-02-07 15:23:20 +01:00
parent d8b84cc025
commit 5fb826ca14
4 changed files with 20 additions and 22 deletions


@@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Disable gradients
+torch.set_grad_enabled(False)
+
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
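The module-level torch.set_grad_enabled(False) switches autograd off for every op the server runs afterwards, so the per-call no_grad / inference_mode wrappers removed below are no longer needed. A minimal standalone sketch of the effect (not from the repo):

import torch

# Process-wide switch: nothing below records an autograd graph.
torch.set_grad_enabled(False)

x = torch.ones(2, 2, requires_grad=True)
y = x * 3

print(torch.is_grad_enabled())  # False
print(y.requires_grad)          # False: no graph was built for y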


@@ -289,17 +289,12 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.position_ids,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                batch.past_key_values,
-            )
 
         # List of indices to cache
         next_batch_keep_indices = []
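Background on why the removed code picked a context manager per device: tensors created under torch.inference_mode are marked as inference tensors and are rejected by autograd and by in-place updates outside the mode, which is the kind of interaction that can trip up other code paths (the GLOO/CPU caveat mentioned above). A small illustration, independent of the repo:

import torch

with torch.no_grad():
    a = torch.ones(3)
with torch.inference_mode():
    b = torch.ones(3)

print(a.is_inference())  # False
print(b.is_inference())  # True

a.add_(1)                # fine
try:
    b.add_(1)            # in-place update to an inference tensor outside the mode
except RuntimeError as err:
    print("rejected:", err)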


@@ -364,19 +364,14 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, encoder_last_hidden_state, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.decoder_input_ids,
+            batch.decoder_attention_mask,
+            batch.encoder_last_hidden_state,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.decoder_input_ids,
-                batch.decoder_attention_mask,
-                batch.encoder_last_hidden_state,
-                batch.past_key_values,
-            )
 
         # List of indices to cache
        next_batch_keep_indices = []


@@ -1,5 +1,6 @@
 import asyncio
 import os
+import torch
 
 from grpc import aio
 from loguru import logger
@@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
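The guard added here is an RAII-style object: inference mode stays enabled for as long as the attribute keeps it alive, i.e. for the lifetime of the service, and it is only installed on CUDA because of the GLOO/CPU caveat. torch._C._InferenceMode is a private API used internally by torch.inference_mode. A stripped-down sketch of the idea, using a hypothetical DummyService stand-in:

import torch


class DummyService:
    """Hypothetical stand-in for TextGenerationService (illustration only)."""

    def __init__(self, device: torch.device):
        if device.type == "cuda":
            # Inference mode stays on while this guard object is referenced.
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)


service = DummyService(torch.device("cuda"))
print(torch.is_inference_mode_enabled())  # True while the guard is alive

del service                                # dropping the guard restores normal mode
print(torch.is_inference_mode_enabled())  # False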