text-generation-inference/server/text_generation_server/models/globals.py
Last commit 1b86d0f31d by Nicolas Patry, 2024-05-24 13:58:08 +00:00: "Using flash decoding"

    Conditional flash decoding.
    Fix max_q.
    Working KV cache.
    Working version with flash decoding.
    Make it work for Mistral.


import os

import torch
from loguru import logger

# Shared memory pool handle for CUDA graph capture (None when CUDA is unavailable).
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

# This is overridden by the CLI.
cuda_graphs = os.getenv("CUDA_GRAPHS")
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
if FLASH_DECODING:
    logger.info("Using FLASH_DECODING")

if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated "
            f"list of batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None
CUDA_GRAPHS = cuda_graphs

# This is overridden at model loading.
MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id
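
For reference, a minimal usage sketch (not part of the file). Both environment variables are evaluated once at module import time, so they must be set before the module is first imported; the model id passed to set_model_id below is only an illustrative placeholder.

import os

# Assumption: these must be set before text_generation_server is imported,
# because CUDA_GRAPHS and FLASH_DECODING are read at import time.
os.environ["CUDA_GRAPHS"] = "1,2,4"   # comma-separated batch sizes to capture
os.environ["FLASH_DECODING"] = "1"    # any of "1", "true", "True" enables it

from text_generation_server.models.globals import (
    CUDA_GRAPHS,
    FLASH_DECODING,
    set_model_id,
)

assert CUDA_GRAPHS == [1, 2, 4]
assert FLASH_DECODING is True

# MODEL_ID starts as None and is filled in at model-loading time via the setter.
set_model_id("mistralai/Mistral-7B-v0.1")  # placeholder model id

Note that a non-numeric value such as CUDA_GRAPHS="1,foo" makes the import itself raise the RuntimeError above, so a misconfigured batch-size list fails fast rather than at graph-capture time.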