Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-21 23:12:07 +00:00

Commit 17f5c3078b: "working & cached tunableop"
Parent: 193dbb683e
@@ -672,9 +672,7 @@ fn shard_manager(
         // We received a shutdown signal
         if shutdown.load(Ordering::SeqCst) {
-            p.kill().unwrap();
-            let _ = p.wait();
-            tracing::info!("Shard terminated");
+            terminate("shard", p, Duration::from_secs(90)).unwrap();
             return;
         }
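
Note: the terminate helper is defined outside this hunk, so only its call site is visible here. A minimal Python sketch of the same pattern (ask the shard to exit cleanly, wait up to a grace period, then force-kill); the polling interval and messages are illustrative, not taken from the launcher:

    import subprocess
    import time

    def terminate(name: str, p: subprocess.Popen, timeout_s: float) -> None:
        # Ask the child process to exit cleanly first (SIGTERM on POSIX).
        p.terminate()
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            if p.poll() is not None:
                print(f"{name} terminated gracefully")
                return
            time.sleep(0.1)
        # The process did not exit within the grace period; force-kill it.
        p.kill()
        p.wait()
        print(f"{name} killed after {timeout_s}s grace period")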
@@ -766,6 +766,9 @@ class FlashCausalLM(Model):
             )
             max_bt = batch.max_blocks
             max_s = max_bt * get_cache_manager().block_size
+
+            if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+                logger.info("PyTorch TunableOp (https://github.com/pytorch/pytorch/tree/v2.3.0/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
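
Note on the toggle: os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False) returns the raw string, so even PYTORCH_TUNABLEOP_ENABLED=0 counts as enabled here. A stricter parse, together with the results-file variable that (per the upstream TunableOp README, if I read it correctly) lets tuning results be cached between runs, could look like this sketch; the file name is an assumption, not taken from the commit:

    import os

    def tunableop_enabled() -> bool:
        # Strict boolean parse: unset, "0" and "false" all disable TunableOp.
        return os.environ.get("PYTORCH_TUNABLEOP_ENABLED", "0").lower() in ("1", "true", "yes")

    # Persist tuning results so later runs can skip the lengthy warmup.
    os.environ.setdefault("PYTORCH_TUNABLEOP_FILENAME", "tunableop_results.csv")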
@@ -820,10 +823,10 @@ class FlashCausalLM(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

-        # TODO: fix
-        if IS_ROCM_SYSTEM and False:
-            total_seqlens = list(range(16))
+        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+            total_seqlens = list(range(2))
             for seqlen in total_seqlens:
+                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen, max_s, max_bt)

         return int(num_blocks * BLOCK_SIZE)
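
tunableop_warmup itself is not shown in this diff. The reason for warming up each seqlen is that TunableOp tunes and caches a GEMM solution per problem shape, so every decode shape exercised up front has its winner recorded before serving starts. A standalone sketch of that idea (hidden size and dtype are placeholders, not the model's real dimensions):

    import torch

    def warmup_gemm(seqlen: int, hidden_size: int = 4096) -> None:
        # Running a matmul at the shape seen during decoding lets TunableOp
        # benchmark candidate kernels once and reuse the winner afterwards.
        m = max(seqlen, 1)
        x = torch.randn(m, hidden_size, device="cuda", dtype=torch.float16)
        w = torch.randn(hidden_size, hidden_size, device="cuda", dtype=torch.float16)
        torch.matmul(x, w)
        torch.cuda.synchronize()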
@@ -166,6 +166,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             total_ns=time.time_ns() - start,
         )

+import signal
+
+
+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
+
+
 def serve(
     model_id: str,
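
One subtlety worth noting: KEEP_PROCESSING starts life as a class attribute, and exit_gracefully then assigns an instance attribute that shadows it. That is fine here because the server polls the shared signal_handler instance, but the class-level value itself never changes. A tiny standalone demo of that behaviour (not part of the commit):

    class Flag:
        KEEP_PROCESSING = True

        def stop(self):
            # Creates an instance attribute shadowing the class attribute.
            self.KEEP_PROCESSING = False

    flag = Flag()
    flag.stop()
    print(flag.KEEP_PROCESSING)   # False (instance attribute)
    print(Flag.KEEP_PROCESSING)   # True (class attribute untouched)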
@@ -231,11 +246,8 @@ def serve(

         logger.info("Server started at {}".format(local_url))

-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)

     asyncio.run(
         serve_inner(
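
With wait_for_termination() removed, shutdown now hinges entirely on the polled flag: the launcher's SIGTERM flips KEEP_PROCESSING and the coroutine simply falls out of the loop. If a graceful gRPC stop were still wanted after the loop, a sketch might look like the following; the helper name and the grace value are assumptions, not taken from the commit:

    import asyncio

    async def run_until_signal(server, signal_handler, poll_interval: float = 0.5):
        # Poll the flag instead of blocking forever in wait_for_termination().
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(poll_interval)
        # Optionally drain in-flight RPCs before exiting (grace in seconds).
        await server.stop(grace=30)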
@@ -4,6 +4,11 @@ from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

 _PARTITION_SIZE = 512

+try:
+    from vllm._C import cache_ops
+except Exception as e:
+    raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
+

 def reshape_and_cache(
     key: torch.Tensor,
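
The branches removed in the next hunk show that CUDA builds expose cache_ops under vllm._C while the older ROCm fork exposed it under vllm, so a build that only provides the latter would now fail at import time. If both layouts need to be tolerated, a fallback import along these lines would cover it (a sketch, not part of the commit):

    try:
        from vllm._C import cache_ops  # layout used by CUDA wheels
    except ImportError:
        try:
            from vllm import cache_ops  # layout used by the older ROCm fork
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )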
@@ -12,18 +17,9 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
-        from vllm._C import cache_ops
-
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-    elif IS_ROCM_SYSTEM:
-        from vllm import cache_ops
-
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    else:
-        raise ValueError("vllm is not supported on your system")
+    cache_ops.reshape_and_cache(
+        key, value, key_cache, value_cache, slots, "auto", 1.0
+    )


 def attention(
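
Two remarks on the simplification: the extra "auto" and 1.0 arguments appear to be the kv-cache dtype selector and scaling factor of the newer kernel signature (hedged; the vllm version is not pinned in this diff), and the old else branch that rejected unsupported systems is gone entirely. If that guard is still wanted, it could sit next to the module-level import, for example:

    # Sketch only: reinstate the old unsupported-system guard at import time.
    if not (IS_CUDA_SYSTEM or IS_ROCM_SYSTEM):
        raise ValueError("vllm is not supported on your system")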