Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Merge branch 'feat/qwen' into qwen2
This commit is contained in: commit 73403aa4db
@ -52,6 +52,8 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
|
||||
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
|
||||
- Stop sequences
|
||||
- Log probabilities
|
||||
- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) for ~2x lower latency
|
||||
- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify the output format to speed up inference and ensure the output is valid according to some specification.
|
||||
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
|
||||
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
|
||||
|
||||
|
@ -3,7 +3,7 @@ import requests
|
||||
|
||||
from aiohttp import ClientSession, ClientTimeout
|
||||
from pydantic import ValidationError
|
||||
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
||||
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
|
||||
|
||||
from text_generation.types import (
|
||||
StreamResponse,
|
||||
@ -11,6 +11,11 @@ from text_generation.types import (
|
||||
Request,
|
||||
Parameters,
|
||||
Grammar,
|
||||
ChatRequest,
|
||||
ChatCompletionChunk,
|
||||
ChatComplete,
|
||||
Message,
|
||||
Tool,
|
||||
)
|
||||
from text_generation.errors import parse_error
|
||||
|
||||
@ -59,6 +64,114 @@ class Client:
|
||||
self.cookies = cookies
|
||||
self.timeout = timeout
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[List[float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Given a list of messages, generate a response
|
||||
|
||||
Args:
|
||||
messages (`List[Message]`):
|
||||
List of messages
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
logit_bias (`List[float]`):
|
||||
Adjust the likelihood of specified tokens
|
||||
logprobs (`bool`):
|
||||
Include log probabilities in the response
|
||||
top_logprobs (`int`):
|
||||
Include the `top_logprobs` most likely tokens at each step
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
n (`int`):
|
||||
Generate `n` completions
|
||||
presence_penalty (`float`):
|
||||
The parameter for presence penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
model="tgi",
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
return ChatComplete(**payload)
|
||||
else:
|
||||
return self._chat_stream_response(request)
|
||||
|
||||
def _chat_stream_response(self, request):
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
# iterate and print stream
|
||||
for byte_payload in resp.iter_lines():
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status_code, json_payload)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -313,6 +426,113 @@ class AsyncClient:
|
||||
self.cookies = cookies
|
||||
self.timeout = ClientTimeout(timeout * 60)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[List[float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
||||
"""
|
||||
Given a list of messages, generate a response asynchronously
|
||||
|
||||
Args:
|
||||
messages (`List[Message]`):
|
||||
List of messages
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
logit_bias (`List[float]`):
|
||||
Adjust the likelihood of specified tokens
|
||||
logprobs (`bool`):
|
||||
Include log probabilities in the response
|
||||
top_logprobs (`int`):
|
||||
Include the `top_logprobs` most likely tokens at each step
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
n (`int`):
|
||||
Generate `n` completions
|
||||
presence_penalty (`float`):
|
||||
The parameter for presence penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
model="tgi",
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
return await self._chat_single_response(request)
|
||||
else:
|
||||
return self._chat_stream_response(request)
|
||||
|
||||
async def _chat_single_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
||||
) as resp:
|
||||
payload = await resp.json()
|
||||
if resp.status != 200:
|
||||
raise parse_error(resp.status, payload)
|
||||
return ChatComplete(**payload)
|
||||
|
||||
async def _chat_stream_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
||||
) as resp:
|
||||
async for byte_payload in resp.content:
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status, json_payload)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import Optional, List, Union
|
||||
from typing import Optional, List, Union, Any
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
|
||||
@ -19,6 +19,124 @@ class Grammar(BaseModel):
|
||||
value: Union[str, dict]
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
# Id of the tool call
|
||||
id: int
|
||||
# Type of the tool call
|
||||
type: str
|
||||
# Function details of the tool call
|
||||
function: dict
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str]
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
tool_calls: Optional[Any] = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
# Type of the tool
|
||||
type: str
|
||||
# Function details of the tool
|
||||
function: dict
|
||||
|
||||
|
||||
class ChatCompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
message: Message
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
# Usage details of the chat completion
|
||||
usage: Any
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChoiceDeltaToolCall(BaseModel):
|
||||
index: int
|
||||
id: str
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
|
||||
class ChoiceDelta(BaseModel):
|
||||
role: str
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[ChoiceDeltaToolCall]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
index: int
|
||||
delta: ChoiceDelta
|
||||
logprobs: Optional[dict] = None
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionChunk(BaseModel):
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[Choice]
|
||||
|
||||
|
||||
class ChatComplete(BaseModel):
|
||||
# Chat completion details
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[ChatCompletionComplete]
|
||||
usage: Any
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
# Model identifier
|
||||
model: str
|
||||
# List of messages in the conversation
|
||||
messages: List[Message]
|
||||
# Penalty for frequency of new tokens
|
||||
frequency_penalty: Optional[float] = None
|
||||
# Bias values for token selection
|
||||
logit_bias: Optional[List[float]] = None
|
||||
# Whether to return log probabilities
|
||||
logprobs: Optional[bool] = None
|
||||
# Number of most likely tokens to return at each position
|
||||
top_logprobs: Optional[int] = None
|
||||
# Maximum number of tokens to generate
|
||||
max_tokens: Optional[int] = None
|
||||
# Number of chat completion choices to generate
|
||||
n: Optional[int] = None
|
||||
# Penalty for presence of new tokens
|
||||
presence_penalty: Optional[float] = None
|
||||
# Flag to indicate streaming response
|
||||
stream: bool = False
|
||||
# Random sampling seed
|
||||
seed: Optional[int] = None
|
||||
# Sampling temperature
|
||||
temperature: Optional[float] = None
|
||||
# Top-p value for nucleus sampling
|
||||
top_p: Optional[float] = None
|
||||
# List of tools to be used
|
||||
tools: Optional[List[Tool]] = None
|
||||
# Choice of tool to be used
|
||||
tool_choice: Optional[str] = None
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
# Activate logits sampling
|
||||
do_sample: bool = False
|
||||
|
@ -9,6 +9,8 @@
|
||||
title: Supported Models and Hardware
|
||||
- local: messages_api
|
||||
title: Messages API
|
||||
- local: guidance
|
||||
title: Guidance
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic_tutorials/consuming_tgi
|
||||
@ -37,4 +39,8 @@
|
||||
title: Safetensors
|
||||
- local: conceptual/flash_attention
|
||||
title: Flash Attention
|
||||
- local: conceptual/speculation
|
||||
title: Speculation (Medusa, ngram)
|
||||
- local: conceptual/guidance
|
||||
title: Guidance, JSON, tools (using outlines)
|
||||
title: Conceptual Guides
|
||||
|
419
docs/source/conceptual/guidance.md
Normal file
@ -0,0 +1,419 @@
|
||||
# Guidance
|
||||
|
||||
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
|
||||
|
||||
These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
|
||||
|
||||
## Quick Start
|
||||
|
||||
Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.
|
||||
|
||||
If you're not up to date, grab the latest version and let's get started!
|
||||
|
||||
## Table of Contents 📚
|
||||
|
||||
### Grammar and Constraints
|
||||
|
||||
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
|
||||
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
|
||||
- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema.
|
||||
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
|
||||
|
||||
### Tools and Functions
|
||||
|
||||
- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
|
||||
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
|
||||
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
|
||||
## Grammar and Constraints 🛣️
|
||||
|
||||
### The Grammar Parameter
|
||||
|
||||
In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.
|
||||
|
||||
Using curl, you can make a request to TGI's `generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API; using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
|
||||
|
||||
```bash
|
||||
curl localhost:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
|
||||
|
||||
```
|
||||
|
||||
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
|
||||
|
||||
> Note: A grammar must be compiled to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
|
||||
|
||||
### Constrain with Pydantic
|
||||
|
||||
Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.
|
||||
|
||||
Using Pydantic models, we can define the same grammar as in the previous example in a shorter and more readable way.
|
||||
|
||||
```python
|
||||
import requests
|
||||
from pydantic import BaseModel, conint
|
||||
from typing import List
|
||||
|
||||
class Animals(BaseModel):
|
||||
location: str
|
||||
activity: str
|
||||
animals_seen: conint(ge=1, le=5) # Constrained integer type
|
||||
animals: List[str]
|
||||
|
||||
prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"
|
||||
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": Animals.schema()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
|
||||
|
||||
```
|
||||
|
||||
### JSON Schema Integration
|
||||
|
||||
If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
json_schema = {
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"activity": {
|
||||
"type": "string"
|
||||
},
|
||||
"animals_seen": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 5
|
||||
},
|
||||
"animals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["location", "activity", "animals_seen", "animals"]
|
||||
}
|
||||
|
||||
data = {
|
||||
"inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
|
||||
"parameters": {
|
||||
"max_new_tokens": 200,
|
||||
"repetition_penalty": 1.3,
|
||||
"grammar": {
|
||||
"type": "json",
|
||||
"value": json_schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'http://127.0.0.1:3000/generate',
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
print(response.json())
|
||||
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
|
||||
|
||||
```
|
||||
|
||||
### Using the client
|
||||
|
||||
TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'generate' to complete
|
||||
response = await client.generate(
|
||||
"Whats Googles DNS",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
seed=1,
|
||||
grammar={
|
||||
"type": GrammarType.Regex,
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||
},
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.generated_text)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# 118.8.0.84
|
||||
|
||||
```
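
The same request can also be made with the synchronous `Client`. The sketch below assumes the blocking `Client.generate` accepts the same `grammar` parameter as the `AsyncClient.generate` call shown above and that a TGI server is running on `localhost:3000`.

```python
from text_generation import Client
from text_generation.types import GrammarType

# Synchronous client pointing at the local TGI server (assumed address)
client = Client(base_url="http://localhost:3000")

# Constrain the output to an IPv4 address using a regex grammar
response = client.generate(
    "Whats Googles DNS",
    max_new_tokens=10,
    decoder_input_details=True,
    seed=1,
    grammar={
        "type": GrammarType.Regex,
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    },
)

print(response.generated_text)
```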
|
||||
|
||||
## Tools and Functions 🛠️
|
||||
|
||||
### The Tools Parameter
|
||||
|
||||
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
|
||||
|
||||
Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
|
||||
|
||||
Functions, similar to grammars, are defined as JSON schemas and can be passed as part of the parameters to the Messages API.
|
||||
|
||||
```bash
|
||||
curl localhost:3000/v1/chat/completions \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in New York?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location."
|
||||
}
|
||||
},
|
||||
"required": ["location", "format"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "get_current_weather"
|
||||
}'
|
||||
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Tools used in example below</summary>
|
||||
|
||||
```python
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_n_day_weather_forecast",
|
||||
"description": "Get an N-day weather forecast",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
"num_days": {
|
||||
"type": "integer",
|
||||
"description": "The number of days to forecast",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Text Generation Inference Client
|
||||
|
||||
TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
|
||||
|
||||
```python
|
||||
from text_generation import AsyncClient
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
# Define an async function to encapsulate the async operation
|
||||
async def main():
|
||||
client = AsyncClient(base_url="http://localhost:3000")
|
||||
|
||||
# Use 'await' to wait for the async method 'chat' to complete
|
||||
response = await client.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You're a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Once the response is received, you can process it
|
||||
print(response.choices[0].message.tool_calls)
|
||||
|
||||
# Ensure the main async function is run in the event loop
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
|
||||
|
||||
```
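
The synchronous `Client` introduced alongside this change exposes the same `chat` method, so the request above can also be made without an event loop. A minimal sketch, assuming the same `tools` list as before and a server on `localhost:3000`:

```python
from text_generation import Client

# NOTE: `tools` is the same list defined earlier and omitted here for brevity
client = Client(base_url="http://localhost:3000")

# Blocking call; returns a ChatComplete response when stream=False (the default)
response = client.chat(
    max_tokens=100,
    seed=1,
    tools=tools,
    presence_penalty=-1.1,
    messages=[
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ],
)

print(response.choices[0].message.tool_calls)
```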
|
||||
|
||||
### OpenAI integration
|
||||
|
||||
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
|
||||
|
||||
However, there are some minor differences in the API. For example, `tool_choice="auto"` will ALWAYS choose a tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will choose a tool only if the model thinks it's necessary.
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Initialize the client, pointing it to one of the available models
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:3000/v1",
|
||||
api_key="_",
|
||||
)
|
||||
|
||||
# NOTE: tools defined above and removed for brevity
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like the next 3 days in San Francisco, CA?",
|
||||
},
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice="auto", # tool selected by model
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
|
||||
called = chat_completion.choices[0].message.tool_calls
|
||||
print(called)
|
||||
# {
|
||||
# "id": 0,
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "description": None,
|
||||
# "name": "tools",
|
||||
# "parameters": {
|
||||
# "format": "celsius",
|
||||
# "location": "San Francisco, CA",
|
||||
# "num_days": 3,
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
```
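
Streaming also works through the OpenAI client by passing `stream=True`. The sketch below streams plain text deltas (no tools) and assumes the same local endpoint; it relies on TGI's chunked `chat/completions` responses being parseable by the OpenAI client, which is the intent of the Messages API but should be verified against your TGI version.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="_")

# Request a streamed chat completion and print the content deltas as they arrive
stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "Briefly describe the weather in Brooklyn."}],
    max_tokens=100,
    stream=True,
)

for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")
print()
```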
|
48
docs/source/conceptual/speculation.md
Normal file
@ -0,0 +1,48 @@
|
||||
## Speculation
|
||||
|
||||
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
|
||||
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens were valid.
|
||||
|
||||
So you are doing *more* computation on your LLM, but if your guesses are correct you produce 1, 2, 3, etc. tokens in a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct enough, this gives 2-3x faster inference (it can be much more for code-oriented tasks, for instance).
|
||||
|
||||
You can read a more [detailed explanation](https://huggingface.co/blog/assisted-generation).
|
||||
|
||||
Text Generation Inference supports two main speculative methods:
|
||||
|
||||
- Medusa
|
||||
- N-gram
|
||||
|
||||
|
||||
### Medusa
|
||||
|
||||
|
||||
Medusa is a [simple method](https://arxiv.org/abs/2401.10774) for generating many tokens in a single pass, using fine-tuned LM heads in addition to your existing model.
|
||||
|
||||
|
||||
You can check a few existing fine-tunes for popular models:
|
||||
|
||||
- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
|
||||
- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
|
||||
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
|
||||
|
||||
|
||||
In order to create your own Medusa heads for your own fine-tune, you should check out the original Medusa repo: [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa)
|
||||
|
||||
|
||||
In order to use Medusa models in TGI, simply point to a Medusa-enabled model, and everything will load automatically.
|
||||
|
||||
|
||||
### N-gram
|
||||
|
||||
|
||||
If you don't have a Medusa model, or don't have the resources to fine-tune one, you can try to use `n-gram`.
|
||||
N-gram works by trying to find existing tokens in the previous sequence that match, and using those as the speculation.
|
||||
|
||||
This is an extremely simple method that works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often.
|
||||
|
||||
|
||||
In order to enable n-gram speculation, simply use
|
||||
|
||||
`--speculate 2` in your flags.
|
||||
|
||||
[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)
|
@ -23,6 +23,8 @@ from text_generation.types import (
|
||||
Token,
|
||||
BestOfSequence,
|
||||
Grammar,
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
)
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
) -> bool:
|
||||
def convert_data(data):
|
||||
data = json.loads(data)
|
||||
if isinstance(data, Dict) and "choices" in data:
|
||||
choices = data["choices"]
|
||||
if (
|
||||
isinstance(choices, List)
|
||||
and len(choices) >= 1
|
||||
and "delta" in choices[0]
|
||||
):
|
||||
return ChatCompletionChunk(**data)
|
||||
return ChatComplete(**data)
|
||||
|
||||
if isinstance(data, Dict):
|
||||
return Response(**data)
|
||||
@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
)
|
||||
)
|
||||
|
||||
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
|
||||
return (
|
||||
response.choices[0].message.content == other.choices[0].message.content
|
||||
)
|
||||
|
||||
def eq_chat_complete_chunk(
|
||||
response: ChatCompletionChunk, other: ChatCompletionChunk
|
||||
) -> bool:
|
||||
return response.choices[0].delta.content == other.choices[0].delta.content
|
||||
|
||||
def eq_response(response: Response, other: Response) -> bool:
|
||||
return response.generated_text == other.generated_text and eq_details(
|
||||
response.details, other.details
|
||||
@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
if not isinstance(snapshot_data, List):
|
||||
snapshot_data = [snapshot_data]
|
||||
|
||||
if isinstance(serialized_data[0], ChatComplete):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
if isinstance(serialized_data[0], ChatCompletionChunk):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[
|
||||
eq_chat_complete_chunk(r, o)
|
||||
for r, o in zip(serialized_data, snapshot_data)
|
||||
]
|
||||
)
|
||||
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
@ -236,6 +270,7 @@ def launcher(event_loop):
|
||||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
@ -268,6 +303,9 @@ def launcher(event_loop):
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
@ -302,6 +340,7 @@ def launcher(event_loop):
|
||||
use_flash_attention: bool = True,
|
||||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
@ -317,6 +356,9 @@ def launcher(event_loop):
|
||||
if dtype is not None:
|
||||
args.append("--dtype")
|
||||
args.append(dtype)
|
||||
if revision is not None:
|
||||
args.append("--revision")
|
||||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
|
@ -0,0 +1,94 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.92626953,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40844727,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27905273,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6118164,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68652344,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4619141,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.7993164,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.63134766,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23278809,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2294922,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
}
|
@ -0,0 +1,394 @@
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 60,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.296875,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.28125,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -0.79248047,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.61816406,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.0619812,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -0.4091797,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 444,
|
||||
"logprob": -0.21655273,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 444,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 731,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 332,
|
||||
"logprob": -0.034698486,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 655,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": -0.20141602,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 332,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "hello"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 444,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 400,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 444,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 49,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 11505,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " age"
|
||||
},
|
||||
{
|
||||
"id": 731,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "):"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 332,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 655,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " name"
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 3021,
|
||||
"logprob": -0.5761719,
|
||||
"special": false,
|
||||
"text": " \","
|
||||
},
|
||||
{
|
||||
"id": 863,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " you"
|
||||
},
|
||||
{
|
||||
"id": 904,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " are"
|
||||
},
|
||||
{
|
||||
"id": 332,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 494,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " +"
|
||||
},
|
||||
{
|
||||
"id": 615,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " str"
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 400,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "age"
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
|
||||
}
|
@ -0,0 +1,378 @@
|
||||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.92626953,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40722656,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4570312,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.6303711,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23327637,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2304688,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.92626953,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40722656,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4570312,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.6303711,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23327637,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2304688,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.92626953,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40722656,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4570312,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.6303711,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23327637,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2304688,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -5.2617188,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38476562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7670,
|
||||
"logprob": -7.640625,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.92626953,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
{
|
||||
"id": 303,
|
||||
"logprob": -0.40722656,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1489,
|
||||
"logprob": -0.27954102,
|
||||
"special": false,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -0.6142578,
|
||||
"special": false,
|
||||
"text": "(\""
|
||||
},
|
||||
{
|
||||
"id": 8302,
|
||||
"logprob": -0.68310547,
|
||||
"special": false,
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -1.4570312,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.80126953,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.6303711,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.23327637,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -1.2304688,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||
}
|
||||
]
|
@ -0,0 +1,26 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1708957015,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 60,
|
||||
"total_tokens": 160
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079417,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079492,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
}
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079493,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 21,
|
||||
"prompt_tokens": 187,
|
||||
"total_tokens": 208
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "</s>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 20,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"finish_reason": "eos_token",
|
||||
"index": 20,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1709087088,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native"
|
||||
}
|
@ -3,7 +3,9 @@ import pytest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_medusa_handle(launcher):
|
||||
with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
|
||||
with launcher(
|
||||
"FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
|
55
integration-tests/models/test_flash_starcoder2.py
Normal file
@ -0,0 +1,55 @@
import pytest


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(
    flash_starcoder2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
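The snapshot tests above drive the server through the `launcher` fixture; outside of pytest, the same request can be reproduced with the Python client. A minimal sketch, assuming a TGI instance serving `bigcode/starcoder2-3b` is already listening on a local port (the URL below is an assumption, not taken from this diff):

from text_generation import Client

# Assumed local endpoint; the integration tests spin this up via the launcher fixture.
client = Client("http://127.0.0.1:8080")

# Mirrors test_flash_starcoder2: greedy decoding, 10 new tokens.
response = client.generate(
    "def print_hello", max_new_tokens=10, decoder_input_details=True
)
print(response.generated_text)
print(response.details.generated_tokens)  # expected to be 10, as asserted above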
integration-tests/models/test_tools_llama.py (new file, 240 lines)
@ -0,0 +1,240 @@
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client


# tools to be used in the following tests
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    },
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    assert (
        response.choices[0].message.content
        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {
                "format": "celsius",
                "location": "New York, NY",
                "num_days": 14,
            },
        },
        "id": 0,
        "type": "function",
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="auto",
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {
                "format": "celsius",
                "location": "New York, NY",
                "num_days": 14,
            },
        },
        "id": 0,
        "type": "function",
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "id": 0,
        "type": "function",
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {"format": "celsius", "location": "New York, NY"},
        },
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Paris, France?",
            },
        ],
        stream=True,
    )

    count = 0
    async for response in responses:
        count += 1

    assert count == 20
    assert response == response_snapshot
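The streaming test above only checks the number of chunks. As a companion, here is a minimal sketch (not part of the test suite) of how a caller could reassemble the streamed tool-call fragments into one JSON object, assuming the `AsyncClient.chat(..., stream=True)` interface exercised by these tests and the delta layout shown in the snapshots; `collect_tool_call` is a hypothetical helper name:

import json

from text_generation import AsyncClient


async def collect_tool_call(client: AsyncClient, messages, tools):
    # Hypothetical helper: stream a tool-constrained chat completion and
    # concatenate the `arguments` fragments carried by each delta.
    chunks = await client.chat(
        messages=messages,
        tools=tools,
        tool_choice="get_current_weather",
        max_tokens=100,
        stream=True,
    )
    arguments = ""
    async for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls is not None:
            arguments += delta.tool_calls.function.arguments
    # The last fragment is the end-of-sequence marker ("</s>" in the snapshot
    # above); strip it before parsing the accumulated JSON.
    return json.loads(arguments.removesuffix("</s>"))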
@ -230,7 +230,6 @@ message WarmupRequest {
    uint32 max_total_tokens = 4;
}

/// Empty response
message WarmupResponse {
    /// Maximum number of tokens supported by the model
    optional uint32 max_supported_total_tokens = 1;
@ -812,23 +812,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -877,28 +881,33 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi again!".to_string(),
|
||||
content: Some("Hi again!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -952,23 +961,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
@ -1006,23 +1019,27 @@ mod tests {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
@ -358,10 +358,11 @@ impl ChatCompletion {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
output: String,
|
||||
output: Option<String>,
|
||||
created: u64,
|
||||
details: Details,
|
||||
return_logprobs: bool,
|
||||
tool_calls: Option<ToolCall>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: String::new(),
|
||||
@ -375,6 +376,7 @@ impl ChatCompletion {
|
||||
role: "assistant".into(),
|
||||
content: output,
|
||||
name: None,
|
||||
tool_calls,
|
||||
},
|
||||
logprobs: return_logprobs
|
||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||
@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
|
||||
pub(crate) struct ChatCompletionDelta {
|
||||
#[schema(example = "user")]
|
||||
pub role: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
// default to None
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<DeltaToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
pub(crate) struct DeltaToolCall {
|
||||
pub index: u32,
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
pub(crate) struct Function {
|
||||
pub name: Option<String>,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl ChatCompletionChunk {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
delta: String,
|
||||
delta: Option<String>,
|
||||
tool_calls: Option<Vec<String>>,
|
||||
created: u64,
|
||||
index: u32,
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
@ -438,6 +460,15 @@ impl ChatCompletionChunk {
|
||||
delta: ChatCompletionDelta {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
tool_calls: tool_calls.map(|tc| DeltaToolCall {
|
||||
index,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: tc[0].to_string(),
|
||||
},
|
||||
}),
|
||||
},
|
||||
logprobs,
|
||||
finish_reason,
|
||||
@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
    #[serde(default)]
    #[schema(nullable = true, example = 0.95)]
    pub top_p: Option<f32>,

    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
    /// functions the model may generate JSON inputs for.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    pub tools: Option<Vec<Tool>>,

    /// A prompt to be appended before the tools
    #[serde(default = "default_tool_prompt")]
    #[schema(
        nullable = true,
        example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
    )]
    pub tool_prompt: Option<String>,

    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
    #[serde(default)]
    #[schema(nullable = true, example = "null")]
    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
    pub tool_choice: Option<ToolType>,
}

fn default_tool_prompt() -> Option<String> {
    Some(
        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
    )
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
enum ToolType {
    FunctionName(String),
    OneOf,
}

/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
mod deserialize_tool_choice {
    use super::*;
    use serde::de;
    use serde::Deserializer;
    use serde_json::Value;

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;

        match value {
            Value::String(s) => match s.as_str() {
                "none" => Ok(None),
                "auto" => Ok(Some(ToolType::OneOf)),
                _ => Ok(Some(ToolType::FunctionName(s))),
            },
            Value::Object(map) => {
                if let Some(content) = map
                    .get("function")
                    .and_then(|v| v.get("name"))
                    .and_then(|v| v.as_str())
                {
                    Ok(Some(ToolType::FunctionName(content.to_string())))
                } else {
                    Err(de::Error::custom("function key not found in tool choice"))
                }
            }
            Value::Null => Ok(Some(ToolType::OneOf)),
            _ => Err(de::Error::custom("invalid token format")),
        }
    }
}

#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct Tools {
    #[serde(flatten)]
    functions_map: FunctionsMap,
    properties: Properties,
}

#[derive(Debug, Serialize, Deserialize)]
struct FunctionsMap {
    #[serde(rename = "$functions")]
    functions: std::collections::HashMap<String, serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize)]
struct FunctionRef {
    #[serde(rename = "$ref")]
    ref_path: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Properties {
    #[serde(serialize_with = "serialize_function")]
    function: Vec<FunctionRef>,
}

fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::ser::SerializeStruct;
    let mut state = serializer.serialize_struct("Function", 1)?;
    state.serialize_field("anyOf", functions)?;
    state.end()
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct FunctionDefinition {
    #[serde(default)]
    pub description: Option<String>,
    pub name: String,
    pub parameters: serde_json::Value,
}

#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool {
    // The type of the tool. Currently, only 'function' is supported.
    #[schema(example = "function")]
    pub r#type: String,
    // Grab the tool as generic JSON for debugging purposes.
    pub function: FunctionDefinition,
}

#[derive(Clone, Serialize, Deserialize)]
@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> {
    add_generation_prompt: bool,
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
    pub id: u32,
    pub r#type: String,
    pub function: FunctionDefinition,
}

#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message {
    #[schema(example = "user")]
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schema(example = "My name is David and I")]
-    pub content: String,
+    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[schema(example = "\"David\"")]
    pub name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<ToolCall>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
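To make the `deserialize_tool_choice` behaviour above concrete, here is a small sketch of the request shapes it accepts, sent as raw JSON from Python. The `/v1/chat/completions` path, the port, and the `model` field are assumptions for illustration and are not taken from this diff:

import requests

URL = "http://127.0.0.1:8080/v1/chat/completions"  # assumed local TGI endpoint

payload = {
    "model": "tgi",  # assumed placeholder model name
    "messages": [
        {"role": "user", "content": "What is the weather like in Brooklyn, New York?"}
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
}

# "auto" or an explicit JSON null map to ToolType::OneOf: any provided tool may be chosen.
requests.post(URL, json={**payload, "tool_choice": "auto"})

# "none" deserializes to None; since the grammar is only built when both `tools` and
# `tool_choice` are present, plain text is generated instead.
requests.post(URL, json={**payload, "tool_choice": "none"})

# A bare function name, or the object form {"function": {"name": ...}}, both become
# ToolType::FunctionName("get_current_weather") and force that tool's schema.
requests.post(URL, json={**payload, "tool_choice": "get_current_weather"})
requests.post(
    URL, json={**payload, "tool_choice": {"function": {"name": "get_current_weather"}}}
)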
@ -10,6 +10,7 @@ use crate::{
|
||||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
@ -239,7 +242,7 @@ async fn generate(
|
||||
headers.insert("x-compute-type", compute_type.parse().unwrap());
|
||||
headers.insert(
|
||||
"x-compute-time",
|
||||
total_time.as_millis().to_string().parse().unwrap(),
|
||||
total_time.as_secs_f64().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-compute-characters",
|
||||
@ -581,7 +584,7 @@ async fn chat_completions(
|
||||
let seed = req.seed;
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(req.messages) {
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
@ -596,6 +599,62 @@ async fn chat_completions(
|
||||
}
|
||||
};
|
||||
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tool choice not found in tool names".to_string(),
|
||||
error_type: "Tool not found".to_string(),
|
||||
}),
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
@ -617,7 +676,7 @@ async fn chat_completions(
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
grammar: tool_grammar.clone(),
|
||||
},
|
||||
};
|
||||
|
||||
@ -640,11 +699,19 @@ async fn chat_completions(
|
||||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
||||
(None, Some(vec![stream_token.token.text]))
|
||||
} else {
|
||||
(Some(stream_token.token.text), None)
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
stream_token.token.text,
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
stream_token.index,
|
||||
logprobs,
|
||||
@ -681,14 +748,54 @@ async fn chat_completions(
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||
// gen_text should be valid json
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_call = Some(ToolCall {
|
||||
id: 0,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").map_or_else(
|
||||
|| {
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
},
|
||||
|f| Ok(f.clone()),
|
||||
)?,
|
||||
},
|
||||
});
|
||||
(tool_call, None)
|
||||
} else {
|
||||
(None, Some(generation.generated_text))
|
||||
};
|
||||
// build the complete response object with the full text
|
||||
let response = ChatCompletion::new(
|
||||
model_id,
|
||||
system_fingerprint,
|
||||
generation.generated_text,
|
||||
output,
|
||||
current_time,
|
||||
generation.details.unwrap(),
|
||||
logprobs,
|
||||
tool_calls,
|
||||
);
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
|
@ -154,12 +154,8 @@ def download_weights(
|
||||
import json
|
||||
|
||||
medusa_head = hf_hub_download(
|
||||
model_id, revision=revision, filename="medusa_lm_head.pt"
|
||||
model_id, revision=revision, filename="medusa_lm_head.safetensors"
|
||||
)
|
||||
if auto_convert:
|
||||
medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
|
||||
if not medusa_sf.exists():
|
||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||
medusa_config = hf_hub_download(
|
||||
model_id, revision=revision, filename="config.json"
|
||||
)
|
||||
@ -198,16 +194,12 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
elif (Path(model_id) / "medusa_lm_head.pt").exists():
|
||||
elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
|
||||
# Try to load as a local Medusa model
|
||||
try:
|
||||
import json
|
||||
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.pt"
|
||||
if auto_convert:
|
||||
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
if not medusa_sf.exists():
|
||||
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
|
||||
medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
|
||||
medusa_config = Path(model_id) / "config.json"
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
@ -3,7 +3,9 @@ import torch
|
||||
from loguru import logger
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.auto import modeling_auto
|
||||
from huggingface_hub import hf_hub_download
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
from text_generation_server.models.model import Model
|
||||
@ -65,6 +67,7 @@ try:
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
||||
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
|
||||
|
||||
except ImportError as e:
|
||||
@ -82,6 +85,7 @@ if FLASH_ATTENTION:
|
||||
__all__.append(FlashMixtral)
|
||||
__all__.append(FlashPhi)
|
||||
__all__.append(FlashQwen2)
|
||||
__all__.append(FlashStarcoder2)
|
||||
|
||||
MAMBA_AVAILABLE = True
|
||||
try:
|
||||
@ -119,44 +123,14 @@ def get_model(
|
||||
else:
|
||||
set_speculate(0)
|
||||
|
||||
if "facebook/galactica" in model_id:
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_id.startswith("bigcode/"):
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||
)
|
||||
else:
|
||||
return SantaCoder(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
use_medusa = None
|
||||
if "medusa_num_heads" in config_dict:
|
||||
use_medusa = model_id
|
||||
medusa_model_id = model_id
|
||||
medusa_revision = revision
|
||||
model_id = config_dict["base_model_name_or_path"]
|
||||
revision = "main"
|
||||
speculate_medusa = config_dict["medusa_num_heads"]
|
||||
@ -173,6 +147,20 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local = Path(medusa_model_id).exists()
|
||||
if not is_local:
|
||||
medusa_config = hf_hub_download(
|
||||
medusa_model_id, revision=medusa_revision, filename="config.json"
|
||||
)
|
||||
hf_hub_download(
|
||||
medusa_model_id,
|
||||
revision=medusa_revision,
|
||||
filename="medusa_lm_head.safetensors",
|
||||
)
|
||||
use_medusa = Path(medusa_config).parent
|
||||
else:
|
||||
use_medusa = Path(medusa_model_id)
|
||||
|
||||
method = "medusa"
|
||||
else:
|
||||
method = "n-gram"
|
||||
@ -197,16 +185,32 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "gpt_bigcode":
|
||||
if model_id.startswith("facebook/galactica"):
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if (
|
||||
model_type == "gpt_bigcode"
|
||||
or model_type == "gpt2"
|
||||
and model_id.startswith("bigcode/")
|
||||
):
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -219,6 +223,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -228,6 +233,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -236,6 +242,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -246,6 +253,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -254,6 +262,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -262,6 +271,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -272,15 +282,16 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_medusa=use_medusa,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -295,6 +306,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -305,9 +317,9 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_medusa=use_medusa,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
@ -316,6 +328,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -346,9 +359,9 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_medusa=use_medusa,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
@ -359,6 +372,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -372,6 +386,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -382,6 +397,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -390,6 +406,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -403,6 +420,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -413,6 +431,19 @@ def get_model(
|
||||
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
|
||||
) or HAS_FLASH_ATTN_V2_CUDA:
|
||||
return FlashMixtral(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if model_type == "starcoder2":
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
if (
|
||||
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
|
||||
) or HAS_FLASH_ATTN_V2_CUDA:
|
||||
return FlashStarcoder2(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
@ -425,6 +456,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -434,6 +466,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -443,6 +476,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -466,6 +500,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -474,6 +509,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -485,6 +521,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
@ -493,6 +530,7 @@ def get_model(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
|
||||
)
|
||||
config.pad_token_id = 3
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
@ -112,4 +114,4 @@ class BLOOMSharded(CausalLM):
|
||||
)
|
||||
|
||||
logits = outputs.logits
|
||||
return logits, outputs.past_key_values
|
||||
return logits, speculative_logits, outputs.past_key_values
|
||||
|
@ -482,6 +482,7 @@ class CausalLM(Model):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -550,7 +551,9 @@ class CausalLM(Model):
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
) -> Tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
]:
|
||||
# Model Forward
|
||||
kwargs = {
|
||||
"input_ids": input_ids,
|
||||
@ -563,7 +566,11 @@ class CausalLM(Model):
|
||||
kwargs["position_ids"] = position_ids
|
||||
|
||||
outputs = self.model.forward(**kwargs)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
if isinstance(outputs, tuple):
|
||||
outputs, speculative_logits = outputs
|
||||
else:
|
||||
speculative_logits = None
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
@ -573,7 +580,7 @@ class CausalLM(Model):
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
|
||||
logits, past = self.forward(
|
||||
logits, speculative_logits, past = self.forward(
|
||||
batch.input_ids,
|
||||
attention_mask,
|
||||
batch.position_ids,
|
||||
|
@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
)
|
||||
|
||||
CUSTOM_KERNELS_ENABLED = False
|
||||
@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.transformer = BloomModel(config, weights)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="word_embeddings",
|
||||
weights=weights,
|
||||
@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
return (
|
||||
CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
logits=logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
),
|
||||
speculative_logits,
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
)
|
||||
@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashGemmaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
|
||||
weights=weights,
|
||||
@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
)
|
||||
@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
)
|
||||
@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
|
@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = MixtralModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
|
@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
get_linear,
|
||||
@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.gpt_neox = FlashGPTNeoXModel(config, weights)
|
||||
|
||||
self.embed_out = TensorParallelHead.load(
|
||||
self.embed_out = SpeculativeHead.load(
|
||||
config, prefix="embed_out", weights=weights
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashPhiModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
|
@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
get_linear,
|
||||
@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
|
||||
self.transformer = FlashRWModel(config, weights)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="lm_head", weights=weights
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
TensorParallelEmbedding,
|
||||
FastLayerNorm,
|
||||
get_linear,
|
||||
@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.transformer = FlashSantacoderModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="transformer.wte", weights=weights
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,545 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2Config(PretrainedConfig):
|
||||
model_type = "starcoder2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=49152,
|
||||
hidden_size=3072,
|
||||
intermediate_size=12288,
|
||||
num_hidden_layers=30,
|
||||
num_attention_heads=24,
|
||||
num_key_value_heads=2,
|
||||
mlp_type="default",
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.018042,
|
||||
norm_type="layer_norm",
|
||||
norm_epsilon=1e-5,
|
||||
use_cache=True,
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=None,
|
||||
attention_dropout=0.0,
|
||||
residual_dropout=0.0,
|
||||
embedding_dropout=0.0,
|
||||
use_bias: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_bias = use_bias
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.mlp_type = mlp_type
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.norm_type = norm_type
|
||||
self.norm_epsilon = norm_epsilon
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.embedding_dropout = embedding_dropout
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
if config.use_bias:
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.bias", dim=0)
|
||||
for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||
]
|
||||
bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=bias, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
config.sliding_window if config.sliding_window is not None else -1
|
||||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class Starcoder2MLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        self.c_fc = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.c_fc",
            weights=weights,
            bias=config.use_bias,
        )
        self.c_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=config.use_bias,
        )

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.c_proj(hidden_states)


class Starcoder2GatedMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=config.use_bias,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


STARCODER2_NORMALIZATION_CLASSES = {
    "layer_norm": FastLayerNorm,
    "rms_norm": FastRMSNorm,
}

STARCODER2_MLP_CLASSES = {
    "default": Starcoder2MLP,
    "gated": Starcoder2GatedMLP,
}
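
For readers unfamiliar with the fused projection above, here is a minimal, self-contained sketch (not part of this diff) of how Starcoder2GatedMLP splits the concatenated gate/up projection and recombines the halves; a plain nn.Linear and GELU stand in for the tensor-parallel layers and the ACT2FN lookup, and the sizes are assumptions.

import torch
import torch.nn as nn

hidden_size, intermediate_size = 16, 32
# One matmul produces both halves, laid out as [gate ; up] along the feature dim.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.GELU()

x = torch.randn(4, hidden_size)
gate_up = gate_up_proj(x).view(-1, 2, intermediate_size)
out = down_proj(act(gate_up[:, 0]) * gate_up[:, 1])  # act(gate) * up, then project down
print(out.shape)  # torch.Size([4, 16])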


class Starcoder2Layer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = Starcoder2Attention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
        )
        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
            config.norm_type
        ].load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            prefill_cache_indices,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class Starcoder2Model(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                Starcoder2Layer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix="model.norm", weights=weights, eps=config.norm_epsilon
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid indexing in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, true_max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
                prefill_cache_indices,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashStarcoder2ForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = Starcoder2Model(config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="model.embed_tokens",
                weights=weights,
            )

        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        true_max_s = max_s
        if prefill_cache_indices is not None:
            # Slots also need to be sliced as they have the same size as the whole kv tensor
            slots = slots[prefill_cache_indices]
        elif self.max_past is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            true_max_s,
            prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        # The speculative head returns the regular logits plus, when Medusa is
        # attached, one extra set of logits per speculated token.
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
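
The kv_head_mapping built in Starcoder2Attention.__init__ tells paged attention which shared key/value head serves each query head under grouped-query attention. A minimal sketch (not part of this diff), with assumed per-shard head counts:

import torch

num_heads = 8              # query heads on this shard (assumed)
num_key_value_heads = 2    # key/value heads on this shard (assumed)
num_groups = num_heads // num_key_value_heads

# Each kv head index is repeated once per query head it serves.
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)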
@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
PositionRotaryEmbedding,
|
||||
FastLinear,
|
||||
)
|
||||
@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
|
||||
weights,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.fc = TensorParallelHead.load(
|
||||
config=config, prefix="lm_head", weights=weights
|
||||
)
|
||||
self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
|
||||
self.additional_fc = FastLinear.load(
|
||||
config=config,
|
||||
prefix="lm_head.additional_fc",
|
||||
@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
output = self.fc(input)
|
||||
output, speculative_logits = self.fc(input)
|
||||
additional_features = self.additional_fc(input)
|
||||
output = torch.cat((output, additional_features), -1)
|
||||
|
||||
return output
|
||||
return output, speculative_logits
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
|
||||
@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
|
||||
return CausalLMOutputWithPastImage(
|
||||
return (
|
||||
CausalLMOutputWithPastImage(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
),
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
|
@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
SpeculativeHead,
|
||||
TensorParallelEmbedding,
|
||||
FastRMSNorm,
|
||||
FastLinear,
|
||||
@ -205,14 +206,12 @@ class MambaModel(nn.Module):
|
||||
self.norm_f = FastRMSNorm.load(
|
||||
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.lm_head = FastLinear.load(
|
||||
config, f"{prefix}.embedding", weights, bias=False
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, inference_params=None, residual=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, residual, conv_state, ssm_state = block(
|
||||
@ -226,8 +225,8 @@ class MambaModel(nn.Module):
|
||||
)
|
||||
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
|
||||
hidden_states = hidden_states.view(residual.shape)
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
|
||||
# update the offset for the next inference using these params
|
||||
inference_params.seqlen_offset += input_ids.size(1)
|
||||
return logits
|
||||
return logits, speculative_logits
|
||||
|
@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
||||
if not config.tie_word_embeddings:
|
||||
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
||||
self.transformer = MPTModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="transformer.wte", weights=weights
|
||||
)
|
||||
self.logit_scale = None
|
||||
@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
|
||||
if self.logit_scale is not None:
|
||||
if self.logit_scale == 0:
|
||||
warnings.warn(
|
||||
@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
|
||||
)
|
||||
return CausalLMOutputWithPast(
|
||||
return (
|
||||
CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
),
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
|
@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
)
|
||||
|
||||
|
||||
@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.gpt_neox = GPTNeoXModel(config, weights)
|
||||
self.embed_out = TensorParallelHead.load(
|
||||
self.embed_out = SpeculativeHead.load(
|
||||
config, prefix="embed_out", weights=weights
|
||||
)
|
||||
|
||||
|
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
)
|
||||
|
||||
EPS = 1e-5
|
||||
@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
|
||||
self.model = OPTModel(config, weights)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="model.decoder.embed_tokens", weights=weights
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
FastLinear,
|
||||
)
|
||||
|
||||
@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.linear = TensorParallelHead.load(
|
||||
self.linear = SpeculativeHead.load(
|
||||
config=config, prefix="lm_head.linear", weights=weights
|
||||
)
|
||||
|
||||
|
@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
SpeculativeHead,
|
||||
)
|
||||
|
||||
|
||||
@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
)
|
||||
|
||||
try:
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="lm_head", weights=weights
|
||||
)
|
||||
except RuntimeError:
|
||||
# Some models like t5-small were saved with shared weights unlike flan
|
||||
# Since they are declared as the same arch we have no choice but hope
|
||||
# that this is OK instead of using a proper flag.
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config, prefix="shared", weights=weights
|
||||
)
|
||||
|
||||
@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
||||
sequence_output = sequence_output * (self.model_dim**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
logits, speculative_logits = self.lm_head(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
@ -1140,9 +1140,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
return (
|
||||
Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
logits=logits,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
@ -1150,6 +1151,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
),
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
|
@ -723,7 +723,7 @@ class FlashCausalLM(Model):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
self.cuda_graphs[bs]["logits"] = self.model.forward(
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
@ -734,6 +734,8 @@ class FlashCausalLM(Model):
|
||||
max_s=max_s,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch):
|
||||
@ -805,7 +807,9 @@ class FlashCausalLM(Model):
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
input_ids = batch.input_ids
|
||||
@ -900,9 +904,14 @@ class FlashCausalLM(Model):
|
||||
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
return cuda_graph["logits"][:bs]
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
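
The capture above records the decode forward pass into a CUDA graph keyed by batch size; forward later copies fresh data into the captured static tensors and replays the recorded kernels, as done here with cuda_graph["graph"].replay(). A generic sketch of that capture/replay pattern (not TGI code; it needs a CUDA device, and the toy model and shapes are assumptions):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # kernels are recorded, not run eagerly

# Replay: overwrite the static input in place, then replay the recorded kernels.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
result = static_output.clone()  # static_output now holds the replayed result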
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
@ -926,16 +935,11 @@ class FlashCausalLM(Model):
|
||||
batch.slots = slots
|
||||
|
||||
try:
|
||||
out = self.forward(batch)
|
||||
out, speculative_logits = self.forward(batch)
|
||||
except Exception as e:
|
||||
del batch
|
||||
raise e
|
||||
|
||||
if isinstance(out, tuple):
|
||||
out, speculative_logits = out
|
||||
else:
|
||||
speculative_logits = None
|
||||
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
||||
|
@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_medusa: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM):
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashGemmaForCausalLM(config, weights)
|
||||
if use_medusa:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma, self).__init__(
|
||||
|
@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_medusa: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM):
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
if use_medusa:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
|
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from typing import Optional, Tuple, Type, List
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||
MEM_POOL = torch.cuda.graph_pool_handle()
|
||||
|
||||
|
||||
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
SLIDING_WINDOW = sliding_window
|
||||
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
|
||||
|
||||
|
||||
def get_sliding_windows() -> Tuple[int, int]:
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
|
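
These helpers replace the direct reads of the module globals used before. A small usage sketch (not part of this diff) with an assumed 4096-token window; it imports the real helpers and BLOCK_SIZE, so it requires a CUDA-enabled TGI install because importing flash_mistral allocates a CUDA graph pool at import time.

import math
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    set_sliding_window,
    get_sliding_windows,
)

# Assumed window size of 4096 tokens.
set_sliding_window(4096, math.ceil(4096 / BLOCK_SIZE))
sliding_window, sliding_window_blocks = get_sliding_windows()

# Prefill then caps the number of KV-cache blocks a request may claim:
total_tokens = 10_000
needed_blocks = min(math.ceil(total_tokens / BLOCK_SIZE), sliding_window_blocks)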
||||
|
||||
|
||||
# Adds windowing logic to FlashCausalLMBatch
|
||||
@dataclass
|
||||
class FlashMistralBatch(FlashCausalLMBatch):
|
||||
@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
sliding_window, sliding_window_blocks = get_sliding_windows()
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
|
||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
||||
if SLIDING_WINDOW_BLOCKS is not None:
|
||||
needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
|
||||
if sliding_window_blocks is not None:
|
||||
needed_blocks = min(needed_blocks, sliding_window_blocks)
|
||||
blocks += needed_blocks
|
||||
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if SLIDING_WINDOW is not None:
|
||||
if sliding_window is not None:
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - SLIDING_WINDOW),
|
||||
cumulative_length + max(0, input_length - sliding_window),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if SLIDING_WINDOW is not None:
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
if SLIDING_WINDOW is not None:
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
|
||||
prefill_cache_indices.to(device) if sliding_window is not None else None
|
||||
)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
@ -294,12 +306,10 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
@ -319,11 +329,13 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
SLIDING_WINDOW = config.sliding_window
|
||||
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@ -394,7 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
self.cuda_graphs[bs]["logits"] = self.model.forward(
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
@ -406,9 +418,13 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(
|
||||
self, batch: FlashMistralBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
input_ids = batch.input_ids
|
||||
@ -479,7 +495,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits = self.model.forward(
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
@ -493,7 +509,7 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
@ -511,7 +527,13 @@ class BaseFlashMistral(FlashCausalLM):
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
return cuda_graph["logits"][:bs]
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class FlashMistral(BaseFlashMistral):
|
||||
@ -520,6 +542,7 @@ class FlashMistral(BaseFlashMistral):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -529,6 +552,7 @@ class FlashMistral(BaseFlashMistral):
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_medusa: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
|
||||
)
|
||||
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
if config.quantize == "gptq":
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
|
@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||
trust_remote_code=True,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
config.transpose = config.architectures[0].startswith("GPT2")
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
86 server/text_generation_server/models/flash_starcoder2.py (new file)
@ -0,0 +1,86 @@
import math

import torch

from typing import Optional

from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
    Starcoder2Config,
    FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashStarcoder2 is only available on GPU")

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        # Set context windows
        if config.sliding_window is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
config.vision_config.quantize = quantize
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
|
@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
|
||||
if self.has_position_ids:
|
||||
kwargs["position_ids"] = position_ids
|
||||
|
||||
outputs = self.model.forward(**kwargs)
|
||||
return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
|
||||
outputs, speculative_logits = self.model.forward(**kwargs)
|
||||
return (
|
||||
outputs.logits,
|
||||
speculative_logits,
|
||||
outputs.past_key_values,
|
||||
outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
def generate_token(
|
||||
@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
|
||||
logits, past, image_hidden_states = self.forward(
|
||||
logits, speculative_logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=batch.position_ids,
|
||||
|
@ -408,6 +408,7 @@ class Mamba(Model):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -444,6 +445,7 @@ class Mamba(Model):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
@ -505,7 +507,7 @@ class Mamba(Model):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
logits = self.model.forward(
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids, inference_params=inference_params
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
@ -514,6 +516,7 @@ class Mamba(Model):
|
||||
"inference_params": inference_params,
|
||||
"graph": graph,
|
||||
"logits": logits,
|
||||
"speculative_logits": speculative_logits,
|
||||
}
|
||||
self.cuda_graphs[batch_size] = graph_dict
|
||||
|
||||
@ -556,9 +559,14 @@ class Mamba(Model):
|
||||
inference_params.ssm_states.copy_(
|
||||
cuda_graph["inference_params"].ssm_states[:, :bs]
|
||||
)
|
||||
|
||||
# Slice output to the correct shape
|
||||
return cuda_graph["logits"][:bs]
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
||||
|
||||
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||
start = time.time_ns()
|
||||
@ -589,7 +597,9 @@ class Mamba(Model):
|
||||
batch.inference_params = inference_params
|
||||
|
||||
# Forward pass
|
||||
logits = self.forward(input_ids, inference_params=batch.inference_params)
|
||||
logits, speculative_logits = self.forward(
|
||||
input_ids, inference_params=batch.inference_params
|
||||
)
|
||||
|
||||
# batch.inference_params = new_inference_params
|
||||
# Results
|
||||
|
@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
|
||||
config = json.load(f)
|
||||
config = PretrainedConfig(**config)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
@ -22,6 +22,7 @@ class Phi(CausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -52,6 +53,7 @@ class Phi(CausalLM):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
|
@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
|
@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -596,6 +597,7 @@ class Seq2SeqLM(Model):
|
||||
past_key_values: Optional = None,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor,
|
||||
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
||||
]:
|
||||
@ -609,8 +611,15 @@ class Seq2SeqLM(Model):
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
if isinstance(outputs, tuple):
|
||||
# Our custom models
|
||||
outputs, speculative_logits = outputs
|
||||
else:
|
||||
# Generic transformers models
|
||||
speculative_logits = None
|
||||
return (
|
||||
outputs.logits,
|
||||
speculative_logits,
|
||||
outputs.encoder_last_hidden_state,
|
||||
outputs.past_key_values,
|
||||
)
|
||||
@ -635,7 +644,7 @@ class Seq2SeqLM(Model):
|
||||
else:
|
||||
encoder_last_hidden_state = None
|
||||
|
||||
logits, encoder_last_hidden_state, past = self.forward(
|
||||
logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
|
||||
batch.input_ids,
|
||||
batch.attention_mask,
|
||||
batch.decoder_input_ids,
|
||||
|
@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.use_medusa = use_medusa
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
||||
]:
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):
|
||||
|
||||
return (
|
||||
outputs.logits,
|
||||
speculative_logits,
|
||||
outputs.encoder_last_hidden_state,
|
||||
outputs.past_key_values,
|
||||
)
|
||||
|
@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
|
||||
and "arguments" not in s.rfilename
|
||||
and "args" not in s.rfilename
|
||||
and "training" not in s.rfilename
|
||||
and "medusa_lm_head" not in s.rfilename
|
||||
]
|
||||
|
||||
|
||||
@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||
and "args" not in f
|
||||
and "adapter" not in f
|
||||
and "training" not in f
|
||||
and "medusa_lm_head" not in f
|
||||
]
|
||||
return filenames
|
||||
|
||||
|
@ -4,7 +4,7 @@ import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import List
|
||||
from typing import List, Tuple, Optional
|
||||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
|
||||
@ -380,6 +380,96 @@ class SuperLayer(nn.Module):
|
||||
return self.linear.forward(x)
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.linear = FastLinear.load(
|
||||
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||
)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.act(self.linear(x))
|
||||
|
||||
|
||||
class MedusaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||
for i in range(config["medusa_num_heads"])
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||
return speculative_logits
|
||||
|
||||
|
||||
class MedusaHead(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||
for i in range(config["medusa_num_layers"])
|
||||
]
|
||||
)
|
||||
n = len(self.blocks)
|
||||
self.out = FastLinear.load(
|
||||
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
class SpeculativeHead(nn.Module):
|
||||
def __init__(self, lm_head, medusa):
|
||||
super().__init__()
|
||||
self.lm_head = lm_head
|
||||
self.medusa = medusa
|
||||
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
lm_head = TensorParallelHead.load(config, prefix, weights)
|
||||
use_medusa = config.use_medusa
|
||||
if use_medusa:
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
weights.routing[k] = filename
|
||||
|
||||
medusa = MedusaModel(config, weights)
|
||||
else:
|
||||
medusa = None
|
||||
return SpeculativeHead(lm_head, medusa)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
logits = self.lm_head(input)
|
||||
speculative_logits = self.medusa(input) if self.medusa is not None else None
|
||||
return logits, speculative_logits
|
||||
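
SpeculativeHead always hands back a (logits, speculative_logits) pair, with the second element None when no Medusa heads are attached; that is what lets the models above return a uniform tuple. A minimal sketch (not part of this diff), using a plain nn.Linear as a stand-in for TensorParallelHead and assuming a working TGI server install for the import:

import torch
import torch.nn as nn
from text_generation_server.utils.layers import SpeculativeHead

hidden_size, vocab_size = 16, 100
head = SpeculativeHead(lm_head=nn.Linear(hidden_size, vocab_size), medusa=None)

hidden = torch.randn(4, hidden_size)
logits, speculative_logits = head(hidden)
print(logits.shape, speculative_logits)  # torch.Size([4, 100]) None
# With Medusa heads loaded, speculative_logits instead has shape
# [batch, medusa_num_heads, vocab_size]: one extra logit set per speculated token.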
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group, should_gather: bool):
|
||||
super().__init__(linear)
|
||||
|
@ -1,59 +0,0 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
||||
|
||||
|
||||
@dataclass
|
||||
class Output:
|
||||
logits: torch.FloatTensor = None
|
||||
speculative_logits: torch.FloatTensor = None
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.linear = FastLinear.load(
|
||||
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
||||
)
|
||||
self.act = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.act(self.linear(x))
|
||||
|
||||
|
||||
class MedusaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights, lm_head):
|
||||
super().__init__()
|
||||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
||||
for i in range(config["medusa_num_heads"])
|
||||
]
|
||||
)
|
||||
self.lm_head = lm_head
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.lm_head(x)
|
||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
||||
return logits, speculative_logits
|
||||
|
||||
|
||||
class MedusaHead(torch.nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[
|
||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
||||
for i in range(config["medusa_num_layers"])
|
||||
]
|
||||
)
|
||||
n = len(self.blocks)
|
||||
self.out = FastLinear.load(
|
||||
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.out(x)
|
||||
return x
|