Merge branch 'main' into main

Marcus Dunn 2023-08-28 14:12:39 -07:00 committed by GitHub
commit d4e2ce7e7b
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 615 additions and 34 deletions

View File

@ -86,7 +86,7 @@ The easiest way of getting started is using the official Docker container:
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
- docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
@ -153,7 +153,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
- docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+ docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
### A note on Shared Memory (shm)

View File

@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
- res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+ res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
@ -64,6 +65,7 @@ async fn generate_runs(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -82,6 +84,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -97,6 +100,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -130,6 +134,7 @@ async fn prefill(
batch_size: u32,
decode_length: u32,
parameters: NextTokenChooserParameters,
top_n_tokens: Option<u32>,
client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
@ -145,6 +150,7 @@ async fn prefill(
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
})
.collect();

View File

@ -22,6 +22,7 @@ pub async fn run(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -75,6 +76,7 @@ pub async fn run(
batch_size.clone(),
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
parameters,
@ -135,6 +137,7 @@ pub async fn run(
tokenizer_name,
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
temperature,

View File

@ -94,6 +94,12 @@ struct Args {
#[clap(long, env)]
do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_n_tokens: Option<u32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env, value_parser=parse_key_val::<String, f32>)]
@ -123,6 +129,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
do_sample,
master_shard_uds_path,
logit_bias,
top_n_tokens,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@ -179,6 +186,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
batch_size,
sequence_length,
decode_length,
top_n_tokens,
runs,
warmups,
temperature,

View File

@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -25,6 +26,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Model", &tokenizer_name]);
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
builder.push_record(["Decode Length", &decode_length.to_string()]);
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
builder.push_record(["N Runs", &n_runs.to_string()]);
builder.push_record(["Warmups", &warmups.to_string()]);
builder.push_record(["Temperature", &format!("{temperature:?}")]);

View File

@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
logit_bias: Dict[str, float] = {},
) -> Response:
"""
@ -114,6 +115,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
@ -137,6 +140,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -168,6 +172,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
logit_bias: Dict[str, float] = {},
) -> Iterator[StreamResponse]:
"""
@ -203,6 +208,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
@ -227,6 +234,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -326,6 +334,7 @@ class AsyncClient:
watermark: bool = False,
decoder_input_details: bool = False,
logit_bias: Dict[str, float] = {},
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -366,6 +375,8 @@ class AsyncClient:
Return the decoder input token logprobs and ids
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -388,6 +399,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -417,6 +429,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
logit_bias: Dict[str, float] = {},
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@ -453,6 +466,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@ -475,6 +490,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
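Taken together, the new keyword argument is threaded through both the blocking and streaming clients, and the returned candidates surface in `details.top_tokens` (non-streaming) or on each `StreamResponse` (streaming). A minimal usage sketch, assuming a TGI server built from this change is listening on 127.0.0.1:8080; the prompt and numbers are arbitrary:

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# Ask for the 3 most likely tokens at every decoding step in addition to the
# sampled token; they come back as one list of Token objects per step.
response = client.generate(
    "What is Deep Learning?",
    max_new_tokens=5,
    top_n_tokens=3,
)
if response.details.top_tokens:
    for token, candidates in zip(response.details.tokens, response.details.top_tokens):
        alternatives = ", ".join(f"{t.text!r} ({t.logprob:.2f})" for t in candidates)
        print(f"picked {token.text!r}; top-3: {alternatives}")

# In streaming mode each StreamResponse carries the per-step candidates directly.
for chunk in client.generate_stream("What is Deep Learning?", max_new_tokens=5, top_n_tokens=3):
    if chunk.top_tokens:
        print([t.text for t in chunk.top_tokens])
```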

View File

@ -41,6 +41,8 @@ class Parameters(BaseModel):
decoder_input_details: bool = False
# Bias generation towards certain tokens
logit_bias: Dict[str, float] = {}
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -103,6 +105,12 @@ class Parameters(BaseModel):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
@validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
class Request(BaseModel):
# Prompt
@ -127,9 +135,7 @@ class Request(BaseModel):
and parameters.best_of > 1
and field_value
):
- raise ValidationError(
-     "`best_of` != 1 is not supported when `stream` == True"
- )
+ raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
return field_value
@ -181,6 +187,8 @@ class BestOfSequence(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details
@ -195,6 +203,8 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@ -221,6 +231,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "1.0.1" "version": "1.0.2"
}, },
"paths": { "paths": {
"/": { "/": {

View File

@ -75,6 +75,81 @@ To serve both ChatUI and TGI in same environment, simply add your own endpoints
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)
## Gradio
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
```bash
pip install huggingface-hub gradio
```
Assuming you are serving your model on port 8080, we will query it through [InferenceClient](consuming_tgi#inference-client).
```python
import gradio as gr
from huggingface_hub import InferenceClient
client = InferenceClient(model="http://127.0.0.1:8080")
def inference(message, history):
partial_message = ""
for token in client.text_generation(message, max_new_tokens=20, stream=True):
partial_message += token
yield partial_message
gr.ChatInterface(
inference,
chatbot=gr.Chatbot(height=300),
textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
title="Gradio 🤝 TGI",
examples=["Are tomatoes vegetables?"],
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
).queue().launch()
```
The UI looks like this 👇
<div class="flex justify-center">
<img
class="block dark:hidden"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
/>
<img
class="hidden dark:block"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
/>
</div>
You can try the demo directly here 👇
<div class="block dark:hidden">
<iframe
src="https://merve-gradio-tgi-2.hf.space?__theme=light"
width="850"
height="750"
></iframe>
</div>
<div class="hidden dark:block">
<iframe
src="https://merve-gradio-tgi-2.hf.space?__theme=dark"
width="850"
height="750"
></iframe>
</div>
You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
```python
def inference(message, history):
return client.text_generation(message, max_new_tokens=20)
```
You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
## API documentation
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).

View File

@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
- docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
<Tip warning={true}>
@ -85,7 +85,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
- docker run ghcr.io/huggingface/text-generation-inference:1.0.1 --help
+ docker run ghcr.io/huggingface/text-generation-inference:1.0.2 --help
```
</Tip>

View File

@ -159,6 +159,14 @@ struct Args {
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens is used to return information about the the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks like for classification or
/// ranking.
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
/// This is the maximum allowed input length (expressed in number of tokens) /// This is the maximum allowed input length (expressed in number of tokens)
/// for users. The larger this value, the longer prompt users can send which /// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load. /// can impact the overall memory required to handle the load.
@ -929,6 +937,8 @@ fn spawn_webserver(
args.max_best_of.to_string(), args.max_best_of.to_string(),
"--max-stop-sequences".to_string(), "--max-stop-sequences".to_string(),
args.max_stop_sequences.to_string(), args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(), "--max-input-length".to_string(),
args.max_input_length.to_string(), args.max_input_length.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),

View File

@ -101,6 +101,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
@ -151,6 +153,17 @@ message PrefillTokens {
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
}
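For readers less familiar with protobuf, the parallel `repeated` fields above hold one entry per candidate token for a single decoding step. A small sketch of how such a message could be built from Python, assuming the generated stubs in `text_generation_server.pb.generate_pb2` are rebuilt from this definition; all values are made up:

```python
from text_generation_server.pb import generate_pb2

# One decoding step with two candidate tokens; the four lists are aligned.
top_tokens = generate_pb2.TopTokens(
    ids=[464, 257],            # candidate token ids
    logprobs=[-0.12, -2.31],   # their log-probabilities
    texts=["The", " a"],       # decoded text for each candidate
    is_special=[False, False], # whether each candidate is a special token
)
print(top_tokens)
```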
message Generation {
/// Request ID
uint64 request_id = 1;
@ -166,6 +179,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Top tokens
TopTokens top_tokens = 8;
}
message FilterBatchRequest {

View File

@ -132,6 +132,7 @@ impl Client {
ignore_eos_token: false,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
}

View File

@ -51,6 +51,7 @@ impl Health {
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,

View File

@ -138,12 +138,15 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@ -164,7 +167,10 @@ impl Infer {
.collect();
}
// Push last token
- InferStreamResponse::Token(token) => result_tokens.push(token),
+ InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
@ -172,8 +178,10 @@ impl Infer {
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -191,6 +199,11 @@ impl Infer {
generated_text,
queued,
start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
@ -520,6 +533,26 @@ fn send_responses(
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
)
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
@ -527,6 +560,7 @@ fn send_responses(
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
@ -536,7 +570,7 @@ fn send_responses(
} else {
// Send message
entry.response_tx.send_timeout(
- Ok(InferStreamResponse::Token(token)),
+ Ok(InferStreamResponse::Intermediate { token, top_tokens }),
Duration::from_millis(10),
)?;
}
@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
- Token(Token),
+ Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]

View File

@ -139,6 +139,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(default=json!({}), example=json!({"hello": 0.5}))]
pub logit_bias: BTreeMap<String, f32>,
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
}
fn default_max_new_tokens() -> u32 {
@ -163,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
seed: None,
logit_bias: BTreeMap::new(),
top_n_tokens: None,
}
}
@ -240,6 +243,8 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -254,6 +259,8 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -277,6 +284,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
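On the HTTP API this surfaces as a new `top_n_tokens` request parameter and a `top_tokens` field in the response details, which is omitted entirely when empty because of the `skip_serializing_if` attributes above. A rough way to exercise it with plain `requests`, assuming a server built from this branch is listening on 127.0.0.1:8080; prompt and values are arbitrary:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "max_new_tokens": 5,
            "details": True,     # include per-token details in the response
            "top_n_tokens": 3,   # the parameter added in this change
        },
    },
)
details = resp.json()["details"]
# `top_tokens` is only present when it is non-empty, so read it defensively.
for step in details.get("top_tokens", []):
    print([(t["text"], round(t["logprob"], 2)) for t in step])
```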

View File

@ -29,6 +29,8 @@ struct Args {
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_length: usize,
#[clap(default_value = "2048", long, env)]
@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,

View File

@ -235,6 +235,7 @@ impl State {
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -329,6 +330,7 @@ mod tests {
max_new_tokens: 1,
stop_sequences: vec![],
},
top_n_tokens: 0,
},
response_tx,
span: info_span!("entry"),

View File

@ -158,7 +158,7 @@ async fn generate(
add_prompt = Some(req.inputs.clone());
}
- let details = req.parameters.details || req.parameters.decoder_input_details;
+ let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
@ -191,6 +192,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
top_tokens: response.top_tokens,
seed: response.generated_text.seed,
}
})
@ -204,6 +205,7 @@ async fn generate(
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
})
}
false => None,
@ -385,12 +387,16 @@ async fn generate_stream(
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
- InferStreamResponse::Token(token) => {
+ InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: None,
details: None,
};
@ -403,6 +409,7 @@ async fn generate_stream(
generated_text,
start,
queued,
top_tokens,
} => {
// Token details
let details = match details {
@ -451,6 +458,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: Some(output_text),
details
};
@ -509,6 +517,7 @@ pub async fn run(
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
@ -571,6 +580,7 @@ pub async fn run(
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);

View File

@ -15,6 +15,7 @@ pub struct Validation {
/// Validation parameters
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
/// Channel to communicate with the background tokenization task
@ -27,6 +28,7 @@ impl Validation {
tokenizer: Option<Tokenizer>,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
@ -54,6 +56,7 @@ impl Validation {
max_best_of,
sender,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
}
@ -143,6 +146,7 @@ impl Validation {
watermark,
decoder_input_details,
logit_bias,
top_n_tokens,
..
} = request.parameters;
@ -219,6 +223,15 @@ impl Validation {
}
};
let top_n_tokens = top_n_tokens
.map(|value| {
if value > self.max_top_n_tokens {
return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
}
Ok(value)
})
.unwrap_or(Ok(0))?;
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
@ -268,6 +281,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
top_n_tokens: top_n_tokens,
})
}
@ -341,6 +355,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub top_n_tokens: u32,
}
#[derive(Error, Debug)]
@ -355,6 +370,10 @@ pub enum ValidationError {
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
TopNTokens(u32, u32),
#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
TopNTokensDisabled,
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream,
#[error("`temperature` must be strictly positive")]
@ -396,14 +415,16 @@ mod tests {
let tokenizer = None;
let max_best_of = 2;
let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_top_n_tokens = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -423,14 +444,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_top_n_tokens = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -440,7 +463,7 @@ mod tests {
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
- Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+ Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
}
}
@ -450,14 +473,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_top_n_tokens = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -482,14 +507,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
- let max_input_length = 4;
- let max_total_tokens = 5;
+ let max_top_n_tokens = 4;
+ let max_input_length = 5;
+ let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -536,4 +563,75 @@ mod tests {
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!(valid_request.parameters.top_p, 1.0);
}
#[tokio::test]
async fn test_validation_top_n_tokens() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(5),
..default_parameters()
},
})
.await
{
Err(ValidationError::TopNTokens(4, 5)) => (),
_ => panic!("Unexpected top_n_tokens"),
}
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
assert_eq!(valid_request.top_n_tokens, 0);
}
}

View File

@ -1,7 +1,9 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
@ -42,3 +44,22 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]
assert topn_tok_ids[2] == [0, 3, 1, 4]
assert topn_tok_ids[3] == [0, 3, 1, 4]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
assert topn_tok_logprobs[0] == []
assert topn_tok_logprobs[1] == [-1, -2]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
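The assertions above pin down the tie-handling semantics: every token whose log-probability matches the n-th best one is returned, so asking for 3 candidates can yield 4 when two of them tie. A minimal sketch of logic that satisfies these assertions; the helper name and the two-argument signature are illustrative, not the actual `batch_top_tokens` interface:

```python
import torch

def batch_top_tokens_sketch(top_n_tokens, logprobs):
    """Illustrative only: per row, keep every token whose logprob ties the n-th best."""
    ids, lps = [], []
    for n, row in zip(top_n_tokens, logprobs):
        if n == 0:
            ids.append([])
            lps.append([])
            continue
        # A stable sort keeps lower token ids first among equal logprobs,
        # matching the [0, 3, 1, 4] ordering asserted in the test above.
        vals, idx = torch.sort(row, descending=True, stable=True)
        keep = vals >= vals[n - 1]  # include ties at the cut-off
        ids.append(idx[keep].tolist())
        lps.append(vals[keep].tolist())
    return ids, lps
```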

View File

@ -125,6 +125,9 @@ def download_weights(
try:
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
@ -12,6 +13,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@ -549,6 +576,12 @@ class CausalLM(Model):
generations: List[Generation] = []
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Zipped iterator
iterator = zip(
batch.requests,
@ -559,6 +592,9 @@ class CausalLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -571,6 +607,9 @@ class CausalLM(Model):
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -637,6 +676,24 @@ class CausalLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -645,6 +702,7 @@ class CausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
- config=config, dim=self.head_size, base=10000.0, device=weights.device
+ config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
)
self.softmax_scale = self.head_size**-0.5
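Rather than hard-coding a base of 10000.0, the rotary embedding now reads `rope_theta` from the model config, which matters for checkpoints trained with a larger base (long-context Llama variants). As a reminder of what the base controls, here is the standard RoPE inverse-frequency computation sketched in plain PyTorch; this is not the code inside `PositionRotaryEmbedding.static`, just the usual formula it is assumed to follow:

```python
import torch

def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard rotary-embedding frequencies: a larger base gives lower
    # frequencies (longer wavelengths), which is why long-context variants
    # ship a bigger rope_theta in their config.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

head_size = 128
print(rope_inv_freq(head_size, base=10000.0)[:4])      # default Llama base
print(rope_inv_freq(head_size, base=1_000_000.0)[:4])  # e.g. a long-context variant
```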

View File

@ -1,5 +1,6 @@
import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -16,6 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token does not have a past
@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
stopping_criterias = []
top_n_tokens = []
blocks = 0
max_blocks = 0
@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = [] start_slots = []
block_tables = [] block_tables = []
@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[ all_input_ids_tensor[
@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots) cumulative_slots += len(batch.slots)
@ -667,6 +691,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
@ -832,10 +858,14 @@ class FlashCausalLM(Model):
else: else:
next_token_logits = out next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser( next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
) )
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
if prefill: if prefill:
if len(batch) > 1 and prefill_logprobs: if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -932,8 +962,11 @@ class FlashCausalLM(Model):
batch.all_input_ids, batch.all_input_ids,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids, next_token_ids,
next_token_logprobs, next_token_logprobs,
batch_top_token_ids,
batch_top_token_logprobs,
) )
# For each member of the batch # For each member of the batch
@ -946,8 +979,11 @@ class FlashCausalLM(Model):
all_input_ids, all_input_ids,
do_sample, do_sample,
seed, seed,
top_n_tokens,
next_token_id, next_token_id,
next_token_logprob, next_token_logprob,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Append next token to all tokens # Append next token to all tokens
all_input_ids.append(next_token_id) all_input_ids.append(next_token_id)
@ -1006,6 +1042,24 @@ class FlashCausalLM(Model):
else: else:
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation( generation = Generation(
request.id, request.id,
prefill_tokens, prefill_tokens,
@ -1014,6 +1068,7 @@ class FlashCausalLM(Model):
next_token_text, next_token_text,
next_token_id in self.all_special_ids, next_token_id in self.all_special_ids,
generated_text, generated_text,
top_tokens,
) )
generations.append(generation) generations.append(generation)

View File

@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
else:
prefill_tokens = None
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens
)
generations.append(generation)

View File

@@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
Batch,
Generation,
PrefillTokens,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Finished requests
generations: List[Generation] = []
stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@@ -1,3 +1,4 @@
from functools import total_ordering
import torch
from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass
class TopTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
)
def __len__(self):
return len(self.token_ids)
@dataclass
class Generation:
request_id: int
@@ -80,6 +100,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)
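To make the new dataclass concrete, here is a minimal sketch of what a populated `TopTokens` looks like for one request; the ids, logprobs, and decoded texts are invented values, not output from a real model.

```python
# Illustrative only: fabricated values for a request that asked for top_n_tokens=3.
top_tokens = TopTokens(
    token_ids=[464, 317, 1212],
    logprobs=[-0.31, -1.87, -2.40],
    # In the server these come from tokenizer.batch_decode(...) and a
    # membership check against the model's all_special_ids set.
    texts=[" The", " A", "This"],
    is_special=[False, False, False],
)

assert len(top_tokens) == 3   # __len__ delegates to token_ids
pb = top_tokens.to_pb()       # the generate_pb2.TopTokens message sent back to the router
```

`Generation.to_pb` then attaches this message (or `None`) to the per-token response, which is why the field stays optional while models such as `IdeficsCausalLM` above do not populate it yet.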

View File

@@ -6,20 +6,22 @@ from transformers import (
PreTrainedTokenizerBase,
SequenceBiasLogitsProcessor,
)
from typing import Callable, List, Tuple, Optional, Dict
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import (
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
@@ -255,11 +257,10 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs, logprobs
def filter(self, indices):
if self.watermark_processor is not None:
@@ -370,3 +371,50 @@ class HeterogeneousSampling:
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
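The chooser now returns the full log-softmax matrix alongside the sampled ids and their logprobs, so the flash path can hand it straight to `batch_top_tokens` (defined below) instead of recomputing probabilities. A tiny shape sketch with made-up scores:

```python
# Illustrative only: two requests over a three-token vocabulary.
import torch

scores = torch.tensor([[1.0, 2.0, 0.5],
                       [0.2, 0.1, 3.0]])
next_ids = torch.tensor([1, 2])                   # one chosen token per request
logprobs = torch.log_softmax(scores, -1)          # [batch, vocab], reused downstream
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)  # [batch]
```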
def batch_top_tokens(
top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
return (
[
idxs[:n] if req_n > 0 else []
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
],
[
vals[:n] if req_n > 0 else []
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
],
)
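As a quick sanity check of the behaviour described in the docstring, a hedged usage sketch (values invented): one request asking for its top 2 tokens and one with the feature disabled.

```python
# Illustrative only: two requests over a five-token vocabulary.
import torch

logprobs = torch.log_softmax(
    torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                  [0.0, 0.0, 0.0, 0.0, 0.0]]),
    dim=-1,
)
top_n = [2, 0]
ids, logps = batch_top_tokens(top_n, torch.tensor(top_n), logprobs)
# ids   -> [[0, 1], []] : request 0 gets its two most likely token ids,
#                         request 1 asked for none so both of its lists stay empty.
# logps mirrors ids with the corresponding log-probabilities.
```

Because the cutoff is taken at the n-th highest log-probability, a request can receive more than n entries when several tokens tie at that value, which is the "fuzzy" behaviour the docstring describes.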