Checking our device.

This commit is contained in:
Nicolas Patry 2023-04-26 12:19:32 +02:00
parent e1867079fd
commit e28b5bf460

View File

@ -30,6 +30,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return self.model.info
async def Health(self, request, context):
if self.model.device.type == "cuda":
torch.zeros((2, 2)).to(device=f"cuda:{os.environ['RANK']}")
return generate_pb2.HealthResponse()
async def ServiceDiscovery(self, request, context):