Merge branch 'main' into main

This commit is contained in:
Marcus Dunn 2023-08-28 14:12:39 -07:00 committed by GitHub
commit d4e2ce7e7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 615 additions and 34 deletions

View File

@ -86,7 +86,7 @@ The easiest way of getting started is using the official Docker container:
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
@ -153,7 +153,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
### A note on Shared Memory (shm)

View File

@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
@ -64,6 +65,7 @@ async fn generate_runs(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
@ -82,6 +84,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -97,6 +100,7 @@ async fn generate_runs(
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
@ -130,6 +134,7 @@ async fn prefill(
batch_size: u32,
decode_length: u32,
parameters: NextTokenChooserParameters,
top_n_tokens: Option<u32>,
client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
@ -145,6 +150,7 @@ async fn prefill(
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
})
.collect();

View File

@ -22,6 +22,7 @@ pub async fn run(
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -75,6 +76,7 @@ pub async fn run(
batch_size.clone(),
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
parameters,
@ -135,6 +137,7 @@ pub async fn run(
tokenizer_name,
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
temperature,

View File

@ -94,6 +94,12 @@ struct Args {
#[clap(long, env)]
do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_n_tokens: Option<u32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env, value_parser=parse_key_val::<String, f32>)]
@ -123,6 +129,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
do_sample,
master_shard_uds_path,
logit_bias,
top_n_tokens,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@ -179,6 +186,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
batch_size,
sequence_length,
decode_length,
top_n_tokens,
runs,
warmups,
temperature,

View File

@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
@ -25,6 +26,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Model", &tokenizer_name]);
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
builder.push_record(["Decode Length", &decode_length.to_string()]);
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
builder.push_record(["N Runs", &n_runs.to_string()]);
builder.push_record(["Warmups", &warmups.to_string()]);
builder.push_record(["Temperature", &format!("{temperature:?}")]);

View File

@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
logit_bias: Dict[str, float] = {},
) -> Response:
"""
@ -114,6 +115,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`):
Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
@ -137,6 +140,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens,
logit_bias=logit_bias,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -168,6 +172,7 @@ class Client:
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
logit_bias: Dict[str, float] = {},
) -> Iterator[StreamResponse]:
"""
@ -203,6 +208,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
@ -227,6 +234,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -326,6 +334,7 @@ class AsyncClient:
watermark: bool = False,
decoder_input_details: bool = False,
logit_bias: Dict[str, float] = {},
top_n_tokens: Optional[int] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
@ -366,6 +375,8 @@ class AsyncClient:
Return the decoder input token logprobs and ids
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
Response: generated response
@ -388,6 +399,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -417,6 +429,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
logit_bias: Dict[str, float] = {},
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
@ -453,6 +466,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
logit_bias (`Dict[str, float]`):
Bias generation towards certain tokens.
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns:
AsyncIterator[StreamResponse]: stream of generated tokens
@ -475,6 +490,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
logit_bias=logit_bias,
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@ -41,6 +41,8 @@ class Parameters(BaseModel):
decoder_input_details: bool = False
# Bias generation towards certain tokens
logit_bias: Dict[str, float] = {}
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
@validator("best_of")
def valid_best_of(cls, field_value, values):
@ -103,6 +105,12 @@ class Parameters(BaseModel):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v
@validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
class Request(BaseModel):
# Prompt
@ -127,9 +135,7 @@ class Request(BaseModel):
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
return field_value
@ -181,6 +187,8 @@ class BestOfSequence(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details
@ -195,6 +203,8 @@ class Details(BaseModel):
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
@ -221,6 +231,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "1.0.1"
"version": "1.0.2"
},
"paths": {
"/": {

View File

@ -75,6 +75,81 @@ To serve both ChatUI and TGI in same environment, simply add your own endpoints
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)
## Gradio
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
```bash
pip install huggingface-hub gradio
```
Assume you are serving your model on port 8080, we will query through [InferenceClient](consuming_tgi#inference-client).
```python
import gradio as gr
from huggingface_hub import InferenceClient
client = InferenceClient(model="http://127.0.0.1:8080")
def inference(message, history):
partial_message = ""
for token in client.text_generation(message, max_new_tokens=20, stream=True):
partial_message += token
yield partial_message
gr.ChatInterface(
inference,
chatbot=gr.Chatbot(height=300),
textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
title="Gradio 🤝 TGI",
examples=["Are tomatoes vegetables?"],
retry_btn="Retry",
undo_btn="Undo",
clear_btn="Clear",
).queue().launch()
```
The UI looks like this 👇
<div class="flex justify-center">
<img
class="block dark:hidden"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
/>
<img
class="hidden dark:block"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
/>
</div>
You can try the demo directly here 👇
<div class="block dark:hidden">
<iframe
src="https://merve-gradio-tgi-2.hf.space?__theme=light"
width="850"
height="750"
></iframe>
</div>
<div class="hidden dark:block">
<iframe
src="https://merve-gradio-tgi-2.hf.space?__theme=dark"
width="850"
height="750"
></iframe>
</div>
You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
```python
def inference(message, history):
return client.text_generation(message, max_new_tokens=20)
```
You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
## API documentation
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).

View File

@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
```
<Tip warning={true}>
@ -85,7 +85,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:1.0.1 --help
docker run ghcr.io/huggingface/text-generation-inference:1.0.2 --help
```
</Tip>

View File

@ -159,6 +159,14 @@ struct Args {
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens is used to return information about the the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks like for classification or
/// ranking.
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
/// This is the maximum allowed input length (expressed in number of tokens)
/// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load.
@ -929,6 +937,8 @@ fn spawn_webserver(
args.max_best_of.to_string(),
"--max-stop-sequences".to_string(),
args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(),
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),

View File

@ -101,6 +101,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
}
message Batch {
@ -151,6 +153,17 @@ message PrefillTokens {
repeated string texts = 3;
}
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
repeated bool is_special = 6;
}
message Generation {
/// Request ID
uint64 request_id = 1;
@ -166,6 +179,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Top tokens
TopTokens top_tokens = 8;
}
message FilterBatchRequest {

View File

@ -132,6 +132,7 @@ impl Client {
ignore_eos_token: false,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
}

View File

@ -51,6 +51,7 @@ impl Health {
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: BATCH_ID,

View File

@ -138,12 +138,15 @@ impl Infer {
&self,
request: GenerateRequest,
) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@ -164,7 +167,10 @@ impl Infer {
.collect();
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message
// Set return values
InferStreamResponse::End {
@ -172,8 +178,10 @@ impl Infer {
generated_text,
start,
queued,
top_tokens,
} => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text);
result_start = Some(start);
result_queued = Some(queued)
@ -191,6 +199,11 @@ impl Infer {
generated_text,
queued,
start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
})
} else {
let err = InferError::IncompleteGeneration;
@ -520,6 +533,26 @@ fn send_responses(
special: generation.token_is_special,
};
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
)
}
if let Some(generated_text) = generation.generated_text {
// Generation has ended
stopped = true;
@ -527,6 +560,7 @@ fn send_responses(
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
@ -536,7 +570,7 @@ fn send_responses(
} else {
// Send message
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Token(token)),
Ok(InferStreamResponse::Intermediate { token, top_tokens }),
Duration::from_millis(10),
)?;
}
@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
// Intermediate messages
Token(Token),
Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message
End {
token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText,
start: Instant,
queued: Instant,
@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
}
#[derive(Debug, Error)]

View File

@ -139,6 +139,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(default=json!({}), example=json!({"hello": 0.5}))]
pub logit_bias: BTreeMap<String, f32>
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
}
fn default_max_new_tokens() -> u32 {
@ -163,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
seed: None,
logit_bias: BTreeMap::new(),
top_n_tokens: None,
}
}
@ -240,6 +243,8 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -254,6 +259,8 @@ pub(crate) struct Details {
pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
}
#[derive(Serialize, ToSchema)]
@ -277,6 +284,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]

