Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 12:54:52 +00:00)

Merge branch 'support-openai-models-endpoint' of github.com:huggingface/text-generation-inference into support-openai-models-endpoint

Commit b348ab4c55
Cargo.lock (generated, 1 line changed)

@@ -2174,6 +2174,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d"
 dependencies = [
  "serde",
+ "serde_json",
 ]

 [[package]]
@@ -316,10 +316,15 @@ impl State {
                         + self.speculate
                         - 1;

-                    match block_allocator
-                        .allocate(tokens, entry.request.input_ids.clone())
-                        .await
-                    {
+                    // If users wants the prefill logprobs, we cannot reuse the cache.
+                    // So no input_ids for the radix tree.
+                    let input_ids = if entry.request.decoder_input_details {
+                        None
+                    } else {
+                        entry.request.input_ids.clone()
+                    };
+
+                    match block_allocator.allocate(tokens, input_ids).await {
                         None => {
                             // Entry is over budget
                             // Add it back to the front
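The allocation change above can be read as a small policy function: when a request asks for decoder input details (prefill logprobs), the prompt tokens are withheld from the radix tree so the prefix cache cannot skip tokens whose logprobs still have to be computed. A minimal sketch of that selection, with hypothetical stand-in types rather than the crate's real structs:

use std::sync::Arc;

// Hypothetical stand-in for the fields the hunk above reads from the queued request.
struct QueuedRequest {
    decoder_input_details: bool,      // the client asked for prefill logprobs
    input_ids: Option<Arc<Vec<u32>>>, // prompt token ids, when known
}

/// Token ids handed to the block allocator for prefix-cache lookups.
/// Returning None forces a full prefill, so every prompt logprob is computed.
fn radix_input_ids(req: &QueuedRequest) -> Option<Arc<Vec<u32>>> {
    if req.decoder_input_details {
        None
    } else {
        req.input_ids.clone()
    }
}

fn main() {
    let req = QueuedRequest {
        decoder_input_details: true,
        input_ids: Some(Arc::new(vec![1, 2, 3])),
    };
    assert!(radix_input_ids(&req).is_none()); // cache reuse disabled for logprob requests
}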
@@ -205,6 +205,7 @@ pub struct RadixTrie {
     /// call that a real time lookup would require.
     time: u64,
 }

 impl Default for RadixTrie {
     fn default() -> Self {
         Self::new()
@@ -757,7 +757,12 @@ class AsyncClient:
                 continue
             payload = byte_payload.decode("utf-8")
             if payload.startswith("data:"):
-                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                payload_data = (
+                    payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                )
+                if payload_data == "[DONE]":
+                    break
+                json_payload = json.loads(payload_data)
                 try:
                     response = ChatCompletionChunk(**json_payload)
                     yield response
@@ -924,7 +924,7 @@
         "tool_prompt": {
           "type": "string",
           "description": "A prompt to be appended before the tools",
-          "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
+          "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
           "nullable": true
         },
         "tools": {
flake.lock (127 changes)

@@ -492,24 +492,6 @@
         "type": "github"
       }
     },
-    "flake-utils_7": {
-      "inputs": {
-        "systems": "systems_7"
-      },
-      "locked": {
-        "lastModified": 1710146030,
-        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
     "gitignore": {
       "inputs": {
         "nixpkgs": [
@@ -594,27 +576,6 @@
         "type": "github"
       }
     },
-    "nix-github-actions": {
-      "inputs": {
-        "nixpkgs": [
-          "poetry2nix",
-          "nixpkgs"
-        ]
-      },
-      "locked": {
-        "lastModified": 1703863825,
-        "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
-        "owner": "nix-community",
-        "repo": "nix-github-actions",
-        "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-community",
-        "repo": "nix-github-actions",
-        "type": "github"
-      }
-    },
     "nix-test-runner": {
       "flake": false,
       "locked": {
@@ -753,31 +714,6 @@
         "type": "github"
       }
     },
-    "poetry2nix": {
-      "inputs": {
-        "flake-utils": "flake-utils_7",
-        "nix-github-actions": "nix-github-actions",
-        "nixpkgs": [
-          "tgi-nix",
-          "nixpkgs"
-        ],
-        "systems": "systems_8",
-        "treefmt-nix": "treefmt-nix"
-      },
-      "locked": {
-        "lastModified": 1723854676,
-        "narHash": "sha256-+BrHfNuXrqeE7PoV6xDaoh0joYiJkvTTCIV0fFR3THw=",
-        "owner": "nix-community",
-        "repo": "poetry2nix",
-        "rev": "d650118bce34c0238b9b54f23f7f173f9e4db867",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-community",
-        "repo": "poetry2nix",
-        "type": "github"
-      }
-    },
     "pre-commit-hooks": {
       "inputs": {
         "flake-compat": [
@@ -887,7 +823,6 @@
           "tgi-nix",
           "nixpkgs"
         ],
-        "poetry2nix": "poetry2nix",
        "rust-overlay": "rust-overlay",
        "tgi-nix": "tgi-nix"
      }
@@ -900,11 +835,11 @@
        ]
      },
      "locked": {
-        "lastModified": 1723515680,
-        "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=",
+        "lastModified": 1724206841,
+        "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=",
        "owner": "oxalica",
        "repo": "rust-overlay",
-        "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3",
+        "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61",
        "type": "github"
      },
      "original": {
@@ -1003,46 +938,17 @@
        "type": "github"
      }
    },
-    "systems_7": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-systems",
-        "repo": "default",
-        "type": "github"
-      }
-    },
-    "systems_8": {
-      "locked": {
-        "lastModified": 1681028828,
-        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-        "owner": "nix-systems",
-        "repo": "default",
-        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-        "type": "github"
-      },
-      "original": {
-        "id": "systems",
-        "type": "indirect"
-      }
-    },
    "tgi-nix": {
      "inputs": {
        "flake-compat": "flake-compat_4",
        "nixpkgs": "nixpkgs_6"
      },
      "locked": {
-        "lastModified": 1723973328,
-        "narHash": "sha256-q5FmW4YFQcRb6fXHnrxL0uno6xcw9dcg+pFBbVM1xeQ=",
+        "lastModified": 1724270760,
+        "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=",
        "owner": "danieldk",
        "repo": "tgi-nix",
-        "rev": "d2038f36589a8a179834e5771ffd081620ba94c3",
+        "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c",
        "type": "github"
      },
      "original": {
@@ -1050,27 +956,6 @@
        "repo": "tgi-nix",
        "type": "github"
      }
-    },
-    "treefmt-nix": {
-      "inputs": {
-        "nixpkgs": [
-          "poetry2nix",
-          "nixpkgs"
-        ]
-      },
-      "locked": {
-        "lastModified": 1719749022,
-        "narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=",
-        "owner": "numtide",
-        "repo": "treefmt-nix",
-        "rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "treefmt-nix",
-        "type": "github"
-      }
    }
  },
  "root": "root",
flake.nix (80 changes)

@@ -8,10 +8,6 @@
     tgi-nix.url = "github:danieldk/tgi-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
-    poetry2nix = {
-      url = "github:nix-community/poetry2nix";
-      inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
-    };
     rust-overlay = {
       url = "github:oxalica/rust-overlay";
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
@@ -26,7 +22,6 @@
       flake-utils,
       rust-overlay,
       tgi-nix,
-      poetry2nix,
     }:
     flake-utils.lib.eachDefaultSystem (
       system:
@@ -47,14 +42,32 @@
           tgi-nix.overlay
         ];
       };
-      inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage;
-      text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; };
       crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
+      benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {
+        inherit crateOverrides;
+      };
+      launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
+        inherit crateOverrides;
+      };
+      router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
+        inherit crateOverrides;
+      };
+      server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
     in
     {
-      devShells.default =
-        with pkgs;
-        mkShell {
+      devShells = with pkgs; rec {
+        default = pure;
+
+        pure = mkShell {
+          buildInputs = [
+            benchmark
+            launcher
+            router
+            server
+          ];
+        };
+
+        impure = mkShell {
          buildInputs =
            [
              openssl.dev
@@ -65,42 +78,16 @@
                "rust-src"
              ];
            })
+            protobuf
          ]
          ++ (with python3.pkgs; [
            venvShellHook
            pip
+            ipdb
-            causal-conv1d
-            click
-            einops
-            exllamav2
-            fbgemm-gpu
-            flashinfer
-            flash-attn
-            flash-attn-layer-norm
-            flash-attn-rotary
-            grpc-interceptor
-            grpcio-reflection
-            grpcio-status
-            grpcio-tools
-            hf-transfer
-            loguru
-            mamba-ssm
-            marlin-kernels
-            opentelemetry-api
-            opentelemetry-exporter-otlp
-            opentelemetry-instrumentation-grpc
-            opentelemetry-semantic-conventions
-            peft
-            tokenizers
-            torch
-            transformers
-            vllm
-
-            (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; })
-            (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; })
          ]);

+          inputsFrom = [ server ];
+
          venvDir = "./.venv";

          postVenv = ''
@@ -108,6 +95,19 @@
          '';
          postShellHook = ''
            unset SOURCE_DATE_EPOCH
+            export PATH=$PATH:~/.cargo/bin
+          '';
+        };
+      };
+
+      packages.default = pkgs.writeShellApplication {
+        name = "text-generation-inference";
+        runtimeInputs = [
+          server
+          router
+        ];
+        text = ''
+          ${launcher}/bin/text-generation-launcher "$@"
          '';
        };
    }
@@ -257,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):

 class LauncherHandle:
     def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}")
+        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)

     def _inner_health(self):
         raise NotImplementedError
@@ -36,6 +36,7 @@ tools = [
                    },
                },
                "required": ["location", "format"],
+               "additionalProperties": False,
            },
        },
    },
@@ -62,13 +63,13 @@ tools = [
                    },
                },
                "required": ["location", "format", "num_days"],
+               "additionalProperties": False,
            },
        },
    },
 ]


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
         max_tokens=100,
         seed=1,
         tools=tools,
-        presence_penalty=-1.1,
+        temperature=0.0,
         messages=[
             {
                 "role": "system",
@@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="auto",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1

-    assert count == 38
+    assert count == 48
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_insufficient_information(
@@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
 ):
     responses = await flash_llama_grammar_tools.chat(
         max_tokens=100,
-        seed=8,
+        seed=24,
         tools=tools,
         tool_choice="auto",
         messages=[
             {
                 "role": "system",
-                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
             },
             {
                 "role": "user",
@@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )

     assert responses.choices[0].message.content is None
-    assert responses.choices[0].message.tool_calls == [
-        {
-            "function": {
-                "arguments": {
-                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-                },
-                "description": None,
-                "name": "notify_error",
-            },
-            "id": 0,
-            "type": "function",
-        }
-    ]
+    assert (
+        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+    )

     assert responses == response_snapshot
@@ -20,8 +20,7 @@ defaultCrateOverrides
   rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };

   grpc-metadata = attrs: {
-    src =
-      filter {
+    src = filter {
       root = ../backends/grpc-metadata;
       include = with filter; [
         isDirectory
@@ -29,9 +28,29 @@ defaultCrateOverrides
       ];
     };
   };
-  text-generation-launcer = attrs: {
-    src =
-      filter {
+  text-generation-benchmark = attrs: {
+    src = filter {
+      root = ../benchmark;
+      include = with filter; [
+        isDirectory
+        (matchExt "rs")
+      ];
+    };
+  };
+  text-generation-client = attrs: {
+    src = filter {
+      root = ../.;
+      include = with filter; [
+        isDirectory
+        (and (inDirectory "backends/client") (matchExt "rs"))
+        (and (inDirectory "proto") (matchExt "proto"))
+      ];
+    };
+    postPatch = "cd backends/client";
+    buildInputs = [ protobuf ];
+  };
+  text-generation-launcher = attrs: {
+    src = filter {
       root = ../launcher;
       include = with filter; [
         isDirectory
@@ -40,8 +59,7 @@ defaultCrateOverrides
     };
   };
   text-generation-router = attrs: {
-    src =
-      filter {
+    src = filter {
       root = ../router;
       include = with filter; [
         isDirectory
nix/server.nix (new file, 109 lines)

@@ -0,0 +1,109 @@
+{
+  nix-filter,
+  buildPythonPackage,
+  poetry-core,
+  mypy-protobuf,
+  awq-inference-engine,
+  causal-conv1d,
+  eetq,
+  einops,
+  exllamav2,
+  fbgemm-gpu,
+  flashinfer,
+  flash-attn,
+  flash-attn-layer-norm,
+  flash-attn-rotary,
+  grpc-interceptor,
+  grpcio-reflection,
+  grpcio-status,
+  grpcio-tools,
+  hf-transfer,
+  loguru,
+  mamba-ssm,
+  marlin-kernels,
+  opentelemetry-api,
+  opentelemetry-exporter-otlp,
+  opentelemetry-instrumentation-grpc,
+  opentelemetry-semantic-conventions,
+  peft,
+  safetensors,
+  tokenizers,
+  sentencepiece,
+  transformers,
+  typer,
+  vllm,
+}:
+
+let
+  filter = nix-filter.lib;
+in
+buildPythonPackage {
+  name = "text-generation-server";
+
+  src = filter {
+    root = ../.;
+    include = with filter; [
+      isDirectory
+      (and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi")))
+      "server/pyproject.toml"
+      (and (inDirectory "proto/v3") (matchExt "proto"))
+    ];
+  };
+
+  pyproject = true;
+
+  build-system = [ poetry-core ];
+
+  nativeBuildInputs = [ mypy-protobuf ];
+
+  pythonRelaxDeps = [
+    "einops"
+    "huggingface-hub"
+    "loguru"
+    "opentelemetry-instrumentation-grpc"
+    "sentencepiece"
+    "typer"
+  ];
+
+  pythonRemoveDeps = [ "scipy" ];
+
+  dependencies = [
+    awq-inference-engine
+    eetq
+    causal-conv1d
+    einops
+    exllamav2
+    fbgemm-gpu
+    flashinfer
+    flash-attn
+    flash-attn-layer-norm
+    flash-attn-rotary
+    grpc-interceptor
+    grpcio-reflection
+    grpcio-status
+    grpcio-tools
+    hf-transfer
+    loguru
+    mamba-ssm
+    marlin-kernels
+    opentelemetry-api
+    opentelemetry-exporter-otlp
+    opentelemetry-instrumentation-grpc
+    opentelemetry-semantic-conventions
+    peft
+    safetensors
+    sentencepiece
+    tokenizers
+    transformers
+    typer
+    vllm
+  ];
+
+  prePatch = ''
+    python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \
+      --grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto
+    find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+    touch server/text_generation_server/pb/__init__.py
+    cd server
+  '';
+}
@@ -46,7 +46,7 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
   "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
+minijinja = { version = "2.0.2", features = ["json"] }
 minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 futures-util = "0.3.30"
 regex = "1.10.3"
@ -1,9 +1,7 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{
|
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
|
||||||
};
|
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
|
|
||||||
@ -32,6 +30,7 @@ impl ChatTemplate {
|
|||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
tracing::debug!("Loading template: {:#?}", template_str);
|
||||||
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
let template = Box::leak(env)
|
let template = Box::leak(env)
|
||||||
@ -42,6 +41,7 @@ impl ChatTemplate {
|
|||||||
let variables = template.undeclared_variables(true);
|
let variables = template.undeclared_variables(true);
|
||||||
// check if the `tools` variable is used in the template
|
// check if the `tools` variable is used in the template
|
||||||
let use_default_tool_template = !variables.contains("tools");
|
let use_default_tool_template = !variables.contains("tools");
|
||||||
|
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
template,
|
template,
|
||||||
@ -56,25 +56,36 @@ impl ChatTemplate {
|
|||||||
&self,
|
&self,
|
||||||
guideline: Option<&str>,
|
guideline: Option<&str>,
|
||||||
mut messages: Vec<Message>,
|
mut messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
// check if guideline is expected but not provided
|
// check if guideline is expected but not provided
|
||||||
if self.variables.contains("guideline") && guideline.is_none() {
|
if self.variables.contains("guideline") && guideline.is_none() {
|
||||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let tools = match tools_and_prompt {
|
||||||
|
Some((tools, tool_prompt)) => {
|
||||||
|
// check if the `tools` variable is used in the template
|
||||||
|
// if not, we need to append the tools to the last message
|
||||||
|
let text = if self.use_default_tool_template {
|
||||||
|
match serde_json::to_string(&tools) {
|
||||||
|
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||||
|
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
|
format!("\n---\n{}", tool_prompt)
|
||||||
|
};
|
||||||
|
if let Some(last_message) = messages.last_mut() {
|
||||||
|
last_message.content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
|
Some(tools)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
self.template
|
self.template
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
guideline,
|
guideline,
|
||||||
@ -82,8 +93,7 @@ impl ChatTemplate {
|
|||||||
bos_token: self.bos_token.as_deref(),
|
bos_token: self.bos_token.as_deref(),
|
||||||
eos_token: self.eos_token.as_deref(),
|
eos_token: self.eos_token.as_deref(),
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
tools: None,
|
tools,
|
||||||
tools_prompt: None,
|
|
||||||
})
|
})
|
||||||
.map_err(InferError::TemplateError)
|
.map_err(InferError::TemplateError)
|
||||||
}
|
}
|
||||||
@ -95,7 +105,7 @@ mod tests {
|
|||||||
use crate::infer::chat_template::raise_exception;
|
use crate::infer::chat_template::raise_exception;
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
|
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||||
};
|
};
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
||||||
@ -854,11 +864,12 @@ mod tests {
|
|||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
content: MessageContent::SingleText("Just testing".to_string()),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools = serde_json::json!("[]");
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(None, msgs, Some(grammer_with_prompt));
|
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
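The apply rewrite in the hunks above distinguishes two rendering paths: templates that never reference a tools variable get the serialized tool list and the tool prompt appended to the last message, while templates that do reference tools receive the structured values and only have the prompt appended. A rough, self-contained sketch of the default-template path, using a simplified tool type rather than the router's own (assumes the serde and serde_json crates with the derive feature):

use serde::Serialize;

// Simplified tool shape, only for illustration.
#[derive(Serialize)]
struct Tool {
    r#type: String,
    name: String,
}

/// Text appended to the final user message when the chat template has no
/// `tools` variable: the serialized tools, then the tool prompt.
fn default_tool_suffix(tools: &[Tool], tool_prompt: &str) -> Result<String, serde_json::Error> {
    let tools_str = serde_json::to_string(tools)?;
    Ok(format!("\n---\n{}\n{}", tools_str, tool_prompt))
}

fn main() -> Result<(), serde_json::Error> {
    let tools = vec![Tool {
        r#type: "function".into(),
        name: "get_current_weather".into(),
    }];
    // Prints the "---" block that ends up in the last user turn.
    println!("{}", default_tool_suffix(&tools, "Respond with a JSON function call.")?);
    Ok(())
}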
@@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;

 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@@ -140,12 +140,12 @@ impl Infer {
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
@ -1,5 +1,8 @@
|
|||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
use crate::{
|
||||||
|
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||||
|
ToolType,
|
||||||
|
};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -16,17 +19,38 @@ impl ToolGrammar {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Vec<Tool>,
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ToolChoice,
|
||||||
) -> Result<Option<Tools>, InferError> {
|
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||||
// if no tools are provided, we return None
|
// if no tools are provided, we return None
|
||||||
let tools = match tools {
|
if tools.is_empty() {
|
||||||
Some(tools) if !tools.is_empty() => tools,
|
return Ok((tools, None));
|
||||||
_ => return Ok(None),
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||||
|
|
||||||
|
let mut tools = tools.clone();
|
||||||
|
|
||||||
|
// add the notify_error function to the tools
|
||||||
|
let notify_error = Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "notify_error".to_string(),
|
||||||
|
description: Some("Notify an error or issue".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The error or issue to notify"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["error"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
tools.push(notify_error);
|
||||||
|
|
||||||
// if tools are provided and no tool_choice we default to the OneOf
|
// if tools are provided and no tool_choice we default to the OneOf
|
||||||
let tools_to_use = match tool_choice {
|
let tools_to_use = match tool_choice {
|
||||||
ToolType::FunctionName(name) => {
|
ToolType::FunctionName(name) => {
|
||||||
@ -35,87 +59,57 @@ impl ToolGrammar {
|
|||||||
ToolType::Function { function } => {
|
ToolType::Function { function } => {
|
||||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
}
|
}
|
||||||
ToolType::OneOf => tools,
|
ToolType::OneOf => tools.clone(),
|
||||||
ToolType::NoTool => return Ok(None),
|
ToolType::NoTool => return Ok((tools, None)),
|
||||||
};
|
};
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
|
||||||
let mut text_response_properties = Map::new();
|
|
||||||
text_response_properties.insert(
|
|
||||||
"error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"description": "The error or issue to notify"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
text_response_properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": "notify_error"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool| {
|
.map(|tool| {
|
||||||
let func = tool.function.clone();
|
let func = tool.function.clone();
|
||||||
|
|
||||||
// Clone the existing parameters, which are expected to be a JSON object
|
let mut params = Map::new();
|
||||||
let mut params = if let Value::Object(params) = &func.arguments {
|
|
||||||
params.clone()
|
|
||||||
} else {
|
|
||||||
Map::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert the function's description at the top level, outside of properties
|
|
||||||
params.insert(
|
params.insert(
|
||||||
"description".to_string(),
|
"description".to_string(),
|
||||||
Value::String(func.description.clone().unwrap_or_default()),
|
Value::String(func.description.unwrap_or_default()),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Ensure 'properties' exists and is an object
|
let mut properties = Map::new();
|
||||||
let properties = params
|
let mut required = vec![Value::String("_name".to_string())];
|
||||||
.entry("properties".to_string())
|
|
||||||
.or_insert_with(|| json!({}))
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Insert the constant for the function name inside 'properties'
|
|
||||||
properties.insert(
|
properties.insert(
|
||||||
"_name".to_string(),
|
"_name".to_string(),
|
||||||
json!({
|
json!({
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"const": func.name.clone(),
|
"const": func.name.clone(),
|
||||||
// "description": "The name of the function"
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
if let Value::Object(args) = func.arguments {
|
||||||
let required = params
|
if let Some(Value::Object(props)) = args.get("properties") {
|
||||||
.entry("required".to_string())
|
properties.extend(props.clone());
|
||||||
.or_insert_with(|| json!([]))
|
|
||||||
.as_array_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Add 'name' to the 'required' array if it is not already present
|
|
||||||
if !required.iter().any(|r| r == "_name") {
|
|
||||||
required.push(json!("_name"));
|
|
||||||
}
|
}
|
||||||
|
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||||
|
required.extend(reqs.clone());
|
||||||
|
}
|
||||||
|
params.insert(
|
||||||
|
"additionalProperties".to_string(),
|
||||||
|
Value::Bool(
|
||||||
|
args.get("additionalProperties").and_then(|v| v.as_str())
|
||||||
|
== Some("true"),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
params.insert("properties".to_string(), Value::Object(properties));
|
||||||
|
params.insert("required".to_string(), Value::Array(required));
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
(func.name, Value::Object(params))
|
||||||
})
|
})
|
||||||
.chain([(
|
|
||||||
"notify_error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"properties": text_response_properties,
|
|
||||||
"required": ["error", "_name"],
|
|
||||||
"type": "object"
|
|
||||||
}),
|
|
||||||
)])
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let tools = Tools {
|
let tool_schema = JsonSchemaTool {
|
||||||
functions_map: FunctionsMap { functions },
|
functions_map: FunctionsMap { functions },
|
||||||
properties: Properties {
|
properties: Properties {
|
||||||
function: tools_to_use
|
function: tools_to_use
|
||||||
@ -123,13 +117,10 @@ impl ToolGrammar {
|
|||||||
.map(|tool| FunctionRef {
|
.map(|tool| FunctionRef {
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||||
})
|
})
|
||||||
.chain(std::iter::once(FunctionRef {
|
|
||||||
ref_path: "#/$functions/notify_error".to_string(),
|
|
||||||
}))
|
|
||||||
.collect(),
|
.collect(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(tools))
|
Ok((tools, Some(tool_schema)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
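ToolGrammar::apply now returns the tool list (extended with notify_error) alongside an optional JsonSchemaTool. For each selected function it builds a schema entry keyed by the function name: the description is hoisted to the top level, a constant `_name` property is injected, the original `properties` and `required` entries are merged in, and `additionalProperties` is taken from the tool arguments. The snippet below hand-builds the approximate shape of two such entries purely as an illustration; the tool values are assumed example inputs, not output captured from the router:

use serde_json::json;

fn main() {
    // Approximate entry produced for a user-supplied get_current_weather tool.
    let get_current_weather = json!({
        "description": "Get the current weather",
        "properties": {
            "_name": { "type": "string", "const": "get_current_weather" },
            "location": { "type": "string", "description": "The city and state, e.g. San Francisco, CA" },
            "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
        },
        "required": ["_name", "location", "format"],
        "additionalProperties": false
    });

    // The always-injected escape hatch the model can call instead of a real tool.
    let notify_error = json!({
        "description": "Notify an error or issue",
        "properties": {
            "_name": { "type": "string", "const": "notify_error" },
            "error": { "type": "string", "description": "The error or issue to notify" }
        },
        "required": ["_name", "error"],
        "additionalProperties": false
    });

    println!("{}\n{}", get_current_weather, notify_error);
}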
@@ -840,10 +840,10 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,

     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
+    #[serde(default)]
     #[schema(
         nullable = true,
-        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
     )]
     pub tool_prompt: Option<String>,

@@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
     pub guideline: Option<String>,
 }

-fn default_tool_prompt() -> Option<String> {
-    Some(
-        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
-    )
+pub fn default_tool_prompt() -> String {
+    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }

 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@@ -910,7 +908,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
 }

 #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
-pub struct Tools {
+pub struct JsonSchemaTool {
     #[serde(flatten)]
     functions_map: FunctionsMap,
     properties: Properties,
@@ -968,8 +966,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
-    tools: Option<&'a str>,
-    tools_prompt: Option<&'a str>,
+    tools: Option<Vec<Tool>>,
     guideline: Option<&'a str>,
 }

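With #[serde(default)] replacing #[serde(default = "default_tool_prompt")], an absent tool_prompt now deserializes to None, and the fallback moves into the handler: the chat_completions hunk further below treats a missing or empty prompt as "use the built-in default". A small sketch of that rule, restating the new default_tool_prompt shown above so the snippet compiles on its own:

// Mirrors the new default_tool_prompt above.
fn default_tool_prompt() -> String {
    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
}

/// Empty or missing prompts fall back to the default, as in chat_completions.
fn effective_tool_prompt(tool_prompt: Option<String>) -> String {
    tool_prompt
        .filter(|s| !s.is_empty())
        .unwrap_or_else(default_tool_prompt)
}

fn main() {
    assert_eq!(effective_tool_prompt(None), default_tool_prompt());
    assert_eq!(effective_tool_prompt(Some(String::new())), default_tool_prompt());
    assert_eq!(effective_tool_prompt(Some("custom".to_string())), "custom");
}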
@ -8,7 +8,7 @@ use crate::kserve::{
|
|||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::ChatTokenizeResponse;
|
use crate::{default_tool_prompt, ChatTokenizeResponse};
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
@ -166,7 +166,7 @@ async fn get_chat_tokenize(
|
|||||||
} = req;
|
} = req;
|
||||||
|
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||||
let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
|
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
@ -1178,14 +1178,16 @@ async fn chat_completions(
|
|||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
let logprobs = logprobs.unwrap_or(false);
|
let logprobs = logprobs.unwrap_or(false);
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
let tool_prompt = tool_prompt
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(default_tool_prompt);
|
||||||
let stop = stop.unwrap_or_default();
|
let stop = stop.unwrap_or_default();
|
||||||
// enable greedy only when temperature is 0
|
// enable greedy only when temperature is 0
|
||||||
let (do_sample, temperature) = match temperature {
|
let (do_sample, temperature) = match temperature {
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
other => (true, other),
|
other => (true, other),
|
||||||
};
|
};
|
||||||
let (inputs, grammar, tool_grammar) = prepare_chat_input(
|
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||||
&infer,
|
&infer,
|
||||||
response_format,
|
response_format,
|
||||||
tools,
|
tools,
|
||||||
@ -1241,7 +1243,7 @@ async fn chat_completions(
|
|||||||
});
|
});
|
||||||
|
|
||||||
// replace the content with the tool calls if grammar is present
|
// replace the content with the tool calls if grammar is present
|
||||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
let (content, tool_calls) = if using_tools {
|
||||||
(None, Some(vec![stream_token.token.text]))
|
(None, Some(vec![stream_token.token.text]))
|
||||||
} else {
|
} else {
|
||||||
let content = if !stream_token.token.special {
|
let content = if !stream_token.token.special {
|
||||||
@ -1295,7 +1297,7 @@ async fn chat_completions(
|
|||||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||||
.as_secs();
|
.as_secs();
|
||||||
|
|
||||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
let (tool_calls, output) = if using_tools {
|
||||||
let gen_text_value: Value =
|
let gen_text_value: Value =
|
||||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||||
InferError::ToolError(format!(
|
InferError::ToolError(format!(
|
||||||
@ -2560,7 +2562,7 @@ fn create_post_processor(
|
|||||||
Ok(post_processor)
|
Ok(post_processor)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
|
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||||
|
|
||||||
fn prepare_chat_input(
|
fn prepare_chat_input(
|
||||||
infer: &Infer,
|
infer: &Infer,
|
||||||
@ -2577,19 +2579,139 @@ fn prepare_chat_input(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when response_format is set, tools are not included when applying the chat template to generate inputs
|
||||||
if let Some(format) = response_format {
|
if let Some(format) = response_format {
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||||
return Ok((inputs, Some(format), None));
|
return Ok((inputs, Some(format), false));
|
||||||
}
|
}
|
||||||
|
|
||||||
// if tools are set, apply the tool grammar and then the chat template
|
// when no response_format is set and tools are included, apply the chat template with the tools
|
||||||
let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
|
// to generate inputs
|
||||||
let grammar = tool_grammar
|
if let Some(tools) = tools {
|
||||||
|
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||||
|
|
||||||
|
let grammar = tool_schema
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||||
let tools_grammar_prompt = tool_grammar
|
|
||||||
.as_ref()
|
let inputs: String = infer.apply_chat_template(
|
||||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
|
guideline,
|
||||||
let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
|
messages,
|
||||||
Ok((inputs, grammar, tool_grammar))
|
Some((updated_tools, tool_prompt.into())),
|
||||||
|
)?;
|
||||||
|
return Ok((inputs, grammar, tool_schema.is_some()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no response_format or tools are set simply apply the chat template to generate inputs
|
||||||
|
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||||
|
Ok((inputs, None, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::ChatTemplateVersions;
|
||||||
|
use crate::HubTokenizerConfig;
|
||||||
|
use crate::TokenizerConfigToken;
|
||||||
|
use crate::Tool;
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_prepare_chat_input() {
|
||||||
|
// Mock Backend to avoid network requests
|
||||||
|
struct MockBackend;
|
||||||
|
|
||||||
|
impl Backend for MockBackend {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
_request: crate::validation::ValidGenerateRequest,
|
||||||
|
) -> Result<
|
||||||
|
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||||
|
Result<InferStreamResponse, InferError>,
|
||||||
|
>,
|
||||||
|
InferError,
|
||||||
|
> {
|
||||||
|
unimplemented!("Never called in this test");
|
||||||
|
}
|
||||||
|
fn health<'a, 'async_trait>(
|
||||||
|
&'a self,
|
||||||
|
_current_health: bool,
|
||||||
|
) -> core::pin::Pin<
|
||||||
|
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
'a: 'async_trait,
|
||||||
|
Self: 'async_trait,
|
||||||
|
{
|
||||||
|
unimplemented!("Never called in this test");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let backend = MockBackend {};
|
||||||
|
|
||||||
|
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||||
|
|
||||||
|
// mock tokenizer config values
|
||||||
|
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||||
|
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||||
|
tokenizer_config.chat_template = Some(
|
||||||
|
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let infer = Infer::new(
|
||||||
|
backend,
|
||||||
|
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
|
||||||
|
1,
|
||||||
|
tokenizer_config,
|
||||||
|
HubProcessorConfig::default(),
|
||||||
|
);
|
||||||
|
let response_format = None;
|
||||||
|
let tools = Some(vec![Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "get_current_weather".to_string(),
|
||||||
|
description: Some("Get the current weather".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["location", "format"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}]);
|
||||||
|
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
|
||||||
|
let guideline = None;
|
||||||
|
let messages = vec![Message {
|
||||||
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"What is the weather like in New York?".to_string(),
|
||||||
|
),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let result = prepare_chat_input(
|
||||||
|
&infer,
|
||||||
|
response_format,
|
||||||
|
tools,
|
||||||
|
ToolChoice(None),
|
||||||
|
tool_prompt,
|
||||||
|
guideline,
|
||||||
|
messages,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (inputs, _grammar, using_tools) = result.unwrap();
|
||||||
|
assert_eq!(using_tools, true);
|
||||||
|
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
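To make the expected string in the assertion above easier to follow, here is a small, purely illustrative Python sketch of the same prompt assembly. The names and the exact tool serialization are hypothetical; the real work is done by the chat template and prepare_chat_input in the router.

import json

# Hypothetical sketch of the prompt layout asserted above: tools go into an
# [AVAILABLE_TOOLS] block, and the tool prompt is appended to the last user
# message after a "---" separator.
tools_json = json.dumps([{"type": "function", "function": {"name": "get_current_weather"}}])
user_message = "What is the weather like in New York?"
tool_prompt = "Given the functions available, please respond with a JSON for a function call ..."

prompt = (
    "<s>[AVAILABLE_TOOLS] " + tools_json + "[/AVAILABLE_TOOLS]"
    + "[INST] " + user_message + "\n---\n" + tool_prompt + "[/INST]"
)
print(prompt)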
@ -6,7 +6,12 @@ from .common import Seqlen
|
|||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .cuda import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
SUPPORTS_WINDOWING,
|
||||||
|
)
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
|
@ -76,7 +76,7 @@ def paged_attention(
|
|||||||
# sequences or heads is large, we use V1 since there is enough work
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
# to parallelize.
|
# to parallelize.
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flash_infer import decode_state
|
from text_generation_server.layers.attention.flashinfer import decode_state
|
||||||
|
|
||||||
return decode_state.get().forward(
|
return decode_state.get().forward(
|
||||||
query.contiguous(),
|
query.contiguous(),
|
||||||
@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
|
|||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
k: torch.Tensor,
|
||||||
v,
|
v: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
|
|||||||
causal=True,
|
causal=True,
|
||||||
softcap=0.0,
|
softcap=0.0,
|
||||||
):
|
):
|
||||||
from text_generation_server.layers.attention.flash_infer import prefill_state
|
assert window_size_left == -1, "Windowing is not supported with flash infer"
|
||||||
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
|
prefill_with_paged_kv_state,
|
||||||
|
)
|
||||||
|
|
||||||
return prefill_state.get().forward(
|
return prefill_with_paged_kv_state.get().forward(
|
||||||
q,
|
q.contiguous(),
|
||||||
k,
|
|
||||||
v,
|
|
||||||
causal=causal,
|
causal=causal,
|
||||||
window_left=window_size_left,
|
paged_kv_cache=(key_cache, value_cache),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
)
|
)
|
||||||
@ -249,6 +252,8 @@ elif V2:
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
@ -289,6 +294,8 @@ else:
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
|
@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
|
|||||||
"prefill_state"
|
"prefill_state"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prefill_with_paged_kv_state: ContextVar[
|
||||||
|
flashinfer.BatchPrefillWithPagedKVCacheWrapper
|
||||||
|
] = ContextVar("prefill_with_paged_kv_state")
|
||||||
|
|
||||||
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
|
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
|
||||||
"decode_state"
|
"decode_state"
|
||||||
)
|
)
|
||||||
@ -24,6 +28,78 @@ def get_workspace(device):
|
|||||||
return workspace
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
def create_prefill_with_paged_kv_state(
|
||||||
|
*,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""Create a prefill state that uses the KV cache."""
|
||||||
|
workspace_buffer = get_workspace(device)
|
||||||
|
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_prefill_with_paged_kv_state(
|
||||||
|
*,
|
||||||
|
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
cu_seqlens: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
page_size: int,
|
||||||
|
query_dtype: str = "float16",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Context manager to set the active flashinfer prefill state to the given
|
||||||
|
`state` and parameters. This state will be used by all calls to the
|
||||||
|
`attention` function while the context manager is active.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indptr = torch.zeros(
|
||||||
|
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
# Round up to page size and then calculate the cumulative sum to get
|
||||||
|
# the indices into the block table.
|
||||||
|
torch.add(input_lengths, page_size - 1, out=indptr[1:])
|
||||||
|
indptr[1:].div_(page_size, rounding_mode="floor")
|
||||||
|
indptr[1:].cumsum_(-1)
|
||||||
|
|
||||||
|
# Get the lengths of the last page in a block.
|
||||||
|
if page_size == 1:
|
||||||
|
last_page_len = torch.ones(
|
||||||
|
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
last_page_len = torch.empty(
|
||||||
|
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||||
|
)
|
||||||
|
torch.sub(input_lengths, 1, out=last_page_len)
|
||||||
|
last_page_len.remainder_(page_size)
|
||||||
|
last_page_len += 1
|
||||||
|
|
||||||
|
token = prefill_with_paged_kv_state.set(state)
|
||||||
|
try:
|
||||||
|
state.begin_forward(
|
||||||
|
qo_indptr=cu_seqlens,
|
||||||
|
paged_kv_indptr=indptr,
|
||||||
|
paged_kv_indices=block_tables,
|
||||||
|
paged_kv_last_page_len=last_page_len,
|
||||||
|
num_qo_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_dim=head_size,
|
||||||
|
q_data_type=query_dtype,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
state.end_forward()
|
||||||
|
if token is not None:
|
||||||
|
prefill_with_paged_kv_state.reset(token)
|
||||||
|
|
||||||
|
|
||||||
def create_prefill_state(
|
def create_prefill_state(
|
||||||
*,
|
*,
|
||||||
device: torch.device,
|
device: torch.device,
|
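The paged-KV prefill state above rounds each sequence length up to the page size to build the paged KV indptr, and derives the length of the final (possibly partial) page separately. A standalone sketch of that arithmetic, using only torch and made-up lengths, may help:

import torch

# Illustrative only: reproduce the indptr / last_page_len arithmetic from
# use_prefill_with_paged_kv_state for a toy batch.
input_lengths = torch.tensor([5, 16, 17], dtype=torch.int32)
page_size = 16

indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
# Round each length up to a whole number of pages, then take the cumulative sum.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)

# Length of the last page of every sequence: (len - 1) % page_size + 1.
last_page_len = torch.empty_like(input_lengths)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1

print(indptr)         # tensor([0, 1, 2, 4], dtype=torch.int32)
print(last_page_len)  # tensor([ 5, 16,  1], dtype=torch.int32)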
@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if not self.heads:
|
||||||
|
return None
|
||||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||||
return speculative_logits
|
return speculative_logits
|
||||||
|
|
||||||
|
@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
qkv[:, 1],
|
qkv[:, 1],
|
||||||
qkv[:, 2],
|
qkv[:, 2],
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -34,6 +34,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
config=config.vision_config,
|
config=config.vision_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
self.post_vision_tower_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix="vision_tower.vision_model.post_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.vision_config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
self.multi_modal_projector = TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
@ -84,7 +89,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
|
||||||
image_outputs = self.vision_tower(pixel_values)
|
image_outputs = self.vision_tower(pixel_values)
|
||||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
last_hidden_state = self.post_vision_tower_layernorm(
|
||||||
|
image_outputs.last_hidden_state
|
||||||
|
)
|
||||||
|
image_features = self.multi_modal_projector(last_hidden_state)
|
||||||
|
|
||||||
# mask where image or padding tokens
|
# mask where image or padding tokens
|
||||||
mask = input_ids == self.config.image_token_index
|
mask = input_ids == self.config.image_token_index
|
||||||
|
@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=2, index=0),
|
torch.select(kv, dim=2, index=0),
|
||||||
torch.select(kv, dim=2, index=1),
|
torch.select(kv, dim=2, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(key_value, dim=1, index=0),
|
torch.select(key_value, dim=1, index=0),
|
||||||
torch.select(key_value, dim=1, index=1),
|
torch.select(key_value, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
|
|||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for idx, encoder_layer in enumerate(self.layers):
|
for idx, encoder_layer in enumerate(self.layers):
|
||||||
hidden_states, _ = encoder_layer(
|
hidden_states, _ = encoder_layer(
|
||||||
@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
self.encoder = SiglipEncoder(
|
self.encoder = SiglipEncoder(
|
||||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||||
)
|
)
|
||||||
self.post_layernorm = nn.LayerNorm.load(
|
|
||||||
prefix=f"{prefix}.post_layernorm",
|
|
||||||
weights=weights,
|
|
||||||
eps=config.layer_norm_eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
raise ValueError("You have to specify pixel_values")
|
raise ValueError("You have to specify pixel_values")
|
||||||
|
|
||||||
@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
inputs_embeds=hidden_states,
|
inputs_embeds=hidden_states,
|
||||||
)
|
)
|
||||||
last_hidden_state = encoder_outputs
|
last_hidden_state = encoder_outputs
|
||||||
post_last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
||||||
|
|
||||||
return BaseModelOutputWithPooling(
|
return BaseModelOutputWithPooling(
|
||||||
last_hidden_state=post_last_hidden_state,
|
last_hidden_state=last_hidden_state,
|
||||||
# pooler_output=pooled_output,
|
# pooler_output=pooled_output,
|
||||||
# hidden_states=encoder_outputs,
|
# hidden_states=encoder_outputs,
|
||||||
)
|
)
|
||||||
|
@ -43,6 +43,7 @@ from text_generation_server.models.globals import (
|
|||||||
ATTENTION,
|
ATTENTION,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
CUDA_GRAPHS,
|
CUDA_GRAPHS,
|
||||||
|
PREFIX_CACHING,
|
||||||
get_adapter_to_index,
|
get_adapter_to_index,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
block_tables_tensor: torch.Tensor
|
block_tables_tensor: torch.Tensor
|
||||||
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
||||||
slots: torch.Tensor
|
slots: torch.Tensor
|
||||||
|
# size [b], containing the number of blocks that can be retrieved from the cache
|
||||||
|
prefix_lens: List[int]
|
||||||
|
prefix_lens_tensor: torch.Tensor
|
||||||
|
|
||||||
max_seqlen: int
|
max_seqlen: int
|
||||||
|
|
||||||
@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefill_next_token_indices: Optional[torch.tensor]
|
prefill_next_token_indices: Optional[torch.tensor]
|
||||||
prefill_cu_outlens: Optional[List[int]]
|
prefill_cu_outlens: Optional[List[int]]
|
||||||
|
|
||||||
|
# Prefixes
|
||||||
|
prefix_ids: List[List[int]]
|
||||||
|
|
||||||
# All tokens
|
# All tokens
|
||||||
all_input_ids: List[List[int]]
|
all_input_ids: List[List[int]]
|
||||||
all_input_ids_tensor: torch.Tensor
|
all_input_ids_tensor: torch.Tensor
|
||||||
@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
prefix_ids = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
all_prefill_logprobs = True
|
all_prefill_logprobs = True
|
||||||
@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_length = 0
|
cumulative_length = 0
|
||||||
cumulative_max_length = 0
|
cumulative_slot_tokens = 0
|
||||||
prefill_out_cumulative_length = 0
|
prefill_out_cumulative_length = 0
|
||||||
|
|
||||||
num_blocks = 0
|
num_blocks = 0
|
||||||
@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
block_tables = []
|
block_tables = []
|
||||||
slots = []
|
slots = []
|
||||||
|
prefix_lens = []
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, (r, tokenized_input) in enumerate(
|
for i, (r, tokenized_input) in enumerate(
|
||||||
@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch):
|
|||||||
):
|
):
|
||||||
tokenized_input = tokenized_input[1:]
|
tokenized_input = tokenized_input[1:]
|
||||||
|
|
||||||
|
orig_input_length = len(tokenized_input)
|
||||||
|
|
||||||
|
if PREFIX_CACHING:
|
||||||
|
prefix_len = r.prefix_len
|
||||||
|
if prefix_len == orig_input_length:
|
||||||
|
assert prefix_len > 0
|
||||||
|
prefix_len -= 1
|
||||||
|
else:
|
||||||
|
prefix_len = 0
|
||||||
|
|
||||||
|
prefix_ids.append(tokenized_input[:prefix_len])
|
||||||
|
tokenized_input = tokenized_input[prefix_len:]
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
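A small, self-contained sketch (with hypothetical values) of what the prefix-caching branch above does to a single request: the cached prefix is split off the tokenized input, kept in prefix_ids, and the position ids start at the prefix length instead of zero.

import torch

# Illustrative request with a cached prefix (token ids are made up).
tokenized_input = [1, 523, 3043, 1202, 29871, 910, 338]   # full prompt token ids
prefix_len = 4                                            # tokens already in the KV cache

orig_input_length = len(tokenized_input)
if prefix_len == orig_input_length:
    # Never reuse the whole prompt: at least one token must still be prefilled.
    prefix_len -= 1

prefix_ids = tokenized_input[:prefix_len]       # kept only to reconstruct text/logprobs later
tokenized_input = tokenized_input[prefix_len:]  # what actually gets prefilled

# Positions continue after the cached prefix rather than restarting at 0.
position_ids = torch.arange(prefix_len, orig_input_length, dtype=torch.int32)

assert len(tokenized_input) == orig_input_length - prefix_len
print(prefix_ids, tokenized_input, position_ids)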
@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
# Position ids
|
# Position ids
|
||||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
request_position_ids = torch.arange(
|
||||||
|
prefix_len, orig_input_length, dtype=torch.int32
|
||||||
|
)
|
||||||
position_ids.append(request_position_ids)
|
position_ids.append(request_position_ids)
|
||||||
|
|
||||||
# Add cumulative lengths of all previous inputs
|
# Add cumulative lengths of all previous inputs
|
||||||
@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch):
|
|||||||
# Remove one as the first token does not have a past
|
# Remove one as the first token does not have a past
|
||||||
speculative_length = get_speculate()
|
speculative_length = get_speculate()
|
||||||
speculative_length = 0 if speculative_length is None else speculative_length
|
speculative_length = 0 if speculative_length is None else speculative_length
|
||||||
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
|
||||||
|
# Tokens that need to be mapped to blocks.
|
||||||
|
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
|
||||||
|
|
||||||
|
# Tokens that need to be mapped to slots. We don't need slots for the
|
||||||
|
# cached prefix (if present).
|
||||||
|
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||||
|
|
||||||
# blocks and slots can be empty (for example in warmup)
|
# blocks and slots can be empty (for example in warmup)
|
||||||
if not r.blocks:
|
if not r.blocks:
|
||||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
|
||||||
request_blocks = [
|
request_blocks = [
|
||||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||||
]
|
]
|
||||||
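The split between block_tokens and slot_tokens above can be summarized with a toy example: blocks still have to cover the whole sequence, cached prefix included, while slots are only needed for the tokens this request will actually write. The numbers below are assumptions for illustration.

import math

BLOCK_SIZE = 16          # paged-attention block size (assumed here)
orig_input_length = 40   # full prompt length, including the cached prefix
prefix_len = 32          # tokens already present in the KV cache
input_length = orig_input_length - prefix_len
max_new_tokens = 20
speculative_length = 0

# Blocks must map every token of the sequence, cached or not.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
# Slots are only needed for tokens that will be written by this request.
slot_tokens = input_length + max_new_tokens - 1 + speculative_length

needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
print(block_tokens, slot_tokens, needed_blocks)  # 59 27 4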
@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
request_blocks = r.blocks
|
request_blocks = r.blocks
|
||||||
request_slots = r.slots
|
request_slots = r.slots[
|
||||||
|
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
|
||||||
|
]
|
||||||
|
|
||||||
block_tables.append(request_blocks)
|
block_tables.append(request_blocks)
|
||||||
slots.extend(request_slots[:total_tokens])
|
|
||||||
|
slots.extend(request_slots)
|
||||||
|
prefix_lens.append(prefix_len)
|
||||||
num_blocks += len(request_blocks)
|
num_blocks += len(request_blocks)
|
||||||
start_slots.append(cumulative_max_length)
|
start_slots.append(cumulative_slot_tokens)
|
||||||
|
|
||||||
request_slot_indices = torch.arange(
|
request_slot_indices = torch.arange(
|
||||||
cumulative_max_length,
|
cumulative_slot_tokens,
|
||||||
cumulative_max_length + input_length,
|
cumulative_slot_tokens + input_length,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
slot_indices.append(request_slot_indices)
|
slot_indices.append(request_slot_indices)
|
||||||
@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
# Update
|
# Update
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
cumulative_max_length += total_tokens
|
cumulative_slot_tokens += slot_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, len(request_blocks))
|
max_blocks = max(max_blocks, len(request_blocks))
|
||||||
max_length = max(
|
max_length = max(
|
||||||
@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
block_tables_tensor = torch.zeros(
|
block_tables_tensor = torch.zeros(
|
||||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
for i, request_blocks in enumerate(block_tables):
|
for i, request_blocks in enumerate(block_tables):
|
||||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||||
block_tables_tensor = block_tables_tensor.to(device)
|
block_tables_tensor = block_tables_tensor.to(device)
|
||||||
|
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
block_tables_tensor=block_tables_tensor,
|
block_tables_tensor=block_tables_tensor,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
|
prefix_lens=prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
prefill_head_indices=prefill_head_indices,
|
prefill_head_indices=prefill_head_indices,
|
||||||
prefill_next_token_indices=prefill_next_token_indices,
|
prefill_next_token_indices=prefill_next_token_indices,
|
||||||
@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
|
prefix_ids=prefix_ids,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch):
|
|||||||
start_slots = []
|
start_slots = []
|
||||||
block_tables = []
|
block_tables = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
prefix_ids = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
|
prefix_lens = []
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
|
|
||||||
@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
# Get length
|
# Get length
|
||||||
request_input_length = self.input_lengths[idx]
|
request_input_length = self.input_lengths[idx]
|
||||||
|
prefix_len = self.prefix_lens[idx]
|
||||||
max_seqlen = max(max_seqlen, request_input_length)
|
max_seqlen = max(max_seqlen, request_input_length)
|
||||||
|
|
||||||
all_input_ids.append(self.all_input_ids[idx])
|
all_input_ids.append(self.all_input_ids[idx])
|
||||||
|
prefix_ids.append(self.prefix_ids[idx])
|
||||||
|
|
||||||
input_lengths.append(request_input_length)
|
input_lengths.append(request_input_length)
|
||||||
|
prefix_lens.append(prefix_len)
|
||||||
prefix_offsets.append(self.prefix_offsets[idx])
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
read_offsets.append(self.read_offsets[idx])
|
read_offsets.append(self.read_offsets[idx])
|
||||||
|
|
||||||
@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
block_tables_tensor = self.block_tables_tensor[indices]
|
block_tables_tensor = self.block_tables_tensor[indices]
|
||||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||||
slots = self.slots[slot_filtering_indices]
|
slots = self.slots[slot_filtering_indices]
|
||||||
|
prefix_lens_tensor = self.prefix_lens_tensor[indices]
|
||||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||||
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
||||||
speculative_ids = (
|
speculative_ids = (
|
||||||
@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefill_cu_outlens=None,
|
prefill_cu_outlens=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
input_lengths_tensor=input_lengths_tensor,
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
|
prefix_lens=prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
|
prefix_ids=prefix_ids,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
|
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
|
||||||
(total_batch_size, max_blocks)
|
(total_batch_size, max_blocks)
|
||||||
)
|
)
|
||||||
|
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
|
||||||
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
||||||
(total_batch_size, max_length)
|
(total_batch_size, max_length)
|
||||||
)
|
)
|
||||||
@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
start_slots = []
|
start_slots = []
|
||||||
block_tables = []
|
block_tables = []
|
||||||
|
prefix_lens = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
prefix_ids = []
|
||||||
|
|
||||||
input_lengths = []
|
input_lengths = []
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch):
|
|||||||
start_index:end_index, : batch.block_tables_tensor.shape[1]
|
start_index:end_index, : batch.block_tables_tensor.shape[1]
|
||||||
] = batch.block_tables_tensor[:, :max_blocks]
|
] = batch.block_tables_tensor[:, :max_blocks]
|
||||||
|
|
||||||
|
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
|
||||||
|
|
||||||
start_slots.append(batch.start_slots + cumulative_slots)
|
start_slots.append(batch.start_slots + cumulative_slots)
|
||||||
|
|
||||||
block_tables.extend(batch.block_tables)
|
block_tables.extend(batch.block_tables)
|
||||||
|
prefix_lens.extend(batch.prefix_lens)
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
|
prefix_ids.extend(batch.prefix_ids)
|
||||||
|
|
||||||
input_lengths.extend(batch.input_lengths)
|
input_lengths.extend(batch.input_lengths)
|
||||||
prefix_offsets.extend(batch.prefix_offsets)
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch):
|
|||||||
slot_indices=slot_indices,
|
slot_indices=slot_indices,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
block_tables_tensor=block_tables_tensor,
|
block_tables_tensor=block_tables_tensor,
|
||||||
|
prefix_lens=prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
prefill_head_indices=None,
|
prefill_head_indices=None,
|
||||||
@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
|
prefix_ids=prefix_ids,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
@ -970,14 +1028,17 @@ class FlashCausalLM(Model):
|
|||||||
self.kv_cache = []
|
self.kv_cache = []
|
||||||
|
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flash_infer import (
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
create_prefill_state,
|
create_prefill_state,
|
||||||
create_decode_state,
|
create_decode_state,
|
||||||
|
create_prefill_with_paged_kv_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prefill_state = create_prefill_state(device=device)
|
self.prefill_state = create_prefill_state(device=device)
|
||||||
|
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
if not CUDA_GRAPHS:
|
|
||||||
self.decode_state = create_decode_state(
|
self.decode_state = create_decode_state(
|
||||||
device=device,
|
device=device,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
@ -1074,11 +1135,22 @@ class FlashCausalLM(Model):
|
|||||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
input_lengths = [max_s] * bs
|
||||||
block_tables = (
|
prefix_lengths = [0] * bs
|
||||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
input_lengths_tensor = (
|
||||||
.repeat(bs)
|
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||||
.reshape((bs, max_bt))
|
)
|
||||||
|
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||||
|
block_tables = torch.arange(
|
||||||
|
max_bt, dtype=torch.int32, device=self.device
|
||||||
|
).repeat(bs)
|
||||||
|
block_tables = block_tables.reshape((bs, max_bt))
|
||||||
|
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
prefix_lens=prefix_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cuda_graphs[bs] = {
|
self.cuda_graphs[bs] = {
|
||||||
@ -1087,14 +1159,14 @@ class FlashCausalLM(Model):
|
|||||||
"kv_cache": self.kv_cache,
|
"kv_cache": self.kv_cache,
|
||||||
"block_tables": block_tables,
|
"block_tables": block_tables,
|
||||||
"slots": slots,
|
"slots": slots,
|
||||||
"input_lengths": input_lengths,
|
"input_lengths": input_lengths_tensor,
|
||||||
}
|
}
|
||||||
input_lengths_ = Seqlen(input_lengths=input_lengths)
|
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
self.cuda_graphs[bs]["graph"] = graph
|
self.cuda_graphs[bs]["graph"] = graph
|
||||||
|
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flash_infer import (
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
create_decode_state_cuda_graphs,
|
create_decode_state_cuda_graphs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1104,7 +1176,7 @@ class FlashCausalLM(Model):
|
|||||||
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||||
state = create_decode_state_cuda_graphs(
|
state = create_decode_state_cuda_graphs(
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
block_tables=block_tables.view(-1),
|
block_tables=block_tables,
|
||||||
block_tables_ptr=block_tables_ptr,
|
block_tables_ptr=block_tables_ptr,
|
||||||
last_page_len=last_page_len,
|
last_page_len=last_page_len,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
@ -1120,7 +1192,10 @@ class FlashCausalLM(Model):
|
|||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
cu_seqlen_prefill=None,
|
cu_seqlen_prefill=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
state=state,
|
state=state,
|
||||||
|
prefix_lens=prefix_lengths,
|
||||||
|
prefix_lens_tensor=prefix_lengths_tensor,
|
||||||
):
|
):
|
||||||
self.model.forward(
|
self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@ -1138,7 +1213,7 @@ class FlashCausalLM(Model):
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@ -1146,7 +1221,7 @@ class FlashCausalLM(Model):
|
|||||||
kv_cache=self.kv_cache,
|
kv_cache=self.kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths_tensor,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
prefill_cache_indices=None,
|
prefill_cache_indices=None,
|
||||||
lm_head_indices=None,
|
lm_head_indices=None,
|
||||||
@ -1334,6 +1409,9 @@ class FlashCausalLM(Model):
|
|||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
|
prefix_lens_tensor = (
|
||||||
|
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
# Copy the block tables for all members
|
# Copy the block tables for all members
|
||||||
block_tables = (
|
block_tables = (
|
||||||
@ -1354,6 +1432,7 @@ class FlashCausalLM(Model):
|
|||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
|
prefix_lens_tensor = batch.prefix_lens_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
@ -1372,10 +1451,20 @@ class FlashCausalLM(Model):
|
|||||||
cuda_graph = None
|
cuda_graph = None
|
||||||
|
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||||
|
input_lengths = input_lengths + prefix_lens_tensor
|
||||||
|
if PREFIX_CACHING:
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
with self._forward_context(
|
with self._forward_context(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
input_lengths=input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
):
|
):
|
||||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
@ -1399,20 +1488,32 @@ class FlashCausalLM(Model):
|
|||||||
# Static inputs are potentially padded
|
# Static inputs are potentially padded
|
||||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||||
|
else:
|
||||||
cuda_graph["block_tables"][
|
cuda_graph["block_tables"][
|
||||||
: block_tables.shape[0], : block_tables.shape[1]
|
: block_tables.shape[0], : block_tables.shape[1]
|
||||||
] = block_tables
|
] = block_tables
|
||||||
cuda_graph["slots"].fill_(-1)
|
cuda_graph["slots"].fill_(-1)
|
||||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||||
cuda_graph["input_lengths"].zero_()
|
cuda_graph["input_lengths"].zero_()
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||||
|
input_lengths + prefix_lens_tensor
|
||||||
|
)
|
||||||
|
|
||||||
state = cuda_graph.get("state")
|
|
||||||
with self._forward_context(
|
with self._forward_context(
|
||||||
block_tables=block_tables,
|
block_tables=cuda_graph["block_tables"],
|
||||||
cu_seqlen_prefill=None,
|
cu_seqlen_prefill=None,
|
||||||
input_lengths=input_lengths,
|
input_lengths=batch.input_lengths,
|
||||||
state=state,
|
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
|
state=cuda_graph.get("state"),
|
||||||
):
|
):
|
||||||
# Replay the graph
|
# Replay the graph
|
||||||
cuda_graph["graph"].replay()
|
cuda_graph["graph"].replay()
|
||||||
@ -1610,6 +1711,7 @@ class FlashCausalLM(Model):
|
|||||||
batch.read_offsets,
|
batch.read_offsets,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
|
batch.prefix_ids,
|
||||||
batch.next_token_chooser.do_sample,
|
batch.next_token_chooser.do_sample,
|
||||||
batch.next_token_chooser.seeds,
|
batch.next_token_chooser.seeds,
|
||||||
batch.top_n_tokens,
|
batch.top_n_tokens,
|
||||||
@ -1627,6 +1729,7 @@ class FlashCausalLM(Model):
|
|||||||
read_offset,
|
read_offset,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
|
prefix_ids,
|
||||||
do_sample,
|
do_sample,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
@ -1701,18 +1804,18 @@ class FlashCausalLM(Model):
|
|||||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||||
|
|
||||||
# Remove generated token to only have prefill and add nan for first prompt token
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
request_prefill_logprobs = [float("nan")] + prefill_logprobs[
|
request_prefill_logprobs = (
|
||||||
out_start_index : out_end_index - 1
|
[float("nan")] * (len(prefix_ids) + 1)
|
||||||
]
|
) + prefill_logprobs[out_start_index : out_end_index - 1]
|
||||||
prefill_token_ids = all_input_ids[:-1]
|
prefill_token_ids = all_input_ids[:-1]
|
||||||
prefill_texts = self.tokenizer.batch_decode(
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
prefill_token_ids,
|
prefix_ids + prefill_token_ids,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill_tokens = Tokens(
|
prefill_tokens = Tokens(
|
||||||
prefill_token_ids,
|
prefix_ids + prefill_token_ids,
|
||||||
request_prefill_logprobs,
|
request_prefill_logprobs,
|
||||||
prefill_texts,
|
prefill_texts,
|
||||||
is_special=[],
|
is_special=[],
|
||||||
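The prefill-logprob handling above pads with NaN for every cached prefix token plus the first prompt token, since none of them have a past to condition on, and prepends the prefix ids so the client still sees the full prompt. A hedged sketch of that bookkeeping with toy values:

# Illustrative values: 3 cached prefix tokens, 4 freshly prefilled tokens.
prefix_ids = [101, 102, 103]
all_input_ids = [104, 105, 106, 107, 999]        # last id is the newly generated token
prefill_logprobs = [-0.7, -1.2, -0.3, -2.1]      # one entry per prefilled position

out_start_index, out_end_index = 0, len(prefill_logprobs)

# No logprob exists for the cached prefix or the very first prompt token.
request_prefill_logprobs = (
    [float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]

prefill_token_ids = all_input_ids[:-1]
# The returned prefill tokens include the cached prefix.
returned_ids = prefix_ids + prefill_token_ids

assert len(request_prefill_logprobs) == len(returned_ids)
print(returned_ids, request_prefill_logprobs)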
@ -1794,33 +1897,68 @@ class FlashCausalLM(Model):
|
|||||||
*,
|
*,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: List[int],
|
||||||
|
input_lengths_tensor: torch.Tensor,
|
||||||
|
prefix_lens: List[int],
|
||||||
|
prefix_lens_tensor: torch.Tensor,
|
||||||
state: Optional[Any] = None,
|
state: Optional[Any] = None,
|
||||||
) -> ContextManager:
|
) -> ContextManager:
|
||||||
if ATTENTION != "flashinfer":
|
if ATTENTION != "flashinfer":
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
|
|
||||||
from text_generation_server.layers.attention.flash_infer import (
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
use_decode_state,
|
use_decode_state,
|
||||||
use_prefill_state,
|
use_prefill_with_paged_kv_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
|
||||||
|
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
return use_prefill_state(
|
return use_prefill_with_paged_kv_state(
|
||||||
state=state if state is not None else self.prefill_state,
|
state=(
|
||||||
|
state if state is not None else self.prefill_with_paged_kv_state
|
||||||
|
),
|
||||||
|
# block_tables=block_tables_to_ragged(
|
||||||
|
# block_tables=block_tables,
|
||||||
|
# input_lengths=input_lengths,
|
||||||
|
# prefix_lens=prefix_lens,
|
||||||
|
# ),
|
||||||
|
block_tables=block_tables,
|
||||||
cu_seqlens=cu_seqlen_prefill,
|
cu_seqlens=cu_seqlen_prefill,
|
||||||
num_heads=self.num_heads,
|
input_lengths=input_lengths_tensor,
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
head_size=self.head_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert input_lengths is not None
|
|
||||||
return use_decode_state(
|
|
||||||
state=state if state is not None else self.decode_state,
|
|
||||||
input_lengths=input_lengths,
|
|
||||||
block_tables=block_tables.view(-1),
|
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
head_size=self.head_size,
|
head_size=self.head_size,
|
||||||
page_size=BLOCK_SIZE,
|
page_size=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
assert input_lengths_tensor is not None
|
||||||
|
return use_decode_state(
|
||||||
|
state=state if state is not None else self.decode_state,
|
||||||
|
input_lengths=input_lengths_tensor,
|
||||||
|
block_tables=block_tables,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
head_size=self.head_size,
|
||||||
|
page_size=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def block_tables_to_ragged(
|
||||||
|
*, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Convert block table to ragged format compatible with FlashInfer."""
|
||||||
|
assert len(input_lengths) == len(prefix_lens)
|
||||||
|
|
||||||
|
total_len = sum(input_lengths) + sum(prefix_lens)
|
||||||
|
block_tables_ragged = torch.empty(
|
||||||
|
total_len, dtype=torch.int32, device=block_tables.device
|
||||||
|
)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
|
||||||
|
seq_len = prefix_len + input_length
|
||||||
|
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
|
||||||
|
offset += seq_len
|
||||||
|
|
||||||
|
return block_tables_ragged
|
||||||
|
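block_tables_to_ragged above flattens the padded [batch, max_blocks] table into the ragged layout flashinfer expects, keeping prefix_len + input_length entries per sequence. A quick sketch with toy values, with the loop inlined for illustration rather than imported:

import torch

# Toy padded block table for two sequences (trailing zeros are padding).
block_tables = torch.tensor(
    [[7, 8, 9, 0],
     [3, 4, 0, 0]],
    dtype=torch.int32,
)
input_lengths = [2, 1]   # tokens to prefill per sequence
prefix_lens = [1, 1]     # cached prefix entries per sequence

# Same logic as block_tables_to_ragged above.
total_len = sum(input_lengths) + sum(prefix_lens)
ragged = torch.empty(total_len, dtype=torch.int32)
offset = 0
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
    seq_len = prefix_len + input_length
    ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
    offset += seq_len

print(ragged)  # tensor([7, 8, 9, 3, 4], dtype=torch.int32)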
@ -5,9 +5,8 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
|
|
||||||
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
|
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
|
||||||
log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
|
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||||
|
|
||||||
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
|
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
|
||||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
_expected = {"paged", "flashdecoding", "flashinfer"}
|
||||||
assert (
|
assert (
|
||||||
@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer":
|
|||||||
else:
|
else:
|
||||||
BLOCK_SIZE = 16
|
BLOCK_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||||
if cuda_graphs is not None:
|
if cuda_graphs is not None:
|
||||||
try:
|
try:
|
||||||
|
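The globals change above makes USE_PREFIX_CACHING an explicit string comparison instead of passing the raw environment value through, which would have been truthy even for "0". A minimal sketch of the new parsing semantics, with an assumed helper name:

import os

def _env_flag(name: str, default: str = "0") -> bool:
    # Hypothetical helper mirroring the new parsing: only "1" / "true" enable the flag.
    return os.getenv(name, default).lower() in {"1", "true"}

os.environ["USE_PREFIX_CACHING"] = "0"
assert _env_flag("USE_PREFIX_CACHING") is False   # the old os.getenv(...) would have been truthy

os.environ["USE_PREFIX_CACHING"] = "True"
assert _env_flag("USE_PREFIX_CACHING") is True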
@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
|
|||||||
from text_generation_server.models.flash_causal_lm import (
|
from text_generation_server.models.flash_causal_lm import (
|
||||||
FlashCausalLMBatch,
|
FlashCausalLMBatch,
|
||||||
FlashCausalLM,
|
FlashCausalLM,
|
||||||
|
block_tables_to_ragged,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if PREFIX_CACHING:
|
||||||
|
raise NotImplementedError("Vlm do not work with prefix caching yet")
|
||||||
if processor_kwargs is None:
|
if processor_kwargs is None:
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
self.processor = processor_class.from_pretrained(
|
self.processor = processor_class.from_pretrained(
|
||||||
@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
input_lengths = (
|
input_lengths = (
|
||||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
).view(-1)
|
).view(-1)
|
||||||
|
prefix_lens_tensor = (
|
||||||
|
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
# Copy the block tables for all members
|
# Copy the block tables for all members
|
||||||
block_tables = (
|
block_tables = (
|
||||||
@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
block_tables = batch.block_tables_tensor
|
block_tables = batch.block_tables_tensor
|
||||||
slots = batch.slots[batch.slot_indices]
|
slots = batch.slots[batch.slot_indices]
|
||||||
input_lengths = batch.input_lengths_tensor
|
input_lengths = batch.input_lengths_tensor
|
||||||
|
prefix_lens_tensor = batch.prefix_lens_tensor
|
||||||
max_s = batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
@ -349,6 +357,21 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
else:
|
else:
|
||||||
cuda_graph = None
|
cuda_graph = None
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||||
|
input_lengths = input_lengths + prefix_lens_tensor
|
||||||
|
if PREFIX_CACHING:
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
|
with self._forward_context(
|
||||||
|
block_tables=block_tables,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
|
):
|
||||||
input_lengths = Seqlen(input_lengths=input_lengths)
|
input_lengths = Seqlen(input_lengths=input_lengths)
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@ -379,13 +402,23 @@ class VlmCausalLM(FlashCausalLM):
|
|||||||
# Static inputs are potentially padded
|
# Static inputs are potentially padded
|
||||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||||
|
else:
|
||||||
cuda_graph["block_tables"][
|
cuda_graph["block_tables"][
|
||||||
: block_tables.shape[0], : block_tables.shape[1]
|
: block_tables.shape[0], : block_tables.shape[1]
|
||||||
] = block_tables
|
] = block_tables
|
||||||
cuda_graph["slots"].fill_(-1)
|
cuda_graph["slots"].fill_(-1)
|
||||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||||
cuda_graph["input_lengths"].zero_()
|
cuda_graph["input_lengths"].zero_()
|
||||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||||
|
input_lengths + prefix_lens_tensor
|
||||||
|
)
|
||||||
|
|
||||||
# Replay the graph
|
# Replay the graph
|
||||||
cuda_graph["graph"].replay()
|
cuda_graph["graph"].replay()
|
||||||
|
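As in the text-only model, the CUDA-graph path above never re-captures the graph: it copies the (possibly smaller) batch into the padded static buffers, resets the unused tail with a sentinel or zeros, and then replays. A schematic sketch of that pattern with dummy tensors, not the actual buffers used by the server:

import torch

# Schematic only: padded static buffers of capacity 8, current batch of size 3.
capacity = 8
static_slots = torch.full((capacity,), -1, dtype=torch.int64)
static_input_lengths = torch.zeros(capacity, dtype=torch.int32)

slots = torch.tensor([5, 9, 12], dtype=torch.int64)
input_lengths = torch.tensor([4, 7, 2], dtype=torch.int32)
prefix_lens_tensor = torch.tensor([0, 3, 1], dtype=torch.int32)

# Reset the whole buffer, then copy the live batch into its head.
static_slots.fill_(-1)
static_slots[: slots.shape[0]] = slots
static_input_lengths.zero_()
# flashinfer sees the total sequence length, cached prefix included.
static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor

print(static_slots, static_input_lengths)
# a real run would now call cuda_graph["graph"].replay()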