Merge branch 'feat/qwen' into qwen2

This commit is contained in:
OlivierDehaene 2024-02-28 14:34:08 +01:00 committed by GitHub
commit 73403aa4db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
65 changed files with 3518 additions and 327 deletions

View File

@ -52,6 +52,8 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
- Stop sequences
- Log probabilities
- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency
- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify output format to speed up inference and make sure the output is valid according to some specs.
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance

View File

@ -3,7 +3,7 @@ import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
from text_generation.types import (
StreamResponse,
@ -11,6 +11,11 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
ChatRequest,
ChatCompletionChunk,
ChatComplete,
Message,
Tool,
)
from text_generation.errors import parse_error
@ -59,6 +64,114 @@ class Client:
self.cookies = cookies
self.timeout = timeout
def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
):
"""
Given a list of messages, generate a response
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
Include log probabilities in the response
top_logprobs (`int`):
Include the `n` most likely tokens at each step
max_tokens (`int`):
Maximum number of generated tokens
n (`int`):
Generate `n` completions
presence_penalty (`float`):
The parameter for presence penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
stream (`bool`):
Stream the response
seed (`int`):
Random sampling seed
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return ChatComplete(**payload)
else:
return self._chat_stream_response(request)
def _chat_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate over the server-sent event stream and yield parsed chunks
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status_code, json_payload)
def generate(
self,
prompt: str,
@ -313,6 +426,113 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
async def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
Include log probabilities in the response
top_logprobs (`int`):
Include the `n` most likely tokens at each step
max_tokens (`int`):
Maximum number of generated tokens
n (`int`):
Generate `n` completions
presence_penalty (`float`):
The parameter for presence penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
stream (`bool`):
Stream the response
seed (`int`):
Random sampling seed
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
if not stream:
return await self._chat_single_response(request)
else:
return self._chat_stream_response(request)
async def _chat_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return ChatComplete(**payload)
async def _chat_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
async def generate(
self,
prompt: str,

View File

@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@ -19,6 +19,124 @@ class Grammar(BaseModel):
value: Union[str, dict]
class ToolCall(BaseModel):
# Id of the tool call
id: int
# Type of the tool call
type: str
# Function details of the tool call
function: dict
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str]
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
tool_calls: Optional[Any] = None
class Tool(BaseModel):
# Type of the tool
type: str
# Function details of the tool
function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Any
class Function(BaseModel):
name: Optional[str]
arguments: str
class ChoiceDeltaToolCall(BaseModel):
index: int
id: str
type: str
function: Function
class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
tool_calls: Optional[ChoiceDeltaToolCall]
class Choice(BaseModel):
index: int
delta: ChoiceDelta
logprobs: Optional[dict] = None
finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatRequest(BaseModel):
# Model identifier
model: str
# List of messages in the conversation
messages: List[Message]
# Penalty for frequency of new tokens
frequency_penalty: Optional[float] = None
# Bias values for token selection
logit_bias: Optional[List[float]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position
top_logprobs: Optional[int] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Number of chat completion choices to generate
n: Optional[int] = None
# Penalty for presence of new tokens
presence_penalty: Optional[float] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# List of tools to be used
tools: Optional[List[Tool]] = None
# Choice of tool to be used
tool_choice: Optional[str] = None
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False

View File

@ -9,6 +9,8 @@
title: Supported Models and Hardware
- local: messages_api
title: Messages API
- local: guidance
title: Guidance
title: Getting started
- sections:
- local: basic_tutorials/consuming_tgi
@ -37,4 +39,8 @@
title: Safetensors
- local: conceptual/flash_attention
title: Flash Attention
- local: conceptual/speculation
title: Speculation (Medusa, ngram)
- local: conceptual/guidance
title: Guidance, JSON, tools (using outlines)
title: Conceptual Guides

View File

@ -0,0 +1,419 @@
# Guidance
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
## Quick Start
Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.
If you're not up to date, grab the latest version and let's get started!
## Table of Contents 📚
### Grammar and Constraints
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
### Tools and Functions
- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
## Grammar and Constraints 🛣️
### The Grammar Parameter
In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.
Using curl, you can make a request to TGI's `generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API, and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
```bash
curl localhost:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
"parameters": {
"repetition_penalty": 1.3,
"grammar": {
"type": "json",
"value": {
"properties": {
"location": {
"type": "string"
},
"activity": {
"type": "string"
},
"animals_seen": {
"type": "integer",
"minimum": 1,
"maximum": 5
},
"animals": {
"type": "array",
"items": {
"type": "string"
}
}
},
"required": ["location", "activity", "animals_seen", "animals"]
}
}
}
}'
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
```
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
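Regular expression grammars work the same way: set the grammar `type` to a regex instead of a JSON schema. Below is a minimal sketch using `requests`; it assumes the HTTP API accepts `"type": "regex"`, mirroring the `GrammarType.Regex` value used in the client example further down.

```python
import requests

# Illustrative sketch: constrain the output to an IPv4-shaped string
# (assumes the server accepts a regex grammar via "type": "regex").
data = {
    "inputs": "What is the IP address of Google's public DNS?",
    "parameters": {
        "max_new_tokens": 30,
        "grammar": {
            "type": "regex",
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    },
}

response = requests.post(
    "http://127.0.0.1:3000/generate",
    headers={"Content-Type": "application/json"},
    json=data,
)
print(response.json())
```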
### Constrain with Pydantic
Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.
Using Pydantic models, we can define the same grammar as in the previous example in a shorter and more readable way.
```python
import requests
from pydantic import BaseModel, conint
from typing import List
class Animals(BaseModel):
location: str
activity: str
animals_seen: conint(ge=1, le=5) # Constrained integer type
animals: List[str]
prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"
data = {
"inputs": prompt,
"parameters": {
"repetition_penalty": 1.3,
"grammar": {
"type": "json",
"value": Animals.schema()
}
}
}
headers = {
"Content-Type": "application/json",
}
response = requests.post(
'http://127.0.0.1:3000/generate',
headers=headers,
json=data
)
print(response.json())
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
```
### JSON Schema Integration
If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.
```python
import requests
json_schema = {
"properties": {
"location": {
"type": "string"
},
"activity": {
"type": "string"
},
"animals_seen": {
"type": "integer",
"minimum": 1,
"maximum": 5
},
"animals": {
"type": "array",
"items": {
"type": "string"
}
}
},
"required": ["location", "activity", "animals_seen", "animals"]
}
data = {
"inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
"parameters": {
"max_new_tokens": 200,
"repetition_penalty": 1.3,
"grammar": {
"type": "json",
"value": json_schema
}
}
}
headers = {
"Content-Type": "application/json",
}
response = requests.post(
'http://127.0.0.1:3000/generate',
headers=headers,
json=data
)
print(response.json())
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
```
### Using the client
TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
```python
from text_generation import AsyncClient
from text_generation.types import GrammarType
# Define an async function to encapsulate the async operation
async def main():
client = AsyncClient(base_url="http://localhost:3000")
# Use 'await' to wait for the async method 'generate' to complete
response = await client.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=1,
grammar={
"type": GrammarType.Regex,
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
# Once the response is received, you can process it
print(response.generated_text)
# Ensure the main async function is run in the event loop
if __name__ == "__main__":
import asyncio
asyncio.run(main())
# 118.8.0.84
```
## Tools and Functions 🛠️
### The Tools Parameter
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
Functions, similar to grammars, are defined as JSON schemas and can be passed as part of the parameters to the Messages API.
```bash
curl localhost:3000/v1/chat/completions \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is the weather like in New York?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}
}
}
],
"tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```
<details>
<summary>Tools used in example below</summary>
```python
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
},
{
"type": "function",
"function": {
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast",
},
},
"required": ["location", "format", "num_days"],
},
},
}
]
```
</details>
### Text Generation Inference Client
TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
```python
from text_generation import AsyncClient
# NOTE: tools defined above and removed for brevity
# Define an async function to encapsulate the async operation
async def main():
client = AsyncClient(base_url="http://localhost:3000")
# Use 'await' to wait for the async method 'chat' to complete
response = await client.chat(
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "You're a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
# Once the response is received, you can process it
print(response.choices[0].message.tool_calls)
# Ensure the main async function is run in the event loop
if __name__ == "__main__":
import asyncio
asyncio.run(main())
# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
```
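The same call works with the synchronous `Client`, since it exposes an equivalent `chat` method. A minimal sketch, assuming the same `tools` list and a server running on `localhost:3000`:

```python
from text_generation import Client

# NOTE: tools defined above and removed for brevity
client = Client(base_url="http://localhost:3000")

response = client.chat(
    max_tokens=100,
    seed=1,
    tools=tools,
    presence_penalty=-1.1,
    messages=[
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ],
)
print(response.choices[0].message.tool_calls)
```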
### OpenAI integration
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
However, there are some minor differences in the API; for example, `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will only choose a tool if the model thinks it's necessary.
```python
from openai import OpenAI
# Initialize the client, pointing it to one of the available models
client = OpenAI(
base_url="http://localhost:3000/v1",
api_key="_",
)
# NOTE: tools defined above and removed for brevity
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{
"role": "system",
"content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
},
{
"role": "user",
"content": "What's the weather like the next 3 days in San Francisco, CA?",
},
],
tools=tools,
tool_choice="auto", # tool selected by model
max_tokens=500,
)
called = chat_completion.choices[0].message.tool_calls
print(called)
# {
# "id": 0,
# "type": "function",
# "function": {
# "description": None,
# "name": "tools",
# "parameters": {
# "format": "celsius",
# "location": "San Francisco, CA",
# "num_days": 3,
# },
# },
# }
```

View File

@ -0,0 +1,48 @@
## Speculation
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens were valid.
So you are making *more* computations on your LLM, but if your guesses are correct you produce 1, 2, 3, etc. tokens on a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct enough, this gives 2-3x faster inference (it can be much more for code-oriented tasks, for instance).
You can check out a more [detailed explanation](https://huggingface.co/blog/assisted-generation).
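The core loop looks roughly like the following. This is an illustrative Python sketch only, not TGI's implementation, and it assumes greedy decoding; `draft_next_tokens` and `model_logits` are hypothetical stand-ins for the speculation source (Medusa heads, n-gram matching, a draft model) and the large model's forward pass.

```python
def speculative_step(model_logits, draft_next_tokens, prompt_ids, n_speculate=2):
    """One decoding step: guess a few tokens cheaply, verify them with a single
    forward pass of the large model, and keep only the verified prefix."""
    # 1. Cheap guesses (Medusa heads, n-gram lookup, a small draft model, ...).
    guesses = draft_next_tokens(prompt_ids, n_speculate)

    # 2. One forward pass of the large model over prompt + guesses.
    #    Since decoding is memory bound, this costs about the same as one token.
    logits = model_logits(prompt_ids + guesses)

    # 3. Keep guesses while they match what the large model would have produced,
    #    then append the model's own token, so every step yields at least one token.
    accepted = []
    for i, guess in enumerate(guesses):
        predicted = int(logits[len(prompt_ids) + i - 1].argmax())
        if predicted != guess:
            break
        accepted.append(guess)
    next_token = int(logits[len(prompt_ids) + len(accepted) - 1].argmax())
    return accepted + [next_token]
```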
Text-generation inference supports 2 main speculative methods:
- Medusa
- N-gram
### Medusa
Medusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing models.
You can check a few existing fine-tunes for popular models:
- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)
In order to create your own Medusa heads for your own fine-tune, you should check out the original Medusa repo: [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa).
In order to use Medusa models in TGI, simply point to a Medusa-enabled model, and everything will load automatically.
### N-gram
If you don't have a Medusa model, or don't have the resources to fine-tune, you can try to use `n-gram`.
N-gram speculation works by trying to find matching tokens in the previous sequence and using those as the speculation.
This is an extremely simple method, which works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often.
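As a minimal sketch of the idea (again, illustrative only, not TGI's implementation): look for the most recent earlier occurrence of the current n-gram in the sequence and propose whatever followed it.

```python
def ngram_speculate(token_ids, n_speculate=2, ngram_size=2):
    """Propose the tokens that followed the most recent earlier occurrence of
    the current n-gram; return an empty list when there is no match."""
    if len(token_ids) < ngram_size:
        return []
    pattern = token_ids[-ngram_size:]
    # Scan backwards, skipping the trailing n-gram itself.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == pattern:
            return token_ids[start + ngram_size:start + ngram_size + n_speculate]
    return []

# Repetitive sequences make the guesses cheap and often correct.
print(ngram_speculate([5, 8, 9, 3, 5, 8]))  # -> [9, 3]
```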
In order to enable n-gram speculation, simply add `--speculate 2` to your flags.
[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)

View File

@ -23,6 +23,8 @@ from text_generation.types import (
Token,
BestOfSequence,
Grammar,
ChatComplete,
ChatCompletionChunk,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if (
isinstance(choices, List)
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data)
return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return (
response.choices[0].message.content == other.choices[0].message.content
)
def eq_chat_complete_chunk(
response: ChatCompletionChunk, other: ChatCompletionChunk
) -> bool:
return response.choices[0].delta.content == other.choices[0].delta.content
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
response.details, other.details
@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatCompletionChunk):
return len(snapshot_data) == len(serialized_data) and all(
[
eq_chat_complete_chunk(r, o)
for r, o in zip(serialized_data, snapshot_data)
]
)
return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
@ -236,6 +270,7 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
revision: Optional[str] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -268,6 +303,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
if trust_remote_code:
args.append("--trust-remote-code")
@ -302,6 +340,7 @@ def launcher(event_loop):
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
revision: Optional[str] = None,
):
port = random.randint(8000, 10_000)
@ -317,6 +356,9 @@ def launcher(event_loop):
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
if trust_remote_code:
args.append("--trust-remote-code")

View File

@ -0,0 +1,94 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40844727,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27905273,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6118164,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68652344,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4619141,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7993164,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.63134766,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23278809,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

View File

@ -0,0 +1,394 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": 0,
"tokens": [
{
"id": 2284,
"logprob": -0.296875,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.28125,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.79248047,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.61816406,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.0619812,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -0.4091797,
"special": false,
"text": "def"
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7670,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 444,
"logprob": -0.21655273,
"special": false,
"text": "name"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 731,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 332,
"logprob": -0.034698486,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 655,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 494,
"logprob": -0.20141602,
"special": false,
"text": " +"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 16013,
"logprob": 0.0,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7670,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 400,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 49,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 11505,
"logprob": 0.0,
"special": false,
"text": " age"
},
{
"id": 731,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 655,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 3021,
"logprob": -0.5761719,
"special": false,
"text": " \","
},
{
"id": 863,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " are"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 615,
"logprob": 0.0,
"special": false,
"text": " str"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 400,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 46,
"logprob": 0.0,
"special": false,
"text": ")"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
}

View File

@ -0,0 +1,378 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
]

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1708957015,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 60,
"total_tokens": 160
}
}

View File

@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079417,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}

View File

@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079492,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}

View File

@ -0,0 +1,37 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079493,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 187,
"total_tokens": 208
}
}

View File

@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "</s>",
"name": null
},
"id": "",
"index": 20,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"logprobs": null
}
],
"created": 1709087088,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native"
}

View File

@ -3,7 +3,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
with launcher(
"FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
) as handle:
yield handle

View File

@ -0,0 +1,55 @@
import pytest
@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
await flash_starcoder2_handle.health(300)
return flash_starcoder2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(
flash_starcoder2, generate_load, response_snapshot
):
responses = await generate_load(
flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,240 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
await flash_llama_grammar_tools_handle.health(300)
return flash_llama_grammar_tools_handle.client
# tools to be used in the following tests
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
},
{
"type": "function",
"function": {
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast",
},
},
"required": ["location", "format", "num_days"],
},
},
},
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
},
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Paris, France?",
},
],
stream=True,
)
count = 0
async for response in responses:
count += 1
assert count == 20
assert response == response_snapshot

View File

@ -230,7 +230,6 @@ message WarmupRequest {
uint32 max_total_tokens = 4;
}
/// Empty response
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;

View File

@ -812,23 +812,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -877,28 +881,33 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("Hi again!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -952,23 +961,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -1006,23 +1019,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),

View File

@ -358,10 +358,11 @@ impl ChatCompletion {
pub(crate) fn new(
model: String,
system_fingerprint: String,
output: Option<String>,
created: u64,
details: Details,
return_logprobs: bool,
tool_calls: Option<ToolCall>,
) -> Self {
Self {
id: String::new(),
@ -375,6 +376,7 @@ impl ChatCompletion {
role: "assistant".into(), role: "assistant".into(),
content: output, content: output,
name: None, name: None,
tool_calls,
}, },
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
pub(crate) struct ChatCompletionDelta { pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, pub role: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
pub content: String, pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
} }
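Taken together, a streamed delta now carries either plain text or a tool-call fragment (via the DeltaToolCall and Function structs that follow), never both at once. A hedged sketch of the two payload shapes a client might observe; the concrete values are illustrative, not captured output:

# Sketch of the two delta shapes implied by ChatCompletionDelta (illustrative values).
content_delta = {"role": "assistant", "content": "Deep Learning is"}
tool_call_delta = {
    "role": "assistant",
    "tool_calls": {
        "index": 0,
        "id": "",
        "type": "function",
        # `arguments` carries the raw generated text, streamed chunk by chunk
        "function": {"name": None, "arguments": '{"'},
    },
}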
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
}
#[allow(clippy::too_many_arguments)]
impl ChatCompletionChunk { impl ChatCompletionChunk {
pub(crate) fn new( pub(crate) fn new(
model: String, model: String,
system_fingerprint: String, system_fingerprint: String,
delta: String, delta: Option<String>,
tool_calls: Option<Vec<String>>,
created: u64, created: u64,
index: u32, index: u32,
logprobs: Option<ChatCompletionLogprobs>, logprobs: Option<ChatCompletionLogprobs>,
@ -438,6 +460,15 @@ impl ChatCompletionChunk {
delta: ChatCompletionDelta { delta: ChatCompletionDelta {
role: "assistant".to_string(), role: "assistant".to_string(),
content: delta, content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
}, },
logprobs, logprobs,
finish_reason, finish_reason,
@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = 0.95)] #[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>, pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")]
#[schema(
nullable = true,
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>,
}
fn default_tool_prompt() -> Option<String> {
Some(
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
)
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
enum ToolType {
FunctionName(String),
OneOf,
}
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
mod deserialize_tool_choice {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => match s.as_str() {
"none" => Ok(None),
"auto" => Ok(Some(ToolType::OneOf)),
_ => Ok(Some(ToolType::FunctionName(s))),
},
Value::Object(map) => {
if let Some(content) = map
.get("function")
.and_then(|v| v.get("name"))
.and_then(|v| v.as_str())
{
Ok(Some(ToolType::FunctionName(content.to_string())))
} else {
Err(de::Error::custom("function key not found in tool choice"))
}
}
Value::Null => Ok(Some(ToolType::OneOf)),
_ => Err(de::Error::custom("invalid token format")),
}
}
}
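As a reading aid, the deserializer above accepts several spellings of tool_choice. A hedged sketch of the equivalent request fragments, written as Python dicts; the function name is made up:

# Accepted `tool_choice` values per deserialize_tool_choice (illustrative only):
choose_any  = {"tool_choice": "auto"}                                     # -> ToolType::OneOf
choose_none = {"tool_choice": "none"}                                     # -> None, tools ignored
by_name     = {"tool_choice": "get_current_weather"}                      # -> ToolType::FunctionName(...)
by_object   = {"tool_choice": {"function": {"name": "get_current_weather"}}}
null_choice = {"tool_choice": None}                                       # JSON null also maps to OneOf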
#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct Tools {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
}
#[derive(Debug, Serialize, Deserialize)]
struct FunctionsMap {
#[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
struct FunctionRef {
#[serde(rename = "$ref")]
ref_path: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct Properties {
#[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>,
}
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("Function", 1)?;
state.serialize_field("anyOf", functions)?;
state.end()
}
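Given the serde attributes above (the $functions rename, the flatten, and the anyOf wrapper), the grammar payload serializes to roughly the following shape. This is an editorial sketch with a made-up function, not output captured from the router:

# Approximate shape of the serialized Tools value used as the JSON grammar.
tools_grammar = {
    "$functions": {
        "get_current_weather": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
    "properties": {
        "function": {"anyOf": [{"$ref": "#/$functions/get_current_weather"}]},
    },
}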
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]
pub r#type: String,
// The function definition the tool exposes to the model.
pub function: FunctionDefinition,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> {
add_generation_prompt: bool, add_generation_prompt: bool,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
pub id: u32,
pub r#type: String,
pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message { pub(crate) struct Message {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")] #[schema(example = "My name is David and I")]
pub content: String, pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]
pub name: Option<String>, pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<ToolCall>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -10,6 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
}; };
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
@ -239,7 +242,7 @@ async fn generate(
headers.insert("x-compute-type", compute_type.parse().unwrap()); headers.insert("x-compute-type", compute_type.parse().unwrap());
headers.insert( headers.insert(
"x-compute-time", "x-compute-time",
total_time.as_millis().to_string().parse().unwrap(), total_time.as_secs_f64().to_string().parse().unwrap(),
); );
headers.insert( headers.insert(
"x-compute-characters", "x-compute-characters",
@ -581,7 +584,7 @@ async fn chat_completions(
let seed = req.seed; let seed = req.seed;
// apply chat template to flatten the request into a single input // apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(req.messages) { let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs, Ok(inputs) => inputs,
Err(err) => { Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation"); metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -596,6 +599,62 @@ async fn chat_completions(
} }
}; };
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
let tool_prompt = req.tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(),
error_type: "Tool not found".to_string(),
}),
)
})?
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
};
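In short, when tools are supplied the templated conversation is extended with the tool prompt and the serialized tool schema, and the same JSON doubles as the grammar constraint. A rough Python equivalent of the string that reaches the model, assuming the default tool prompt; chat_inputs and the abbreviated grammar dict are stand-ins:

import json

# Hedged sketch of the prompt assembly performed above.
chat_inputs = "<templated conversation>"
tool_prompt = "\nBased on the conversation, please choose the most appropriate tool to use: "
tools_grammar = {"$functions": {}, "properties": {"function": {"anyOf": []}}}
tools_str = json.dumps(tools_grammar)              # the same JSON also becomes GrammarType::Json
inputs = f"{chat_inputs}{tool_prompt}{tools_str}"  # what the model actually sees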
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
@ -617,7 +676,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar: tool_grammar.clone(),
}, },
}; };
@ -640,11 +699,19 @@ async fn chat_completions(
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
}); });
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text]))
} else {
(Some(stream_token.token.text), None)
};
event event
.json_data(ChatCompletionChunk::new( .json_data(ChatCompletionChunk::new(
model_id.clone(), model_id.clone(),
system_fingerprint.clone(), system_fingerprint.clone(),
stream_token.token.text, content,
tool_calls,
current_time, current_time,
stream_token.index, stream_token.index,
logprobs, logprobs,
@ -681,14 +748,54 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let (tool_calls, output) = if tool_grammar.is_some() {
// gen_text should be valid json
let gen_text_value: Value =
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
let tool_call = Some(ToolCall {
id: 0,
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name: "tools".to_string(),
parameters: gen_text_value.get("function").map_or_else(
|| {
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})
},
|f| Ok(f.clone()),
)?,
},
});
(tool_call, None)
} else {
(None, Some(generation.generated_text))
};
// build the complete response object with the full text // build the complete response object with the full text
let response = ChatCompletion::new( let response = ChatCompletion::new(
model_id, model_id,
system_fingerprint, system_fingerprint,
generation.generated_text, output,
current_time, current_time,
generation.details.unwrap(), generation.details.unwrap(),
logprobs, logprobs,
tool_calls,
); );
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
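For an end-to-end check, the handler above can be exercised over HTTP. This is a hedged sketch rather than captured output: the endpoint path and field names follow this commit, while the model name, the example tool, and the response indexing are assumptions in the usual OpenAI-compatible layout:

import requests

# Assumes a TGI instance on localhost:8080 serving a chat-capable model (illustrative).
payload = {
    "model": "tgi",  # placeholder; requirements depend on the deployment
    "messages": [{"role": "user", "content": "What is the weather like in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
    "tool_choice": "auto",
}
resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload).json()
# With a tool grammar active, `content` is omitted and the call lands in `tool_calls`.
print(resp["choices"][0]["message"].get("tool_calls"))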
@ -154,12 +154,8 @@ def download_weights(
import json import json
medusa_head = hf_hub_download( medusa_head = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.pt" model_id, revision=revision, filename="medusa_lm_head.safetensors"
) )
if auto_convert:
medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = hf_hub_download( medusa_config = hf_hub_download(
model_id, revision=revision, filename="config.json" model_id, revision=revision, filename="config.json"
) )
@ -198,16 +194,12 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert: if not extension == ".safetensors" or not auto_convert:
raise e raise e
elif (Path(model_id) / "medusa_lm_head.pt").exists(): elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
# Try to load as a local Medusa model # Try to load as a local Medusa model
try: try:
import json import json
medusa_head = Path(model_id) / "medusa_lm_head.pt" medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
if auto_convert:
medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
if not medusa_sf.exists():
utils.convert_files([Path(medusa_head)], [medusa_sf], [])
medusa_config = Path(model_id) / "config.json" medusa_config = Path(model_id) / "config.json"
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)
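The hunk above drops the on-the-fly .pt conversion: Medusa heads are now expected to ship directly as safetensors. A hedged sketch of what the download now amounts to; the repo id is invented:

from huggingface_hub import hf_hub_download

# Fetch the Medusa head and its config as published; no convert_files step anymore.
medusa_repo = "some-org/vicuna-7b-medusa"   # illustrative repo id
medusa_head = hf_hub_download(medusa_repo, filename="medusa_lm_head.safetensors")
medusa_config = hf_hub_download(medusa_repo, filename="config.json")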
@ -3,7 +3,9 @@ import torch
from loguru import logger from loguru import logger
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional from typing import Optional
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
@ -65,6 +67,7 @@ try:
from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e: except ImportError as e:
@ -82,6 +85,7 @@ if FLASH_ATTENTION:
__all__.append(FlashMixtral) __all__.append(FlashMixtral)
__all__.append(FlashPhi) __all__.append(FlashPhi)
__all__.append(FlashQwen2) __all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
MAMBA_AVAILABLE = True MAMBA_AVAILABLE = True
try: try:
@ -119,44 +123,14 @@ def get_model(
else: else:
set_speculate(0) set_speculate(0)
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return SantaCoder(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
use_medusa = None use_medusa = None
if "medusa_num_heads" in config_dict: if "medusa_num_heads" in config_dict:
use_medusa = model_id medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"] model_id = config_dict["base_model_name_or_path"]
revision = "main" revision = "main"
speculate_medusa = config_dict["medusa_num_heads"] speculate_medusa = config_dict["medusa_num_heads"]
@ -173,6 +147,20 @@ def get_model(
config_dict, _ = PretrainedConfig.get_config_dict( config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
is_local = Path(medusa_model_id).exists()
if not is_local:
medusa_config = hf_hub_download(
medusa_model_id, revision=medusa_revision, filename="config.json"
)
hf_hub_download(
medusa_model_id,
revision=medusa_revision,
filename="medusa_lm_head.safetensors",
)
use_medusa = Path(medusa_config).parent
else:
use_medusa = Path(medusa_model_id)
method = "medusa" method = "medusa"
else: else:
method = "n-gram" method = "n-gram"
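To make the control flow concrete: a Medusa repo's config.json names its base model, so the requested model_id is swapped for the base model while use_medusa ends up pointing at the directory that holds medusa_lm_head.safetensors. A hedged sketch; the keys follow the code above, the values are invented:

# Illustrative Medusa config; only the keys used above matter.
config_dict = {
    "medusa_num_heads": 4,
    "base_model_name_or_path": "lmsys/vicuna-7b-v1.5",
}
if "medusa_num_heads" in config_dict:
    medusa_model_id = "some-org/vicuna-7b-medusa"       # the id the user asked for
    model_id = config_dict["base_model_name_or_path"]   # weights actually loaded
    speculate_medusa = config_dict["medusa_num_heads"]  # tokens speculated per step
    # use_medusa then resolves to the local directory containing the Medusa head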
@ -197,16 +185,32 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "gpt_bigcode": if model_id.startswith("facebook/galactica"):
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if (
model_type == "gpt_bigcode"
or model_type == "gpt2"
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION: if FLASH_ATTENTION:
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -219,6 +223,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -228,6 +233,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -236,6 +242,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -246,6 +253,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -254,6 +262,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -262,6 +271,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -272,15 +282,16 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
else: else:
return CausalLM( return CausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -295,6 +306,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -305,9 +317,9 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -316,6 +328,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -346,9 +359,9 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError( raise NotImplementedError(
@ -359,6 +372,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -372,6 +386,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -382,6 +397,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -390,6 +406,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -403,6 +420,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -413,6 +431,19 @@ def get_model(
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
) or HAS_FLASH_ATTN_V2_CUDA: ) or HAS_FLASH_ATTN_V2_CUDA:
return FlashMixtral( return FlashMixtral(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "starcoder2":
sliding_window = config_dict.get("sliding_window", -1)
if (
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
) or HAS_FLASH_ATTN_V2_CUDA:
return FlashStarcoder2(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -425,6 +456,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -434,6 +466,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -443,6 +476,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -466,6 +500,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -474,6 +509,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -485,6 +521,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -493,6 +530,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
) )
config.pad_token_id = 3 config.pad_token_id = 3
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
@ -112,4 +114,4 @@ class BLOOMSharded(CausalLM):
) )
logits = outputs.logits logits = outputs.logits
return logits, outputs.past_key_values return logits, speculative_logits, outputs.past_key_values
@ -482,6 +482,7 @@ class CausalLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -550,7 +551,9 @@ class CausalLM(Model):
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[
torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
]:
# Model Forward # Model Forward
kwargs = { kwargs = {
"input_ids": input_ids, "input_ids": input_ids,
@ -563,7 +566,11 @@ class CausalLM(Model):
kwargs["position_ids"] = position_ids kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs) outputs = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values if isinstance(outputs, tuple):
outputs, speculative_logits = outputs
else:
speculative_logits = None
return outputs.logits, speculative_logits, outputs.past_key_values
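The pattern above, where a backend may or may not hand back speculative logits next to the regular output, recurs across the models touched in this commit. A minimal standalone sketch of the normalization; the names are generic, not the TGI classes:

from typing import Optional, Tuple
import torch

def unpack_forward(outputs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Models with a speculative head return (model_output, speculative_logits);
    # plain models return just the usual output object.
    if isinstance(outputs, tuple):
        outputs, speculative_logits = outputs
    else:
        speculative_logits = None
    return outputs.logits, speculative_logits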
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -573,7 +580,7 @@ class CausalLM(Model):
# slice the attention mask to the correct shape # slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
logits, past = self.forward( logits, speculative_logits, past = self.forward(
batch.input_ids, batch.input_ids,
attention_mask, attention_mask,
batch.position_ids, batch.position_ids,
@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
CUSTOM_KERNELS_ENABLED = False CUSTOM_KERNELS_ENABLED = False
@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
super().__init__(config) super().__init__(config)
self.transformer = BloomModel(config, weights) self.transformer = BloomModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="word_embeddings", prefix="word_embeddings",
weights=weights, weights=weights,
@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
loss = None loss = None
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions( return (
CausalLMOutputWithCrossAttentions(
loss=loss, loss=loss,
logits=lm_logits, logits=logits,
past_key_values=transformer_outputs.past_key_values, past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
),
speculative_logits,
) )
@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashGemmaModel(config, weights) self.model = FlashGemmaModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
weights=weights, weights=weights,
@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits return logits, speculative_logits
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashLlamaModel(config, weights) self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor, input_lengths: torch.Tensor,
max_s: int, max_s: int,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
return logits return logits, speculative_logits
@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = MistralModel(config, weights) self.model = MistralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
) )
@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = MixtralModel(config, weights) self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear, get_linear,
@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights) self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )
@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastLayerNorm, FastLayerNorm,
) )
@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = FlashPhiModel(config, weights) self.model = FlashPhiModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,
@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear, get_linear,
@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
self.transformer = FlashRWModel(config, weights) self.transformer = FlashRWModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
config, prefix="lm_head", weights=weights
)
def forward( def forward(
self, self,
@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelHead, SpeculativeHead,
TensorParallelEmbedding, TensorParallelEmbedding,
FastLayerNorm, FastLayerNorm,
get_linear, get_linear,
@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
self.transformer = FlashSantacoderModel(config, weights) self.transformer = FlashSantacoderModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix="transformer.wte", weights=weights
) )
@ -0,0 +1,545 @@
# coding=utf-8
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
FastRMSNorm,
FastLayerNorm,
)
class Starcoder2Config(PretrainedConfig):
model_type = "starcoder2"
def __init__(
self,
vocab_size=49152,
hidden_size=3072,
intermediate_size=12288,
num_hidden_layers=30,
num_attention_heads=24,
num_key_value_heads=2,
mlp_type="default",
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=4096,
initializer_range=0.018042,
norm_type="layer_norm",
norm_epsilon=1e-5,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
rope_theta=10000.0,
sliding_window=None,
attention_dropout=0.0,
residual_dropout=0.0,
embedding_dropout=0.0,
use_bias: bool = True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.use_bias = use_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.mlp_type = mlp_type
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_type = norm_type
self.norm_epsilon = norm_epsilon
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.embedding_dropout = embedding_dropout
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=config.use_bias,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
if config.use_bias:
w = [
weights.get_sharded(f"{p}.bias", dim=0)
for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
]
bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
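The shape assertion in _load_gqa encodes the fused layout: per shard, the query heads and the two key/value head groups are concatenated along dim 0. A quick numeric check using the Starcoder2Config defaults above, assuming a single tensor-parallel shard:

hidden_size = 3072
num_attention_heads = 24
num_key_value_heads = 2
head_size = hidden_size // num_attention_heads                       # 128
fused_rows = (num_attention_heads + 2 * num_key_value_heads) * head_size
assert fused_rows == 3584                                            # q (3072) + k (256) + v (256)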
class Starcoder2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=config.use_bias,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
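kv_head_mapping above tells paged attention which key/value head each query head shares. With the default 24 query heads and 2 KV heads, each KV index simply repeats over a group of 12 query heads; a small sketch, assuming a single shard:

import torch

num_heads = 24
num_key_value_heads = 2
num_groups = num_heads // num_key_value_heads        # 12 query heads per KV head
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
# tensor([0, 0, ..., 0, 1, 1, ..., 1])  -- twelve 0s followed by twelve 1s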
class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.c_fc = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.c_fc",
weights=weights,
bias=config.use_bias,
)
self.c_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=config.use_bias,
)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=config.use_bias,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=config.use_bias,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
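Both MLP variants sit behind the same layer interface; the gated one splits the fused projection and applies the activation to the gate only. A compact functional sketch of the two forward paths with plain tensors (no tensor parallelism, biases omitted):

import torch
import torch.nn.functional as F

def default_mlp(x, c_fc_w, c_proj_w):
    # Starcoder2MLP path: up-project, tanh-approximated gelu, down-project.
    return F.gelu(x @ c_fc_w.T, approximate="tanh") @ c_proj_w.T

def gated_mlp(x, gate_up_w, down_w, intermediate_size):
    # Starcoder2GatedMLP path: the fused gate/up output is split, the gate half is
    # activated and multiplied elementwise with the up half before down-projecting.
    gate_up = (x @ gate_up_w.T).view(-1, 2, intermediate_size)
    return (F.gelu(gate_up[:, 0], approximate="tanh") * gate_up[:, 1]) @ down_w.T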
STARCODER2_NORMALIZATION_CLASSES = {
"layer_norm": FastLayerNorm,
"rms_norm": FastRMSNorm,
}
STARCODER2_MLP_CLASSES = {
"default": Starcoder2MLP,
"gated": Starcoder2GatedMLP,
}
class Starcoder2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
)
self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
config.norm_type
].load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.norm_epsilon,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Starcoder2Layer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = Starcoder2Model(config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
weights=weights,
)
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
if self.max_past is not None
else None
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits
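One subtlety in the forward above: during decode, input_lengths are clamped to the sliding window so paged attention never looks beyond the cached window, while prefill passes the true lengths and relies on window_size_left in the flash-attention call. A tiny sketch of the clamp:

import torch

sliding_window = 4096                                   # illustrative value
input_lengths = torch.tensor([120, 9000, 4097])
clamped = torch.clamp(input_lengths, max=sliding_window)
# tensor([ 120, 4096, 4096])  -- decode path only; prefill keeps the true lengths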
@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
PositionRotaryEmbedding, PositionRotaryEmbedding,
FastLinear, FastLinear,
) )
@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
weights, weights,
) -> None: ) -> None:
super().__init__() super().__init__()
self.fc = TensorParallelHead.load( self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
config=config, prefix="lm_head", weights=weights
)
self.additional_fc = FastLinear.load( self.additional_fc = FastLinear.load(
config=config, config=config,
prefix="lm_head.additional_fc", prefix="lm_head.additional_fc",
@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
) )
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.fc(input) output, speculative_logits = self.fc(input)
additional_features = self.additional_fc(input) additional_features = self.additional_fc(input)
output = torch.cat((output, additional_features), -1) output = torch.cat((output, additional_features), -1)
return output return output, speculative_logits
def extra_repr(self) -> str: def extra_repr(self) -> str:
"""Overwriting `nn.Linear.extra_repr` to include new parameters.""" """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
loss = None loss = None
return CausalLMOutputWithPastImage( return (
CausalLMOutputWithPastImage(
loss=loss, loss=loss,
logits=logits, logits=logits,
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states, image_hidden_states=outputs.image_hidden_states,
),
speculative_logits,
) )
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
SpeculativeHead,
TensorParallelEmbedding, TensorParallelEmbedding,
FastRMSNorm, FastRMSNorm,
FastLinear, FastLinear,
@ -205,14 +206,12 @@ class MambaModel(nn.Module):
self.norm_f = FastRMSNorm.load( self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
) )
self.lm_head = FastLinear.load( self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
config, f"{prefix}.embedding", weights, bias=False
)
self.config = config self.config = config
def forward( def forward(
self, input_ids: torch.Tensor, inference_params=None, residual=None self, input_ids: torch.Tensor, inference_params=None, residual=None
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
hidden_states, residual, conv_state, ssm_state = block( hidden_states, residual, conv_state, ssm_state = block(
@ -226,8 +225,8 @@ class MambaModel(nn.Module):
) )
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
hidden_states = hidden_states.view(residual.shape) hidden_states = hidden_states.view(residual.shape)
logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
# update the offset for the next inference using these params # update the offset for the next inference using these params
inference_params.seqlen_offset += input_ids.size(1) inference_params.seqlen_offset += input_ids.size(1)
return logits return logits, speculative_logits
@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
) )
@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings") raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights) self.transformer = MPTModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights config, prefix="transformer.wte", weights=weights
) )
self.logit_scale = None self.logit_scale = None
@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
use_cache=use_cache, use_cache=use_cache,
) )
logits = self.lm_head(outputs.last_hidden_state) logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
if self.logit_scale is not None: if self.logit_scale is not None:
if self.logit_scale == 0: if self.logit_scale == 0:
warnings.warn( warnings.warn(
@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
loss = F.cross_entropy( loss = F.cross_entropy(
logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
) )
return CausalLMOutputWithPast( return (
CausalLMOutputWithPast(
loss=loss, loss=loss,
logits=logits, logits=logits,
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
),
speculative_logits,
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
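The same convention shows up across the non-flash modeling files (MPT here, T5 and Idefics below): instead of returning the bare `ModelOutput` dataclass, forward wraps it in a tuple so the speculative logits can ride along without changing the dataclass itself. A tiny sketch of the pattern, using hypothetical names:

```python
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn

@dataclass
class ToyCausalLMOutput:
    logits: torch.Tensor
    loss: Optional[torch.Tensor] = None

def forward(hidden: torch.Tensor, lm_head: nn.Module, medusa: Optional[nn.Module] = None):
    logits = lm_head(hidden)
    speculative_logits = medusa(hidden) if medusa is not None else None
    # The ModelOutput is wrapped in a tuple so speculative logits can be
    # returned without touching the output dataclass.
    return ToyCausalLMOutput(logits=logits), speculative_logits

outputs, speculative_logits = forward(torch.randn(2, 8), nn.Linear(8, 32))
print(outputs.logits.shape, speculative_logits)
```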

View File

@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights) self.gpt_neox = GPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load( self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights config, prefix="embed_out", weights=weights
) )

View File

@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
EPS = 1e-5 EPS = 1e-5
@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
self.model = OPTModel(config, weights) self.model = OPTModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights config, prefix="model.decoder.embed_tokens", weights=weights
) )

View File

@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead, SpeculativeHead,
FastLinear, FastLinear,
) )
@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
weights=weights, weights=weights,
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
self.linear = TensorParallelHead.load( self.linear = SpeculativeHead.load(
config=config, prefix="lm_head.linear", weights=weights config=config, prefix="lm_head.linear", weights=weights
) )

View File

@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelHead, SpeculativeHead,
) )
@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
) )
try: try:
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="lm_head", weights=weights config, prefix="lm_head", weights=weights
) )
except RuntimeError: except RuntimeError:
# Some models like t5-small were saved with shared weights unlike flan # Some models like t5-small were saved with shared weights unlike flan
# Since they are declared as the same arch we have no choice but hope # Since they are declared as the same arch we have no choice but hope
# that this is OK instead of using a proper flag. # that this is OK instead of using a proper flag.
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, prefix="shared", weights=weights config, prefix="shared", weights=weights
) )
@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5) sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output) logits, speculative_logits = self.lm_head(sequence_output)
loss = None loss = None
if labels is not None: if labels is not None:
@ -1140,9 +1140,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput( return (
Seq2SeqLMOutput(
loss=loss, loss=loss,
logits=lm_logits, logits=logits,
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
@ -1150,6 +1151,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
),
speculative_logits,
) )
def prepare_inputs_for_generation( def prepare_inputs_for_generation(

View File

@ -723,7 +723,7 @@ class FlashCausalLM(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
@ -734,6 +734,8 @@ class FlashCausalLM(Model):
max_s=max_s, max_s=max_s,
lm_head_indices=None, lm_head_indices=None,
) )
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
@ -805,7 +807,9 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE) return int(num_blocks * BLOCK_SIZE)
def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: def forward(
self, batch: FlashCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids = batch.input_ids input_ids = batch.input_ids
@ -900,9 +904,14 @@ class FlashCausalLM(Model):
# Replay the graph # Replay the graph
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -926,16 +935,11 @@ class FlashCausalLM(Model):
batch.slots = slots batch.slots = slots
try: try:
out = self.forward(batch) out, speculative_logits = self.forward(batch)
except Exception as e: except Exception as e:
del batch del batch
raise e raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out out[batch.prefill_next_token_indices] if prefill_logprobs else out
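The CUDA graph path in `FlashCausalLM` now captures both outputs: `logits` and `speculative_logits` are stored as static tensors at capture time, and replay slices both down to the live batch size. A rough sketch of that capture/replay pattern (simplified, assuming a model that returns the `(logits, speculative_logits)` pair; real code additionally pads inputs per captured batch size and warms the model up before capture):

```python
import torch

def capture(model, static_input, pool=None):
    """Capture one forward pass; the returned dict keeps references to the
    static output tensors that graph.replay() overwrites in place."""
    graph = torch.cuda.CUDAGraph()
    model(static_input)            # warm-up pass outside the graph
    torch.cuda.synchronize()
    with torch.cuda.graph(graph, pool=pool):
        logits, speculative_logits = model(static_input)
    torch.cuda.synchronize()
    return {
        "graph": graph,
        "input": static_input,
        "logits": logits,
        "speculative_logits": speculative_logits,
    }

def replay(entry, live_input, bs):
    # Copy live inputs into the (possibly padded) static input ...
    entry["input"][:bs].copy_(live_input)
    entry["graph"].replay()
    # ... then slice both static outputs back to the live batch size.
    speculative_logits = (
        entry["speculative_logits"][:bs]
        if entry["speculative_logits"] is not None
        else None
    )
    return entry["logits"][:bs], speculative_logits
```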

View File

@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM):
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = FlashGemmaForCausalLM(config, weights) model = FlashGemmaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__( super(FlashGemma, self).__init__(
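The per-model Medusa bootstrapping deleted above (downloading `config.json` and `medusa_lm_head.pt`, then wrapping `lm_head`) is no longer needed because the head owns that logic now: the model class only records the flag on the config, and `SpeculativeHead.load` (see the layers.py change further down) reads it. A minimal sketch of the new wiring with simplified stand-in classes, not TGI's real signatures:

```python
from typing import Optional

class ToyConfig:
    def __init__(self):
        self.quantize: Optional[str] = None
        self.use_medusa: Optional[str] = None

class ToyHead:
    @staticmethod
    def load(config: ToyConfig) -> "ToyHead":
        # The head, not the model class, decides whether to attach Medusa weights.
        if config.use_medusa:
            print(f"loading medusa head from {config.use_medusa}")
        return ToyHead()

def build_model(quantize: Optional[str], use_medusa: Optional[str]) -> ToyHead:
    config = ToyConfig()
    config.quantize = quantize
    config.use_medusa = use_medusa       # the model just records the flag ...
    return ToyHead.load(config)          # ... and the head handles the rest

build_model(quantize=None, use_medusa="/path/to/medusa")
```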

View File

@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = FlashLlamaForCausalLM(config, weights) model = FlashLlamaForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__( super(FlashLlama, self).__init__(
model=model, model=model,

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type, List from typing import Optional, Tuple, Type
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
MEM_POOL = torch.cuda.graph_pool_handle() MEM_POOL = torch.cuda.graph_pool_handle()
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
SLIDING_WINDOW = sliding_window
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
def get_sliding_windows() -> Tuple[int, int]:
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
# Adds windowing logic to FlashCausalLMBatch # Adds windowing logic to FlashCausalLMBatch
@dataclass @dataclass
class FlashMistralBatch(FlashCausalLMBatch): class FlashMistralBatch(FlashCausalLMBatch):
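The `set_sliding_window` / `get_sliding_windows` helpers added above replace the scattered `global SLIDING_WINDOW` statements that the later hunks delete: callers read the pair through one function instead of re-declaring globals. A small self-contained sketch of the same registry pattern (the `BLOCK_SIZE` value here is an assumption for illustration):

```python
import math
from typing import Optional, Tuple

BLOCK_SIZE = 16  # assumed paged-attention block size for this example

SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

def set_sliding_window(sliding_window: int, sliding_window_blocks: int) -> None:
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks

def get_sliding_windows() -> Tuple[Optional[int], Optional[int]]:
    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS

# A model with a 4096-token window on 16-token blocks:
set_sliding_window(4096, math.ceil(4096 / BLOCK_SIZE))
window, window_blocks = get_sliding_windows()
needed_blocks = min(100, window_blocks)  # cap per-request blocks at the window
print(window, window_blocks, needed_blocks)
```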
@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
global SLIDING_WINDOW sliding_window, sliding_window_blocks = get_sliding_windows()
global SLIDING_WINDOW_BLOCKS
batch_inputs = [] batch_inputs = []
max_truncation = 0 max_truncation = 0
@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
if SLIDING_WINDOW_BLOCKS is not None: if sliding_window_blocks is not None:
needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS) needed_blocks = min(needed_blocks, sliding_window_blocks)
blocks += needed_blocks blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens)) needed_blocks_slots.append((needed_blocks, total_tokens))
@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
slot_indices.append(request_slot_indices) slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill # Create tensor to slice into the kv tensor in prefill
if SLIDING_WINDOW is not None: if sliding_window is not None:
request_prefill_cache_indices = torch.arange( request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW), cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length, cumulative_length + input_length,
dtype=torch.int64, dtype=torch.int64,
) )
@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
input_ids = np.concatenate(all_input_ids, dtype=np.int64) input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids) position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices) slot_indices = torch.cat(slot_indices)
if SLIDING_WINDOW is not None: if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices) prefill_cache_indices = torch.cat(prefill_cache_indices)
else: else:
input_ids = all_input_ids[0] input_ids = all_input_ids[0]
position_ids = position_ids[0] position_ids = position_ids[0]
slot_indices = slot_indices[0] slot_indices = slot_indices[0]
if SLIDING_WINDOW is not None: if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0] prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill = torch.tensor(
@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
position_ids = position_ids.to(device) position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
prefill_cache_indices = ( prefill_cache_indices = (
prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None prefill_cache_indices.to(device) if sliding_window is not None else None
) )
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor( input_lengths_tensor = torch.tensor(
@ -294,12 +306,10 @@ class BaseFlashMistral(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
@ -319,11 +329,13 @@ class BaseFlashMistral(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None:
SLIDING_WINDOW = config.sliding_window set_sliding_window(
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -394,7 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
@ -406,9 +418,13 @@ class BaseFlashMistral(FlashCausalLM):
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
) )
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(
self, batch: FlashMistralBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids = batch.input_ids input_ids = batch.input_ids
@ -479,7 +495,7 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -493,7 +509,7 @@ class BaseFlashMistral(FlashCausalLM):
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
return logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
@ -511,7 +527,13 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class FlashMistral(BaseFlashMistral): class FlashMistral(BaseFlashMistral):
@ -520,6 +542,7 @@ class FlashMistral(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -529,6 +552,7 @@ class FlashMistral(BaseFlashMistral):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
if config.quantize == "gptq": if config.quantize == "gptq":
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)

View File

@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
trust_remote_code=True, trust_remote_code=True,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
config.transpose = config.architectures[0].startswith("GPT2") config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -0,0 +1,86 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = Starcoder2Config.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows
if config.sliding_window is not None:
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)
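`FlashStarcoder2` inherits `BaseFlashMistral` for its sliding-window batch logic but skips that constructor entirely, calling the grandparent `FlashCausalLM.__init__` through `super(BaseFlashMistral, self).__init__(...)`. A minimal sketch of that grandparent-constructor pattern in plain Python (toy classes, not the TGI hierarchy):

```python
class A:
    def __init__(self, name: str):
        self.name = name

class B(A):
    def __init__(self, name: str):
        # B does extra set-up that C wants to avoid.
        super().__init__(name + "-with-B-setup")

    def helper(self) -> str:
        return f"helper for {self.name}"

class C(B):
    def __init__(self, name: str):
        # Bypass B.__init__ and initialise A directly.
        super(B, self).__init__(name)

print(C("starcoder2").helper())  # -> "helper for starcoder2"
```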

View File

@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
config.vision_config.quantize = quantize config.vision_config.quantize = quantize
tokenizer = LlamaTokenizerFast.from_pretrained( tokenizer = LlamaTokenizerFast.from_pretrained(

View File

@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
if self.has_position_ids: if self.has_position_ids:
kwargs["position_ids"] = position_ids kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs) outputs, speculative_logits = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values, outputs.image_hidden_states return (
outputs.logits,
speculative_logits,
outputs.past_key_values,
outputs.image_hidden_states,
)
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
def generate_token( def generate_token(
@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
:, : -batch.padding_right_offset :, : -batch.padding_right_offset
] ]
logits, past, image_hidden_states = self.forward( logits, speculative_logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids, input_ids=batch.input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=batch.position_ids, position_ids=batch.position_ids,

View File

@ -408,6 +408,7 @@ class Mamba(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -444,6 +445,7 @@ class Mamba(Model):
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)
@ -505,7 +507,7 @@ class Mamba(Model):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, inference_params=inference_params input_ids=input_ids, inference_params=inference_params
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -514,6 +516,7 @@ class Mamba(Model):
"inference_params": inference_params, "inference_params": inference_params,
"graph": graph, "graph": graph,
"logits": logits, "logits": logits,
"speculative_logits": speculative_logits,
} }
self.cuda_graphs[batch_size] = graph_dict self.cuda_graphs[batch_size] = graph_dict
@ -556,9 +559,14 @@ class Mamba(Model):
inference_params.ssm_states.copy_( inference_params.ssm_states.copy_(
cuda_graph["inference_params"].ssm_states[:, :bs] cuda_graph["inference_params"].ssm_states[:, :bs]
) )
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns() start = time.time_ns()
@ -589,7 +597,9 @@ class Mamba(Model):
batch.inference_params = inference_params batch.inference_params = inference_params
# Forward pass # Forward pass
logits = self.forward(input_ids, inference_params=batch.inference_params) logits, speculative_logits = self.forward(
input_ids, inference_params=batch.inference_params
)
# batch.inference_params = new_inference_params # batch.inference_params = new_inference_params
# Results # Results

View File

@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
config = json.load(f) config = json.load(f)
config = PretrainedConfig(**config) config = PretrainedConfig(**config)
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -22,6 +22,7 @@ class Phi(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -52,6 +53,7 @@ class Phi(CausalLM):
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group) weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):

View File

@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -596,6 +597,7 @@ class Seq2SeqLM(Model):
past_key_values: Optional = None, past_key_values: Optional = None,
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor,
Optional[torch.Tensor],
torch.Tensor, torch.Tensor,
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
@ -609,8 +611,15 @@ class Seq2SeqLM(Model):
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
if isinstance(outputs, tuple):
# Our custom models
outputs, speculative_logits = outputs
else:
# Generic transformers models
speculative_logits = None
return ( return (
outputs.logits, outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state, outputs.encoder_last_hidden_state,
outputs.past_key_values, outputs.past_key_values,
) )
@ -635,7 +644,7 @@ class Seq2SeqLM(Model):
else: else:
encoder_last_hidden_state = None encoder_last_hidden_state = None
logits, encoder_last_hidden_state, past = self.forward( logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
batch.input_ids, batch.input_ids,
batch.attention_mask, batch.attention_mask,
batch.decoder_input_ids, batch.decoder_input_ids,
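Unlike the flash path, `Seq2SeqLM` keeps an `isinstance` shim because it can serve either a custom TGI model (which returns the `(outputs, speculative_logits)` tuple) or a stock transformers model (which returns a single `ModelOutput`). A short sketch of that normalisation, under a hypothetical helper name:

```python
from typing import Any, Optional, Tuple
import torch

def unpack_model_output(outputs: Any) -> Tuple[Any, Optional[torch.Tensor]]:
    """Normalise the two return shapes: custom TGI models return a
    (outputs, speculative_logits) tuple, generic transformers models
    return a single ModelOutput-like object."""
    if isinstance(outputs, tuple):
        return outputs          # custom model: already (outputs, speculative_logits)
    return outputs, None        # generic model: no speculative logits
```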

View File

@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
]: ]:
# Model Forward # Model Forward
outputs = self.model.forward( outputs, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):
return ( return (
outputs.logits, outputs.logits,
speculative_logits,
outputs.encoder_last_hidden_state, outputs.encoder_last_hidden_state,
outputs.past_key_values, outputs.past_key_values,
) )

View File

@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
and "arguments" not in s.rfilename and "arguments" not in s.rfilename
and "args" not in s.rfilename and "args" not in s.rfilename
and "training" not in s.rfilename and "training" not in s.rfilename
and "medusa_lm_head" not in s.rfilename
] ]
@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
and "args" not in f and "args" not in f
and "adapter" not in f and "adapter" not in f
and "training" not in f and "training" not in f
and "medusa_lm_head" not in f
] ]
return filenames return filenames
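Weight discovery now skips `medusa_lm_head` files so a Medusa head stored next to the model is not mistaken for a regular shard. A sketch of the same filename filter as a standalone helper (the function name is hypothetical):

```python
from pathlib import Path
from typing import List

# Auxiliary files that must not be collected as model shards.
EXCLUDED_MARKERS = ("arguments", "args", "adapter", "training", "medusa_lm_head")

def local_weight_files(d: Path, extension: str = ".safetensors") -> List[str]:
    return [
        f.name
        for f in d.iterdir()
        if f.name.endswith(extension)
        and not any(marker in f.name for marker in EXCLUDED_MARKERS)
    ]
```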

View File

@ -4,7 +4,7 @@ import torch.distributed
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import List from typing import List, Tuple, Optional
from loguru import logger from loguru import logger
from functools import lru_cache from functools import lru_cache
@ -380,6 +380,96 @@ class SuperLayer(nn.Module):
return self.linear.forward(x) return self.linear.forward(x)
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
def forward(self, x):
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
if use_medusa:
from pathlib import Path
from safetensors import safe_open
import json
medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
with open(medusa_config, "r") as f:
config = json.load(f)
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
weights.routing[k] = filename
medusa = MedusaModel(config, weights)
else:
medusa = None
return SpeculativeHead(lm_head, medusa)
def forward(
self, input: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input) if self.medusa is not None else None
return logits, speculative_logits
class TensorParallelHead(SuperLayer): class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group, should_gather: bool): def __init__(self, linear, process_group, should_gather: bool):
super().__init__(linear) super().__init__(linear)
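Each Medusa head added above is a stack of residual SiLU blocks followed by an unbiased projection to the vocabulary, and the per-head logits are stacked on a new speculative-position dimension. A small runnable sketch of that architecture with toy sizes (stand-in classes, plain `nn.Linear` instead of TGI's weight-loading layers):

```python
import torch
import torch.nn as nn

class ToyResBlock(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden, bias=True)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))

class ToyMedusaHead(nn.Module):
    def __init__(self, hidden: int, vocab: int, num_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList([ToyResBlock(hidden) for _ in range(num_layers)])
        self.out = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.out(x)

hidden, vocab = 8, 32
heads = nn.ModuleList([ToyMedusaHead(hidden, vocab, num_layers=2) for _ in range(3)])
x = torch.randn(5, hidden)                                    # 5 tokens
speculative_logits = torch.stack([h(x) for h in heads], dim=1)
print(speculative_logits.shape)                               # torch.Size([5, 3, 32])
```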

View File

@ -1,59 +0,0 @@
import torch
from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights, lm_head):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
self.lm_head = lm_head
def forward(self, x):
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x