diff --git a/proto/generate.proto b/proto/generate.proto
index 45ba8da5..3f607dc5 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -154,17 +154,6 @@ message TopTokens {
     repeated bool is_special = 6;
 }
 
-message TopToken {
-    /// Token ID
-    uint32 token_id = 3;
-    /// Logprob
-    float token_logprob = 4;
-    /// Text
-    string token_text = 5;
-    /// Is it a special token
-    bool token_is_special = 6;
-}
-
 message Generation {
     /// Request ID
     uint64 request_id = 1;
@@ -181,7 +170,6 @@ message Generation {
     /// Complete generated text
     optional GeneratedText generated_text = 7;
     /// Top tokens
-    // repeated TopToken top_tokens = 8;
     TopTokens top_tokens = 8;
 }
 
diff --git a/router/src/server.rs b/router/src/server.rs
index 7a191d61..91164098 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -198,11 +198,6 @@ async fn generate(
                     .collect()
             });
 
-            // let top_tokens = match response.top_tokens.is_empty() {
-            //     true => None,
-            //     false => Some(response.top_tokens),
-            // };
-
             Some(Details {
                 finish_reason: FinishReason::from(response.generated_text.finish_reason),
                 generated_tokens: response.generated_text.generated_tokens,
diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py
index 46b1220f..4187ff25 100644
--- a/server/tests/utils/test_tokens.py
+++ b/server/tests/utils/test_tokens.py
@@ -47,9 +47,10 @@ def test_stopping_criteria_max():
 
 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
+    top_n_tokens_tensor = torch.tensor(top_n_tokens)
     inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
 
-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, inp_logprobs)
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
 
     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index b12a9751..233893fe 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -32,6 +32,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
+    backend: str = "cuda",
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -79,7 +80,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, backend, trust_remote_code, uds_path
     )
 
 