mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Merge branch 'feat/qwen' into qwen2
commit 73403aa4db
@@ -52,6 +52,8 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan

- Logits warper (temperature scaling, top-p, top-k, repetition penalty; for more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
- Stop sequences
- Log probabilities
- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency
- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify the output format to speed up inference and make sure the output is valid according to some spec.
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance

@@ -3,7 +3,7 @@ import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union

from text_generation.types import (
    StreamResponse,
@@ -11,6 +11,11 @@ from text_generation.types import (
    Request,
    Parameters,
    Grammar,
    ChatRequest,
    ChatCompletionChunk,
    ChatComplete,
    Message,
    Tool,
)
from text_generation.errors import parse_error

@@ -59,6 +64,114 @@ class Client:
        self.cookies = cookies
        self.timeout = timeout

    def chat(
        self,
        messages: List[Message],
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stream: bool = False,
        seed: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_choice: Optional[str] = None,
    ):
        """
        Given a list of messages, generate a response

        Args:
            messages (`List[Message]`):
                List of messages
            frequency_penalty (`float`):
                The parameter for frequency penalty. 0.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            logit_bias (`List[float]`):
                Adjust the likelihood of specified tokens
            logprobs (`bool`):
                Include log probabilities in the response
            top_logprobs (`int`):
                Include the `n` most likely tokens at each step
            max_tokens (`int`):
                Maximum number of generated tokens
            n (`int`):
                Generate `n` completions
            presence_penalty (`float`):
                The parameter for presence penalty. 0.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            stream (`bool`):
                Stream the response
            seed (`int`):
                Random sampling seed
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_choice (`str`):
                The tool to use

        """
        request = ChatRequest(
            model="tgi",
            messages=messages,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            stream=stream,
            seed=seed,
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_choice=tool_choice,
        )
        if not stream:
            resp = requests.post(
                f"{self.base_url}/v1/chat/completions",
                json=request.dict(),
                headers=self.headers,
                cookies=self.cookies,
                timeout=self.timeout,
            )
            payload = resp.json()
            if resp.status_code != 200:
                raise parse_error(resp.status_code, payload)
            return ChatComplete(**payload)
        else:
            return self._chat_stream_response(request)

    def _chat_stream_response(self, request):
        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )
        # iterate over the SSE byte stream and yield parsed chunks
        for byte_payload in resp.iter_lines():
            if byte_payload == b"\n":
                continue
            payload = byte_payload.decode("utf-8")
            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                try:
                    response = ChatCompletionChunk(**json_payload)
                    yield response
                except ValidationError:
                    raise parse_error(resp.status_code, json_payload)

    def generate(
        self,
        prompt: str,
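
A minimal usage sketch for the new synchronous `chat` method added above. The endpoint, `Message` fields, and return shape come from this diff; the server URL and the served model are assumptions:

```python
from text_generation import Client
from text_generation.types import Message

# Assumes a TGI instance serving a chat-capable model locally.
client = Client(base_url="http://localhost:3000")

complete = client.chat(
    messages=[Message(role="user", content="What is deep learning?")],
    max_tokens=64,
)

# ChatComplete mirrors the OpenAI response shape: choices[0].message holds the reply.
print(complete.choices[0].message.content)
```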
@@ -313,6 +426,113 @@ class AsyncClient:
        self.cookies = cookies
        self.timeout = ClientTimeout(timeout * 60)

    async def chat(
        self,
        messages: List[Message],
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stream: bool = False,
        seed: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_choice: Optional[str] = None,
    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
        """
        Given a list of messages, generate a response asynchronously

        Args:
            messages (`List[Message]`):
                List of messages
            frequency_penalty (`float`):
                The parameter for frequency penalty. 0.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            logit_bias (`List[float]`):
                Adjust the likelihood of specified tokens
            logprobs (`bool`):
                Include log probabilities in the response
            top_logprobs (`int`):
                Include the `n` most likely tokens at each step
            max_tokens (`int`):
                Maximum number of generated tokens
            n (`int`):
                Generate `n` completions
            presence_penalty (`float`):
                The parameter for presence penalty. 0.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            stream (`bool`):
                Stream the response
            seed (`int`):
                Random sampling seed
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation
            tools (`List[Tool]`):
                List of tools to use
            tool_choice (`str`):
                The tool to use

        """
        request = ChatRequest(
            model="tgi",
            messages=messages,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            stream=stream,
            seed=seed,
            temperature=temperature,
            top_p=top_p,
            tools=tools,
            tool_choice=tool_choice,
        )
        if not stream:
            return await self._chat_single_response(request)
        else:
            return self._chat_stream_response(request)

    async def _chat_single_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/chat/completions", json=request.dict()
            ) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return ChatComplete(**payload)

    async def _chat_stream_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/chat/completions", json=request.dict()
            ) as resp:
                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue
                    payload = byte_payload.decode("utf-8")
                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        try:
                            response = ChatCompletionChunk(**json_payload)
                            yield response
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)

    async def generate(
        self,
        prompt: str,
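
A companion sketch for the asynchronous streaming path: with `stream=True`, `AsyncClient.chat` hands back the async generator produced by `_chat_stream_response`, which yields `ChatCompletionChunk` objects. The server URL and prompt are illustrative only:

```python
import asyncio

from text_generation import AsyncClient
from text_generation.types import Message


async def main():
    # Assumes a TGI instance serving a chat-capable model locally.
    client = AsyncClient(base_url="http://localhost:3000")

    stream = await client.chat(
        messages=[Message(role="user", content="Tell me a short story.")],
        max_tokens=64,
        stream=True,
    )

    async for chunk in stream:
        # Each chunk carries a ChoiceDelta with the newly generated text.
        print(chunk.choices[0].delta.content or "", end="")


if __name__ == "__main__":
    asyncio.run(main())
```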
@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List, Union, Any

from text_generation.errors import ValidationError

@@ -19,6 +19,124 @@ class Grammar(BaseModel):
    value: Union[str, dict]


class ToolCall(BaseModel):
    # Id of the tool call
    id: int
    # Type of the tool call
    type: str
    # Function details of the tool call
    function: dict


class Message(BaseModel):
    # Role of the message sender
    role: str
    # Content of the message
    content: Optional[str]
    # Optional name of the message sender
    name: Optional[str] = None
    # Tool calls associated with the chat completion
    tool_calls: Optional[Any] = None


class Tool(BaseModel):
    # Type of the tool
    type: str
    # Function details of the tool
    function: dict


class ChatCompletionComplete(BaseModel):
    # Index of the chat completion
    index: int
    # Message associated with the chat completion
    message: Message
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
    finish_reason: str
    # Usage details of the chat completion
    usage: Any


class Function(BaseModel):
    name: Optional[str]
    arguments: str


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: str
    type: str
    function: Function


class ChoiceDelta(BaseModel):
    role: str
    content: Optional[str]
    tool_calls: Optional[ChoiceDeltaToolCall]


class Choice(BaseModel):
    index: int
    delta: ChoiceDelta
    logprobs: Optional[dict] = None
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[Choice]


class ChatComplete(BaseModel):
    # Chat completion details
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[ChatCompletionComplete]
    usage: Any


class ChatRequest(BaseModel):
    # Model identifier
    model: str
    # List of messages in the conversation
    messages: List[Message]
    # Penalty for frequency of new tokens
    frequency_penalty: Optional[float] = None
    # Bias values for token selection
    logit_bias: Optional[List[float]] = None
    # Whether to return log probabilities
    logprobs: Optional[bool] = None
    # Number of most likely tokens to return at each position
    top_logprobs: Optional[int] = None
    # Maximum number of tokens to generate
    max_tokens: Optional[int] = None
    # Number of chat completion choices to generate
    n: Optional[int] = None
    # Penalty for presence of new tokens
    presence_penalty: Optional[float] = None
    # Flag to indicate streaming response
    stream: bool = False
    # Random sampling seed
    seed: Optional[int] = None
    # Sampling temperature
    temperature: Optional[float] = None
    # Top-p value for nucleus sampling
    top_p: Optional[float] = None
    # List of tools to be used
    tools: Optional[List[Tool]] = None
    # Choice of tool to be used
    tool_choice: Optional[str] = None


class Parameters(BaseModel):
    # Activate logits sampling
    do_sample: bool = False
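
The new Pydantic models above can also be exercised directly; this is a small sketch of how a chat request is assembled (the field names come from the models above, the values are illustrative):

```python
from text_generation.types import ChatRequest, Message, Tool

# Build the same payload the client sends to /v1/chat/completions.
request = ChatRequest(
    model="tgi",
    messages=[
        Message(role="system", content="You are a concise assistant."),
        Message(role="user", content="What is the weather like in New York?"),
    ],
    tools=[
        Tool(
            type="function",
            function={
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {"type": "object", "properties": {}},
            },
        )
    ],
    tool_choice="get_current_weather",
    max_tokens=100,
)

# The serialized dict is what ends up on the wire.
print(request.dict())
```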
@@ -9,6 +9,8 @@
    title: Supported Models and Hardware
  - local: messages_api
    title: Messages API
  - local: guidance
    title: Guidance
  title: Getting started
- sections:
  - local: basic_tutorials/consuming_tgi
@@ -37,4 +39,8 @@
    title: Safetensors
  - local: conceptual/flash_attention
    title: Flash Attention
  - local: conceptual/speculation
    title: Speculation (Medusa, ngram)
  - local: conceptual/guidance
    title: Guidance, JSON, tools (using outlines)
  title: Conceptual Guides
docs/source/conceptual/guidance.md (new file, 419 lines)
@@ -0,0 +1,419 @@
# Guidance

Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.

These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

## Quick Start

Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.

If you're not up to date, grab the latest version and let's get started!

## Table of Contents 📚

### Grammar and Constraints

- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.

### Tools and Functions

- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

## Grammar and Constraints 🛣️

### The Grammar Parameter

In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.

Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API; using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.

```bash
curl localhost:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "activity": {
                        "type": "string"
                    },
                    "animals_seen": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 5
                    },
                    "animals": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["location", "activity", "animals_seen", "animals"]
            }
        }
    }
}'
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
```

A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.

> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.

### Constrain with Pydantic

Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.

Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way.

```python
import requests
from pydantic import BaseModel, conint
from typing import List

class Animals(BaseModel):
    location: str
    activity: str
    animals_seen: conint(ge=1, le=5)  # Constrained integer type
    animals: List[str]

prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"

data = {
    "inputs": prompt,
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": Animals.schema()
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
```

### JSON Schema Integration

If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.

```python
import requests

json_schema = {
    "properties": {
        "location": {
            "type": "string"
        },
        "activity": {
            "type": "string"
        },
        "animals_seen": {
            "type": "integer",
            "minimum": 1,
            "maximum": 5
        },
        "animals": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": ["location", "activity", "animals_seen", "animals"]
}

data = {
    "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
    "parameters": {
        "max_new_tokens": 200,
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": json_schema
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
```

### Using the client

TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.

```python
from text_generation import AsyncClient
from text_generation.types import GrammarType

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'generate' to complete
    response = await client.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=1,
        grammar={
            "type": GrammarType.Regex,
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    )

    # Once the response is received, you can process it
    print(response.generated_text)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# 118.8.0.84
```

## Tools and Functions 🛠️

### The Tools Parameter

In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.

Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.

Functions, similar to grammars, are defined as JSON schema and can be passed as part of the parameters to the Messages API.

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is the weather like in New York?"
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        }
    ],
    "tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```

<details>
<summary>Tools used in example below</summary>

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    }
]
```

</details>

### Text Generation Inference Client

TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.

```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'chat' to complete
    response = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # Once the response is received, you can process it
    print(response.choices[0].message.tool_calls)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
```

### OpenAI integration

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

However, there are some minor differences in the API. For example, `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.

```python
from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)


called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#   "id": 0,
#   "type": "function",
#   "function": {
#     "description": None,
#     "name": "tools",
#     "parameters": {
#       "format": "celsius",
#       "location": "San Francisco, CA",
#       "num_days": 3,
#     },
#   },
# }
```
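
The guide above stops at printing `tool_calls`; a natural follow-up is dispatching the selected tool to a local Python function. This is only a sketch, assuming the returned call is dict-like (as in the sample outputs above) and that `function.name` matches one of the declared tools; the local implementations are hypothetical:

```python
# Hypothetical local implementations of the two tools declared earlier.
def get_current_weather(location: str, format: str) -> str:
    return f"It is sunny in {location} ({format})."


def get_n_day_weather_forecast(location: str, format: str, num_days: int) -> str:
    return f"{num_days}-day forecast for {location}: mild ({format})."


LOCAL_TOOLS = {
    "get_current_weather": get_current_weather,
    "get_n_day_weather_forecast": get_n_day_weather_forecast,
}


def dispatch(tool_call: dict) -> str:
    # Look up the function the model chose and call it with the model-provided arguments.
    function = tool_call["function"]
    handler = LOCAL_TOOLS.get(function["name"])
    if handler is None:
        raise ValueError(f"Unknown tool: {function['name']}")
    return handler(**function["parameters"])
```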
docs/source/conceptual/speculation.md (new file, 48 lines)
@@ -0,0 +1,48 @@
## Speculation

Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens were valid.

So you are making *more* computations on your LLM, but if you are correct you produce 1, 2, 3, etc. tokens in a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct enough, this gives 2-3x faster inference (it can be much more for code-oriented tasks, for instance).

You can check a more [detailed explanation](https://huggingface.co/blog/assisted-generation).

Text-generation inference supports 2 main speculative methods:

- Medusa
- N-gram

### Medusa

Medusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing models.

You can check a few existing fine-tunes for popular models:

- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)
- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)
- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)

In order to create your own Medusa heads for your own fine-tune, you should check out the original Medusa repo: [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa)

In order to use Medusa models in TGI, simply point to a Medusa-enabled model, and everything will load automatically.

### N-gram

If you don't have a Medusa model, or don't have the resources to fine-tune, you can try to use `n-gram`.
N-gram works by looking for matching tokens in the previous sequence and using those as the speculation.

This is an extremely simple method, which works best for code or highly repetitive text. It might not be beneficial if the speculation misses too often.

In order to enable n-gram speculation, simply pass `--speculate 2` in your flags; a launch sketch follows below.

[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)
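
For reference, a minimal launch sketch that enables n-gram speculation with the flag described above; only `--speculate 2` comes from this document, while the Docker image tag and model id are illustrative placeholders:

```bash
# Hypothetical example: serve a model with 2-token n-gram speculation enabled.
docker run --gpus all --shm-size 1g -p 3000:80 \
    ghcr.io/huggingface/text-generation-inference:1.4 \
    --model-id mistralai/Mistral-7B-Instruct-v0.2 \
    --speculate 2
```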
@@ -23,6 +23,8 @@ from text_generation.types import (
    Token,
    BestOfSequence,
    Grammar,
    ChatComplete,
    ChatCompletionChunk,
)

DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
    ) -> bool:
        def convert_data(data):
            data = json.loads(data)
            if isinstance(data, Dict) and "choices" in data:
                choices = data["choices"]
                if (
                    isinstance(choices, List)
                    and len(choices) >= 1
                    and "delta" in choices[0]
                ):
                    return ChatCompletionChunk(**data)
                return ChatComplete(**data)

            if isinstance(data, Dict):
                return Response(**data)
@@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
                )
            )

        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
            return (
                response.choices[0].message.content == other.choices[0].message.content
            )

        def eq_chat_complete_chunk(
            response: ChatCompletionChunk, other: ChatCompletionChunk
        ) -> bool:
            return response.choices[0].delta.content == other.choices[0].delta.content

        def eq_response(response: Response, other: Response) -> bool:
            return response.generated_text == other.generated_text and eq_details(
                response.details, other.details
@@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
        if not isinstance(snapshot_data, List):
            snapshot_data = [snapshot_data]

        if isinstance(serialized_data[0], ChatComplete):
            return len(snapshot_data) == len(serialized_data) and all(
                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
            )

        if isinstance(serialized_data[0], ChatCompletionChunk):
            return len(snapshot_data) == len(serialized_data) and all(
                [
                    eq_chat_complete_chunk(r, o)
                    for r, o in zip(serialized_data, snapshot_data)
                ]
            )

        return len(snapshot_data) == len(serialized_data) and all(
            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
        )
@@ -236,6 +270,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)
@@ -268,6 +303,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")

@@ -302,6 +340,7 @@ def launcher(event_loop):
        use_flash_attention: bool = True,
        disable_grammar_support: bool = False,
        dtype: Optional[str] = None,
        revision: Optional[str] = None,
    ):
        port = random.randint(8000, 10_000)

@@ -317,6 +356,9 @@ def launcher(event_loop):
        if dtype is not None:
            args.append("--dtype")
            args.append(dtype)
        if revision is not None:
            args.append("--revision")
            args.append(revision)
        if trust_remote_code:
            args.append("--trust-remote-code")

@@ -0,0 +1,94 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 610, "logprob": null, "text": "def" },
      { "id": 1489, "logprob": -5.2617188, "text": " print" },
      { "id": 100, "logprob": -0.38476562, "text": "_" },
      { "id": 7670, "logprob": -7.640625, "text": "hello" }
    ],
    "seed": null,
    "tokens": [
      { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
      { "id": 303, "logprob": -0.40844727, "special": false, "text": "\n " },
      { "id": 1489, "logprob": -0.27905273, "special": false, "text": " print" },
      { "id": 459, "logprob": -0.6118164, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": -0.68652344, "special": false, "text": "Hello" },
      { "id": 10914, "logprob": -1.4619141, "special": false, "text": " World" },
      { "id": 16013, "logprob": -0.7993164, "special": false, "text": "!\")" },
      { "id": 222, "logprob": -0.63134766, "special": false, "text": "\n" },
      { "id": 222, "logprob": -0.23278809, "special": false, "text": "\n" },
      { "id": 610, "logprob": -1.2294922, "special": false, "text": "def" }
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
@@ -0,0 +1,394 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 60,
    "prefill": [
      { "id": 610, "logprob": null, "text": "def" },
      { "id": 1489, "logprob": -5.2617188, "text": " print" },
      { "id": 100, "logprob": -0.38476562, "text": "_" },
      { "id": 7670, "logprob": -7.640625, "text": "hello" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 2284, "logprob": -0.296875, "special": false, "text": "():" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n " },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": -0.28125, "special": false, "text": "Hello" },
      { "id": 10914, "logprob": -0.79248047, "special": false, "text": " World" },
      { "id": 16013, "logprob": -0.61816406, "special": false, "text": "!\")" },
      { "id": 222, "logprob": -0.0619812, "special": false, "text": "\n" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 610, "logprob": -0.4091797, "special": false, "text": "def" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 7670, "logprob": 0.0, "special": false, "text": "hello" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 444, "logprob": -0.21655273, "special": false, "text": "name" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 731, "logprob": 0.0, "special": false, "text": "):" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n " },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": 0.0, "special": false, "text": "Hello" },
      { "id": 332, "logprob": -0.034698486, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 655, "logprob": 0.0, "special": false, "text": " name" },
      { "id": 494, "logprob": -0.20141602, "special": false, "text": " +" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 16013, "logprob": 0.0, "special": false, "text": "!\")" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 610, "logprob": 0.0, "special": false, "text": "def" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 7670, "logprob": 0.0, "special": false, "text": "hello" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 400, "logprob": 0.0, "special": false, "text": "age" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 49, "logprob": 0.0, "special": false, "text": "," },
      { "id": 11505, "logprob": 0.0, "special": false, "text": " age" },
      { "id": 731, "logprob": 0.0, "special": false, "text": "):" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n " },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": 0.0, "special": false, "text": "Hello" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 655, "logprob": 0.0, "special": false, "text": " name" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 3021, "logprob": -0.5761719, "special": false, "text": " \"," },
      { "id": 863, "logprob": 0.0, "special": false, "text": " you" },
      { "id": 904, "logprob": 0.0, "special": false, "text": " are" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 615, "logprob": 0.0, "special": false, "text": " str" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 400, "logprob": 0.0, "special": false, "text": "age" },
      { "id": 46, "logprob": 0.0, "special": false, "text": ")" }
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
}
@@ -0,0 +1,378 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n " },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n " },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n " },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n " },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302,
"logprob": -0.68310547,
|
||||||
|
"special": false,
|
||||||
|
"text": "Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10914,
|
||||||
|
"logprob": -1.4570312,
|
||||||
|
"special": false,
|
||||||
|
"text": " World"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16013,
|
||||||
|
"logprob": -0.80126953,
|
||||||
|
"special": false,
|
||||||
|
"text": "!\")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -0.6303711,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 222,
|
||||||
|
"logprob": -0.23327637,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 610,
|
||||||
|
"logprob": -1.2304688,
|
||||||
|
"special": false,
|
||||||
|
"text": "def"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
|
||||||
|
}
|
||||||
|
]
@ -0,0 +1,26 @@
{
  "choices": [
    {
      "finish_reason": "length",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
        "name": null,
        "role": "assistant",
        "tool_calls": null
      },
      "usage": null
    }
  ],
  "created": 1708957015,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
    "completion_tokens": 100,
    "prompt_tokens": 60,
    "total_tokens": 160
  }
}
@ -0,0 +1,38 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "name": null,
        "role": "assistant",
        "tool_calls": {
          "function": {
            "description": null,
            "name": "tools",
            "parameters": {
              "format": "celsius",
              "location": "New York, NY",
              "num_days": 14
            }
          },
          "id": 0,
          "type": "function"
        }
      },
      "usage": null
    }
  ],
  "created": 1709079417,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
    "completion_tokens": 29,
    "prompt_tokens": 316,
    "total_tokens": 345
  }
}
@ -0,0 +1,38 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "name": null,
        "role": "assistant",
        "tool_calls": {
          "function": {
            "description": null,
            "name": "tools",
            "parameters": {
              "format": "celsius",
              "location": "New York, NY",
              "num_days": 14
            }
          },
          "id": 0,
          "type": "function"
        }
      },
      "usage": null
    }
  ],
  "created": 1709079492,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
    "completion_tokens": 29,
    "prompt_tokens": 316,
    "total_tokens": 345
  }
}
@ -0,0 +1,37 @@
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "name": null,
        "role": "assistant",
        "tool_calls": {
          "function": {
            "description": null,
            "name": "tools",
            "parameters": {
              "format": "celsius",
              "location": "New York, NY"
            }
          },
          "id": 0,
          "type": "function"
        }
      },
      "usage": null
    }
  ],
  "created": 1709079493,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
    "completion_tokens": 21,
    "prompt_tokens": 187,
    "total_tokens": 208
  }
}
@ -0,0 +1,27 @@
{
  "choices": [
    {
      "delta": {
        "content": null,
        "role": "assistant",
        "tool_calls": {
          "function": {
            "arguments": "</s>",
            "name": null
          },
          "id": "",
          "index": 20,
          "type": "function"
        }
      },
      "finish_reason": "eos_token",
      "index": 20,
      "logprobs": null
    }
  ],
  "created": 1709087088,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native"
}
@ -3,7 +3,9 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_medusa_handle(launcher):
-    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
+    with launcher(
+        "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1"
+    ) as handle:
         yield handle

integration-tests/models/test_flash_starcoder2.py (new file, 55 lines)
@ -0,0 +1,55 @@
import pytest


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(
    flash_starcoder2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
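Note: the StarCoder2 tests above go through the async integration-test client. For a quick manual check, the same request can be made with the synchronous Python client; a minimal sketch (the URL and an already running server are assumptions, not part of this change):

    # Sketch only: assumes a TGI instance serving bigcode/starcoder2-3b at this address.
    from text_generation import Client

    client = Client("http://127.0.0.1:8080")
    response = client.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )
    print(response.generated_text)            # per the snapshot above: '():\n print("Hello World!")\n\ndef'
    print(response.details.generated_tokens)  # 10, since the request stops on length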
integration-tests/models/test_tools_llama.py (new file, 240 lines)
@ -0,0 +1,240 @@
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client


# tools to be used in the following tests
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."},
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The temperature unit to use. Infer this from the users location."},
                    "num_days": {"type": "integer", "description": "The number of days to forecast"},
                },
                "required": ["location", "format", "num_days"],
            },
        },
    },
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        messages=[
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Brooklyn, New York?"},
        ],
    )

    assert (
        response.choices[0].message.content
        == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Brooklyn, New York?"},
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {"format": "celsius", "location": "New York, NY", "num_days": 14},
        },
        "id": 0,
        "type": "function",
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="auto",
        presence_penalty=-1.1,
        messages=[
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Brooklyn, New York?"},
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {"format": "celsius", "location": "New York, NY", "num_days": 14},
        },
        "id": 0,
        "type": "function",
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        presence_penalty=-1.1,
        messages=[
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Brooklyn, New York?"},
        ],
    )
    assert response.choices[0].message.content == None
    assert response.choices[0].message.tool_calls == {
        "id": 0,
        "type": "function",
        "function": {
            "description": None,
            "name": "tools",
            "parameters": {"format": "celsius", "location": "New York, NY"},
        },
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        presence_penalty=-1.1,
        messages=[
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Paris, France?"},
        ],
        stream=True,
    )

    count = 0
    async for response in responses:
        count += 1

    assert count == 20
    assert response == response_snapshot
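The tests above exercise the new tool-calling flow end to end. For reference, a minimal sketch of the same request made with the synchronous Python client (the server URL, a locally running grammar-enabled TGI instance, and passing the test module's dict-style `tools` list directly are assumptions here, not part of this change):

    # Sketch only, mirroring test_flash_llama_grammar_tools_choice above.
    from text_generation import Client
    from text_generation.types import Message

    client = Client("http://127.0.0.1:8080")
    response = client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,  # the same `tools` list defined in the test module above
        tool_choice="get_current_weather",
        messages=[
            Message(role="system", content="You are a helpful assistant."),
            Message(role="user", content="What is the weather like in Brooklyn, New York?"),
        ],
    )
    # When a tool is selected, `content` is None and the arguments arrive in `tool_calls`.
    print(response.choices[0].message.tool_calls)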
@ -230,7 +230,6 @@ message WarmupRequest {
   uint32 max_total_tokens = 4;
 }
 
-/// Empty response
 message WarmupResponse {
   /// Maximum number of tokens supported by the model
   optional uint32 max_supported_total_tokens = 1;
@ -812,23 +812,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@ -877,28 +881,33 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "Hi again!".to_string(),
+                    content: Some("Hi again!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@ -952,23 +961,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@ -1006,23 +1019,27 @@ mod tests {
             messages: vec![
                 Message {
                     role: "user".to_string(),
-                    content: "Hi!".to_string(),
+                    content: Some("Hi!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "Hello how can I help?".to_string(),
+                    content: Some("Hello how can I help?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "user".to_string(),
-                    content: "What is Deep Learning?".to_string(),
+                    content: Some("What is Deep Learning?".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
                 Message {
                     role: "assistant".to_string(),
-                    content: "magic!".to_string(),
+                    content: Some("magic!".to_string()),
                     name: None,
+                    tool_calls: None,
                 },
             ],
             bos_token: Some("[BOS]"),
@ -358,10 +358,11 @@ impl ChatCompletion {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        output: String,
+        output: Option<String>,
         created: u64,
         details: Details,
         return_logprobs: bool,
+        tool_calls: Option<ToolCall>,
     ) -> Self {
         Self {
             id: String::new(),
@ -375,6 +376,7 @@ impl ChatCompletion {
                 role: "assistant".into(),
                 content: output,
                 name: None,
+                tool_calls,
             },
             logprobs: return_logprobs
                 .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
 pub(crate) struct ChatCompletionDelta {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "What is Deep Learning?")]
-    pub content: String,
+    pub content: Option<String>,
+    // default to None
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<DeltaToolCall>,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct DeltaToolCall {
+    pub index: u32,
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct Function {
+    pub name: Option<String>,
+    pub arguments: String,
+}
+
+#[allow(clippy::too_many_arguments)]
 impl ChatCompletionChunk {
     pub(crate) fn new(
         model: String,
         system_fingerprint: String,
-        delta: String,
+        delta: Option<String>,
+        tool_calls: Option<Vec<String>>,
         created: u64,
         index: u32,
         logprobs: Option<ChatCompletionLogprobs>,
@ -438,6 +460,15 @@ impl ChatCompletionChunk {
             delta: ChatCompletionDelta {
                 role: "assistant".to_string(),
                 content: delta,
+                tool_calls: tool_calls.map(|tc| DeltaToolCall {
+                    index,
+                    id: String::new(),
+                    r#type: "function".to_string(),
+                    function: Function {
+                        name: None,
+                        arguments: tc[0].to_string(),
+                    },
+                }),
             },
             logprobs,
             finish_reason,
@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, example = 0.95)]
     pub top_p: Option<f32>,
+
+    /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
+    /// functions the model may generate JSON inputs for.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub tools: Option<Vec<Tool>>,
+
+    /// A prompt to be appended before the tools
+    #[serde(default = "default_tool_prompt")]
+    #[schema(
+        nullable = true,
+        example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
+    )]
+    pub tool_prompt: Option<String>,
+
+    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
+    pub tool_choice: Option<ToolType>,
+}
+
+fn default_tool_prompt() -> Option<String> {
+    Some(
+        "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+    )
+}
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+enum ToolType {
+    FunctionName(String),
+    OneOf,
+}
+
+/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
+mod deserialize_tool_choice {
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => match s.as_str() {
+                "none" => Ok(None),
+                "auto" => Ok(Some(ToolType::OneOf)),
+                _ => Ok(Some(ToolType::FunctionName(s))),
+            },
+            Value::Object(map) => {
+                if let Some(content) = map
+                    .get("function")
+                    .and_then(|v| v.get("name"))
+                    .and_then(|v| v.as_str())
+                {
+                    Ok(Some(ToolType::FunctionName(content.to_string())))
+                } else {
+                    Err(de::Error::custom("function key not found in tool choice"))
+                }
+            }
+            Value::Null => Ok(Some(ToolType::OneOf)),
+            _ => Err(de::Error::custom("invalid token format")),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize, ToSchema)]
+pub struct Tools {
+    #[serde(flatten)]
+    functions_map: FunctionsMap,
+    properties: Properties,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct FunctionsMap {
+    #[serde(rename = "$functions")]
+    functions: std::collections::HashMap<String, serde_json::Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct FunctionRef {
+    #[serde(rename = "$ref")]
+    ref_path: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Properties {
+    #[serde(serialize_with = "serialize_function")]
+    function: Vec<FunctionRef>,
+}
+
+fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: serde::Serializer,
+{
+    use serde::ser::SerializeStruct;
+    let mut state = serializer.serialize_struct("Function", 1)?;
+    state.serialize_field("anyOf", functions)?;
+    state.end()
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
+pub(crate) struct FunctionDefinition {
+    #[serde(default)]
+    pub description: Option<String>,
+    pub name: String,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
+pub(crate) struct Tool {
+    // The type of the tool. Currently, only 'function' is supported.
+    #[schema(example = "function")]
+    pub r#type: String,
+    // Grab the tool as generic JSON for debugging purposes.
+    pub function: FunctionDefinition,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
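For orientation: because `functions_map` is `#[serde(flatten)]`-ed and `Properties::function` is serialized through the custom `anyOf` serializer, `Tools` serializes into a JSON-schema-like object. A hedged sketch of that shape, written as a Python literal (function names borrowed from the integration test above; the parameter bodies are placeholders):

    tools_schema = {
        "$functions": {
            # each value is that tool's `parameters` JSON schema, keyed by function name
            "get_current_weather": {"type": "object", "properties": {}},          # placeholder
            "get_n_day_weather_forecast": {"type": "object", "properties": {}},   # placeholder
        },
        "properties": {
            "function": {
                "anyOf": [
                    {"$ref": "#/$functions/get_current_weather"},
                    {"$ref": "#/$functions/get_n_day_weather_forecast"},
                ]
            }
        },
    }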
@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> {
     add_generation_prompt: bool,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
+pub(crate) struct ToolCall {
+    pub id: u32,
+    pub r#type: String,
+    pub function: FunctionDefinition,
+}
+
 #[derive(Clone, Deserialize, ToSchema, Serialize)]
 pub(crate) struct Message {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
     #[schema(example = "My name is David and I")]
-    pub content: String,
+    pub content: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<ToolCall>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -10,6 +10,7 @@ use crate::{
     HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
     StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
 };
+use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
 use futures::Stream;
 use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
+use serde_json::Value;
+use std::collections::HashMap;
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
@ -239,7 +242,7 @@ async fn generate(
     headers.insert("x-compute-type", compute_type.parse().unwrap());
     headers.insert(
         "x-compute-time",
-        total_time.as_millis().to_string().parse().unwrap(),
+        total_time.as_secs_f64().to_string().parse().unwrap(),
     );
     headers.insert(
         "x-compute-characters",
@ -581,7 +584,7 @@ async fn chat_completions(
     let seed = req.seed;
 
     // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(req.messages) {
+    let mut inputs = match infer.apply_chat_template(req.messages) {
         Ok(inputs) => inputs,
         Err(err) => {
             metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -596,6 +599,62 @@ async fn chat_completions(
         }
     };
 
+    let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
+        let tool_prompt = req.tool_prompt.unwrap_or_default();
+        let tools_to_use = match tool_choice {
+            ToolType::FunctionName(name) => {
+                vec![req_tools
+                    .iter()
+                    .find(|tool| tool.function.name == *name)
+                    .ok_or_else(|| {
+                        (
+                            StatusCode::UNPROCESSABLE_ENTITY,
+                            Json(ErrorResponse {
+                                error: "Tool choice not found in tool names".to_string(),
+                                error_type: "Tool not found".to_string(),
+                            }),
+                        )
+                    })?
+                    .clone()]
+            }
+            ToolType::OneOf => req_tools.to_owned(),
+        };
+
+        let functions: HashMap<String, Value> = tools_to_use
+            .iter()
+            .map(|tool| {
+                let func = tool.function.clone();
+                (func.name, func.parameters)
+            })
+            .collect();
+
+        let tools = Tools {
+            functions_map: FunctionsMap { functions },
+            properties: Properties {
+                function: tools_to_use
+                    .iter()
+                    .map(|tool| FunctionRef {
+                        ref_path: format!("#/$functions/{}", tool.function.name.clone()),
+                    })
+                    .collect(),
+            },
+        };
+
+        let tools_str = serde_json::to_string(&tools).map_err(|e| {
+            (
+                StatusCode::UNPROCESSABLE_ENTITY,
+                Json(ErrorResponse {
+                    error: e.to_string(),
+                    error_type: "Input validation error".to_string(),
+                }),
+            )
+        })?;
+        inputs = format!("{inputs}{tool_prompt}{tools_str}");
+        Some(GrammarType::Json(serde_json::json!(tools)))
+    } else {
+        None
+    };
+
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
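In short, when tools are present the templated conversation is extended before generation and the output is then constrained by a JSON grammar built from the same schema. A rough Python sketch of the string that ends up in `inputs` (variable names here are illustrative only, not the router's):

    import json

    templated_chat = "<output of the chat template>"  # placeholder
    tool_prompt = "\nBased on the conversation, please choose the most appropriate tool to use: "
    inputs = templated_chat + tool_prompt + json.dumps(tools_schema)  # tools_schema as sketched earlier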
@ -617,7 +676,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar: None,
+            grammar: tool_grammar.clone(),
         },
     };
 
@ -640,11 +699,19 @@ async fn chat_completions(
                     ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
                 });
 
+                // replace the content with the tool calls if grammar is present
+                let (content, tool_calls) = if tool_grammar.is_some() {
+                    (None, Some(vec![stream_token.token.text]))
+                } else {
+                    (Some(stream_token.token.text), None)
+                };
+
                 event
                     .json_data(ChatCompletionChunk::new(
                         model_id.clone(),
                         system_fingerprint.clone(),
-                        stream_token.token.text,
+                        content,
+                        tool_calls,
                         current_time,
                         stream_token.index,
                         logprobs,
@ -681,14 +748,54 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
+            let (tool_calls, output) = if tool_grammar.is_some() {
+                // gen_text should be valid json
+                let gen_text_value: Value =
+                    serde_json::from_str(&generation.generated_text).map_err(|e| {
+                        (
+                            StatusCode::UNPROCESSABLE_ENTITY,
+                            Json(ErrorResponse {
+                                error: e.to_string(),
+                                error_type: "Input validation error".to_string(),
+                            }),
+                        )
+                    })?;
+
+                let tool_call = Some(ToolCall {
+                    id: 0,
+                    r#type: "function".to_string(),
+                    function: FunctionDefinition {
+                        description: None,
+                        name: "tools".to_string(),
+                        parameters: gen_text_value.get("function").map_or_else(
+                            || {
+                                serde_json::from_str(&generation.generated_text).map_err(|e| {
+                                    (
+                                        StatusCode::UNPROCESSABLE_ENTITY,
+                                        Json(ErrorResponse {
+                                            error: e.to_string(),
+                                            error_type: "Input validation error".to_string(),
+                                        }),
+                                    )
+                                })
+                            },
+                            |f| Ok(f.clone()),
+                        )?,
+                    },
+                });
+                (tool_call, None)
+            } else {
+                (None, Some(generation.generated_text))
+            };
             // build the complete response object with the full text
             let response = ChatCompletion::new(
                 model_id,
                 system_fingerprint,
-                generation.generated_text,
+                output,
                 current_time,
                 generation.details.unwrap(),
                 logprobs,
+                tool_calls,
             );
 
             // wrap generation inside a Vec to match api-inference
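The grammar-constrained text is then parsed back into a single tool call. A hedged Python sketch of the equivalent post-processing (the literal is one example output, not a fixed format):

    import json

    generated_text = '{"function": {"format": "celsius", "location": "New York, NY", "num_days": 14}}'
    value = json.loads(generated_text)         # must be valid JSON or the request fails validation
    parameters = value.get("function", value)  # fall back to the whole object when no "function" key
    tool_call = {
        "id": 0,
        "type": "function",
        "function": {"description": None, "name": "tools", "parameters": parameters},
    }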
@ -154,12 +154,8 @@ def download_weights(
             import json
 
             medusa_head = hf_hub_download(
-                model_id, revision=revision, filename="medusa_lm_head.pt"
+                model_id, revision=revision, filename="medusa_lm_head.safetensors"
             )
-            if auto_convert:
-                medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
-                if not medusa_sf.exists():
-                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
             medusa_config = hf_hub_download(
                 model_id, revision=revision, filename="config.json"
             )
@ -198,16 +194,12 @@ def download_weights(
         if not extension == ".safetensors" or not auto_convert:
             raise e
 
-    elif (Path(model_id) / "medusa_lm_head.pt").exists():
+    elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
         # Try to load as a local Medusa model
         try:
             import json
 
-            medusa_head = Path(model_id) / "medusa_lm_head.pt"
+            medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
-            if auto_convert:
-                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
-                if not medusa_sf.exists():
-                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
             medusa_config = Path(model_id) / "config.json"
             with open(medusa_config, "r") as f:
                 config = json.load(f)
@ -3,7 +3,9 @@ import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from huggingface_hub import hf_hub_download
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@ -65,6 +67,7 @@ try:
     from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
+    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 
 except ImportError as e:
@ -82,6 +85,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashMixtral)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
+    __all__.append(FlashStarcoder2)
 
 MAMBA_AVAILABLE = True
 try:
@ -119,44 +123,14 @@ def get_model(
     else:
         set_speculate(0)
 
-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_id.startswith("bigcode/"):
-        if FLASH_ATTENTION:
-            return FlashSantacoderSharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
-            )
-        else:
-            return SantaCoder(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
 
     use_medusa = None
     if "medusa_num_heads" in config_dict:
-        use_medusa = model_id
+        medusa_model_id = model_id
+        medusa_revision = revision
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
         speculate_medusa = config_dict["medusa_num_heads"]
@ -173,6 +147,20 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        is_local = Path(medusa_model_id).exists()
+        if not is_local:
+            medusa_config = hf_hub_download(
+                medusa_model_id, revision=medusa_revision, filename="config.json"
+            )
+            hf_hub_download(
+                medusa_model_id,
+                revision=medusa_revision,
+                filename="medusa_lm_head.safetensors",
+            )
+            use_medusa = Path(medusa_config).parent
+        else:
+            use_medusa = Path(medusa_model_id)
+
         method = "medusa"
     else:
         method = "n-gram"
|
||||||
@ -197,16 +185,32 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type == "gpt_bigcode":
+    if model_id.startswith("facebook/galactica"):
+        return GalacticaSharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
+    if (
+        model_type == "gpt_bigcode"
+        or model_type == "gpt2"
+        and model_id.startswith("bigcode/")
+    ):
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@ -219,6 +223,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -228,6 +233,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -236,6 +242,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -246,6 +253,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -254,6 +262,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -262,6 +271,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -272,15 +282,16 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            use_medusa=use_medusa,
         )
     else:
         return CausalLM(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -295,6 +306,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -305,9 +317,9 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            use_medusa=use_medusa,
         )
     elif sharded:
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -316,6 +328,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -346,9 +359,9 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
-            use_medusa=use_medusa,
         )
     elif sharded:
         raise NotImplementedError(
@ -359,6 +372,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -372,6 +386,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -382,6 +397,7 def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -390,6 +406,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -403,6 +420,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -413,6 +431,19 @@ def get_model(
             (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
         ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+    if model_type == "starcoder2":
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
+            return FlashStarcoder2(
                 model_id,
                 revision,
                 quantize=quantize,
@ -425,6 +456,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -434,6 +466,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -443,6 +476,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -466,6 +500,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -474,6 +509,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -485,6 +521,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@ -493,6 +530,7 @@ def get_model(
             model_id,
             revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
use_medusa=use_medusa,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
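The hunks above thread a single new `use_medusa` value from `get_model` into every model constructor, which then stores it on the model config. A minimal, self-contained sketch of that pattern follows; the names here are illustrative stand-ins, not the real classes.

from dataclasses import dataclass
from typing import Optional

@dataclass
class DummyConfig:
    quantize: Optional[str] = None
    use_medusa: Optional[str] = None  # path or Hub id of a Medusa head, or None

def configure(config: DummyConfig, quantize: Optional[str], use_medusa: Optional[str]) -> DummyConfig:
    # Mirrors the repeated `config.quantize = quantize` / `config.use_medusa = use_medusa`
    # assignments added to each model's __init__ in the hunks below.
    config.quantize = quantize
    config.use_medusa = use_medusa
    return config

cfg = configure(DummyConfig(), quantize=None, use_medusa="some/medusa-head")
assert cfg.use_medusa == "some/medusa-head"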
@@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -112,4 +114,4 @@ class BLOOMSharded(CausalLM):
         )

         logits = outputs.logits
-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values
@@ -482,6 +482,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -550,7 +551,9 @@ class CausalLM(Model):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -563,7 +566,11 @@ class CausalLM(Model):
             kwargs["position_ids"] = position_ids

         outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values
+        if isinstance(outputs, tuple):
+            outputs, speculative_logits = outputs
+        else:
+            speculative_logits = None
+        return outputs.logits, speculative_logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -573,7 +580,7 @@ class CausalLM(Model):
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

-        logits, past = self.forward(
+        logits, speculative_logits, past = self.forward(
             batch.input_ids,
             attention_mask,
             batch.position_ids,
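The new `CausalLM.forward` tolerates inner models that return either a plain output object or an `(output, speculative_logits)` pair. A self-contained sketch of that normalization step, with strings standing in for tensors:

from typing import Any, Optional, Tuple

def normalize_forward_output(outputs: Any) -> Tuple[Any, Optional[Any]]:
    # Mirrors the `isinstance(outputs, tuple)` branch added above: converted
    # heads return a pair, legacy heads still return a single object.
    if isinstance(outputs, tuple):
        outputs, speculative_logits = outputs
    else:
        speculative_logits = None
    return outputs, speculative_logits

# Both call shapes collapse to the same downstream contract.
assert normalize_forward_output(("logits", "spec")) == ("logits", "spec")
assert normalize_forward_output("logits") == ("logits", None)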
@@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 CUSTOM_KERNELS_ENABLED = False
@@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="word_embeddings",
             weights=weights,
@@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )
         hidden_states = transformer_outputs[0]

-        lm_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         loss = None

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

-        return CausalLMOutputWithCrossAttentions(
+        return (
+            CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
         )
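Across the modeling files, `TensorParallelHead` is swapped for `SpeculativeHead`, whose call returns `(logits, speculative_logits)` instead of a single tensor. A minimal sketch of that contract, assuming a plain linear base projection and an optional Medusa-style extra head; this is an illustration of the interface, not the actual TGI implementation:

import torch
from torch import nn
from typing import Optional, Tuple

class SpeculativeHeadSketch(nn.Module):
    def __init__(self, base_head: nn.Module, medusa_head: Optional[nn.Module] = None):
        super().__init__()
        self.base_head = base_head      # the usual lm_head projection
        self.medusa_head = medusa_head  # optional extra head(s) for speculation

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.base_head(hidden_states)
        speculative_logits = (
            self.medusa_head(hidden_states) if self.medusa_head is not None else None
        )
        return logits, speculative_logits

# Callers always unpack a pair, which is why every `logits = self.lm_head(...)`
# in the hunks becomes `logits, speculative_logits = self.lm_head(...)`.
head = SpeculativeHeadSketch(nn.Linear(16, 32))
logits, spec = head(torch.randn(2, 16))
assert spec is None and logits.shape == (2, 32)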
@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashGemmaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
             weights=weights,
@@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashLlamaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MistralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MixtralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         super().__init__(config)
         self.gpt_neox = FlashGPTNeoXModel(config, weights)

-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastLayerNorm,
 )
@@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashPhiModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

         self.transformer = FlashRWModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

     def forward(
         self,
@@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastLayerNorm,
     get_linear,
@@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.transformer = FlashSantacoderModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )

@@ -0,0 +1,545 @@
+# coding=utf-8
+# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.layers import (
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    PositionRotaryEmbedding,
+    SpeculativeHead,
+    get_linear,
+    FastRMSNorm,
+    FastLayerNorm,
+)
+
+
+class Starcoder2Config(PretrainedConfig):
+    model_type = "starcoder2"
+
+    def __init__(
+        self,
+        vocab_size=49152,
+        hidden_size=3072,
+        intermediate_size=12288,
+        num_hidden_layers=30,
+        num_attention_heads=24,
+        num_key_value_heads=2,
+        mlp_type="default",
+        hidden_act="gelu_pytorch_tanh",
+        max_position_embeddings=4096,
+        initializer_range=0.018042,
+        norm_type="layer_norm",
+        norm_epsilon=1e-5,
+        use_cache=True,
+        bos_token_id=50256,
+        eos_token_id=50256,
+        rope_theta=10000.0,
+        sliding_window=None,
+        attention_dropout=0.0,
+        residual_dropout=0.0,
+        embedding_dropout=0.0,
+        use_bias: bool = True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.use_bias = use_bias
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.mlp_type = mlp_type
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.norm_type = norm_type
+        self.norm_epsilon = norm_epsilon
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.residual_dropout = residual_dropout
+        self.embedding_dropout = embedding_dropout
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+
+def load_attention(config, prefix, weights):
+    if config.num_attention_heads != config.num_key_value_heads:
+        return _load_gqa(config, prefix, weights)
+    else:
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=config.use_bias,
+        )
+
+
+def _load_gqa(config, prefix: str, weights):
+    assert config.hidden_size % config.num_attention_heads == 0
+    assert config.num_attention_heads % weights.process_group.size() == 0
+
+    weight = weights.get_multi_weights_col(
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        quantize=config.quantize,
+        dim=0,
+    )
+
+    if config.quantize not in ["gptq", "awq"]:
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+        head_size = config.hidden_size // config.num_attention_heads
+        num_heads = config.num_attention_heads // weights.process_group.size()
+        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+        assert list(weight.shape) == [
+            (num_heads + 2 * num_key_value_heads) * head_size,
+            config.hidden_size,
+        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+    if config.use_bias:
+        w = [
+            weights.get_sharded(f"{p}.bias", dim=0)
+            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+        ]
+        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(
+        get_linear(weight, bias=bias, quantize=config.quantize)
+    )
+
+
+class Starcoder2Attention(torch.nn.Module):
+    def __init__(
+        self,
+        prefix: str,
+        config,
+        weights,
+    ):
+        super().__init__()
+        self.max_past = (
+            config.sliding_window if config.sliding_window is not None else -1
+        )
+        self.num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.num_heads
+
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.head_size,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
+        self.softmax_scale = self.head_size**-0.5
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
+        self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+
+        self.query_key_value = load_attention(config, prefix, weights)
+
+        self.o_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.o_proj",
+            weights=weights,
+            bias=config.use_bias,
+        )
+        self.num_groups = self.num_heads // self.num_key_value_heads
+        self.kv_head_mapping = torch.arange(
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
+
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+        prefill_cache_indices,
+    ):
+        qkv = self.query_key_value(hidden_states)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
+        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
+        if prefill_cache_indices is not None:
+            kv_to_cache = kv[prefill_cache_indices]
+        else:
+            kv_to_cache = kv
+
+        paged_attention.reshape_and_cache(
+            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+
+        # output tensor
+        attn_output = torch.empty_like(query)
+
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+                window_size_left=self.max_past,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
+
+        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class Starcoder2MLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        act = config.hidden_act
+        self.act = (
+            ACT2FN[act]
+            if "gelu" not in act
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        )
+        # Fuse gate and up proj
+        self.c_fc = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.c_fc",
+            weights=weights,
+            bias=config.use_bias,
+        )
+        self.c_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.c_proj",
+            weights=weights,
+            bias=config.use_bias,
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        return self.c_proj(hidden_states)
+
+
+class Starcoder2GatedMLP(nn.Module):
+    def __init__(self, prefix, config, weights):
+        super().__init__()
+        act = config.hidden_act
+        self.act = (
+            ACT2FN[act]
+            if "gelu" not in act
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
+            )
+        )
+        # Fuse gate and up proj
+        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+            weights=weights,
+            dim=0,
+            bias=config.use_bias,
+        )
+        self.down_proj = TensorParallelRowLinear.load(
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=config.use_bias,
+        )
+        self.intermediate_size = (
+            config.intermediate_size // weights.process_group.size()
+        )
+
+    def forward(self, hidden_states):
+        gate_up_states = self.gate_up_proj(hidden_states)
+        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+STARCODER2_NORMALIZATION_CLASSES = {
+    "layer_norm": FastLayerNorm,
+    "rms_norm": FastRMSNorm,
+}
+
+STARCODER2_MLP_CLASSES = {
+    "default": Starcoder2MLP,
+    "gated": Starcoder2GatedMLP,
+}
+
+
+class Starcoder2Layer(nn.Module):
+    def __init__(self, layer_id, config, weights):
+        super().__init__()
+        prefix = f"model.layers.{layer_id}"
+        self.self_attn = Starcoder2Attention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
+
+        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
+            prefix=f"{prefix}.mlp", config=config, weights=weights
+        )
+
+        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
+        )
+        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
+            config.norm_type
+        ].load(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.norm_epsilon,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+        prefill_cache_indices,
+    ):
+        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        attn_output = self.self_attn(
+            normed_hidden_states,
+            cos,
+            sin,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            prefill_cache_indices,
+        )
+
+        # faster post attention rms norm
+        normed_attn_res_output, attn_res = self.post_attention_layernorm(
+            attn_output, res
+        )
+
+        mlp_output = self.mlp(normed_attn_res_output)
+
+        return mlp_output, attn_res
+
+
+class Starcoder2Model(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        process_group = weights.process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.embed_tokens = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                Starcoder2Layer(
+                    layer_id,
+                    config,
+                    weights,
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
+            prefix="model.norm", weights=weights, eps=config.norm_epsilon
+        )
+
+        self.gradient_checkpointing = False
+
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        true_max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+
+        # Get rotary cos and sin for this forward
+        # Avoid to index in each layer
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, true_max_s, hidden_states.dtype
+        )
+
+        residual = None
+        for i, layer in enumerate(self.layers):
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                cos,
+                sin,
+                cu_seqlen_prefill,
+                kv_cache[i],
+                block_tables,
+                slots,
+                input_lengths,
+                max_s,
+                prefill_cache_indices,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+
+        return hidden_states
+
+
+class FlashStarcoder2ForCausalLM(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+
+        self.model = Starcoder2Model(config, weights)
+        try:
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix="lm_head",
+                weights=weights,
+            )
+        except RuntimeError:
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix="model.embed_tokens",
+                weights=weights,
+            )
+
+        self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
+        elif self.max_past is not None:
+            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+            # kernel requires the true values
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            input_lengths,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.lm_head(hidden_states)
+        return logits
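In `Starcoder2Attention` above, grouped-query attention maps each query head onto one of the shared key/value heads via `kv_head_mapping`. A small self-contained check of that mapping on CPU, using the default head counts from `Starcoder2Config` (24 attention heads, 2 key/value heads) and a single shard for illustration:

import torch

num_attention_heads = 24   # Starcoder2Config default above
num_key_value_heads = 2    # Starcoder2Config default above
world_size = 1             # single shard for this illustration

num_heads = num_attention_heads // world_size
num_kv_heads = num_key_value_heads // world_size
num_groups = num_heads // num_kv_heads  # 12 query heads share each kv head

kv_head_mapping = torch.arange(0, num_kv_heads, dtype=torch.int32).repeat_interleave(num_groups)
assert kv_head_mapping.tolist() == [0] * 12 + [1] * 12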
@@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     PositionRotaryEmbedding,
     FastLinear,
 )
@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = TensorParallelHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",
@@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        output = self.fc(input)
+        output, speculative_logits = self.fc(input)
         additional_features = self.additional_fc(input)
         output = torch.cat((output, additional_features), -1)

-        return output
+        return output, speculative_logits

     def extra_repr(self) -> str:
         """Overwriting `nn.Linear.extra_repr` to include new parameters."""
@@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
         )

         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         loss = None

-        return CausalLMOutputWithPastImage(
+        return (
+            CausalLMOutputWithPastImage(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             image_hidden_states=outputs.image_hidden_states,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F

 from text_generation_server.utils.layers import (
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -205,14 +206,12 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = FastLinear.load(
-            config, f"{prefix}.embedding", weights, bias=False
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config

     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
-    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.embed_tokens(input_ids)
         for i, block in enumerate(self.blocks):
             hidden_states, residual, conv_state, ssm_state = block(
@@ -226,8 +225,8 @@ class MambaModel(nn.Module):
             )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits
+        return logits, speculative_logits
@@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
         self.transformer = MPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
         self.logit_scale = None
@@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
         )
-        logits = self.lm_head(outputs.last_hidden_state)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
         if self.logit_scale is not None:
             if self.logit_scale == 0:
                 warnings.warn(
@@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
             loss = F.cross_entropy(
                 logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
             )
-        return CausalLMOutputWithPast(
+        return (
+            CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(
@@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 EPS = 1e-5
@@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):

         self.model = OPTModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="model.decoder.embed_tokens", weights=weights
         )

@@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLinear,
 )

@@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
-        self.linear = TensorParallelHead.load(
+        self.linear = SpeculativeHead.load(
             config=config, prefix="lm_head.linear", weights=weights
         )

@@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         )

         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="lm_head", weights=weights
             )
         except RuntimeError:
             # Some models like t5-small were saved with shared weights unlike flan
             # Since they are declared as the same arch we have no choice but hope
             # that this is OK instead of using a proper flag.
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="shared", weights=weights
             )

@@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
         sequence_output = sequence_output * (self.model_dim**-0.5)

-        lm_logits = self.lm_head(sequence_output)
+        logits, speculative_logits = self.lm_head(sequence_output)

         loss = None
         if labels is not None:
@@ -1140,9 +1140,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output

-        return Seq2SeqLMOutput(
+        return (
+            Seq2SeqLMOutput(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
@@ -1150,6 +1151,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
         )

     def prepare_inputs_for_generation(
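Both the T5 hunks above and the new Starcoder2 head keep a try/except fallback for checkpoints that store tied or shared embedding weights under a different prefix. A self-contained sketch of that pattern, using a dictionary as a stand-in for the weight store and a key check in place of the loader's RuntimeError:

weights = {"shared.weight": "tensor-placeholder"}  # e.g. a t5-small style checkpoint without a separate lm_head

def load_head_prefix(weights: dict, primary: str, fallback: str) -> str:
    # Mirrors: try SpeculativeHead.load(prefix="lm_head") / except RuntimeError:
    #          SpeculativeHead.load(prefix="shared")
    for prefix in (primary, fallback):
        if f"{prefix}.weight" in weights:
            return prefix
    raise RuntimeError(f"neither {primary} nor {fallback} found in checkpoint")

assert load_head_prefix(weights, "lm_head", "shared") == "shared"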
@@ -723,7 +723,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -734,6 +734,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -805,7 +807,9 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -900,9 +904,14 @@ class FlashCausalLM(Model):

         # Replay the graph
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -926,16 +935,11 @@ class FlashCausalLM(Model):
         batch.slots = slots

         try:
-            out = self.forward(batch)
+            out, speculative_logits = self.forward(batch)
         except Exception as e:
             del batch
             raise e

-        if isinstance(out, tuple):
-            out, speculative_logits = out
-        else:
-            speculative_logits = None
-
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
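The CUDA-graph path in `FlashCausalLM` now records both outputs of the captured forward so that replay returns the same `(logits, speculative_logits)` pair as the eager path. A minimal sketch of that capture/replay bookkeeping, assuming a CUDA device and a callable whose inputs and outputs are static tensors reused across replays; this illustrates the pattern rather than reproducing the real warmup logic:

import torch

def capture_decode_graph(forward, static_inputs: dict) -> dict:
    # `forward` must return (logits, speculative_logits or None) and must only
    # read from the static input tensors so the capture stays valid on replay.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        logits, speculative_logits = forward(**static_inputs)
    return {"graph": graph, "logits": logits, "speculative_logits": speculative_logits}

def replay_decode_graph(cuda_graph: dict, bs: int):
    # Replay fills the captured output tensors in place; slice to the live batch size.
    cuda_graph["graph"].replay()
    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    return cuda_graph["logits"][:bs], speculative_logits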
@@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashGemmaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
use_medusa: Optional[str] = None,
|
|
||||||
):
|
):
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
config.use_medusa = use_medusa
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM):
|
|||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
model = FlashLlamaForCausalLM(config, weights)
|
model = FlashLlamaForCausalLM(config, weights)
|
||||||
if use_medusa:
|
|
||||||
from text_generation_server.utils.medusa import MedusaModel
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
is_local_model = (
|
|
||||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
|
||||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
|
||||||
|
|
||||||
if not is_local_model:
|
|
||||||
medusa_config = hf_hub_download(
|
|
||||||
use_medusa, revision=revision, filename="config.json"
|
|
||||||
)
|
|
||||||
medusa_head = hf_hub_download(
|
|
||||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
|
||||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
|
||||||
weights = Weights(
|
|
||||||
[medusa_sf], device, dtype, process_group=self.process_group
|
|
||||||
)
|
|
||||||
lm_head = model.lm_head
|
|
||||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashLlama, self).__init__(
|
super(FlashLlama, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
|
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from transformers.models.llama import LlamaTokenizerFast
-from typing import Optional, Tuple, Type, List
+from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
@@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
 MEM_POOL = torch.cuda.graph_pool_handle()


+def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    SLIDING_WINDOW = sliding_window
+    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
+
+
+def get_sliding_windows() -> Tuple[int, int]:
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
+
+
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
 class FlashMistralBatch(FlashCausalLMBatch):
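Aside (illustration, not part of the diff): the two helpers above make the module-level sliding-window globals readable and writable through a single pair of functions, so batch-building code no longer declares `global` itself. A rough sketch of the intended call pattern; the 4096-token window below is an arbitrary example value:

import math

from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    get_sliding_windows,
    set_sliding_window,
)

# At model load time: record the window and the matching number of KV-cache blocks.
example_window = 4096  # illustrative only
set_sliding_window(example_window, math.ceil(example_window / BLOCK_SIZE))

# Later, e.g. while building a batch: read both values back instead of touching globals.
sliding_window, sliding_window_blocks = get_sliding_windows()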
@@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
+        sliding_window, sliding_window_blocks = get_sliding_windows()

         batch_inputs = []
         max_truncation = 0
@@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):

             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if SLIDING_WINDOW_BLOCKS is not None:
-                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
+            if sliding_window_blocks is not None:
+                needed_blocks = min(needed_blocks, sliding_window_blocks)
             blocks += needed_blocks

             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             slot_indices.append(request_slot_indices)

             # Create tensor to slice into the kv tensor in prefill
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
                     dtype=torch.int64,
                 )
@@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = prefill_cache_indices[0]

         cu_seqlen_prefill = torch.tensor(
@@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         prefill_cache_indices = (
-            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+            prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
@@ -294,12 +306,10 @@ class BaseFlashMistral(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -319,11 +329,13 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         # Set context windows
         if config.sliding_window is not None:
-            SLIDING_WINDOW = config.sliding_window
-            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )

         torch.distributed.barrier(group=self.process_group)

@@ -394,7 +406,7 @@ class BaseFlashMistral(FlashCausalLM):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -406,9 +418,13 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -479,7 +495,7 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -493,7 +509,7 @@ class BaseFlashMistral(FlashCausalLM):
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
-            return logits
+            return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -511,7 +527,13 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits


 class FlashMistral(BaseFlashMistral):
@@ -520,6 +542,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -529,6 +552,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         )

         config.quantize = quantize
+        config.use_medusa = use_medusa
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)

@@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.transpose = config.architectures[0].startswith("GPT2")

         torch.distributed.barrier(group=self.process_group)
 server/text_generation_server/models/flash_starcoder2.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import math
+
+import torch
+
+from typing import Optional
+
+from transformers.models.gpt2 import GPT2TokenizerFast
+
+from text_generation_server.models.cache_manager import BLOCK_SIZE
+from text_generation_server.models.flash_mistral import (
+    BaseFlashMistral,
+    set_sliding_window,
+)
+from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
+    Starcoder2Config,
+    FlashStarcoder2ForCausalLM,
+)
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    Weights,
+)
+
+
+# Starcoder2 has the same base as Mistral
+class FlashStarcoder2(BaseFlashMistral):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        self.process_group, rank, world_size = initialize_torch_distributed()
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            raise NotImplementedError("FlashLlama is only available on GPU")
+
+        tokenizer = GPT2TokenizerFast.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = Starcoder2Config.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        config.quantize = quantize
+        config.use_medusa = use_medusa
+
+        # Set context windows
+        if config.sliding_window is not None:
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )
+
+        torch.distributed.barrier(group=self.process_group)
+
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        if config.quantize in ["gptq", "awq"]:
+            weights._set_gptq_params(model_id, revision)
+
+        model = FlashStarcoder2ForCausalLM(config, weights)
+
+        self.cuda_graphs = {}
+
+        torch.distributed.barrier(group=self.process_group)
+        super(BaseFlashMistral, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            num_layers=len(model.model.layers),
+            num_kv_heads=model.model.num_key_value_heads,
+            head_size=model.model.head_size,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            world_size=world_size,
+            sliding_window=config.sliding_window,
+        )
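Aside (illustration, not part of the diff): a hypothetical instantiation of the new class, just to show where the constructor arguments land; the checkpoint name is an example and the surrounding server wiring is assumed rather than shown here.

import torch

from text_generation_server.models.flash_starcoder2 import FlashStarcoder2

model = FlashStarcoder2(
    model_id="bigcode/starcoder2-15b",  # example checkpoint, not taken from this diff
    quantize=None,
    use_medusa=None,  # or a directory with config.json + medusa_lm_head.safetensors
    dtype=torch.float16,
)
# Like the other flash models after this change, forward returns a pair:
# logits, speculative_logits = model.forward(batch)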
@@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.vision_config.quantize = quantize

         tokenizer = LlamaTokenizerFast.from_pretrained(
@@ -662,8 +662,13 @@ class IdeficsCausalLM(Model):
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids

-        outputs = self.model.forward(**kwargs)
-        return outputs.logits, outputs.past_key_values, outputs.image_hidden_states
+        outputs, speculative_logits = self.model.forward(**kwargs)
+        return (
+            outputs.logits,
+            speculative_logits,
+            outputs.past_key_values,
+            outputs.image_hidden_states,
+        )

     @tracer.start_as_current_span("generate_token")
     def generate_token(
@@ -686,7 +691,7 @@ class IdeficsCausalLM(Model):
             :, : -batch.padding_right_offset
         ]

-        logits, past, image_hidden_states = self.forward(
+        logits, speculative_logits, past, image_hidden_states = self.forward(
             input_ids=batch.input_ids,
             attention_mask=attention_mask,
             position_ids=batch.position_ids,
@@ -408,6 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -444,6 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -505,7 +507,7 @@ class Mamba(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids, inference_params=inference_params
            )
         torch.cuda.synchronize()
@@ -514,6 +516,7 @@ class Mamba(Model):
             "inference_params": inference_params,
             "graph": graph,
             "logits": logits,
+            "speculative_logits": speculative_logits,
         }
         self.cuda_graphs[batch_size] = graph_dict

@@ -556,9 +559,14 @@ class Mamba(Model):
         inference_params.ssm_states.copy_(
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
@@ -589,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params

         # Forward pass
-        logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )

         # batch.inference_params = new_inference_params
         # Results
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
config = PretrainedConfig(**config)
|
config = PretrainedConfig(**config)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
config.use_medusa = use_medusa
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
config.use_medusa = use_medusa
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
@ -22,6 +22,7 @@ class Phi(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -52,6 +53,7 @@ class Phi(CausalLM):
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
config.use_medusa = use_medusa
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
use_medusa: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -596,6 +597,7 @@ class Seq2SeqLM(Model):
|
|||||||
past_key_values: Optional = None,
|
past_key_values: Optional = None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
|
Optional[torch.Tensor],
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
|
||||||
]:
|
]:
|
||||||
@ -609,8 +611,15 @@ class Seq2SeqLM(Model):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
# Our custom models
|
||||||
|
outputs, speculative_logits = outputs
|
||||||
|
else:
|
||||||
|
# Generic transformers models
|
||||||
|
speculative_logits = None
|
||||||
return (
|
return (
|
||||||
outputs.logits,
|
outputs.logits,
|
||||||
|
speculative_logits,
|
||||||
outputs.encoder_last_hidden_state,
|
outputs.encoder_last_hidden_state,
|
||||||
outputs.past_key_values,
|
outputs.past_key_values,
|
||||||
)
|
)
|
||||||
@ -635,7 +644,7 @@ class Seq2SeqLM(Model):
|
|||||||
else:
|
else:
|
||||||
encoder_last_hidden_state = None
|
encoder_last_hidden_state = None
|
||||||
|
|
||||||
logits, encoder_last_hidden_state, past = self.forward(
|
logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
|
||||||
batch.input_ids,
|
batch.input_ids,
|
||||||
batch.attention_mask,
|
batch.attention_mask,
|
||||||
batch.decoder_input_ids,
|
batch.decoder_input_ids,
|
||||||
|
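Aside (illustration, not part of the diff): the `isinstance` branch above exists because only TGI's custom modules return an `(outputs, speculative_logits)` tuple, while stock `transformers` models return a single ModelOutput. The same dispatch, isolated into a tiny helper for clarity; the helper name is invented:

from typing import Any, Optional, Tuple

import torch


def split_speculative(outputs: Any) -> Tuple[Any, Optional[torch.Tensor]]:
    """Normalize a forward() result to (outputs, speculative_logits)."""
    if isinstance(outputs, tuple):
        # Custom TGI modules already return the pair.
        return outputs
    # Generic transformers modules: no speculative logits.
    return outputs, None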
@@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):

         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )
@@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info(
         and "arguments" not in s.rfilename
         and "args" not in s.rfilename
         and "training" not in s.rfilename
+        and "medusa_lm_head" not in s.rfilename
     ]


@@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
         and "args" not in f
         and "adapter" not in f
         and "training" not in f
+        and "medusa_lm_head" not in f
     ]
     return filenames

@@ -4,7 +4,7 @@ import torch.distributed

 from torch import nn
 from torch.nn import functional as F
-from typing import List
+from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache

@@ -380,6 +380,96 @@ class SuperLayer(nn.Module):
         return self.linear.forward(x)


+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(
+            config, prefix=f"{prefix}.linear", weights=weights, bias=True
+        )
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [
+                MedusaHead(config, prefix=f"{i}", weights=weights)
+                for i in range(config["medusa_num_heads"])
+            ]
+        )
+
+    def forward(self, x):
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList(
+            [
+                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
+                for i in range(config["medusa_num_layers"])
+            ]
+        )
+        n = len(self.blocks)
+        self.out = FastLinear.load(
+            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
+        )
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
+
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.lm_head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
+        lm_head = TensorParallelHead.load(config, prefix, weights)
+        use_medusa = config.use_medusa
+        if use_medusa:
+            from pathlib import Path
+            from safetensors import safe_open
+            import json
+
+            medusa_config = str(Path(use_medusa) / "config.json")
+            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    weights.routing[k] = filename
+
+            medusa = MedusaModel(config, weights)
+        else:
+            medusa = None
+        return SpeculativeHead(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input) if self.medusa is not None else None
+        return logits, speculative_logits
+
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
|
|||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Output:
|
|
||||||
logits: torch.FloatTensor = None
|
|
||||||
speculative_logits: torch.FloatTensor = None
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(torch.nn.Module):
|
|
||||||
def __init__(self, config, prefix, weights):
|
|
||||||
super().__init__()
|
|
||||||
self.linear = FastLinear.load(
|
|
||||||
config, prefix=f"{prefix}.linear", weights=weights, bias=True
|
|
||||||
)
|
|
||||||
self.act = torch.nn.SiLU()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x + self.act(self.linear(x))
|
|
||||||
|
|
||||||
|
|
||||||
class MedusaModel(torch.nn.Module):
|
|
||||||
def __init__(self, config, weights, lm_head):
|
|
||||||
super().__init__()
|
|
||||||
self.heads = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
MedusaHead(config, prefix=f"{i}", weights=weights)
|
|
||||||
for i in range(config["medusa_num_heads"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.lm_head = lm_head
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
logits = self.lm_head(x)
|
|
||||||
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
|
|
||||||
return logits, speculative_logits
|
|
||||||
|
|
||||||
|
|
||||||
class MedusaHead(torch.nn.Module):
|
|
||||||
def __init__(self, config, prefix, weights):
|
|
||||||
super().__init__()
|
|
||||||
self.blocks = torch.nn.ModuleList(
|
|
||||||
[
|
|
||||||
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
|
|
||||||
for i in range(config["medusa_num_layers"])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
n = len(self.blocks)
|
|
||||||
self.out = FastLinear.load(
|
|
||||||
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
x = self.out(x)
|
|
||||||
return x
|
|