Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00
Removing dead code + "Fix" test.
This commit is contained in:
parent e0b197ea09
commit 730e5938f5
@@ -154,17 +154,6 @@ message TopTokens {
     repeated bool is_special = 6;
 }
 
-message TopToken {
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
-}
-
 message Generation {
     /// Request ID
     uint64 request_id = 1;
@@ -181,7 +170,6 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    // repeated TopToken top_tokens = 8;
     TopTokens top_tokens = 8;
 }
 
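The deleted per-token TopToken message was dead code: generation populates the batch-oriented TopTokens message kept above, which stores one repeated field per attribute in parallel (struct-of-arrays) instead of one message per token. A minimal Python sketch of zipping such parallel fields back into per-token records; the field names ids, logprobs, and texts are assumptions sitting alongside the is_special field visible in the hunk:

from dataclasses import dataclass

@dataclass
class TopTokenView:
    token_id: int
    logprob: float
    text: str
    is_special: bool

def iter_top_tokens(top_tokens):
    # Zip the parallel repeated fields (struct-of-arrays) back into
    # per-token records (array-of-structs) for downstream use.
    return [
        TopTokenView(i, lp, txt, sp)
        for i, lp, txt, sp in zip(
            top_tokens.ids,         # assumed field name
            top_tokens.logprobs,    # assumed field name
            top_tokens.texts,       # assumed field name
            top_tokens.is_special,  # shown in the schema above
        )
    ]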
@@ -198,11 +198,6 @@ async fn generate(
                 .collect()
         });
 
-        // let top_tokens = match response.top_tokens.is_empty() {
-        //     true => None,
-        //     false => Some(response.top_tokens),
-        // };
-
         Some(Details {
             finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
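The block removed here was already commented out, so deleting it changes no behavior; it would have mapped an empty top-tokens collection to None so the serialized details omit the field. The same mapping in a few lines of Python, with illustrative names:

from typing import Optional

def normalize_top_tokens(top_tokens: list) -> Optional[list]:
    # Empty list -> None, non-empty list passes through unchanged,
    # mirroring `match top_tokens.is_empty() { true => None, .. }`.
    return top_tokens or None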
@@ -47,9 +47,10 @@ def test_stopping_criteria_max():
 
 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
+    top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
 
-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
 
     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
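The "fix" updates the test to the new batch_top_tokens signature, which now also receives the requested counts as a tensor. A minimal sketch of an implementation consistent with these assertions; the real function in the server may well differ (for instance, in how it treats logprob ties):

import torch

def batch_top_tokens(top_n_tokens, top_n_tokens_tensor, logprobs):
    # One topk over the whole batch with the largest requested n,
    # then slice each row down to its own n; n == 0 yields [].
    max_top_n = int(top_n_tokens_tensor.max())
    if max_top_n == 0:
        return [[] for _ in top_n_tokens], [[] for _ in top_n_tokens]
    top_values, top_indices = logprobs.topk(k=max_top_n, dim=-1)
    ids = [row[:n].tolist() for n, row in zip(top_n_tokens, top_indices)]
    lps = [row[:n].tolist() for n, row in zip(top_n_tokens, top_values)]
    return ids, lps

Checking against the test data: each row of inp_logprobs is [-1, -3, -4, -2, -3], so the two largest logprobs are -1 (index 0) and -2 (index 3), matching topn_tok_ids[1] == [0, 3], and the request with n = 0 gets an empty list.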
@@ -32,6 +32,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
+    backend: str = "cuda",
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -79,7 +80,7 @@ def serve(
         "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
     )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, backend, trust_remote_code, uds_path
     )
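This hunk threads a new backend parameter through the serve entry point. The server CLI is built on typer, so the keyword argument surfaces as a --backend flag defaulting to "cuda"; a trimmed-down sketch of the pattern (the real command takes many more options, and the dispatch on the value inside server.serve is assumed, not shown in the diff):

import typer

app = typer.Typer()

@app.command()
def serve(model_id: str, backend: str = "cuda"):
    # typer turns `backend` into an optional --backend flag; server.serve
    # (not shown here) is assumed to select the execution backend from it.
    typer.echo(f"serving {model_id} with backend={backend}")

if __name__ == "__main__":
    app()

# Example invocation:
#   text-generation-server serve bigscience/bloom-560m --backend cuda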