mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 06:42:10 +00:00
fix(server): better handling of inference mode (#57)
This commit is contained in:
parent
e114d87486
commit
4acc42a605
@ -38,9 +38,9 @@ struct Args {
|
|||||||
port: u16,
|
port: u16,
|
||||||
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
||||||
shard_uds_path: String,
|
shard_uds_path: String,
|
||||||
#[clap(default_value = "localhost", long, env)]
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
master_addr: String,
|
master_addr: String,
|
||||||
#[clap(default_value = "29500", long, env)]
|
#[clap(default_value = "6000", long, env)]
|
||||||
master_port: usize,
|
master_port: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
json_output: bool,
|
json_output: bool,
|
||||||
|
@ -28,6 +28,9 @@ torch.backends.cuda.matmul.allow_tf32 = True
|
|||||||
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# Disable gradients
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
|
model_id: str, revision: Optional[str], sharded: bool, quantize: bool
|
||||||
|
@ -289,11 +289,6 @@ class CausalLM(Model):
|
|||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: CausalLMBatch
|
self, batch: CausalLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
|
||||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
|
||||||
context_manager = (
|
|
||||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
|
||||||
)
|
|
||||||
with context_manager():
|
|
||||||
logits, past = self.forward(
|
logits, past = self.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
batch.attention_mask,
|
batch.attention_mask,
|
||||||
|
@ -364,11 +364,6 @@ class Seq2SeqLM(Model):
|
|||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: Seq2SeqLMBatch
|
self, batch: Seq2SeqLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
|
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
|
||||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
|
||||||
context_manager = (
|
|
||||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
|
||||||
)
|
|
||||||
with context_manager():
|
|
||||||
logits, encoder_last_hidden_state, past = self.forward(
|
logits, encoder_last_hidden_state, past = self.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
batch.attention_mask,
|
batch.attention_mask,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
from grpc import aio
|
from grpc import aio
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -19,6 +20,10 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||||||
self.cache = cache
|
self.cache = cache
|
||||||
self.model = model
|
self.model = model
|
||||||
self.server_urls = server_urls
|
self.server_urls = server_urls
|
||||||
|
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
# Force inference mode for the lifetime of TextGenerationService
|
||||||
|
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
|
||||||
|
|
||||||
async def ServiceDiscovery(self, request, context):
|
async def ServiceDiscovery(self, request, context):
|
||||||
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
|
||||||
@ -89,7 +94,11 @@ def serve(
|
|||||||
local_url = unix_socket_template.format(uds_path, 0)
|
local_url = unix_socket_template.format(uds_path, 0)
|
||||||
server_urls = [local_url]
|
server_urls = [local_url]
|
||||||
|
|
||||||
|
try:
|
||||||
model = get_model(model_id, revision, sharded, quantize)
|
model = get_model(model_id, revision, sharded, quantize)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error when initializing model")
|
||||||
|
raise
|
||||||
|
|
||||||
server = aio.server(interceptors=[ExceptionInterceptor()])
|
server = aio.server(interceptors=[ExceptionInterceptor()])
|
||||||
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
|
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
|
||||||
@ -101,8 +110,11 @@ def serve(
|
|||||||
)
|
)
|
||||||
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
||||||
server.add_insecure_port(local_url)
|
server.add_insecure_port(local_url)
|
||||||
|
|
||||||
await server.start()
|
await server.start()
|
||||||
|
|
||||||
logger.info("Server started at {}".format(local_url))
|
logger.info("Server started at {}".format(local_url))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await server.wait_for_termination()
|
await server.wait_for_termination()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
@ -171,9 +171,14 @@ def initialize_torch_distributed():
|
|||||||
else:
|
else:
|
||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
|
|
||||||
|
master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
|
||||||
|
master_port = os.getenv("MASTER_PORT", "6000")
|
||||||
|
init_method = f"tcp://{master_ip}:{master_port}"
|
||||||
|
|
||||||
# Call the init process.
|
# Call the init process.
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
|
init_method=init_method,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
timeout=timedelta(seconds=60),
|
timeout=timedelta(seconds=60),
|
||||||
|
Loading…
Reference in New Issue
Block a user