import math
import os
import time
import itertools
import torch
import torch.distributed

import numpy as np

from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Iterable, Optional, Tuple, List, Type, Dict

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)

tracer = trace.get_tracer(__name__)

BLOCK_SIZE: int = 16
# Will be set in init
SLIDING_WINDOW: Optional[int] = None


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window


def get_sliding_windows() -> Optional[int]:
    global SLIDING_WINDOW
    return SLIDING_WINDOW
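
# Illustrative arithmetic for the paged KV cache (mirrors the allocation logic
# in FlashCausalLMBatch.from_tokenized below): a request that may need
# `total_tokens` KV entries is given math.ceil(total_tokens / BLOCK_SIZE)
# blocks, and block `b` owns the BLOCK_SIZE consecutive slots
# b * BLOCK_SIZE ... (b + 1) * BLOCK_SIZE - 1.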


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    speculative_ids: Optional[torch.Tensor]

    # Flash Attention values

    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    # Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor

    # list of length b of list of length s_i // block_size
    block_tables: List[List[int]]
    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: torch.Tensor

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]
    top_n_tokens: List[int]
    top_n_tokens_tensor: torch.Tensor

    # Number of blocks in this batch
    num_blocks: int
    # Maximum number of blocks
    max_blocks: int
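
    # Illustrative example of how these fields relate (arbitrary values): with
    # BLOCK_SIZE = 16, a batch whose two requests own blocks [[0, 1], [2]] has
    # block_tables_tensor padded to [[0, 1], [2, 0]], slots listing the
    # individual cache slots of those blocks (0..31 then 32..47), and
    # slot_indices pointing at the slot(s) each sequence is currently writing.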

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
        )

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer
    ):
        batch_inputs = []
        max_truncation = 0
        for r in requests:
            batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        return batch_tokenized_inputs
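
    # Note on truncation: the tokenizer call in batch_tokenized_inputs only
    # truncates to the batch-wide max_truncation; the per-request `r.truncate`
    # limit is applied again in from_tokenized via `tokenized_input[-r.truncate:]`.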

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        sliding_window = get_sliding_windows()
        position_ids = []
        cu_seqlen_prefill = [0]
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        num_blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        block_tables = []
        slots = []
        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]
            # Remove a duplicated leading BOS token if present
            if (
                tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
                request_slots = [
                    s
                    for b in request_blocks
                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
                ]
            else:
                request_blocks = r.blocks
                request_slots = r.slots

            block_tables.append(request_blocks)
            slots.extend(request_slots[:total_tokens])
            num_blocks += len(request_blocks)

            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )
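
        # Worked example (illustrative, arbitrary values): a request with
        # input_length=3, max_new_tokens=20 and no speculation needs
        # total_tokens = 3 + 20 - 1 = 22 slots, i.e. math.ceil(22 / 16) = 2 blocks,
        # of which the first 22 slots are kept. With prefill_logprobs the head
        # indices keep every prefill position and the next token is read from the
        # last of those outputs; otherwise only the last prefill position is kept.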

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            prefill_cache_indices=prefill_cache_indices,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
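
    # Sketch of the typical construction path (example arguments):
    #     batch = FlashCausalLMBatch.from_pb(pb, tokenizer, torch.float16, device)
    # which tokenizes every request and lays out its paged-attention metadata in one call.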

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []
        top_n_tokens = []

        num_blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i
            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])
            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            top_n_tokens.append(self.top_n_tokens[idx])

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            num_blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
            self.speculative_ids[indices] if self.speculative_ids is not None else None
        )

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )
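
    # Note: filter() is the batch-shrinking path: it keeps only the given
    # request_ids, re-indexes the per-request tensors, and drops the KV-cache
    # slots of finished requests through slot_filtering_indices, e.g.
    # (hypothetical ids) smaller_batch = batch.filter([0, 3]).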

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        num_blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)
            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)
            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
            tokenizer=batches[0].next_token_chooser.tokenizer,
            fsm_grammar_states=fsm_grammar_states,
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            prefill_cache_indices=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            num_blocks=num_blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )
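
    # Note: concatenate() merges batches that have already run prefill into one
    # decode batch (the prefill-only fields are reset to None), offsetting
    # slot_indices by the cumulative slot count and requests_idx_mapping by the
    # cumulative batch size, e.g. merged = FlashCausalLMBatch.concatenate([batch_a, batch_b]).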

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        self.cuda_graphs = {}
        self.kv_cache = []

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def max_past(self) -> Optional[int]:
        return getattr(self.model, "max_past", None)

    def init_kv_cache(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.kv_cache = []
        empty_cache()

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if IPEX_AVAIL and SYSTEM == "cpu":
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, num_heads, BLOCK_SIZE, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, num_heads, head_size, BLOCK_SIZE),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
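
        # Resulting layout of the non-CPU branch above: each layer's key cache is
        # [num_blocks, num_heads, head_size // x, BLOCK_SIZE, x] with
        # x = BLOCK_SIZE // element_size (x = 8 for float16; x is forced to 1 on xpu),
        # and the value cache is [num_blocks, num_heads, head_size, BLOCK_SIZE].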

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()
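
        # The tensors stored in self.cuda_graphs[bs] are the static buffers captured
        # by the graph; the decode path is expected to refill them in place and
        # replay the graph rather than re-launching the kernels one by one.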

    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        empty_cache()

        try:
            self.init_kv_cache(
                batch.num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * BLOCK_SIZE
            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                torch.cuda.tunable.tuning_enable(False)

            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        synchronize(self.device)

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
            + batch.num_blocks
        )
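        # Worked example of the sizing above (hypothetical numbers, for illustration
        # only): assuming BLOCK_SIZE = 16, num_kv_heads = 8, head_size = 128,
        # num_layers = 80 and a float16 cache (dtype_size = 2), each block holds
        # 16 * 8 * 128 = 16384 elements per layer, and the K + V caches across all
        # layers take 80 * 16384 * 2 * 2 bytes = 5 MiB per block. With ~60 GiB of
        # free memory, int(60 GiB * 0.95 // 5 MiB) is roughly 11.7k blocks, i.e.
        # about 187k cacheable tokens before adding batch.num_blocks back in.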
        del batch

        self.init_kv_cache(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.dtype,
            self.device,
        )
        if SYSTEM == "rocm":
            if (
                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
            ):
                torch.cuda.tunable.enable()
                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                    torch.cuda.tunable.tuning_enable(True)

                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
                    tuning_sequences = [
                        int(val)
                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                    ]
                elif CUDA_GRAPHS is not None:
                    tuning_sequences = CUDA_GRAPHS
                else:
                    # For seqlen = 1, we dispatch to LLMM1 kernel.
                    tuning_sequences = [2, 3, 4, 5, 6, 7]
                tunableop_filepath = os.path.join(
                    HUGGINGFACE_HUB_CACHE,
                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                )

                logger.info(
                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
                )

                if os.path.isfile(tunableop_filepath):
                    logger.info(
                        f"The file {tunableop_filepath} already exists and will be reused."
                    )
                    torch.cuda.tunable.read_file(tunableop_filepath)

                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)

                for seqlen in tuning_sequences:
                    logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                    self.tunableop_warmup(seqlen)
                torch.cuda.tunable.write_file(tunableop_filepath)
                torch.cuda.tunable.tuning_enable(False)
            else:
                logger.info(
                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
                )

        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception(f"Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")

        return int(num_blocks * BLOCK_SIZE)
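    # The TunableOp setup above boils down to a tune-once / reuse pattern. A
    # minimal sketch in comments, using only the torch.cuda.tunable calls already
    # used in warmup(); the path, the helper and the sequence lengths are
    # illustrative assumptions, not part of this file:
    #
    #     torch.cuda.tunable.enable()
    #     torch.cuda.tunable.tuning_enable(True)
    #     if os.path.isfile(csv_path):
    #         torch.cuda.tunable.read_file(csv_path)   # reuse previous results
    #     for seqlen in (2, 3, 4, 5, 6, 7):
    #         run_one_prefill(seqlen)                  # hypothetical helper
    #     torch.cuda.tunable.write_file(csv_path)      # persist the tuned GEMMs
    #     torch.cuda.tunable.tuning_enable(False)      # serve with tuned kernels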
    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=self.kv_cache,
            block_tables=None,
            input_lengths=input_lengths,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )
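    # Note on the warmup forward above: each tunableop_warmup(seqlen) call runs a
    # dummy prefill over `seqlen` tokens, so the GEMMs seen by TunableOp have the
    # same M dimension as a decode step over a batch of `seqlen` requests. This is
    # why the tuning_sequences chosen in warmup() double as target decode batch
    # sizes (an interpretation based on the logging in warmup(), not a separate
    # API guarantee).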
    def forward(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length
            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits
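    # Illustration of the speculative expansion in forward() (made-up shapes): with
    # B = 2 requests, speculative_length = 2 and therefore new_length = 3, the last
    # verified tokens [[a], [b]] are concatenated with speculative_ids
    # [[a1, a2], [b1, b2]] and flattened to [a, a1, a2, b, b1, b2]. position_ids and
    # slots are expanded the same way by adding arange(3) to each request's value,
    # input_lengths grows by [0, 1, 2] per request, and each request's row of
    # block_tables is repeated new_length times so every candidate token attends
    # against the same paged KV cache.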
    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        out, speculative_logits = self.forward(batch)

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            next_token_logits = out

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )
        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # Update values
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()
        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            left = 0

            if n_accepted_ids > 1:
                if RANK == 0:
                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index : index + n_accepted_ids - left
            ]
            index += n_accepted_ids
            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # accept each new token for this specific request since we may
            # have more than one new token per request with speculative decoding
            for next_token_id in _next_token_ids:
                batch.next_token_chooser = (
                    batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                )

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
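    # Illustration of the speculative bookkeeping above (made-up values): with
    # n_accepted_ids = 3 and flattened next_token_ids = [t0, t1, t2, ...] starting
    # at `index`, the inner loop appends t0, t1, t2 in order. If a stopping
    # criterion fires on t1 (j = index + 1), then left = 3 - 1 - 1 = 1, so
    # _next_token_ids keeps only [t0, t1] and the trailing speculated token t2 is
    # dropped; `index` still advances by the full n_accepted_ids so the next
    # request reads its own slice of the flattened token list.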