breaking(router): modify /generate API to only return generated text

OlivierDehaene 2023-02-01 18:38:30 +01:00
parent 7b870e1e18
commit f36e736723
8 changed files with 23 additions and 25 deletions
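Taken together, the hunks below break the /generate contract in two client-visible ways: the response body is now a single JSON object instead of a one-element array, and generated_text contains only the newly generated tokens rather than the input prompt followed by them. A minimal Python sketch of the client-side migration; the host, port, and payload are illustrative assumptions, not taken from this commit:

    import requests

    # Illustrative endpoint and payload; only the /generate route and the
    # generated_text field come from this commit.
    resp = requests.post(
        "http://localhost:3000/generate",
        json={"inputs": "Test request"},
    )

    # Before: the body was a one-element array and generated_text echoed
    # the prompt:
    #   generated_text = resp.json()[0]["generated_text"]

    # After: the body is a single object holding only the generated tokens.
    generated_text = resp.json()["generated_text"]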

View File

@@ -118,6 +118,6 @@
     ]
   ]
 },
-"generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
 }
 ]

View File

@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();

-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {

View File

@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };

     Ok((headers, Json(response)))
 }
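This handler change is the array-to-object half of the break: the response was previously wrapped in vec![] and therefore serialized as a one-element JSON array. A sketch of the two body shapes as a client sees them, with generated_text values copied from the fixture diff above (the details field is omitted for brevity):

    import json

    # Old shape: Vec<GenerateResponse> serialized as a one-element array,
    # prompt echoed in generated_text.
    old_body = json.loads('[{"generated_text": "Test request.get(...)"}]')

    # New shape: a bare GenerateResponse object, generation only.
    new_body = json.loads('{"generated_text": ".get(...)"}')

    assert isinstance(old_body, list) and isinstance(old_body[0], dict)
    assert isinstance(new_body, dict)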

View File

@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
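Every expectation in this file shrinks by exactly one "Test": the fixtures appear to prompt the model with the single token "Test" (an inference from the old expectations, not stated in the diff), greedy decoding repeats it, and the server no longer prepends request.inputs to the output. A sketch of the arithmetic behind the updated fixtures:

    # Prompt string and token count are inferred from the expectations above.
    prompt = "Test"          # assumed fixture input
    generated = "Test" * 10  # ten greedily generated tokens

    old_expectation = prompt + generated  # "Test" * 11
    new_expectation = generated           # "Test" * 10

    assert old_expectation == "Test" * 11
    assert new_expectation == "Test" * 10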

View File

@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None
     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id

View File

@@ -42,7 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
     return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)

-@pytest.mark.skip
+# @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
         default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch)
     assert next_batch is None
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -65,7 +65,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch)
     )

-@pytest.mark.skip
+# @pytest.mark.skip
 def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
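Two things happen in this file: the @pytest.mark.skip markers are commented out, re-enabling both santacoder tests, and the expectations lose their prompt prefix like everywhere else. For the fill-in-the-middle test the prefix is the FIM template itself, which makes the change easy to see; both strings below are copied from the hunk above:

    # The FIM prompt, previously echoed back in generated_text, is now omitted.
    fim_prompt = "<fim-prefix>def<fim-suffix>world<fim-middle>"
    completion = 'ineProperty(exports, "__esModule", { value'

    old_expectation = fim_prompt + completion
    new_expectation = completion

    assert old_expectation.endswith(new_expectation)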

View File

@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)
     if config.model_type == "bloom":
         if sharded:
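A small unrelated fix rides along here: get_model takes a revision argument but was not forwarding it when loading the config, so the config always resolved to the default branch even when the weights were pinned to another revision. AutoConfig.from_pretrained accepts revision directly; the model name and revision below are illustrative:

    from transformers import AutoConfig

    # The config is now resolved at the same revision as the weights it
    # will be paired with.
    config = AutoConfig.from_pretrained("bigscience/bloom-560m", revision="main")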

View File

@@ -360,11 +360,9 @@ class CausalLM(Model):
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
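This is the server-side half of the break. all_input_ids holds the prompt tokens followed by the generated ones, and the slice already keeps only the last stopping_criteria.current_tokens of them, so the decoded string never contained the prompt; the removed line was what glued request.inputs back on. A sketch with a stand-in decoder (names illustrative):

    # Stand-in for the model's decode step, to show the slicing only.
    def decode(token_ids):
        return "".join(token_ids)

    all_input_ids = ["Test"] * 11  # 1 prompt token + 10 generated (illustrative)
    current_tokens = 10            # stopping_criteria.current_tokens

    generated_text = decode(all_input_ids[-current_tokens:])

    # Before: output_text = request.inputs + generated_text
    # After:  output_text = generated_text
    assert generated_text == "Test" * 10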