feat(server): modify nccl init_method

This commit is contained in:
OlivierDehaene 2023-02-07 14:38:43 +01:00
parent e114d87486
commit d8b84cc025
3 changed files with 15 additions and 3 deletions

View File

@@ -38,9 +38,9 @@ struct Args {
port: u16, port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)] #[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String, shard_uds_path: String,
#[clap(default_value = "localhost", long, env)] #[clap(default_value = "0.0.0.0", long, env)]
master_addr: String, master_addr: String,
#[clap(default_value = "29500", long, env)] #[clap(default_value = "6000", long, env)]
master_port: usize, master_port: usize,
#[clap(long, env)] #[clap(long, env)]
json_output: bool, json_output: bool,

View File

@@ -89,7 +89,11 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
try:
model = get_model(model_id, revision, sharded, quantize) model = get_model(model_id, revision, sharded, quantize)
except Exception:
logger.exception("Error when initializing model")
raise
server = aio.server(interceptors=[ExceptionInterceptor()]) server = aio.server(interceptors=[ExceptionInterceptor()])
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -101,8 +105,11 @@ def serve(
) )
reflection.enable_server_reflection(SERVICE_NAMES, server) reflection.enable_server_reflection(SERVICE_NAMES, server)
server.add_insecure_port(local_url) server.add_insecure_port(local_url)
await server.start() await server.start()
logger.info("Server started at {}".format(local_url)) logger.info("Server started at {}".format(local_url))
try: try:
await server.wait_for_termination() await server.wait_for_termination()
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@@ -171,9 +171,14 @@ def initialize_torch_distributed():
else: else:
backend = "gloo" backend = "gloo"
master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
master_port = os.getenv("MASTER_PORT", "6000")
init_method = f"tcp://{master_ip}:{master_port}"
# Call the init process. # Call the init process.
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
init_method=init_method,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
timeout=timedelta(seconds=60), timeout=timedelta(seconds=60),