mirror of https://github.com/huggingface/text-generation-inference.git

working & cached tunableop

parent 193dbb683e
commit 17f5c3078b
@@ -672,9 +672,7 @@ fn shard_manager(
        // We received a shutdown signal
        if shutdown.load(Ordering::SeqCst) {
            p.kill().unwrap();
            let _ = p.wait();
            tracing::info!("Shard terminated");
            terminate("shard", p, Duration::from_secs(90)).unwrap();
            return;
        }
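The launcher change above swaps the immediate p.kill() / p.wait() pair for a terminate helper with a 90-second grace period, so the shard gets a chance to shut down cleanly. The helper's Rust body is not part of this diff; the following is only a rough Python analogue of that pattern (SIGTERM, bounded wait, then SIGKILL), with the name and timeout mirroring the call above rather than the real implementation.

import signal
import subprocess


def terminate(name: str, p: subprocess.Popen, timeout_secs: float) -> int:
    """Ask a child process to exit, escalating to SIGKILL after a grace period."""
    p.send_signal(signal.SIGTERM)  # give the process a chance to clean up
    try:
        return p.wait(timeout=timeout_secs)
    except subprocess.TimeoutExpired:
        print(f"{name} did not exit within {timeout_secs}s, killing it")
        p.kill()
        return p.wait()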
@@ -766,6 +766,9 @@ class FlashCausalLM(Model):
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size

            if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
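One detail of the gate above: os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False) is truthy for any non-empty value, so PYTORCH_TUNABLEOP_ENABLED=0 still takes the TunableOp path, while an unset variable does not. If stricter parsing were wanted, a helper along these lines would do it (illustrative only, not part of the commit):

import os


def env_flag(name: str, default: bool = False) -> bool:
    # Treat the usual falsy spellings ("0", "false", "off", "") as disabled.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("0", "false", "off", "")


# The diff's check treats PYTORCH_TUNABLEOP_ENABLED=0 as enabled;
# env_flag("PYTORCH_TUNABLEOP_ENABLED") would treat it as disabled.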
@@ -820,10 +823,10 @@ class FlashCausalLM(Model):
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        # TODO: fix
        if IS_ROCM_SYSTEM and False:
            total_seqlens = list(range(16))
        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
            total_seqlens = list(range(2))
            for seqlen in total_seqlens:
                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                self.tunableop_warmup(seqlen, max_s, max_bt)

        return int(num_blocks * BLOCK_SIZE)
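The "cached tunableop" in the commit title refers to how TunableOp works: the first time a GEMM shape is seen it benchmarks candidate kernels (hence the per-seqlen warmup loop above), and the winning solutions can be written to a CSV that later runs read back, skipping the search. Below is a minimal sketch of that idea against PyTorch's torch.cuda.tunable interface (PyTorch >= 2.3, per the URL in the log message); the hidden size, the dummy matmul and the warmup_gemms helper are illustrative stand-ins for the model's actual tunableop_warmup(seqlen, max_s, max_bt).

import torch


def warmup_gemms(seqlens, hidden_size=4096, dtype=torch.float16):
    # Run one GEMM per target sequence length so tuning happens now,
    # during warmup, instead of on the first real request.
    weight = torch.randn(hidden_size, hidden_size, device="cuda", dtype=dtype)
    for seqlen in seqlens:
        x = torch.randn(seqlen, hidden_size, device="cuda", dtype=dtype)
        _ = x @ weight  # a new shape triggers TunableOp's kernel search


if torch.cuda.is_available() and torch.cuda.tunable.is_enabled():
    warmup_gemms(seqlens=list(range(2)))  # mirrors total_seqlens above
    # Persist the selected kernels to the results CSV (PYTORCH_TUNABLEOP_FILENAME)
    # so later launches reuse them instead of re-tuning.
    torch.cuda.tunable.write_file()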
@@ -166,6 +166,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            total_ns=time.time_ns() - start,
        )


import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


def serve(
    model_id: str,
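The new SignalHandler turns SIGINT/SIGTERM into a flag the serving loop can poll, instead of relying on the default handler raising KeyboardInterrupt. A tiny self-contained check of that behaviour (the os.kill harness is for illustration only, not part of the commit):

import os
import signal
import time


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False  # instance attribute shadows the class default


handler = SignalHandler()
os.kill(os.getpid(), signal.SIGTERM)  # simulate the launcher's graceful terminate
time.sleep(0.1)                       # let the interpreter run the handler
while handler.KEEP_PROCESSING:        # same polling pattern as the serving loop
    time.sleep(0.5)
print("loop exited cleanly")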
@@ -231,11 +246,8 @@ def serve(

        logger.info("Server started at {}".format(local_url))

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
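In the hunk above, the try/except KeyboardInterrupt around wait_for_termination() gives way to polling signal_handler.KEEP_PROCESSING, so shutdown is driven entirely by the signal handler defined earlier. A stripped-down sketch of that loop against grpc.aio; the server setup, the socket path and the explicit stop() after the loop are assumptions for illustration (the hunk does not show what runs once the loop exits):

import asyncio

import grpc  # grpcio's asyncio API


async def serve_sketch(signal_handler) -> None:
    server = grpc.aio.server()
    server.add_insecure_port("unix:///tmp/tgi-sketch.sock")  # placeholder path
    await server.start()
    # Poll the flag instead of awaiting wait_for_termination(): the SIGTERM
    # handler flips KEEP_PROCESSING and the loop falls through on its own.
    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    await server.stop(grace=None)  # assumption: stop once the flag flips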
@@ -4,6 +4,11 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

_PARTITION_SIZE = 512

try:
    from vllm._C import cache_ops
except Exception as e:
    raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")


def reshape_and_cache(
    key: torch.Tensor,
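The module-level guard above imports cache_ops once and fails fast with a pointer at the vllm installation. Since the dispatch below shows two different homes for cache_ops (vllm._C on CUDA builds, plain vllm on the ROCm fork), a more tolerant variant of the same guard could probe both; this is only a sketch, not what the commit does:

try:
    from vllm._C import cache_ops  # layout used by CUDA builds of vllm
except ImportError:
    try:
        from vllm import cache_ops  # layout used by the ROCm fork
    except ImportError as e:
        raise ImportError(
            "Could not import vllm paged attention. Make sure your installation "
            f"is correct. Complete error: {e}"
        )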
@@ -12,18 +17,9 @@ def reshape_and_cache(
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if IS_CUDA_SYSTEM:
        from vllm._C import cache_ops

        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif IS_ROCM_SYSTEM:
        from vllm import cache_ops

        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
    else:
        raise ValueError("vllm is not supported on your system")


def attention(
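For readers new to paged attention: reshape_and_cache scatters each token's freshly computed key/value vectors into the preallocated KV cache at the positions given by slots. The snippet below is only a conceptual model with a flat, slot-indexed layout; the real vllm kernels use a blocked cache layout, and the extra "auto", 1.0 arguments on the CUDA path are version-specific details not modelled here.

import torch


def reshape_and_cache_reference(
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_slots, num_heads, head_size] (flat stand-in)
    value_cache: torch.Tensor,  # [num_slots, num_heads, head_size]
    slots: torch.Tensor,        # [num_tokens] int64 slot indices
) -> None:
    # Conceptual equivalent: copy each token's K/V into its assigned cache slot.
    key_cache[slots] = key
    value_cache[slots] = value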