Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 16:32:12 +00:00
Commit: ff5e16b0e2 ("working tunable")
Parent: 2677bf856a
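The diff below wires PyTorch TunableOp into TGI's FlashCausalLM warmup on ROCm: tuning is enabled during warmup, the selected GEMM kernels are written to a per-rank CSV file, and tuning is disabled afterwards. As background, here is a minimal sketch of that tune-then-freeze workflow, assuming a PyTorch build that exposes torch.cuda.tunable (2.3+, or the patched branch linked in the diff); the shapes, dtype, and function name are illustrative, not taken from the commit.

# Minimal sketch (not part of the commit): tune a few representative GEMMs once,
# persist the selected kernels, then freeze tuning for later runs.
import torch

def tune_and_freeze(filename: str = "tunableop_results.csv") -> None:
    torch.cuda.tunable.enable(True)          # turn TunableOp on
    torch.cuda.tunable.tuning_enable(True)   # allow searching for the best kernel per GEMM shape
    for m in (1, 2, 4, 8):                   # illustrative matmul shapes
        a = torch.randn(m, 4096, device="cuda", dtype=torch.float16)
        b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        torch.matmul(a, b)
    torch.cuda.tunable.write_file(filename)  # persist kernel selections to CSV
    torch.cuda.tunable.tuning_enable(False)  # later matmuls reuse results, no re-tuning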
@@ -17,7 +17,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed
@@ -769,10 +769,7 @@ class FlashCausalLM(Model):

            if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)

            logger.info("calling self.generate_token(batch)")
            _, batch, _ = self.generate_token(batch)
            logger.info("end it")
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@@ -814,6 +811,21 @@ class FlashCausalLM(Model):
            self.device,
        )

        if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
            if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
                torch.cuda.tunable.tuning_enable(True)

            tuning_sequences = range(1, 8)
            tunableop_filename = f"tunableop_tp{self.world_size}_rank{self.rank}.csv"

            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the optimal ROCm matrix multiplication kernels for the target lengths {', '.join(str(seqlen) for seqlen in tuning_sequences)}.")
            torch.cuda.tunable.read_file(tunableop_filename)

            for seqlen in range(1, 8):
                self.tunableop_warmup(seqlen)
            torch.cuda.tunable.write_file(tunableop_filename)
            torch.cuda.tunable.tuning_enable(False)

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
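The added block keys the TunableOp results file by tensor-parallel world size and rank, so each shard tunes and persists its own kernel selections. A sketch of that per-rank load, warm up, save cycle follows; warmup_fn is a hypothetical stand-in for self.tunableop_warmup, and the file-existence guard is an assumption rather than something the diff does.

# Sketch only: per-rank TunableOp results handling, mirroring the block above.
import os
import torch

def tunableop_cycle(world_size: int, rank: int, warmup_fn) -> None:
    filename = f"tunableop_tp{world_size}_rank{rank}.csv"  # one CSV per rank
    torch.cuda.tunable.tuning_enable(True)
    if os.path.isfile(filename):
        torch.cuda.tunable.read_file(filename)  # reuse kernels tuned on a previous run
    for seqlen in range(1, 8):                  # same target lengths as the diff
        warmup_fn(seqlen)                       # hypothetical warmup callback
    torch.cuda.tunable.write_file(filename)     # persist (possibly updated) selections
    torch.cuda.tunable.tuning_enable(False)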
@@ -826,48 +838,24 @@ class FlashCausalLM(Model):
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        # if IS_ROCM_SYSTEM and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
        #     if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
        #         torch.cuda.tunable.tuning_enable(True)
        #         logger.info("enable tuning here")

            logger.info("PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes.")
            for seqlen in range(1, 3):
                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                self.tunableop_warmup(seqlen, max_s, max_bt)
            logger.info("call write file")
            torch.cuda.tunable.write_file()
            torch.cuda.tunable.tuning_enable(False)

        logger.info("finished tunable op")
        return int(num_blocks * BLOCK_SIZE)

    def tunableop_warmup(self, seqlen: int, max_s: int, max_bt: int):
    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        input_lengths = (
            torch.ones(seqlen, dtype=torch.int32, device=self.device) * max_s
        )
        bs = 1
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        logger.info("call self.model.forward")
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            cu_seqlen_prefill=torch.tensor([0, seqlen], device=self.device, dtype=torch.int32),
            kv_cache=get_cache_manager().kv_cache,
            block_tables=None,
            input_lengths=None,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            max_s=seqlen,
            lm_head_indices=None,
        )
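In the rewritten tunableop_warmup, passing a cu_seqlen_prefill tensor makes the forward pass run as a prefill, which (per the in-diff comment) avoids allocating and freeing paged-attention cache slots just for warmup. For reference, a cumulative-sequence-length tensor for variable-length prefill attention lists one boundary per sequence; the tiny illustration below is not from the diff.

# Illustration only: cumulative sequence lengths for variable-length prefill attention.
import torch

seqlen = 4
cu_single = torch.tensor([0, seqlen], dtype=torch.int32)  # batch of one sequence, length 4
cu_pair = torch.tensor([0, 3, 8], dtype=torch.int32)      # two sequences, lengths 3 and 5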
@@ -10,7 +10,6 @@ try:
except Exception as e:
    raise ImportError(f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}")


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,