View File

@ -29,6 +29,8 @@ struct Args {
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_length: usize,
#[clap(default_value = "2048", long, env)]
@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,
@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
waiting_served_ratio,

View File

@ -235,6 +235,7 @@ impl State {
truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -329,6 +330,7 @@ mod tests {
max_new_tokens: 1,
stop_sequences: vec![],
},
top_n_tokens: 0,
},
response_tx,
span: info_span!("entry"),

View File

@ -158,7 +158,7 @@ async fn generate(
add_prompt = Some(req.inputs.clone());
}
let details = req.parameters.details || req.parameters.decoder_input_details;
let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference
let (response, best_of_responses) = match req.parameters.best_of {
@ -191,6 +191,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
top_tokens: response.top_tokens,
seed: response.generated_text.seed,
}
})
@ -204,6 +205,7 @@ async fn generate(
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
top_tokens: response.top_tokens,
})
}
false => None,
@ -385,12 +387,16 @@ async fn generate_stream(
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: None,
details: None,
};
@ -403,6 +409,7 @@ async fn generate_stream(
generated_text,
start,
queued,
top_tokens,
} => {
// Token details
let details = match details {
@ -451,6 +458,7 @@ async fn generate_stream(
let stream_token = StreamResponse {
token,
top_tokens: top_tokens,
generated_text: Some(output_text),
details
};
@ -509,6 +517,7 @@ pub async fn run(
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
@ -571,6 +580,7 @@ pub async fn run(
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);

View File

@ -15,6 +15,7 @@ pub struct Validation {
/// Validation parameters
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
/// Channel to communicate with the background tokenization task
@ -27,6 +28,7 @@ impl Validation {
tokenizer: Option<Tokenizer>,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
@ -54,6 +56,7 @@ impl Validation {
max_best_of,
sender,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
}
@ -143,6 +146,7 @@ impl Validation {
watermark,
decoder_input_details,
logit_bias,
top_n_tokens,
..
} = request.parameters;
@ -219,6 +223,15 @@ impl Validation {
}
};
let top_n_tokens = top_n_tokens
.map(|value| {
if value > self.max_top_n_tokens {
return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
}
Ok(value)
})
.unwrap_or(Ok(0))?;
// Check if inputs is empty
if request.inputs.is_empty() {
return Err(EmptyInput);
@ -268,6 +281,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters,
stopping_parameters,
top_n_tokens: top_n_tokens,
})
}
@ -341,6 +355,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub top_n_tokens: u32,
}
#[derive(Error, Debug)]
@ -355,6 +370,10 @@ pub enum ValidationError {
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
TopNTokens(u32, u32),
#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
TopNTokensDisabled,
#[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream,
#[error("`temperature` must be strictly positive")]
@ -396,14 +415,16 @@ mod tests {
let tokenizer = None;
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -423,14 +444,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -440,7 +463,7 @@ mod tests {
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
{
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens"),
}
}
@ -450,14 +473,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -482,14 +507,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequence = 3;
let max_input_length = 4;
let max_total_tokens = 5;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequence,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
@ -536,4 +563,75 @@ mod tests {
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!(valid_request.parameters.top_p, 1.0);
}
#[tokio::test]
async fn test_validation_top_n_tokens() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1;
let validation = Validation::new(
workers,
tokenizer,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_length,
max_total_tokens,
);
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(5),
..default_parameters()
},
})
.await
{
Err(ValidationError::TopNTokens(4, 5)) => (),
_ => panic!("Unexpected top_n_tokens"),
}
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: 1,
..default_parameters()
},
})
.await
.unwrap();
assert_eq!(valid_request.top_n_tokens, 0);
}
}

