Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00

Commit d4e2ce7e7b: Merge branch 'main' into main
@@ -86,7 +86,7 @@ The easiest way of getting started is using the official Docker container:
 model=tiiuae/falcon-7b-instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
 ```
 **Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
 

@@ -153,7 +153,7 @@ model=meta-llama/Llama-2-7b-chat-hf
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>
 
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
 ```
 
 ### A note on Shared Memory (shm)
@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
 batch_size: Vec<u32>,
 sequence_length: u32,
 decode_length: u32,
+top_n_tokens: Option<u32>,
 n_runs: usize,
 warmups: usize,
 parameters: NextTokenChooserParameters,

@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
 // End task if a message is received on shutdown_receiver
 // _shutdown_guard_sender will be dropped once the task is finished
 tokio::select! {
-res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
 if let Err(err) = res {
 run_sender.send(Err(err)).await.unwrap_or(());
 }

@@ -64,6 +65,7 @@ async fn generate_runs(
 batch_size: Vec<u32>,
 sequence_length: u32,
 decode_length: u32,
+top_n_tokens: Option<u32>,
 n_runs: usize,
 warmups: usize,
 parameters: NextTokenChooserParameters,

@@ -82,6 +84,7 @@ async fn generate_runs(
 b,
 decode_length,
 parameters.clone(),
+top_n_tokens,
 &mut client,
 )
 .await?;

@@ -97,6 +100,7 @@ async fn generate_runs(
 b,
 decode_length,
 parameters.clone(),
+top_n_tokens,
 &mut client,
 )
 .await?;

@@ -130,6 +134,7 @@ async fn prefill(
 batch_size: u32,
 decode_length: u32,
 parameters: NextTokenChooserParameters,
+top_n_tokens: Option<u32>,
 client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
 // Create requests

@@ -145,6 +150,7 @@ async fn prefill(
 stop_sequences: vec![],
 ignore_eos_token: true, // Will not stop even if a eos token is generated
 }),
+top_n_tokens: top_n_tokens.unwrap_or(0),
 })
 .collect();
 
@@ -22,6 +22,7 @@ pub async fn run(
 batch_size: Vec<u32>,
 sequence_length: u32,
 decode_length: u32,
+top_n_tokens: Option<u32>,
 n_runs: usize,
 warmups: usize,
 temperature: Option<f32>,

@@ -75,6 +76,7 @@ pub async fn run(
 batch_size.clone(),
 sequence_length,
 decode_length,
+top_n_tokens,
 n_runs,
 warmups,
 parameters,

@@ -135,6 +137,7 @@ pub async fn run(
 tokenizer_name,
 sequence_length,
 decode_length,
+top_n_tokens,
 n_runs,
 warmups,
 temperature,
@@ -94,6 +94,12 @@ struct Args {
 #[clap(long, env)]
 do_sample: bool,
 
+
+/// Generation parameter in case you want to specifically test/debug particular
+/// decoding strategies, for full doc refer to the `text-generation-server`
+#[clap(long, env)]
+top_n_tokens: Option<u32>,
+
 /// Generation parameter in case you want to specifically test/debug particular
 /// decoding strategies, for full doc refer to the `text-generation-server`
 #[clap(long, env, value_parser=parse_key_val::<String, f32>)]

@@ -123,6 +129,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 do_sample,
 master_shard_uds_path,
 logit_bias,
+top_n_tokens,
 } = args;
 
 let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);

@@ -179,6 +186,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 batch_size,
 sequence_length,
 decode_length,
+top_n_tokens,
 runs,
 warmups,
 temperature,
@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
 tokenizer_name: String,
 sequence_length: u32,
 decode_length: u32,
+top_n_tokens: Option<u32>,
 n_runs: usize,
 warmups: usize,
 temperature: Option<f32>,

@@ -25,6 +26,7 @@ pub(crate) fn parameters_table(
 builder.push_record(["Model", &tokenizer_name]);
 builder.push_record(["Sequence Length", &sequence_length.to_string()]);
 builder.push_record(["Decode Length", &decode_length.to_string()]);
+builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
 builder.push_record(["N Runs", &n_runs.to_string()]);
 builder.push_record(["Warmups", &warmups.to_string()]);
 builder.push_record(["Temperature", &format!("{temperature:?}")]);
@@ -75,6 +75,7 @@ class Client:
 typical_p: Optional[float] = None,
 watermark: bool = False,
 decoder_input_details: bool = False,
+top_n_tokens: Optional[int] = None,
 logit_bias: Dict[str, float] = {},
 ) -> Response:
 """

@@ -114,6 +115,8 @@ class Client:
 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 decoder_input_details (`bool`):
 Return the decoder input token logprobs and ids
+top_n_tokens (`int`):
+Return the `n` most likely tokens at each step
 logit_bias (`Dict[str, float]`):
 Bias generation towards certain tokens.
 

@@ -137,6 +140,7 @@ class Client:
 typical_p=typical_p,
 watermark=watermark,
 decoder_input_details=decoder_input_details,
+top_n_tokens=top_n_tokens,
 logit_bias=logit_bias,
 )
 request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -168,6 +172,7 @@ class Client:
 truncate: Optional[int] = None,
 typical_p: Optional[float] = None,
 watermark: bool = False,
+top_n_tokens: Optional[int] = None,
 logit_bias: Dict[str, float] = {},
 ) -> Iterator[StreamResponse]:
 """

@@ -203,6 +208,8 @@ class Client:
 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
 watermark (`bool`):
 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+top_n_tokens (`int`):
+Return the `n` most likely tokens at each step
 logit_bias (`Dict[str, float]`):
 Bias generation towards certain tokens.
 

@@ -227,6 +234,7 @@ class Client:
 typical_p=typical_p,
 watermark=watermark,
 logit_bias=logit_bias,
+top_n_tokens=top_n_tokens,
 )
 request = Request(inputs=prompt, stream=True, parameters=parameters)
 

@@ -326,6 +334,7 @@ class AsyncClient:
 watermark: bool = False,
 decoder_input_details: bool = False,
 logit_bias: Dict[str, float] = {},
+top_n_tokens: Optional[int] = None,
 ) -> Response:
 """
 Given a prompt, generate the following text asynchronously

@@ -366,6 +375,8 @@ class AsyncClient:
 Return the decoder input token logprobs and ids
 logit_bias (`Dict[str, float]`):
 Bias generation towards certain tokens.
+top_n_tokens (`int`):
+Return the `n` most likely tokens at each step
 
 Returns:
 Response: generated response

@@ -388,6 +399,7 @@ class AsyncClient:
 typical_p=typical_p,
 watermark=watermark,
 logit_bias=logit_bias,
+top_n_tokens=top_n_tokens,
 )
 request = Request(inputs=prompt, stream=False, parameters=parameters)
 

@@ -417,6 +429,7 @@ class AsyncClient:
 typical_p: Optional[float] = None,
 watermark: bool = False,
 logit_bias: Dict[str, float] = {},
+top_n_tokens: Optional[int] = None,
 ) -> AsyncIterator[StreamResponse]:
 """
 Given a prompt, generate the following stream of tokens asynchronously

@@ -453,6 +466,8 @@ class AsyncClient:
 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 logit_bias (`Dict[str, float]`):
 Bias generation towards certain tokens.
+top_n_tokens (`int`):
+Return the `n` most likely tokens at each step
 
 Returns:
 AsyncIterator[StreamResponse]: stream of generated tokens

@@ -475,6 +490,7 @@ class AsyncClient:
 typical_p=typical_p,
 watermark=watermark,
 logit_bias=logit_bias,
+top_n_tokens=top_n_tokens,
 )
 request = Request(inputs=prompt, stream=True, parameters=parameters)
 
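
The hunks above add `top_n_tokens` to the Python client's `generate`/`generate_stream` signatures. As an illustration only (not part of this commit), here is a minimal sketch of how the new parameter could be used against a locally running TGI server; the URL, prompt, and printed fields are assumptions:

```python
# Illustrative sketch (not part of the diff): query a local TGI server and read
# the per-step candidate tokens returned when `top_n_tokens` is set.
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # assumed local endpoint
response = client.generate(
    "What is Deep Learning?",  # assumed prompt
    max_new_tokens=5,
    top_n_tokens=3,
)

# `Details.top_tokens` (added in this commit) holds, for each generated step,
# the `n` most likely candidate tokens with their logprobs.
for step, candidates in enumerate(response.details.top_tokens or []):
    print(step, [(t.text, t.logprob) for t in candidates])
```
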
@@ -41,6 +41,8 @@ class Parameters(BaseModel):
 decoder_input_details: bool = False
 # Bias generation towards certain tokens
 logit_bias: Dict[str, float] = {}
+# Return the N most likely tokens at each step
+top_n_tokens: Optional[int]
 
 @validator("best_of")
 def valid_best_of(cls, field_value, values):

@@ -103,6 +105,12 @@ class Parameters(BaseModel):
 raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
 return v
 
+@validator("top_n_tokens")
+def valid_top_n_tokens(cls, v):
+if v is not None and v <= 0:
+raise ValidationError("`top_n_tokens` must be strictly positive")
+return v
+
 
 class Request(BaseModel):
 # Prompt

@@ -127,9 +135,7 @@ class Request(BaseModel):
 and parameters.best_of > 1
 and field_value
 ):
-raise ValidationError(
-"`best_of` != 1 is not supported when `stream` == True"
-)
+raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
 return field_value
 
 

@@ -181,6 +187,8 @@ class BestOfSequence(BaseModel):
 prefill: List[InputToken]
 # Generated tokens
 tokens: List[Token]
+# Most likely tokens
+top_tokens: Optional[List[List[Token]]]
 
 
 # `generate` details

@@ -195,6 +203,8 @@ class Details(BaseModel):
 prefill: List[InputToken]
 # Generated tokens
 tokens: List[Token]
+# Most likely tokens
+top_tokens: Optional[List[List[Token]]]
 # Additional sequences when using the `best_of` parameter
 best_of_sequences: Optional[List[BestOfSequence]]
 

@@ -221,6 +231,8 @@ class StreamDetails(BaseModel):
 class StreamResponse(BaseModel):
 # Generated token
 token: Token
+# Most likely tokens
+top_tokens: Optional[List[Token]]
 # Complete generated text
 # Only available when the generation is finished
 generated_text: Optional[str]
@@ -10,7 +10,7 @@
 "name": "Apache 2.0",
 "url": "https://www.apache.org/licenses/LICENSE-2.0"
 },
-"version": "1.0.1"
+"version": "1.0.2"
 },
 "paths": {
 "/": {
@@ -75,6 +75,81 @@ To serve both ChatUI and TGI in same environment, simply add your own endpoints
 
 
 
+## Gradio
+
+Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
+
+```bash
+pip install huggingface-hub gradio
+```
+
+Assume you are serving your model on port 8080, we will query through [InferenceClient](consuming_tgi#inference-client).
+
+```python
+import gradio as gr
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(model="http://127.0.0.1:8080")
+
+def inference(message, history):
+partial_message = ""
+for token in client.text_generation(message, max_new_tokens=20, stream=True):
+partial_message += token
+yield partial_message
+
+gr.ChatInterface(
+inference,
+chatbot=gr.Chatbot(height=300),
+textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
+description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
+title="Gradio 🤝 TGI",
+examples=["Are tomatoes vegetables?"],
+retry_btn="Retry",
+undo_btn="Undo",
+clear_btn="Clear",
+).queue().launch()
+```
+
+The UI looks like this 👇
+
+<div class="flex justify-center">
+<img
+class="block dark:hidden"
+src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
+/>
+<img
+class="hidden dark:block"
+src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
+/>
+</div>
+
+You can try the demo directly here 👇
+
+<div class="block dark:hidden">
+<iframe
+src="https://merve-gradio-tgi-2.hf.space?__theme=light"
+width="850"
+height="750"
+></iframe>
+</div>
+<div class="hidden dark:block">
+<iframe
+src="https://merve-gradio-tgi-2.hf.space?__theme=dark"
+width="850"
+height="750"
+></iframe>
+</div>
+
+
+You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
+
+```python
+def inference(message, history):
+return client.text_generation(message, max_new_tokens=20)
+```
+
+You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
+
 ## API documentation
 
 You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).
@@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
 model=tiiuae/falcon-7b-instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.1 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.0.2 --model-id $model
 ```
 
 <Tip warning={true}>

@@ -85,7 +85,7 @@ curl 127.0.0.1:8080/generate \
 To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
 
 ```bash
-docker run ghcr.io/huggingface/text-generation-inference:1.0.1 --help
+docker run ghcr.io/huggingface/text-generation-inference:1.0.2 --help
 ```
 
 </Tip>
@@ -159,6 +159,14 @@ struct Args {
 #[clap(default_value = "4", long, env)]
 max_stop_sequences: usize,
 
+/// This is the maximum allowed value for clients to set `top_n_tokens`.
+/// `top_n_tokens is used to return information about the the `n` most likely
+/// tokens at each generation step, instead of just the sampled token. This
+/// information can be used for downstream tasks like for classification or
+/// ranking.
+#[clap(default_value = "5", long, env)]
+max_top_n_tokens: u32,
+
 /// This is the maximum allowed input length (expressed in number of tokens)
 /// for users. The larger this value, the longer prompt users can send which
 /// can impact the overall memory required to handle the load.

@@ -929,6 +937,8 @@ fn spawn_webserver(
 args.max_best_of.to_string(),
 "--max-stop-sequences".to_string(),
 args.max_stop_sequences.to_string(),
+"--max-top-n-tokens".to_string(),
+args.max_top_n_tokens.to_string(),
 "--max-input-length".to_string(),
 args.max_input_length.to_string(),
 "--max-total-tokens".to_string(),
@@ -101,6 +101,8 @@ message Request {
 StoppingCriteriaParameters stopping_parameters = 5;
 /// Return prefill logprobs
 bool prefill_logprobs = 6;
+/// Return most likely n tokens
+uint32 top_n_tokens = 7;
 }
 
 message Batch {

@@ -151,6 +153,17 @@ message PrefillTokens {
 repeated string texts = 3;
 }
 
+message TopTokens {
+/// Top Token IDs
+repeated uint32 ids = 1;
+/// Top Logprobs
+repeated float logprobs = 2;
+/// Top Token Texts
+repeated string texts = 3;
+/// If the tokens are special
+repeated bool is_special = 6;
+}
+
 message Generation {
 /// Request ID
 uint64 request_id = 1;

@@ -166,6 +179,8 @@ message Generation {
 bool token_is_special = 6;
 /// Complete generated text
 optional GeneratedText generated_text = 7;
+/// Top tokens
+TopTokens top_tokens = 8;
 }
 
 message FilterBatchRequest {
@@ -132,6 +132,7 @@ impl Client {
 ignore_eos_token: false,
 }),
 prefill_logprobs: true,
+top_n_tokens: 20,
 });
 n_tokens += max_input_length;
 }
@@ -51,6 +51,7 @@ impl Health {
 stop_sequences: vec![],
 ignore_eos_token: false,
 }),
+top_n_tokens: 0,
 };
 let batch = Batch {
 id: BATCH_ID,
@@ -138,12 +138,15 @@ impl Infer {
 &self,
 request: GenerateRequest,
 ) -> Result<InferResponse, InferError> {
+let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
+
 // Create stream and keep semaphore permit as long as generate lives
 let (_permit, mut stream) = self.generate_stream(request).await?;
 
 // Return values
 let mut result_prefill = Vec::new();
 let mut result_tokens = Vec::new();
+let mut result_top_tokens = Vec::new();
 let mut result_generated_text = None;
 let mut result_start = None;
 let mut result_queued = None;

@@ -164,7 +167,10 @@ impl Infer {
 .collect();
 }
 // Push last token
-InferStreamResponse::Token(token) => result_tokens.push(token),
+InferStreamResponse::Intermediate { token, top_tokens } => {
+result_tokens.push(token);
+result_top_tokens.push(top_tokens);
+}
 // Final message
 // Set return values
 InferStreamResponse::End {

@@ -172,8 +178,10 @@ impl Infer {
 generated_text,
 start,
 queued,
+top_tokens,
 } => {
 result_tokens.push(token);
+result_top_tokens.push(top_tokens);
 result_generated_text = Some(generated_text);
 result_start = Some(start);
 result_queued = Some(queued)

@@ -191,6 +199,11 @@ impl Infer {
 generated_text,
 queued,
 start,
+top_tokens: if use_top_tokens {
+result_top_tokens
+} else {
+Vec::new()
+},
 })
 } else {
 let err = InferError::IncompleteGeneration;

@@ -520,6 +533,26 @@ fn send_responses(
 special: generation.token_is_special,
 };
 
+// generation.top_tokens
+
+let mut top_tokens = Vec::new();
+if let Some(top_tokens_) = generation.top_tokens {
+top_tokens.extend(
+top_tokens_
+.ids
+.into_iter()
+.zip(top_tokens_.logprobs.into_iter())
+.zip(top_tokens_.texts.into_iter())
+.zip(top_tokens_.is_special.into_iter())
+.map(|(((id, logprob), text), special)| Token {
+id,
+text,
+logprob,
+special,
+}),
+)
+}
+
 if let Some(generated_text) = generation.generated_text {
 // Generation has ended
 stopped = true;

@@ -527,6 +560,7 @@ fn send_responses(
 entry.response_tx.send_timeout(
 Ok(InferStreamResponse::End {
 token,
+top_tokens,
 generated_text,
 queued: entry.queue_time,
 start: entry.batch_time.unwrap(),

@@ -536,7 +570,7 @@ fn send_responses(
 } else {
 // Send message
 entry.response_tx.send_timeout(
-Ok(InferStreamResponse::Token(token)),
+Ok(InferStreamResponse::Intermediate { token, top_tokens }),
 Duration::from_millis(10),
 )?;
 }

@@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
 // Optional first message
 Prefill(PrefillTokens),
 // Intermediate messages
-Token(Token),
+Intermediate {
+token: Token,
+top_tokens: Vec<Token>,
+},
 // Last message
 End {
 token: Token,
+top_tokens: Vec<Token>,
 generated_text: GeneratedText,
 start: Instant,
 queued: Instant,

@@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
 pub(crate) generated_text: GeneratedText,
 pub(crate) queued: Instant,
 pub(crate) start: Instant,
+pub(crate) top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Debug, Error)]
@@ -139,6 +139,8 @@ pub(crate) struct GenerateParameters {
 #[serde(default)]
 #[schema(default=json!({}), example=json!({"hello": 0.5}))]
 pub logit_bias: BTreeMap<String, f32>
+#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
+pub top_n_tokens: Option<u32>,
 }
 
 fn default_max_new_tokens() -> u32 {

@@ -163,6 +165,7 @@ fn default_parameters() -> GenerateParameters {
 decoder_input_details: false,
 seed: None,
 logit_bias: BTreeMap::new(),
+top_n_tokens: None,
 }
 }
 

@@ -240,6 +243,8 @@ pub(crate) struct BestOfSequence {
 pub seed: Option<u64>,
 pub prefill: Vec<PrefillToken>,
 pub tokens: Vec<Token>,
+#[serde(skip_serializing_if = "Vec::is_empty")]
+pub top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Serialize, ToSchema)]

@@ -254,6 +259,8 @@ pub(crate) struct Details {
 pub tokens: Vec<Token>,
 #[serde(skip_serializing_if = "Option::is_none")]
 pub best_of_sequences: Option<Vec<BestOfSequence>>,
+#[serde(skip_serializing_if = "Vec::is_empty")]
+pub top_tokens: Vec<Vec<Token>>,
 }
 
 #[derive(Serialize, ToSchema)]

@@ -277,6 +284,8 @@ pub(crate) struct StreamDetails {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
 pub token: Token,
+#[serde(skip_serializing_if = "Vec::is_empty")]
+pub top_tokens: Vec<Token>,
 #[schema(nullable = true, default = "null", example = "test")]
 pub generated_text: Option<String>,
 #[schema(nullable = true, default = "null")]
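
The router hunks above add `top_n_tokens` to `GenerateParameters` and `top_tokens` to the `Details`/`StreamResponse` payloads. For illustration only (not part of this commit), a sketch of exercising the new field over the REST API; the endpoint URL and prompt are assumptions, and `details` must be enabled for the `Details` object to be returned:

```python
# Illustrative sketch (not part of the diff): call /generate with the new
# `top_n_tokens` parameter and inspect the candidate tokens per step.
import requests  # assumed dependency for this example

payload = {
    "inputs": "What is Deep Learning?",  # assumed prompt
    "parameters": {
        "max_new_tokens": 5,
        "details": True,     # Details (including top_tokens) is only serialized when requested
        "top_n_tokens": 3,   # must not exceed the launcher's --max-top-n-tokens (default 5)
    },
}
resp = requests.post("http://127.0.0.1:8080/generate", json=payload, timeout=60)
resp.raise_for_status()

details = resp.json()["details"]
# `top_tokens` is skipped when empty, so fall back to an empty list.
for step, candidates in enumerate(details.get("top_tokens", [])):
    print(step, [(t["text"], t["logprob"]) for t in candidates])
```
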
@@ -29,6 +29,8 @@ struct Args {
 max_best_of: usize,
 #[clap(default_value = "4", long, env)]
 max_stop_sequences: usize,
+#[clap(default_value = "5", long, env)]
+max_top_n_tokens: u32,
 #[clap(default_value = "1024", long, env)]
 max_input_length: usize,
 #[clap(default_value = "2048", long, env)]

@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
 max_concurrent_requests,
 max_best_of,
 max_stop_sequences,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 waiting_served_ratio,

@@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
 max_concurrent_requests,
 max_best_of,
 max_stop_sequences,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 waiting_served_ratio,
@@ -235,6 +235,7 @@ impl State {
 truncate: entry.request.truncate,
 parameters: Some(entry.request.parameters.clone()),
 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
+top_n_tokens: entry.request.top_n_tokens,
 });
 // Set batch_time
 entry.batch_time = Some(Instant::now());

@@ -329,6 +330,7 @@ mod tests {
 max_new_tokens: 1,
 stop_sequences: vec![],
 },
+top_n_tokens: 0,
 },
 response_tx,
 span: info_span!("entry"),
@@ -158,7 +158,7 @@ async fn generate(
 add_prompt = Some(req.inputs.clone());
 }
 
-let details = req.parameters.details || req.parameters.decoder_input_details;
+let details: bool = req.parameters.details || req.parameters.decoder_input_details;
 
 // Inference
 let (response, best_of_responses) = match req.parameters.best_of {

@@ -191,6 +191,7 @@ async fn generate(
 generated_tokens: response.generated_text.generated_tokens,
 prefill: response.prefill,
 tokens: response.tokens,
+top_tokens: response.top_tokens,
 seed: response.generated_text.seed,
 }
 })

@@ -204,6 +205,7 @@ async fn generate(
 tokens: response.tokens,
 seed: response.generated_text.seed,
 best_of_sequences,
+top_tokens: response.top_tokens,
 })
 }
 false => None,

@@ -385,12 +387,16 @@ async fn generate_stream(
 // Prefill is ignored
 InferStreamResponse::Prefill(_) => {}
 // Yield event for every new token
-InferStreamResponse::Token(token) => {
+InferStreamResponse::Intermediate{
+token,
+top_tokens,
+} => {
 tracing::debug!(parent: &span, "Token: {:?}", token);
 
 // StreamResponse
 let stream_token = StreamResponse {
 token,
+top_tokens: top_tokens,
 generated_text: None,
 details: None,
 };

@@ -403,6 +409,7 @@ async fn generate_stream(
 generated_text,
 start,
 queued,
+top_tokens,
 } => {
 // Token details
 let details = match details {

@@ -451,6 +458,7 @@ async fn generate_stream(
 
 let stream_token = StreamResponse {
 token,
+top_tokens: top_tokens,
 generated_text: Some(output_text),
 details
 };

@@ -509,6 +517,7 @@ pub async fn run(
 max_concurrent_requests: usize,
 max_best_of: usize,
 max_stop_sequences: usize,
+max_top_n_tokens: u32,
 max_input_length: usize,
 max_total_tokens: usize,
 waiting_served_ratio: f32,

@@ -571,6 +580,7 @@ pub async fn run(
 tokenizer,
 max_best_of,
 max_stop_sequences,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 );
@@ -15,6 +15,7 @@ pub struct Validation {
 /// Validation parameters
 max_best_of: usize,
 max_stop_sequences: usize,
+max_top_n_tokens: u32,
 max_input_length: usize,
 max_total_tokens: usize,
 /// Channel to communicate with the background tokenization task

@@ -27,6 +28,7 @@ impl Validation {
 tokenizer: Option<Tokenizer>,
 max_best_of: usize,
 max_stop_sequences: usize,
+max_top_n_tokens: u32,
 max_input_length: usize,
 max_total_tokens: usize,
 ) -> Self {

@@ -54,6 +56,7 @@ impl Validation {
 max_best_of,
 sender,
 max_stop_sequences,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 }

@@ -143,6 +146,7 @@ impl Validation {
 watermark,
 decoder_input_details,
 logit_bias,
+top_n_tokens,
 ..
 } = request.parameters;
 

@@ -219,6 +223,15 @@ impl Validation {
 }
 };
 
+let top_n_tokens = top_n_tokens
+.map(|value| {
+if value > self.max_top_n_tokens {
+return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
+}
+Ok(value)
+})
+.unwrap_or(Ok(0))?;
+
 // Check if inputs is empty
 if request.inputs.is_empty() {
 return Err(EmptyInput);

@@ -268,6 +281,7 @@ impl Validation {
 truncate: truncate.unwrap_or(self.max_input_length) as u32,
 parameters,
 stopping_parameters,
+top_n_tokens: top_n_tokens,
 })
 }
 

@@ -341,6 +355,7 @@ pub(crate) struct ValidGenerateRequest {
 pub decoder_input_details: bool,
 pub parameters: NextTokenChooserParameters,
 pub stopping_parameters: StoppingCriteriaParameters,
+pub top_n_tokens: u32,
 }
 
 #[derive(Error, Debug)]

@@ -355,6 +370,10 @@ pub enum ValidationError {
 BestOfSeed,
 #[error("`best_of` != 1 is not supported when streaming tokens")]
 BestOfStream,
+#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
+TopNTokens(u32, u32),
+#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
+TopNTokensDisabled,
 #[error("`decoder_input_details` == true is not supported when streaming tokens")]
 PrefillDetailsStream,
 #[error("`temperature` must be strictly positive")]

@@ -396,14 +415,16 @@ mod tests {
 let tokenizer = None;
 let max_best_of = 2;
 let max_stop_sequence = 3;
-let max_input_length = 4;
-let max_total_tokens = 5;
+let max_top_n_tokens = 4;
+let max_input_length = 5;
+let max_total_tokens = 6;
 let workers = 1;
 let validation = Validation::new(
 workers,
 tokenizer,
 max_best_of,
 max_stop_sequence,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 );

@@ -423,14 +444,16 @@ mod tests {
 let tokenizer = Some(get_tokenizer().await);
 let max_best_of = 2;
 let max_stop_sequence = 3;
-let max_input_length = 4;
-let max_total_tokens = 5;
+let max_top_n_tokens = 4;
+let max_input_length = 5;
+let max_total_tokens = 6;
 let workers = 1;
 let validation = Validation::new(
 workers,
 tokenizer,
 max_best_of,
 max_stop_sequence,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 );

@@ -440,7 +463,7 @@ mod tests {
 .validate_input("Hello".to_string(), None, max_new_tokens)
 .await
 {
-Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
+Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
 _ => panic!("Unexpected not max new tokens"),
 }
 }

@@ -450,14 +473,16 @@ mod tests {
 let tokenizer = Some(get_tokenizer().await);
 let max_best_of = 2;
 let max_stop_sequence = 3;
-let max_input_length = 4;
-let max_total_tokens = 5;
+let max_top_n_tokens = 4;
+let max_input_length = 5;
+let max_total_tokens = 6;
 let workers = 1;
 let validation = Validation::new(
 workers,
 tokenizer,
 max_best_of,
 max_stop_sequence,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 );

@@ -482,14 +507,16 @@ mod tests {
 let tokenizer = Some(get_tokenizer().await);
 let max_best_of = 2;
 let max_stop_sequence = 3;
-let max_input_length = 4;
-let max_total_tokens = 5;
+let max_top_n_tokens = 4;
+let max_input_length = 5;
+let max_total_tokens = 6;
 let workers = 1;
 let validation = Validation::new(
 workers,
 tokenizer,
 max_best_of,
 max_stop_sequence,
+max_top_n_tokens,
 max_input_length,
 max_total_tokens,
 );

@@ -536,4 +563,75 @@ mod tests {
 // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
 assert_eq!(valid_request.parameters.top_p, 1.0);
 }
+
+#[tokio::test]
+async fn test_validation_top_n_tokens() {
+let tokenizer = Some(get_tokenizer().await);
+let max_best_of = 2;
+let max_stop_sequences = 3;
+let max_top_n_tokens = 4;
+let max_input_length = 5;
+let max_total_tokens = 6;
+let workers = 1;
+let validation = Validation::new(
+workers,
+tokenizer,
+max_best_of,
+max_stop_sequences,
+max_top_n_tokens,
+max_input_length,
+max_total_tokens,
+);
+match validation
+.validate(GenerateRequest {
+inputs: "Hello".to_string(),
+parameters: GenerateParameters {
+top_n_tokens: Some(5),
+..default_parameters()
+},
+})
+.await
+{
+Err(ValidationError::TopNTokens(4, 5)) => (),
+_ => panic!("Unexpected top_n_tokens"),
+}
+
+validation
+.validate(GenerateRequest {
+inputs: "Hello".to_string(),
+parameters: GenerateParameters {
+top_n_tokens: Some(4),
+max_new_tokens: 1,
+..default_parameters()
+},
+})
+.await
+.unwrap();
+
+validation
+.validate(GenerateRequest {
+inputs: "Hello".to_string(),
+parameters: GenerateParameters {
+top_n_tokens: Some(0),
+max_new_tokens: 1,
+..default_parameters()
+},
+})
+.await
+.unwrap();
+
+let valid_request = validation
+.validate(GenerateRequest {
+inputs: "Hello".to_string(),
+parameters: GenerateParameters {
+top_n_tokens: None,
+max_new_tokens: 1,
+..default_parameters()
+},
+})
+.await
+.unwrap();
+
+assert_eq!(valid_request.top_n_tokens, 0);
+}
 }
@@ -1,7 +1,9 @@
+import torch
 from text_generation_server.utils.tokens import (
 StopSequenceCriteria,
 StoppingCriteria,
 FinishReason,
+batch_top_tokens,
 )
 
 

@@ -42,3 +44,22 @@ def test_stopping_criteria_max():
 assert criteria(1, "") == (False, None)
 assert criteria(1, "") == (False, None)
 assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
+
+def test_batch_top_tokens():
+top_n_tokens = [0, 2, 3, 4, 5]
+top_n_tokens_tensor = torch.tensor(top_n_tokens)
+inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+
+topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+
+assert topn_tok_ids[0] == []
+assert topn_tok_ids[1] == [0, 3]
+assert topn_tok_ids[2] == [0, 3, 1, 4]
+assert topn_tok_ids[3] == [0, 3, 1, 4]
+assert topn_tok_ids[4] == [0, 3, 1, 4, 2]
+
+assert topn_tok_logprobs[0] == []
+assert topn_tok_logprobs[1] == [-1, -2]
+assert topn_tok_logprobs[2] == [-1, -2, -3, -3]
+assert topn_tok_logprobs[3] == [-1, -2, -3, -3]
+assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4]
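
The new test above pins down a notable behavior of `batch_top_tokens`: candidates tied with the n-th best logprob are kept, so asking for 3 tokens can return 4. As an illustration only, here is a minimal sketch that reproduces the per-row selection the assertions describe; it is not the server's actual implementation in `text_generation_server.utils.tokens`:

```python
# Illustrative sketch (not the actual batch_top_tokens implementation): select
# the top-n logprobs of one row, keeping any candidates tied with the n-th value.
import torch

def top_n_with_ties(logprobs: torch.Tensor, n: int):
    if n == 0:
        return [], []
    # Stable sort so tied logprobs keep their original index order.
    sorted_lp, sorted_ids = torch.sort(logprobs, descending=True, stable=True)
    threshold = sorted_lp[n - 1]      # logprob of the n-th best candidate
    keep = sorted_lp >= threshold     # include everything tied with it
    return sorted_ids[keep].tolist(), sorted_lp[keep].tolist()

row = torch.tensor([-1.0, -3.0, -4.0, -2.0, -3.0])
print(top_n_with_ties(row, 3))  # ([0, 3, 1, 4], [-1.0, -2.0, -3.0, -3.0]), as in the test
```
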
@@ -125,6 +125,9 @@ def download_weights(
 try:
 adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
 utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
+is_local_model = True
+utils.weight_files(model_id, revision, extension)
+return
 except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
 pass
 
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import inspect
 
@@ -12,6 +13,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         prefix_offsets = []
         read_offsets = []
         requests_idx_mapping = {}
@@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
@@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
 
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
 
         total_remaining_decode_tokens = 0
         new_padding_right_offset = 0
@@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
             layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
             del past_values
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
 
         self.requests = requests
@@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
@@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0
 
         # Batch tensors
@@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
         attention_mask = None
         position_ids = None
         past_key_values = []
+        top_n_tokens_tensor = None
 
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
@@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
 
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
                 (total_batch_size, max_input_length + padding_right_offset),
             )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
@@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
            read_offsets=read_offsets,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
+           top_n_tokens=top_n_tokens,
+           top_n_tokens_tensor=top_n_tokens_tensor,
            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
            keys_head_dim_last=batches[0].keys_head_dim_last,
@@ -549,6 +576,12 @@ class CausalLM(Model):
         generations: List[Generation] = []
         stopped = True
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -559,6 +592,9 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -571,6 +607,9 @@ class CausalLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -637,6 +676,24 @@ class CausalLM(Model):
             else:
                 prefill_tokens = None
 
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -645,6 +702,7 @@ class CausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
 
             generations.append(generation)
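Earlier in this file's `concatenate` hunks, `top_n_tokens_tensor` is merged like the other per-request tensors: allocate once with `new_zeros` from the first batch, then copy each batch's slice in. A minimal standalone sketch of that pattern (the values below are made up):

```python
import torch

# Minimal sketch of the new_zeros + slice-assignment pattern used when batches
# are concatenated; the tensors below are illustrative, not real batch state.
batch_tensors = [torch.tensor([0, 2]), torch.tensor([5]), torch.tensor([1, 3, 4])]
total_batch_size = sum(len(t) for t in batch_tensors)

merged = batch_tensors[0].new_zeros(total_batch_size)  # same dtype/device as the first batch
start_index = 0
for t in batch_tensors:
    end_index = start_index + len(t)
    merged[start_index:end_index] = t
    start_index = end_index

print(merged)  # tensor([0, 2, 5, 1, 3, 4])
```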
@@ -64,6 +64,7 @@ class LlamaConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -84,6 +85,7 @@ class LlamaConfig(PretrainedConfig):
         self.pretraining_tp = pretraining_tp
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -189,7 +191,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
         )
 
         self.softmax_scale = self.head_size**-0.5
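The `base=config.rope_theta` change above threads the model's configured rotary base through instead of the hard-coded 10000.0. For orientation, this is the standard RoPE inverse-frequency formula that such a base typically feeds into; it is assumed here for illustration, the actual `PositionRotaryEmbedding` implementation lives elsewhere in the repo:

```python
import torch

# Standard RoPE inverse frequencies; `base` is what rope_theta now overrides.
def rope_inv_freq(dim: int, base: float) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

head_size = 128
print(rope_inv_freq(head_size, 10000.0)[:4])      # previous hard-coded default
print(rope_inv_freq(head_size, 1_000_000.0)[:4])  # a model shipping a larger rope_theta
```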
@@ -1,5 +1,6 @@
 import math
 import itertools
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 import torch.distributed
 
@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
     PrefillTokens,
     Generation,
     GeneratedText,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
     # Generation helpers
     next_token_chooser: HeterogeneousNextTokenChooser
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
     blocks: int
@@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_length = 0
@@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
             )
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
 
             # Paged attention
             # Remove one as the first token des not have a past
@@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
         prefill_next_token_indices = torch.tensor(
             prefill_next_token_indices, dtype=torch.int64, device=device
         )
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         return cls(
             batch_id=pb.id,
@@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
 
         stopping_criterias = []
+        top_n_tokens = []
 
         blocks = 0
         max_blocks = 0
@@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
 
+            top_n_tokens.append(self.top_n_tokens[idx])
+
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
+        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+            total_batch_size,
+        )
 
         start_slots = []
         block_tables = []
@@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
 
         next_token_chooser_parameters = []
         stopping_criterias = []
+        top_n_tokens = []
 
         # Cumulative length
         cumulative_batch_size = 0
@@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
             position_ids[start_index:end_index] = batch.position_ids
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
             all_input_ids_tensor[
@@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
 
+            top_n_tokens.extend(batch.top_n_tokens)
+
             # Update
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)
@@ -667,6 +691,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
         )
@@ -832,10 +858,14 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
+        )
+
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -932,8 +962,11 @@ class FlashCausalLM(Model):
             batch.all_input_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
+            batch.top_n_tokens,
             next_token_ids,
             next_token_logprobs,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -946,8 +979,11 @@ class FlashCausalLM(Model):
             all_input_ids,
             do_sample,
             seed,
+            top_n_tokens,
             next_token_id,
             next_token_logprob,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
             all_input_ids.append(next_token_id)
@@ -1006,6 +1042,24 @@ class FlashCausalLM(Model):
             else:
                 prefill_tokens = None
 
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -1014,6 +1068,7 @@ class FlashCausalLM(Model):
                 next_token_text,
                 next_token_id in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
 
             generations.append(generation)
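The `filter()` hunks above keep the new per-request fields aligned with the rest of the batch by indexing every parallel structure with the same kept indices. A toy sketch of that pattern with made-up values:

```python
import torch

# Toy sketch of the filter() pattern: every parallel structure is indexed with
# the same kept indices so per-request state stays aligned.
top_n_tokens = [0, 2, 5, 1]
top_n_tokens_tensor = torch.tensor(top_n_tokens)

keep = torch.tensor([1, 3])  # requests that remain in the batch
print([top_n_tokens[i] for i in keep.tolist()])  # [2, 1]
print(top_n_tokens_tensor[keep])                 # tensor([2, 1])
```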
@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
             else:
                 prefill_tokens = None
 
+            top_tokens=None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens
             )
 
             generations.append(generation)
@@ -1,3 +1,4 @@
+from text_generation_server.utils.tokens import batch_top_tokens
 import torch
 
 from dataclasses import dataclass
@@ -11,6 +12,7 @@ from text_generation_server.models.types import (
     Batch,
     Generation,
     PrefillTokens,
+    TopTokens,
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch):
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+    top_n_tokens: List[int]
+    top_n_tokens_tensor: torch.Tensor
 
     # Metadata used for padding
     max_input_length: int
@@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         decoder_input_lengths = []
         prefix_offsets = []
         read_offsets = []
@@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(r.top_n_tokens)
             max_truncation = max(max_truncation, r.truncate)
             max_decode_tokens += stopping_criteria.max_new_tokens
             padding_right_offset = max(
@@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch):
             prefix_offsets.append(0)
             read_offsets.append(1)
         all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
+        top_n_tokens_tensor = torch.tensor(
+            top_n_tokens, device=device, dtype=torch.int64
+        )
 
         max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
 
@@ -146,6 +154,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
@@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch):
 
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
 
         max_input_length = 0
         max_decoder_input_length = 0
@@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch):
             next_token_choosers.append(self.next_token_choosers[idx])
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
+            top_n_tokens.append(self.top_n_tokens[idx])
             remaining_decode_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
@@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch):
             layer[2] = layer[2][keep_indices, :, -max_input_length:]
             layer[3] = layer[3][keep_indices, :, -max_input_length:]
 
+        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = (
             len(request_ids) * (max_input_length + max_decoder_input_length)
             + remaining_decode_tokens
@@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch):
         self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
+        self.top_n_tokens = top_n_tokens
+        self.top_n_tokens_tensor = top_n_tokens_tensor
         self.max_input_length = max_input_length
         self.max_decoder_input_length = max_decoder_input_length
         self.padding_right_offset = padding_right_offset
@@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch):
         read_offsets = []
         next_token_choosers = []
         stopping_criterias = []
+        top_n_tokens = []
         max_tokens = 0
 
         # Batch tensors
@@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = None
         decoder_attention_mask = None
         encoder_last_hidden_state = None
+        top_n_tokens_tensor = None
         past_key_values = []
 
         # Used for slicing correctly inside the tensors
@@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch):
             read_offsets.extend(batch.read_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
+            top_n_tokens.extend(batch.top_n_tokens)
 
             if i == 0:
                 requests_idx_mapping = batch.requests_idx_mapping
@@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch):
                 ),
             )
 
+            if top_n_tokens_tensor is None:
+                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
+                    total_batch_size,
+                )
+            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
+
             # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_input_length :, :
@@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch):
             read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            top_n_tokens=top_n_tokens,
+            top_n_tokens_tensor=top_n_tokens_tensor,
             max_input_length=max_input_length,
             max_decoder_input_length=max_decoder_input_length,
             padding_right_offset=padding_right_offset,
@@ -613,6 +639,12 @@ class Seq2SeqLM(Model):
             batch.past_key_values,
         )
 
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.softmax(logits[:, -1], -1),
+        )
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True
@@ -628,6 +660,9 @@ class Seq2SeqLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_decoder_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
         )
 
         # For each member of the batch
@@ -641,6 +676,9 @@ class Seq2SeqLM(Model):
             next_token_chooser,
             stopping_criteria,
             all_decoder_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -698,6 +736,24 @@ class Seq2SeqLM(Model):
             else:
                 prefill_tokens = None
 
+            if top_n_tokens > 0:
+                toptoken_texts = self.tokenizer.batch_decode(
+                    top_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                special_toptokens = [
+                    token_id in self.all_special_ids for token_id in top_token_ids
+                ]
+                top_tokens = TopTokens(
+                    top_token_ids,
+                    top_token_logprobs,
+                    toptoken_texts,
+                    special_toptokens,
+                )
+            else:
+                top_tokens = None
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -706,6 +762,7 @@ class Seq2SeqLM(Model):
                 next_token_text,
                 next_token_id_squeezed.item() in self.all_special_ids,
                 generated_text,
+                top_tokens,
             )
 
             generations.append(generation)
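As in `causal_lm.py`, the `generate_token` hunks above feed `torch.softmax(logits[:, -1], -1)` into `batch_top_tokens`, while the flash path passes the log-probabilities returned by its token chooser. The selected ids are the same either way, since log is monotonic and only the ranking matters; the reported scores naturally differ. A quick sketch:

```python
import torch

# The ranking used to pick top ids is the same for probabilities and
# log-probabilities, because log is monotonic; only the returned scores differ.
logits = torch.tensor([[2.0, -1.0, 0.5, 0.4]])

probs = torch.softmax(logits, -1)
logprobs = torch.log_softmax(logits, -1)

print(torch.topk(probs, k=3).indices)     # tensor([[0, 2, 3]])
print(torch.topk(logprobs, k=3).indices)  # tensor([[0, 2, 3]])
```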
@@ -1,3 +1,4 @@
+from functools import total_ordering
 import torch
 
 from abc import ABC, abstractmethod
@@ -71,6 +72,25 @@ class PrefillTokens:
         return len(self.token_ids)
 
 
+@dataclass
+class TopTokens:
+    token_ids: List[int]
+    logprobs: List[float]
+    texts: List[str]
+    is_special: List[bool]
+
+    def to_pb(self) -> generate_pb2.TopTokens:
+        return generate_pb2.TopTokens(
+            ids=self.token_ids,
+            logprobs=self.logprobs,
+            texts=self.texts,
+            is_special=self.is_special,
+        )
+
+    def __len__(self):
+        return len(self.token_ids)
+
+
 @dataclass
 class Generation:
     request_id: int
@@ -80,6 +100,8 @@ class Generation:
     token_text: str
     token_is_special: bool
     generated_text: Optional[GeneratedText]
+    # Optional for now, since it's not yet supported for every model.
+    top_tokens: Optional[TopTokens]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
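For orientation, this is roughly what a populated record for one generated position looks like. The class below is a stand-in with the same fields and made-up values, not the `TopTokens` dataclass above (which additionally carries the `to_pb()` protobuf conversion):

```python
from dataclasses import dataclass
from typing import List

# Stand-in with the same shape as the new dataclass; values are made up.
@dataclass
class TopTokensExample:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __len__(self) -> int:
        return len(self.token_ids)

record = TopTokensExample(
    token_ids=[284, 262],
    logprobs=[-0.41, -1.73],
    texts=[" to", " the"],
    is_special=[False, False],
)
print(len(record))  # 2 candidates were returned for this position
```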
@@ -6,20 +6,22 @@ from transformers import (
     PreTrainedTokenizerBase,
     SequenceBiasLogitsProcessor,
 )
-from typing import List, Tuple, Optional, Dict
+from typing import Callable, List, Tuple, Optional, Dict
 
+import torch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 
 class NextTokenChooser:
@@ -255,11 +257,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
 
     def filter(self, indices):
         if self.watermark_processor is not None:
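The `__call__` change above keeps the full log-probability matrix around (it is what `batch_top_tokens` consumes downstream) while still gathering each row's chosen-token logprob. A tiny sketch of that gather, using a greedy choice in place of `self.choice`:

```python
import torch

# Greedy stand-in for self.choice(scores): pick one id per row, gather its logprob,
# and keep the full matrix for the top-tokens computation downstream.
scores = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])

logprobs = torch.log_softmax(scores, -1)
next_ids = logprobs.argmax(dim=-1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

print(next_ids)        # tensor([0, 1])
print(next_logprobs)   # logprob of the chosen token in each row
print(logprobs.shape)  # torch.Size([2, 3]) -- full matrix, returned as well
```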
@@ -370,3 +371,50 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def batch_top_tokens(
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    max_top_n = max(top_n_tokens)
+    # Early exit when top_n_tokens is not used
+    if max_top_n == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
+    )
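The function above avoids a per-row Python loop by fetching each row's n-th highest value with one sorted `topk` plus `gather`, then thresholding with `>=` so tied tokens are kept. A small numeric walk-through of just that step, not the full function:

```python
import torch

# Walk-through of the batched "n-th highest then threshold" step.
logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0],
                         [-0.5, -0.7, -2.0, -3.0, -4.0]])
top_n_tokens_tensor = torch.tensor([3, 2])

max_top_n = int(top_n_tokens_tensor.max())
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
nth_highest = torch.gather(
    sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)

print(nth_highest.squeeze(1))   # roughly tensor([-3.0, -0.7])
print(logprobs >= nth_highest)  # per-row mask of the "fuzzy" top-n, ties included
```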