mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)

commit 5fb826ca14 (parent d8b84cc025)

    better handling of inference mode
@@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True
 
+# Disable gradients
+torch.set_grad_enabled(False)
+
 
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
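The hunk above moves gradient handling to import time: torch.set_grad_enabled(False) turns autograd off process-wide, so every forward pass served by this module runs without building a graph. A minimal standalone sketch of the effect (not TGI code; the tiny linear layer is only for illustration):

import torch

# Disable gradient tracking for the whole process, as the diff does at import time.
torch.set_grad_enabled(False)

layer = torch.nn.Linear(4, 2)
out = layer(torch.randn(1, 4))

# With autograd globally disabled, outputs carry no graph and do not require grad.
print(out.requires_grad)        # False
print(torch.is_grad_enabled())  # False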
@@ -289,17 +289,12 @@ class CausalLM(Model):
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                batch.past_key_values,
-            )
+        logits, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.position_ids,
+            batch.past_key_values,
+        )
 
         # List of indices to cache
         next_batch_keep_indices = []
@@ -364,19 +364,14 @@ class Seq2SeqLM(Model):
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.decoder_input_ids,
-                batch.decoder_attention_mask,
-                batch.encoder_last_hidden_state,
-                batch.past_key_values,
-            )
+        logits, encoder_last_hidden_state, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.decoder_input_ids,
+            batch.decoder_attention_mask,
+            batch.encoder_last_hidden_state,
+            batch.past_key_values,
+        )
 
         # List of indices to cache
         next_batch_keep_indices = []
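The two generate_token hunks above drop the per-call context manager that picked torch.no_grad on CPU (where inference_mode reportedly does not work well with the GLOO backend) and torch.inference_mode on CUDA; with gradients disabled globally and inference mode enforced at the service level, the forward call no longer needs a wrapper. A hedged sketch of the removed pattern, using a stand-in forward callable rather than the real model:

from typing import Callable

import torch

def run_forward(forward: Callable[[torch.Tensor], torch.Tensor],
                inputs: torch.Tensor,
                device_type: str) -> torch.Tensor:
    # The removed pattern: no_grad on CPU (inference_mode clashed with GLOO there),
    # inference_mode everywhere else.
    context_manager = torch.no_grad if device_type == "cpu" else torch.inference_mode
    with context_manager():
        return forward(inputs)

# Usage with a stand-in forward; any callable that maps a tensor to a tensor works.
logits = run_forward(torch.nn.Linear(8, 8), torch.randn(2, 8), device_type="cpu")
print(logits.shape)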
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import torch
 
 from grpc import aio
 from loguru import logger
@@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
 
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
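Instead of wrapping each call, the service now holds a torch._C._InferenceMode(True) guard for its whole lifetime when the model sits on CUDA, so everything executed by the gRPC handlers runs in inference mode; on CPU it relies on the global torch.set_grad_enabled(False) from the first hunk. _InferenceMode is a private RAII-style object: inference mode stays on for as long as the object is kept alive. A small sketch of that lifetime-scoped pattern (hypothetical Service class, not the TGI servicer):

import torch

class Service:
    """Hypothetical service that keeps inference mode on for its whole lifetime."""

    def __init__(self, device: torch.device):
        if device.type == "cuda":
            # Keeping a reference to the guard keeps inference mode enabled
            # until the service object is garbage collected.
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    def handle(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Runs under the guard on CUDA; falls back to the global grad switch on CPU.
        return model(x)

svc = Service(torch.device("cpu"))  # use torch.device("cuda") on a GPU host
print(svc.handle(torch.nn.Linear(4, 4), torch.randn(1, 4)).shape)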