mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Merge branch 'main' into main
This commit is contained in:
commit
d4e2ce7e7b
@ -86,7 +86,7 @@ The easiest way of getting started is using the official Docker container:
|
||||
model=tiiuae/falcon-7b-instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
|
||||
```
|
||||
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||
|
||||
@ -153,7 +153,7 @@ model=meta-llama/Llama-2-7b-chat-hf
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
token=<your cli READ token>
|
||||
|
||||
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
|
||||
```
|
||||
|
||||
### A note on Shared Memory (shm)
|
||||
|
@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
parameters: NextTokenChooserParameters,
|
||||
@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
|
||||
// End task if a message is received on shutdown_receiver
|
||||
// _shutdown_guard_sender will be dropped once the task is finished
|
||||
tokio::select! {
|
||||
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
|
||||
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
|
||||
if let Err(err) = res {
|
||||
run_sender.send(Err(err)).await.unwrap_or(());
|
||||
}
|
||||
@ -64,6 +65,7 @@ async fn generate_runs(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
parameters: NextTokenChooserParameters,
|
||||
@ -82,6 +84,7 @@ async fn generate_runs(
|
||||
b,
|
||||
decode_length,
|
||||
parameters.clone(),
|
||||
top_n_tokens,
|
||||
&mut client,
|
||||
)
|
||||
.await?;
|
||||
@ -97,6 +100,7 @@ async fn generate_runs(
|
||||
b,
|
||||
decode_length,
|
||||
parameters.clone(),
|
||||
top_n_tokens,
|
||||
&mut client,
|
||||
)
|
||||
.await?;
|
||||
@ -130,6 +134,7 @@ async fn prefill(
|
||||
batch_size: u32,
|
||||
decode_length: u32,
|
||||
parameters: NextTokenChooserParameters,
|
||||
top_n_tokens: Option<u32>,
|
||||
client: &mut ShardedClient,
|
||||
) -> Result<(Prefill, CachedBatch), ClientError> {
|
||||
// Create requests
|
||||
@ -145,6 +150,7 @@ async fn prefill(
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||
}),
|
||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -22,6 +22,7 @@ pub async fn run(
|
||||
batch_size: Vec<u32>,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
temperature: Option<f32>,
|
||||
@ -75,6 +76,7 @@ pub async fn run(
|
||||
batch_size.clone(),
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
n_runs,
|
||||
warmups,
|
||||
parameters,
|
||||
@ -135,6 +137,7 @@ pub async fn run(
|
||||
tokenizer_name,
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
n_runs,
|
||||
warmups,
|
||||
temperature,
|
||||
|
@ -94,6 +94,12 @@ struct Args {
|
||||
#[clap(long, env)]
|
||||
do_sample: bool,
|
||||
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env)]
|
||||
top_n_tokens: Option<u32>,
|
||||
|
||||
/// Generation parameter in case you want to specifically test/debug particular
|
||||
/// decoding strategies, for full doc refer to the `text-generation-server`
|
||||
#[clap(long, env, value_parser=parse_key_val::<String, f32>)]
|
||||
@ -123,6 +129,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
do_sample,
|
||||
master_shard_uds_path,
|
||||
logit_bias,
|
||||
top_n_tokens,
|
||||
} = args;
|
||||
|
||||
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||
@ -179,6 +186,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
batch_size,
|
||||
sequence_length,
|
||||
decode_length,
|
||||
top_n_tokens,
|
||||
runs,
|
||||
warmups,
|
||||
temperature,
|
||||
|
@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
|
||||
tokenizer_name: String,
|
||||
sequence_length: u32,
|
||||
decode_length: u32,
|
||||
top_n_tokens: Option<u32>,
|
||||
n_runs: usize,
|
||||
warmups: usize,
|
||||
temperature: Option<f32>,
|
||||
@ -25,6 +26,7 @@ pub(crate) fn parameters_table(
|
||||
builder.push_record(["Model", &tokenizer_name]);
|
||||
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
|
||||
builder.push_record(["Decode Length", &decode_length.to_string()]);
|
||||
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
|
||||
builder.push_record(["N Runs", &n_runs.to_string()]);
|
||||
builder.push_record(["Warmups", &warmups.to_string()]);
|
||||
builder.push_record(["Temperature", &format!("{temperature:?}")]);
|
||||
|
@ -75,6 +75,7 @@ class Client:
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: bool = False,
|
||||
decoder_input_details: bool = False,
|
||||
top_n_tokens: Optional[int] = None,
|
||||
logit_bias: Dict[str, float] = {},
|
||||
) -> Response:
|
||||
"""
|
||||
@ -114,6 +115,8 @@ class Client:
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
decoder_input_details (`bool`):
|
||||
Return the decoder input token logprobs and ids
|
||||
top_n_tokens (`int`):
|
||||
Return the `n` most likely tokens at each step
|
||||
logit_bias (`Dict[str, float]`):
|
||||
Bias generation towards certain tokens.
|
||||
|
||||
@ -137,6 +140,7 @@ class Client:
|
||||
typical_p=typical_p,
|
||||
watermark=watermark,
|
||||
decoder_input_details=decoder_input_details,
|
||||
top_n_tokens=top_n_tokens,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||
@ -168,6 +172,7 @@ class Client:
|
||||
truncate: Optional[int] = None,
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: bool = False,
|
||||
top_n_tokens: Optional[int] = None,
|
||||
logit_bias: Dict[str, float] = {},
|
||||
) -> Iterator[StreamResponse]:
|
||||
"""
|
||||
@ -203,6 +208,8 @@ class Client:
|
||||
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
|
||||
watermark (`bool`):
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
top_n_tokens (`int`):
|
||||
Return the `n` most likely tokens at each step
|
||||
logit_bias (`Dict[str, float]`):
|
||||
Bias generation towards certain tokens.
|
||||
|
||||
@ -227,6 +234,7 @@ class Client:
|
||||
typical_p=typical_p,
|
||||
watermark=watermark,
|
||||
logit_bias=logit_bias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||
|
||||
@ -326,6 +334,7 @@ class AsyncClient:
|
||||
watermark: bool = False,
|
||||
decoder_input_details: bool = False,
|
||||
logit_bias: Dict[str, float] = {},
|
||||
top_n_tokens: Optional[int] = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Given a prompt, generate the following text asynchronously
|
||||
@ -366,6 +375,8 @@ class AsyncClient:
|
||||
Return the decoder input token logprobs and ids
|
||||
logit_bias (`Dict[str, float]`):
|
||||
Bias generation towards certain tokens.
|
||||
top_n_tokens (`int`):
|
||||
Return the `n` most likely tokens at each step
|
||||
|
||||
Returns:
|
||||
Response: generated response
|
||||
@ -388,6 +399,7 @@ class AsyncClient:
|
||||
typical_p=typical_p,
|
||||
watermark=watermark,
|
||||
logit_bias=logit_bias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||
|
||||
@ -417,6 +429,7 @@ class AsyncClient:
|
||||
typical_p: Optional[float] = None,
|
||||
watermark: bool = False,
|
||||
logit_bias: Dict[str, float] = {},
|
||||
top_n_tokens: Optional[int] = None,
|
||||
) -> AsyncIterator[StreamResponse]:
|
||||
"""
|
||||
Given a prompt, generate the following stream of tokens asynchronously
|
||||
@ -453,6 +466,8 @@ class AsyncClient:
|
||||
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
logit_bias (`Dict[str, float]`):
|
||||
Bias generation towards certain tokens.
|
||||
top_n_tokens (`int`):
|
||||
Return the `n` most likely tokens at each step
|
||||
|
||||
Returns:
|
||||
AsyncIterator[StreamResponse]: stream of generated tokens
|
||||
@ -475,6 +490,7 @@ class AsyncClient:
|
||||
typical_p=typical_p,
|
||||
watermark=watermark,
|
||||
logit_bias=logit_bias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
)
|
||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||
|
||||
|
@ -41,6 +41,8 @@ class Parameters(BaseModel):
|
||||
decoder_input_details: bool = False
|
||||
# Bias generation towards certain tokens
|
||||
logit_bias: Dict[str, float] = {}
|
||||
# Return the N most likely tokens at each step
|
||||
top_n_tokens: Optional[int]
|
||||
|
||||
@validator("best_of")
|
||||
def valid_best_of(cls, field_value, values):
|
||||
@ -103,6 +105,12 @@ class Parameters(BaseModel):
|
||||
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
|
||||
return v
|
||||
|
||||
@validator("top_n_tokens")
|
||||
def valid_top_n_tokens(cls, v):
|
||||
if v is not None and v <= 0:
|
||||
raise ValidationError("`top_n_tokens` must be strictly positive")
|
||||
return v
|
||||
|
||||
|
||||
class Request(BaseModel):
|
||||
# Prompt
|
||||
@ -127,9 +135,7 @@ class Request(BaseModel):
|
||||
and parameters.best_of > 1
|
||||
and field_value
|
||||
):
|
||||
raise ValidationError(
|
||||
"`best_of` != 1 is not supported when `stream` == True"
|
||||
)
|
||||
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
|
||||
return field_value
|
||||
|
||||
|
||||
@ -181,6 +187,8 @@ class BestOfSequence(BaseModel):
|
||||
prefill: List[InputToken]
|
||||
# Generated tokens
|
||||
tokens: List[Token]
|
||||
# Most likely tokens
|
||||
top_tokens: Optional[List[List[Token]]]
|
||||
|
||||
|
||||
# `generate` details
|
||||
@ -195,6 +203,8 @@ class Details(BaseModel):
|
||||
prefill: List[InputToken]
|
||||
# Generated tokens
|
||||
tokens: List[Token]
|
||||
# Most likely tokens
|
||||
top_tokens: Optional[List[List[Token]]]
|
||||
# Additional sequences when using the `best_of` parameter
|
||||
best_of_sequences: Optional[List[BestOfSequence]]
|
||||
|
||||
@ -221,6 +231,8 @@ class StreamDetails(BaseModel):
|
||||
class StreamResponse(BaseModel):
|
||||
# Generated token
|
||||
token: Token
|
||||
# Most likely tokens
|
||||
top_tokens: Optional[List[Token]]
|
||||
# Complete generated text
|
||||
# Only available when the generation is finished
|
||||
generated_text: Optional[str]
|
||||
|
@ -10,7 +10,7 @@
|
||||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "1.0.1"
|
||||
"version": "1.0.2"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -75,6 +75,81 @@ To serve both ChatUI and TGI in same environment, simply add your own endpoints
|
||||
|
||||

|
||||
|
||||
## Gradio
|
||||
|
||||
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub gradio
|
||||
```
|
||||
|
||||
Assume you are serving your model on port 8080, we will query through [InferenceClient](consuming_tgi#inference-client).
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
client = InferenceClient(model="http://127.0.0.1:8080")
|
||||
|
||||
def inference(message, history):
|
||||
partial_message = ""
|
||||
for token in client.text_generation(message, max_new_tokens=20, stream=True):
|
||||
partial_message += token
|
||||
yield partial_message
|
||||
|
||||
gr.ChatInterface(
|
||||
inference,
|
||||
chatbot=gr.Chatbot(height=300),
|
||||
textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
|
||||
description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
|
||||
title="Gradio 🤝 TGI",
|
||||
examples=["Are tomatoes vegetables?"],
|
||||
retry_btn="Retry",
|
||||
undo_btn="Undo",
|
||||
clear_btn="Clear",
|
||||
).queue().launch()
|
||||
```
|
||||
|
||||
The UI looks like this 👇
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can try the demo directly here 👇
|
||||
|
||||
<div class="block dark:hidden">
|
||||
<iframe
|
||||
src="https://merve-gradio-tgi-2.hf.space?__theme=light"
|
||||
width="850"
|
||||
height="750"
|
||||
></iframe>
|
||||
</div>
|
||||
<div class="hidden dark:block">
|
||||
<iframe
|
||||
src="https://merve-gradio-tgi-2.hf.space?__theme=dark"
|
||||
width="850"
|
||||
height="750"
|
||||
></iframe>
|
||||
</div>
|
||||
|
||||
|
||||
You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
|
||||
|
||||
```python
|
||||
def inference(message, history):
|
||||
return client.text_generation(message, max_new_tokens=20)
|
||||
```
|
||||
|
||||
You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
|
||||
|
||||
## API documentation
|
||||
|
||||
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).
|
||||
|
@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
|
||||
model=tiiuae/falcon-7b-instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
@ -85,7 +85,7 @@ curl 127.0.0.1:8080/generate \
|
||||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```bash
|
||||
docker run ghcr.io/huggingface/text-generation-inference:1.0.1 --help
|
||||
docker run ghcr.io/huggingface/text-generation-inference:1.0.2 --help
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
@ -159,6 +159,14 @@ struct Args {
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
|
||||
/// This is the maximum allowed value for clients to set `top_n_tokens`.
|
||||
/// `top_n_tokens is used to return information about the the `n` most likely
|
||||
/// tokens at each generation step, instead of just the sampled token. This
|
||||
/// information can be used for downstream tasks like for classification or
|
||||
/// ranking.
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
|
||||
/// This is the maximum allowed input length (expressed in number of tokens)
|
||||
/// for users. The larger this value, the longer prompt users can send which
|
||||
/// can impact the overall memory required to handle the load.
|
||||
@ -929,6 +937,8 @@ fn spawn_webserver(
|
||||
args.max_best_of.to_string(),
|
||||
"--max-stop-sequences".to_string(),
|
||||
args.max_stop_sequences.to_string(),
|
||||
"--max-top-n-tokens".to_string(),
|
||||
args.max_top_n_tokens.to_string(),
|
||||
"--max-input-length".to_string(),
|
||||
args.max_input_length.to_string(),
|
||||
"--max-total-tokens".to_string(),
|
||||
|
@ -101,6 +101,8 @@ message Request {
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
/// Return prefill logprobs
|
||||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
@ -151,6 +153,17 @@ message PrefillTokens {
|
||||
repeated string texts = 3;
|
||||
}
|
||||
|
||||
message TopTokens {
|
||||
/// Top Token IDs
|
||||
repeated uint32 ids = 1;
|
||||
/// Top Logprobs
|
||||
repeated float logprobs = 2;
|
||||
/// Top Token Texts
|
||||
repeated string texts = 3;
|
||||
/// If the tokens are special
|
||||
repeated bool is_special = 6;
|
||||
}
|
||||
|
||||
message Generation {
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
@ -166,6 +179,8 @@ message Generation {
|
||||
bool token_is_special = 6;
|
||||
/// Complete generated text
|
||||
optional GeneratedText generated_text = 7;
|
||||
/// Top tokens
|
||||
TopTokens top_tokens = 8;
|
||||
}
|
||||
|
||||
message FilterBatchRequest {
|
||||
|
@ -132,6 +132,7 @@ impl Client {
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
}
|
||||
|
@ -51,6 +51,7 @@ impl Health {
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: BATCH_ID,
|
||||
|
@ -138,12 +138,15 @@ impl Infer {
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<InferResponse, InferError> {
|
||||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, mut stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
let mut result_tokens = Vec::new();
|
||||
let mut result_top_tokens = Vec::new();
|
||||
let mut result_generated_text = None;
|
||||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
@ -164,7 +167,10 @@ impl Infer {
|
||||
.collect();
|
||||
}
|
||||
// Push last token
|
||||
InferStreamResponse::Token(token) => result_tokens.push(token),
|
||||
InferStreamResponse::Intermediate { token, top_tokens } => {
|
||||
result_tokens.push(token);
|
||||
result_top_tokens.push(top_tokens);
|
||||
}
|
||||
// Final message
|
||||
// Set return values
|
||||
InferStreamResponse::End {
|
||||
@ -172,8 +178,10 @@ impl Infer {
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
top_tokens,
|
||||
} => {
|
||||
result_tokens.push(token);
|
||||
result_top_tokens.push(top_tokens);
|
||||
result_generated_text = Some(generated_text);
|
||||
result_start = Some(start);
|
||||
result_queued = Some(queued)
|
||||
@ -191,6 +199,11 @@ impl Infer {
|
||||
generated_text,
|
||||
queued,
|
||||
start,
|
||||
top_tokens: if use_top_tokens {
|
||||
result_top_tokens
|
||||
} else {
|
||||
Vec::new()
|
||||
},
|
||||
})
|
||||
} else {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
@ -520,6 +533,26 @@ fn send_responses(
|
||||
special: generation.token_is_special,
|
||||
};
|
||||
|
||||
// generation.top_tokens
|
||||
|
||||
let mut top_tokens = Vec::new();
|
||||
if let Some(top_tokens_) = generation.top_tokens {
|
||||
top_tokens.extend(
|
||||
top_tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(top_tokens_.logprobs.into_iter())
|
||||
.zip(top_tokens_.texts.into_iter())
|
||||
.zip(top_tokens_.is_special.into_iter())
|
||||
.map(|(((id, logprob), text), special)| Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
if let Some(generated_text) = generation.generated_text {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
@ -527,6 +560,7 @@ fn send_responses(
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
@ -536,7 +570,7 @@ fn send_responses(
|
||||
} else {
|
||||
// Send message
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::Token(token)),
|
||||
Ok(InferStreamResponse::Intermediate { token, top_tokens }),
|
||||
Duration::from_millis(10),
|
||||
)?;
|
||||
}
|
||||
@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(PrefillTokens),
|
||||
// Intermediate messages
|
||||
Token(Token),
|
||||
Intermediate {
|
||||
token: Token,
|
||||
top_tokens: Vec<Token>,
|
||||
},
|
||||
// Last message
|
||||
End {
|
||||
token: Token,
|
||||
top_tokens: Vec<Token>,
|
||||
generated_text: GeneratedText,
|
||||
start: Instant,
|
||||
queued: Instant,
|
||||
@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
|
||||
pub(crate) generated_text: GeneratedText,
|
||||
pub(crate) queued: Instant,
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
@ -139,6 +139,8 @@ pub(crate) struct GenerateParameters {
|
||||
#[serde(default)]
|
||||
#[schema(default=json!({}), example=json!({"hello": 0.5}))]
|
||||
pub logit_bias: BTreeMap<String, f32>
|
||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||
pub top_n_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
@ -163,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
|
||||
decoder_input_details: false,
|
||||
seed: None,
|
||||
logit_bias: BTreeMap::new(),
|
||||
top_n_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,6 +243,8 @@ pub(crate) struct BestOfSequence {
|
||||
pub seed: Option<u64>,
|
||||
pub prefill: Vec<PrefillToken>,
|
||||
pub tokens: Vec<Token>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
@ -254,6 +259,8 @@ pub(crate) struct Details {
|
||||
pub tokens: Vec<Token>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub best_of_sequences: Option<Vec<BestOfSequence>>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub top_tokens: Vec<Vec<Token>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
@ -277,6 +284,8 @@ pub(crate) struct StreamDetails {
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct StreamResponse {
|
||||
pub token: Token,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub top_tokens: Vec<Token>,
|
||||
#[schema(nullable = true, default = "null", example = "test")]
|
||||
pub generated_text: Option<String>,
|
||||
#[schema(nullable = true, default = "null")]
|
||||
|
@ -29,6 +29,8 @@ struct Args {
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_length: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
|
@ -235,6 +235,7 @@ impl State {
|
||||
truncate: entry.request.truncate,
|
||||
parameters: Some(entry.request.parameters.clone()),
|
||||
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
|
||||
top_n_tokens: entry.request.top_n_tokens,
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
@ -329,6 +330,7 @@ mod tests {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
|
@ -158,7 +158,7 @@ async fn generate(
|
||||
add_prompt = Some(req.inputs.clone());
|
||||
}
|
||||
|
||||
let details = req.parameters.details || req.parameters.decoder_input_details;
|
||||
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
|
||||
|
||||
// Inference
|
||||
let (response, best_of_responses) = match req.parameters.best_of {
|
||||
@ -191,6 +191,7 @@ async fn generate(
|
||||
generated_tokens: response.generated_text.generated_tokens,
|
||||
prefill: response.prefill,
|
||||
tokens: response.tokens,
|
||||
top_tokens: response.top_tokens,
|
||||
seed: response.generated_text.seed,
|
||||
}
|
||||
})
|
||||
@ -204,6 +205,7 @@ async fn generate(
|
||||
tokens: response.tokens,
|
||||
seed: response.generated_text.seed,
|
||||
best_of_sequences,
|
||||
top_tokens: response.top_tokens,
|
||||
})
|
||||
}
|
||||
false => None,
|
||||
@ -385,12 +387,16 @@ async fn generate_stream(
|
||||
// Prefill is ignored
|
||||
InferStreamResponse::Prefill(_) => {}
|
||||
// Yield event for every new token
|
||||
InferStreamResponse::Token(token) => {
|
||||
InferStreamResponse::Intermediate{
|
||||
token,
|
||||
top_tokens,
|
||||
} => {
|
||||
tracing::debug!(parent: &span, "Token: {:?}", token);
|
||||
|
||||
// StreamResponse
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens: top_tokens,
|
||||
generated_text: None,
|
||||
details: None,
|
||||
};
|
||||
@ -403,6 +409,7 @@ async fn generate_stream(
|
||||
generated_text,
|
||||
start,
|
||||
queued,
|
||||
top_tokens,
|
||||
} => {
|
||||
// Token details
|
||||
let details = match details {
|
||||
@ -451,6 +458,7 @@ async fn generate_stream(
|
||||
|
||||
let stream_token = StreamResponse {
|
||||
token,
|
||||
top_tokens: top_tokens,
|
||||
generated_text: Some(output_text),
|
||||
details
|
||||
};
|
||||
@ -509,6 +517,7 @@ pub async fn run(
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
@ -571,6 +580,7 @@ pub async fn run(
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
|
@ -15,6 +15,7 @@ pub struct Validation {
|
||||
/// Validation parameters
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
/// Channel to communicate with the background tokenization task
|
||||
@ -27,6 +28,7 @@ impl Validation {
|
||||
tokenizer: Option<Tokenizer>,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
@ -54,6 +56,7 @@ impl Validation {
|
||||
max_best_of,
|
||||
sender,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
}
|
||||
@ -143,6 +146,7 @@ impl Validation {
|
||||
watermark,
|
||||
decoder_input_details,
|
||||
logit_bias,
|
||||
top_n_tokens,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
@ -219,6 +223,15 @@ impl Validation {
|
||||
}
|
||||
};
|
||||
|
||||
let top_n_tokens = top_n_tokens
|
||||
.map(|value| {
|
||||
if value > self.max_top_n_tokens {
|
||||
return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
|
||||
}
|
||||
Ok(value)
|
||||
})
|
||||
.unwrap_or(Ok(0))?;
|
||||
|
||||
// Check if inputs is empty
|
||||
if request.inputs.is_empty() {
|
||||
return Err(EmptyInput);
|
||||
@ -268,6 +281,7 @@ impl Validation {
|
||||
truncate: truncate.unwrap_or(self.max_input_length) as u32,
|
||||
parameters,
|
||||
stopping_parameters,
|
||||
top_n_tokens: top_n_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
@ -341,6 +355,7 @@ pub(crate) struct ValidGenerateRequest {
|
||||
pub decoder_input_details: bool,
|
||||
pub parameters: NextTokenChooserParameters,
|
||||
pub stopping_parameters: StoppingCriteriaParameters,
|
||||
pub top_n_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@ -355,6 +370,10 @@ pub enum ValidationError {
|
||||
BestOfSeed,
|
||||
#[error("`best_of` != 1 is not supported when streaming tokens")]
|
||||
BestOfStream,
|
||||
#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
|
||||
TopNTokens(u32, u32),
|
||||
#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
|
||||
TopNTokensDisabled,
|
||||
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
|
||||
PrefillDetailsStream,
|
||||
#[error("`temperature` must be strictly positive")]
|
||||
@ -396,14 +415,16 @@ mod tests {
|
||||
let tokenizer = None;
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
@ -423,14 +444,16 @@ mod tests {
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
@ -440,7 +463,7 @@ mod tests {
|
||||
.validate_input("Hello".to_string(), None, max_new_tokens)
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
|
||||
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
|
||||
_ => panic!("Unexpected not max new tokens"),
|
||||
}
|
||||
}
|
||||
@ -450,14 +473,16 @@ mod tests {
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
@ -482,14 +507,16 @@ mod tests {
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequence = 3;
|
||||
let max_input_length = 4;
|
||||
let max_total_tokens = 5;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
@ -536,4 +563,75 @@ mod tests {
|
||||
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
|
||||
assert_eq!(valid_request.parameters.top_p, 1.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_validation_top_n_tokens() {
|
||||
let tokenizer = Some(get_tokenizer().await);
|
||||
let max_best_of = 2;
|
||||
let max_stop_sequences = 3;
|
||||
let max_top_n_tokens = 4;
|
||||
let max_input_length = 5;
|
||||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_total_tokens,
|
||||
);
|
||||
match validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::TopNTokens(4, 5)) => (),
|
||||
_ => panic!("Unexpected top_n_tokens"),
|
||||
}
|
||||
|
||||
validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(4),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(0),
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let valid_request = validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: None,
|
||||
max_new_tokens: 1,
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(valid_request.top_n_tokens, 0);
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from text_generation_server.utils.tokens import (
|
||||
StopSequenceCriteria,
|
||||
StoppingCriteria,
|
||||
FinishReason,
|
||||
batch_top_tokens,
|
||||
)
|
||||
|
||||
|
||||
@ -42,3 +44,22 @@ def test_stopping_criteria_max():
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (False, None)
|
||||
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
|
||||
|
||||
def test_batch_top_tokens():
|
||||
top_n_tokens = [0, 2, 3, 4, 5]
|
||||
top_n_tokens_tensor = torch.tensor(top_n_tokens)
|
||||
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
|
||||
|
||||
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
|
||||
|
||||
assert topn_tok_ids[0] == []
|
||||
assert topn_tok_ids[1] == [0, 3]
|
||||
assert topn_tok_ids[2] == [0, 3, 1, 4]
|
||||
assert topn_tok_ids[3] == [0, 3, 1, 4]
|
||||
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
|
||||
|
||||
assert topn_tok_logprobs[0] == []
|
||||
assert topn_tok_logprobs[1] == [-1, -2]
|
||||
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
|
||||
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
|
||||
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
|
||||
|
@ -125,6 +125,9 @@ def download_weights(
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
|
||||
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
|
||||
is_local_model = True
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
import torch
|
||||
import inspect
|
||||
|
||||
@ -12,6 +13,7 @@ from text_generation_server.models.types import (
|
||||
PrefillTokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Metadata used for padding
|
||||
max_input_length: int
|
||||
@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
requests_idx_mapping = {}
|
||||
@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||
|
||||
@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
total_remaining_decode_tokens = 0
|
||||
new_padding_right_offset = 0
|
||||
@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
remaining_decode_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
|
||||
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
|
||||
del past_values
|
||||
|
||||
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||
|
||||
self.requests = requests
|
||||
@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
|
||||
self.read_offsets = read_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.top_n_tokens = top_n_tokens
|
||||
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
max_tokens = 0
|
||||
|
||||
# Batch tensors
|
||||
@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
|
||||
attention_mask = None
|
||||
position_ids = None
|
||||
past_key_values = []
|
||||
top_n_tokens_tensor = None
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
# Equivalent to a cumsum on batch sizes
|
||||
@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
|
||||
(total_batch_size, max_input_length + padding_right_offset),
|
||||
)
|
||||
|
||||
if top_n_tokens_tensor is None:
|
||||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
# and to remove unused allocated space
|
||||
left_offset = max_input_length - batch.max_input_length
|
||||
@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
@ -549,6 +576,12 @@ class CausalLM(Model):
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.softmax(logits[:, -1], -1),
|
||||
)
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
@ -559,6 +592,9 @@ class CausalLM(Model):
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.top_n_tokens,
|
||||
batch_top_token_ids,
|
||||
batch_top_token_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
@ -571,6 +607,9 @@ class CausalLM(Model):
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_input_ids,
|
||||
top_n_tokens,
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
@ -637,6 +676,24 @@ class CausalLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = TopTokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
@ -645,6 +702,7 @@ class CausalLM(Model):
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_scaling=None,
|
||||
rope_theta=10000.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_scaling = rope_scaling
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
# )
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config, dim=self.head_size, base=10000.0, device=weights.device
|
||||
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import itertools
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
@ -16,6 +17,7 @@ from text_generation_server.models.types import (
|
||||
PrefillTokens,
|
||||
Generation,
|
||||
GeneratedText,
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
|
||||
# Generation helpers
|
||||
next_token_chooser: HeterogeneousNextTokenChooser
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Number of blocks in this batch
|
||||
blocks: int
|
||||
@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
|
||||
)
|
||||
max_new_tokens = stopping_criteria.max_new_tokens
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token des not have a past
|
||||
@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
|
||||
prefill_next_token_indices = torch.tensor(
|
||||
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
)
|
||||
@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
|
||||
read_offsets = []
|
||||
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
blocks = 0
|
||||
max_blocks = 0
|
||||
@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
|
||||
remaining_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
|
||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
next_token_chooser = self.next_token_chooser.filter(indices)
|
||||
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
|
||||
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
)
|
||||
@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
|
||||
(total_batch_size, max_length)
|
||||
)
|
||||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
|
||||
start_slots = []
|
||||
block_tables = []
|
||||
@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
|
||||
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
|
||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
|
||||
all_input_ids_tensor[
|
||||
@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
|
||||
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
|
||||
# Update
|
||||
cumulative_batch_size += len(batch)
|
||||
cumulative_slots += len(batch.slots)
|
||||
@ -667,6 +691,8 @@ class FlashCausalLMBatch(Batch):
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
)
|
||||
@ -832,10 +858,14 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
next_token_logits = out
|
||||
|
||||
next_input_ids, next_token_logprobs = batch.next_token_chooser(
|
||||
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
|
||||
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
|
||||
)
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
|
||||
)
|
||||
|
||||
if prefill:
|
||||
if len(batch) > 1 and prefill_logprobs:
|
||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||
@ -932,8 +962,11 @@ class FlashCausalLM(Model):
|
||||
batch.all_input_ids,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
batch.top_n_tokens,
|
||||
next_token_ids,
|
||||
next_token_logprobs,
|
||||
batch_top_token_ids,
|
||||
batch_top_token_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
@ -946,8 +979,11 @@ class FlashCausalLM(Model):
|
||||
all_input_ids,
|
||||
do_sample,
|
||||
seed,
|
||||
top_n_tokens,
|
||||
next_token_id,
|
||||
next_token_logprob,
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
# Append next token to all tokens
|
||||
all_input_ids.append(next_token_id)
|
||||
@ -1006,6 +1042,24 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = TopTokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
@ -1014,6 +1068,7 @@ class FlashCausalLM(Model):
|
||||
next_token_text,
|
||||
next_token_id in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
top_tokens=None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
@ -11,6 +12,7 @@ from text_generation_server.models.types import (
|
||||
Batch,
|
||||
Generation,
|
||||
PrefillTokens,
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
# Generation helpers
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Metadata used for padding
|
||||
max_input_length: int
|
||||
@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
top_n_tokens = []
|
||||
decoder_input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||
padding_right_offset = max(
|
||||
@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
|
||||
prefix_offsets.append(0)
|
||||
read_offsets.append(1)
|
||||
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||
|
||||
@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length.item(),
|
||||
max_decoder_input_length=1,
|
||||
padding_right_offset=padding_right_offset,
|
||||
@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
max_input_length = 0
|
||||
max_decoder_input_length = 0
|
||||
@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
next_token_choosers.append(self.next_token_choosers[idx])
|
||||
stopping_criteria = self.stopping_criterias[idx]
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
remaining_decode_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
layer[2] = layer[2][keep_indices, :, -max_input_length:]
|
||||
layer[3] = layer[3][keep_indices, :, -max_input_length:]
|
||||
|
||||
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||
max_tokens = (
|
||||
len(request_ids) * (max_input_length + max_decoder_input_length)
|
||||
+ remaining_decode_tokens
|
||||
@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
self.read_offsets = read_offsets
|
||||
self.next_token_choosers = next_token_choosers
|
||||
self.stopping_criterias = stopping_criterias
|
||||
self.top_n_tokens = top_n_tokens
|
||||
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||
self.max_input_length = max_input_length
|
||||
self.max_decoder_input_length = max_decoder_input_length
|
||||
self.padding_right_offset = padding_right_offset
|
||||
@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
read_offsets = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
max_tokens = 0
|
||||
|
||||
# Batch tensors
|
||||
@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
decoder_input_ids = None
|
||||
decoder_attention_mask = None
|
||||
encoder_last_hidden_state = None
|
||||
top_n_tokens_tensor = None
|
||||
past_key_values = []
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
|
||||
read_offsets.extend(batch.read_offsets)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
|
||||
if i == 0:
|
||||
requests_idx_mapping = batch.requests_idx_mapping
|
||||
@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
|
||||
),
|
||||
)
|
||||
|
||||
if top_n_tokens_tensor is None:
|
||||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
|
||||
# Copy to correct indices
|
||||
encoder_last_hidden_state[
|
||||
start_index:end_index, -batch.max_input_length :, :
|
||||
@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
|
||||
read_offsets=read_offsets,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
max_input_length=max_input_length,
|
||||
max_decoder_input_length=max_decoder_input_length,
|
||||
padding_right_offset=padding_right_offset,
|
||||
@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
|
||||
batch.past_key_values,
|
||||
)
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens,
|
||||
batch.top_n_tokens_tensor,
|
||||
torch.softmax(logits[:, -1], -1),
|
||||
)
|
||||
|
||||
# Finished requests
|
||||
generations: List[Generation] = []
|
||||
stopped = True
|
||||
@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_decoder_input_ids,
|
||||
batch.top_n_tokens,
|
||||
batch_top_token_ids,
|
||||
batch_top_token_logprobs,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_decoder_input_ids,
|
||||
top_n_tokens,
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token_id, logprobs = next_token_chooser(
|
||||
@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
||||
if top_n_tokens > 0:
|
||||
toptoken_texts = self.tokenizer.batch_decode(
|
||||
top_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
special_toptokens = [
|
||||
token_id in self.all_special_ids for token_id in top_token_ids
|
||||
]
|
||||
top_tokens = TopTokens(
|
||||
top_token_ids,
|
||||
top_token_logprobs,
|
||||
toptoken_texts,
|
||||
special_toptokens,
|
||||
)
|
||||
else:
|
||||
top_tokens = None
|
||||
|
||||
generation = Generation(
|
||||
request.id,
|
||||
prefill_tokens,
|
||||
@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
|
||||
next_token_text,
|
||||
next_token_id_squeezed.item() in self.all_special_ids,
|
||||
generated_text,
|
||||
top_tokens,
|
||||
)
|
||||
|
||||
generations.append(generation)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from functools import total_ordering
|
||||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@ -71,6 +72,25 @@ class PrefillTokens:
|
||||
return len(self.token_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TopTokens:
|
||||
token_ids: List[int]
|
||||
logprobs: List[float]
|
||||
texts: List[str]
|
||||
is_special: List[bool]
|
||||
|
||||
def to_pb(self) -> generate_pb2.TopTokens:
|
||||
return generate_pb2.TopTokens(
|
||||
ids=self.token_ids,
|
||||
logprobs=self.logprobs,
|
||||
texts=self.texts,
|
||||
is_special=self.is_special,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.token_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Generation:
|
||||
request_id: int
|
||||
@ -80,6 +100,8 @@ class Generation:
|
||||
token_text: str
|
||||
token_is_special: bool
|
||||
generated_text: Optional[GeneratedText]
|
||||
# Optional for now, since it's not yet supported for every model.
|
||||
top_tokens: Optional[TopTokens]
|
||||
|
||||
def to_pb(self) -> generate_pb2.Generation:
|
||||
return generate_pb2.Generation(
|
||||
@ -94,4 +116,5 @@ class Generation:
|
||||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
|
||||
)
|
||||
|
@ -6,20 +6,22 @@ from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
SequenceBiasLogitsProcessor,
|
||||
)
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from typing import Callable, List, Tuple, Optional, Dict
|
||||
|
||||
import torch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from text_generation_server.utils.logits_process import (
|
||||
static_warper,
|
||||
HeterogeneousProcessorWrapper,
|
||||
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
||||
HeterogeneousTemperatureLogitsWarper,
|
||||
HeterogeneousTopKLogitsWarper,
|
||||
HeterogeneousTopPLogitsWarper,
|
||||
HeterogeneousTypicalLogitsWarper,
|
||||
HeterogeneousProcessorWrapper,
|
||||
static_warper,
|
||||
)
|
||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
|
||||
|
||||
|
||||
class NextTokenChooser:
|
||||
@ -255,11 +257,10 @@ class HeterogeneousNextTokenChooser:
|
||||
scores = warper(input_ids, scores)
|
||||
|
||||
next_ids = self.choice(scores)
|
||||
next_logprobs = torch.gather(
|
||||
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
|
||||
).view(-1)
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
|
||||
return next_ids, next_logprobs
|
||||
return next_ids, next_logprobs, logprobs
|
||||
|
||||
def filter(self, indices):
|
||||
if self.watermark_processor is not None:
|
||||
@ -370,3 +371,50 @@ class HeterogeneousSampling:
|
||||
self.greedy_indices = new_greedy_indices
|
||||
self.sampling_mapping = new_sampling_mapping
|
||||
return self
|
||||
|
||||
|
||||
def batch_top_tokens(
|
||||
top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
|
||||
) -> Tuple[List[List[int]], List[List[float]]]:
|
||||
"""Find the top n most likely tokens for a batch of generations.
|
||||
|
||||
When multiple tokens have equal probabilities and they don't all fit, the
|
||||
remaining tokens are also returned.
|
||||
"""
|
||||
max_top_n = max(top_n_tokens)
|
||||
# Early exit when top_n_tokens is not used
|
||||
if max_top_n == 0:
|
||||
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
|
||||
|
||||
# Ensure top_n doesn't exceed vocab size
|
||||
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
|
||||
|
||||
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
|
||||
# Sorted topk is faster than torch.sort() since we only need a small subset
|
||||
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
|
||||
nth_highest = torch.gather(
|
||||
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
|
||||
)
|
||||
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
|
||||
|
||||
# Find the new "fuzzy" top n values
|
||||
top_n_indices = (logprobs >= nth_highest).nonzero()
|
||||
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
|
||||
|
||||
# Take a new topk for these new max n values
|
||||
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
|
||||
|
||||
top_n_ishes = top_n_ishes.tolist()
|
||||
top_indices = top_k.indices.tolist()
|
||||
top_values = top_k.values.tolist()
|
||||
|
||||
return (
|
||||
[
|
||||
idxs[:n] if req_n > 0 else []
|
||||
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
|
||||
],
|
||||
[
|
||||
vals[:n] if req_n > 0 else []
|
||||
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user