View File

@ -1,7 +1,9 @@
import torch
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,
batch_top_tokens,
)
@ -42,3 +44,22 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]
assert topn_tok_ids[2] == [0, 3, 1, 4]
assert topn_tok_ids[3] == [0, 3, 1, 4]
assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
assert topn_tok_logprobs[0] == []
assert topn_tok_logprobs[1] == [-1, -2]
assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]

View File

@ -125,6 +125,9 @@ def download_weights(
try:
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import inspect
@ -12,6 +13,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0
new_padding_right_offset = 0
@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests
@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
attention_mask = None
position_ids = None
past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last,
@ -549,6 +576,12 @@ class CausalLM(Model):
generations: List[Generation] = []
stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Zipped iterator
iterator = zip(
batch.requests,
@ -559,6 +592,9 @@ class CausalLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -571,6 +607,9 @@ class CausalLM(Model):
next_token_chooser,
stopping_criteria,
all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -637,6 +676,24 @@ class CausalLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -645,6 +702,7 @@ class CausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
)
self.softmax_scale = self.head_size**-0.5

View File

@ -1,5 +1,6 @@
import math
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -16,6 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens,
Generation,
GeneratedText,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token des not have a past
@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
stopping_criterias = []
top_n_tokens = []
blocks = 0
max_blocks = 0
@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = []
block_tables = []
@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_batch_size = 0
@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[
@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots)
@ -667,6 +691,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
)
@ -832,10 +858,14 @@ class FlashCausalLM(Model):
else:
next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser(
next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@ -932,8 +962,11 @@ class FlashCausalLM(Model):
batch.all_input_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids,
next_token_logprobs,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -946,8 +979,11 @@ class FlashCausalLM(Model):
all_input_ids,
do_sample,
seed,
top_n_tokens,
next_token_id,
next_token_logprob,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
@ -1006,6 +1042,24 @@ class FlashCausalLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -1014,6 +1068,7 @@ class FlashCausalLM(Model):
next_token_text,
next_token_id in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
else:
prefill_tokens = None
top_tokens=None
generation = Generation(
request.id,
prefill_tokens,
@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
from dataclasses import dataclass
@ -11,6 +12,7 @@ from text_generation_server.models.types import (
Batch,
Generation,
PrefillTokens,
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
batch.past_key_values,
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Finished requests
generations: List[Generation] = []
stopped = True
@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_decoder_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
next_token_chooser,
stopping_criteria,
all_decoder_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
else:
prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens,
)
generations.append(generation)

View File

@ -1,3 +1,4 @@
from functools import total_ordering
import torch
from abc import ABC, abstractmethod
@ -71,6 +72,25 @@ class PrefillTokens:
return len(self.token_ids)
@dataclass
class TopTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
is_special: List[bool]
def to_pb(self) -> generate_pb2.TopTokens:
return generate_pb2.TopTokens(
ids=self.token_ids,
logprobs=self.logprobs,
texts=self.texts,
is_special=self.is_special,
)
def __len__(self):
return len(self.token_ids)
@dataclass
class Generation:
request_id: int
@ -80,6 +100,8 @@ class Generation:
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
# Optional for now, since it's not yet supported for every model.
top_tokens: Optional[TopTokens]
def to_pb(self) -> generate_pb2.Generation:
return generate_pb2.Generation(
@ -94,4 +116,5 @@ class Generation:
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
)

View File

@ -6,20 +6,22 @@ from transformers import (
PreTrainedTokenizerBase,
SequenceBiasLogitsProcessor,
)
from typing import List, Tuple, Optional, Dict
from typing import Callable, List, Tuple, Optional, Dict
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from text_generation_server.utils.logits_process import (
static_warper,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousProcessorWrapper,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
class NextTokenChooser:
@ -255,11 +257,10 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
next_ids = self.choice(scores)
next_logprobs = torch.gather(
torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
).view(-1)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
return next_ids, next_logprobs
return next_ids, next_logprobs, logprobs
def filter(self, indices):
if self.watermark_processor is not None:
@ -370,3 +371,50 @@ class HeterogeneousSampling:
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
def batch_top_tokens(
top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
) -> Tuple[List[List[int]], List[List[float]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
return (
[
idxs[:n] if req_n > 0 else []
for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
],
[
vals[:n] if req_n > 0 else []
for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
],
)