Removing dead code + "Fix" test.

Nicolas Patry 2023-08-18 12:41:10 +02:00
parent e0b197ea09
commit 730e5938f5
4 changed files with 4 additions and 19 deletions

View File

@@ -154,17 +154,6 @@ message TopTokens {
     repeated bool is_special = 6;
 }

-message TopToken {
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
-}
-
 message Generation {
     /// Request ID
     uint64 request_id = 1;
@@ -181,7 +170,6 @@ message Generation
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    // repeated TopToken top_tokens = 8;
     TopTokens top_tokens = 8;
 }
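The deleted TopToken message stored one struct per token; the surviving TopTokens message instead keeps one repeated field per attribute (a struct-of-arrays layout, so a whole decoding step's top tokens travel in one message). A minimal Python sketch of reading that shape back into per-token records; only is_special is visible in the hunk above, so the ids, logprobs, and texts field names are assumptions:

from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class TopTokens:
    # Parallel arrays, one entry per candidate token. Field names other
    # than `is_special` are assumed, not taken from the proto above.
    ids: List[int] = field(default_factory=list)
    logprobs: List[float] = field(default_factory=list)
    texts: List[str] = field(default_factory=list)
    is_special: List[bool] = field(default_factory=list)

def iter_top_tokens(tt: TopTokens) -> List[Tuple[int, float, str, bool]]:
    # Zip the parallel arrays back into per-token tuples, i.e. the view
    # the removed `repeated TopToken` shape used to encode directly.
    return list(zip(tt.ids, tt.logprobs, tt.texts, tt.is_special))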

View File

@@ -198,11 +198,6 @@ async fn generate(
                 .collect()
             });

-            // let top_tokens = match response.top_tokens.is_empty() {
-            //     true => None,
-            //     false => Some(response.top_tokens),
-            // };
-
             Some(Details {
                 finish_reason: FinishReason::from(response.generated_text.finish_reason),
                 generated_tokens: response.generated_text.generated_tokens,

View File

@@ -47,9 +47,10 @@ def test_stopping_criteria_max():
 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
+    top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)

-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)

     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
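The test now builds the top-n tensor once and passes it alongside the list, matching the new batch_top_tokens signature. A minimal sketch consistent with the visible assertions; this is not the repository's implementation (which presumably uses the pre-built tensor for vectorized selection), so the tensor argument is accepted but unused here:

from typing import List, Tuple

import torch

def batch_top_tokens(
    top_n_tokens: List[int],
    top_n_tokens_tensor: torch.Tensor,  # mirrors the new signature; unused in this sketch
    logprobs: torch.Tensor,
) -> Tuple[List[List[int]], List[List[float]]]:
    ids, lps = [], []
    for n, row in zip(top_n_tokens, logprobs):
        if n == 0:
            # n == 0 means "top tokens disabled" for this request.
            ids.append([])
            lps.append([])
            continue
        # Highest-logprob tokens first: row 1 of the test's input
        # ([-1, -3, -4, -2, -3]) with n == 2 yields ids [0, 3].
        values, indices = torch.topk(row, k=min(n, row.numel()))
        ids.append(indices.tolist())
        lps.append(values.tolist())
    return ids, lps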

View File

@@ -32,6 +32,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
+    backend: str = "cuda",
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -79,7 +80,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, backend, trust_remote_code, uds_path
     )
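The serve entrypoint is a typer command, so the new keyword parameter surfaces as a --backend option whose default preserves the previous behaviour. A hypothetical standalone sketch of the mechanism (not the repository's file):

import typer

app = typer.Typer()

@app.command()
def serve(
    model_id: str,
    backend: str = "cuda",  # the new parameter becomes a --backend flag
):
    # The real command forwards `backend` on to server.serve(...);
    # printing stands in for that here.
    typer.echo(f"serving {model_id} with backend={backend}")

if __name__ == "__main__":
    app()

Running `python cli.py my-model --backend cpu` overrides the default; omitting the flag keeps `cuda`, so existing invocations are unchanged.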