working & cached tunableop

fxmarty 2024-04-29 14:55:59 +00:00
parent 193dbb683e
commit 17f5c3078b
4 changed files with 32 additions and 23 deletions

View File

@@ -672,9 +672,7 @@ fn shard_manager(
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
-            tracing::info!("Shard terminated");
+            terminate("shard", p, Duration::from_secs(90)).unwrap();
             return;
         }

View File

@@ -766,6 +766,9 @@ class FlashCausalLM(Model):
             )
             max_bt = batch.max_blocks
             max_s = max_bt * get_cache_manager().block_size
+
+            if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+                logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
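One detail worth noting about the check above: `os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False)` returns the raw string, so any non-empty value, including `"0"`, takes the branch. A small sketch of a stricter parse; the `env_flag` helper is hypothetical and not part of this commit:

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    # Environment variables are strings, so a bare truthiness check accepts
    # "0" and "false"; parse the common spellings explicitly instead.
    # Hypothetical helper, not part of this diff.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


# e.g. instead of the raw check used above:
# if IS_ROCM_SYSTEM and env_flag("PYTORCH_TUNABLEOP_ENABLED"):
#     ...
```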
@@ -820,10 +823,10 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        # TODO: fix
-        if IS_ROCM_SYSTEM and False:
-            total_seqlens = list(range(16))
+        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+            total_seqlens = list(range(2))
             for seqlen in total_seqlens:
+                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen, max_s, max_bt)
 
         return int(num_blocks * BLOCK_SIZE)
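For context, the warmup loop relies on PyTorch TunableOp benchmarking GEMM implementations for each shape it encounters and caching the winners, which is why the first pass is slow and later runs are not. A minimal standalone sketch of the same idea outside the server; the environment variable names follow the TunableOp docs linked above, while the result filename and the shapes are assumptions:

```python
# Sketch: enable TunableOp via environment variables and exercise a few GEMM
# shapes so tuned kernel choices are cached to disk for later runs.
import os

# Set before the first CUDA op so TunableOp picks the settings up.
os.environ.setdefault("PYTORCH_TUNABLEOP_ENABLED", "1")
os.environ.setdefault("PYTORCH_TUNABLEOP_FILENAME", "tunableop_results.csv")  # assumed default-style name

import torch


def warmup_gemms(batch_sizes, hidden=4096, dtype=torch.float16):
    # The first matmul at each new shape triggers tuning (slow); subsequent
    # calls reuse the cached selection from the CSV file.
    for bs in batch_sizes:
        a = torch.randn(bs, hidden, device="cuda", dtype=dtype)
        b = torch.randn(hidden, hidden, device="cuda", dtype=dtype)
        torch.matmul(a, b)
    torch.cuda.synchronize()


if __name__ == "__main__":
    warmup_gemms(batch_sizes=[1, 2, 4, 8])
```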

View File

@@ -166,6 +166,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )
 
+import signal
+
+
+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
 
 def serve(
     model_id: str,
@@ -231,11 +246,8 @@ def serve(
         logger.info("Server started at {}".format(local_url))
 
-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)
 
     asyncio.run(
         serve_inner(
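The change above swaps `wait_for_termination()` plus a `KeyboardInterrupt` handler for polling a flag that the signal handler flips on SIGINT/SIGTERM. A self-contained sketch of that pattern with an explicit grace period on shutdown; the `grpc.aio` usage and the 30-second grace value are assumptions, not taken from this commit:

```python
# Sketch: a signal handler flips a flag, the serving coroutine polls it, and
# the gRPC server is then stopped with a grace period for in-flight RPCs.
import asyncio
import signal

from grpc import aio


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


async def main():
    handler = SignalHandler()
    server = aio.server()
    server.add_insecure_port("[::]:50051")
    await server.start()
    while handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    # Let in-flight RPCs finish before tearing the server down (grace value assumed).
    await server.stop(grace=30)


if __name__ == "__main__":
    asyncio.run(main())
```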

View File

@@ -4,6 +4,11 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
 
 _PARTITION_SIZE = 512
 
+try:
+    from vllm._C import cache_ops
+except Exception as e:
+    raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+
 
 def reshape_and_cache(
     key: torch.Tensor,
@@ -12,18 +17,9 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
-        from vllm._C import cache_ops
-
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-    elif IS_ROCM_SYSTEM:
-        from vllm import cache_ops
-
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    else:
-        raise ValueError("vllm is not supported on your system")
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots, "auto", 1.0
+    )
 
 
 def attention(
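The module-level try/except above makes a missing or broken vllm build fail at import time with an explicit message. If the rest of the module should stay importable without vllm, one alternative (not what this commit does) is to defer the failure to the first call; the call signature below is the one used in the diff:

```python
# Alternative sketch: keep the module importable when the vllm extension is
# missing and only raise when the kernels are actually needed.
try:
    from vllm._C import cache_ops
except Exception as import_error:  # missing or broken vllm build
    cache_ops = None
    _CACHE_OPS_IMPORT_ERROR = import_error


def reshape_and_cache(key, value, key_cache, value_cache, slots):
    if cache_ops is None:
        raise ImportError(
            "Could not import vllm paged attention. Make sure your "
            f"installation is correct. Complete error: {_CACHE_OPS_IMPORT_ERROR}"
        )
    cache_ops.reshape_and_cache(
        key, value, key_cache, value_cache, slots, "auto", 1.0
    )
```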