fix: set sharded true if WORLD_SIZE is set

This commit is contained in:
drbh 2024-06-12 17:12:18 +00:00
parent cdbf802860
commit 9854f20225

View File

@@ -44,6 +44,9 @@ def serve(
otlp_endpoint: Optional[str] = None,
max_input_tokens: Optional[int] = None,
):
# derive sharded from environment variables if not provided
sharded = sharded or os.getenv("WORLD_SIZE", None) is not None
if sharded:
assert (
os.getenv("RANK", None) is not None