working tunable

fxmarty 2024-05-02 13:29:20 +00:00
parent 2677bf856a
commit ff5e16b0e2
3 changed files with 22 additions and 36 deletions


@@ -17,7 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed


@@ -769,10 +769,7 @@ class FlashCausalLM(Model):
if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
logger.info("calling self.generate_token(batch)")
_, batch, _ = self.generate_token(batch)
logger.info("end it")
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -814,6 +811,21 @@ class FlashCausalLM(Model):
self.device,
)
if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
torch.cuda.tunable.tuning_enable(True)
tuning_sequences = range(1, 8)
tunableop_filename = f"tunableop_tp{self.world_size}_rank{self.rank}.csv"
logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join(tuning_sequences)}.")
torch.cuda.tunable.read_file(tunableop_filename)
for seqlen in tuning_sequences:
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filename)
torch.cuda.tunable.tuning_enable(False)
if CUDA_GRAPHS:
try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -826,48 +838,24 @@ class FlashCausalLM(Model):
else:
logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
# if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
# if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
# torch.cuda.tunable.tuning_enable(True)
# logger.info("enable tuning here")
logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
for seqlen in range(1, 3):
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen, max_s, max_bt)
logger.info("call write file")
torch.cuda.tunable.write_file()
torch.cuda.tunable.tuning_enable(False)
logger.info("finished tunable op")
return int(num_blocks * BLOCK_SIZE)
def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
input_lengths = (
torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
)
bs = 1
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
logger.info("call self.model.forward")
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=None,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
max_s=seqlen,
lm_head_indices=None,
)
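The rewritten tunableop_warmup drives a prefill-only forward by passing cu_seqlen_prefill explicitly, which, as the comment above notes, sidesteps paged-attention cache allocation. cu_seqlen_prefill follows the cumulative-sequence-lengths convention of varlen flash-attention kernels: entry i is the offset where sequence i starts in the packed batch, so a single sequence of length seqlen becomes [0, seqlen]. A small illustrative sketch with made-up lengths:

import torch

lengths = [3, 5, 2]  # hypothetical prompt lengths for a packed batch of three requests
cu_seqlen_prefill = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlen_prefill[1:] = torch.cumsum(torch.tensor(lengths), dim=0)
print(cu_seqlen_prefill)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)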


@@ -10,7 +10,6 @@ try:
except Exception as e:
raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,