mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 03:14:53 +00:00)

breaking(router): modify /generate API to only return generated text

commit f36e736723 (parent 7b870e1e18)
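
For callers of the router's /generate endpoint, the practical effect of this change is sketched below. This is a hypothetical client snippet, not part of the diff: the host, port, and request payload fields are illustrative assumptions, while the response shapes follow the router and test-fixture changes in the hunks that follow.

# Hypothetical client sketch; host/port and request fields are assumptions,
# only the response handling reflects this commit.
import requests

resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "Test request", "parameters": {"max_new_tokens": 24}},
)

# Before this commit: the body was a JSON array and generated_text echoed the prompt,
# e.g. [{"generated_text": "Test request.get(\"action\");...", "details": {...}}]
# old_text = resp.json()[0]["generated_text"]

# After this commit: the body is a single object and generated_text contains only the
# newly generated tokens, e.g. {"generated_text": ".get(\"action\");...", "details": {...}}
result = resp.json()
text = result["generated_text"]
details = result.get("details")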
@@ -118,6 +118,6 @@
       ]
     ]
   },
-  "generated_text": "Test request.get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
 }
]
@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
 
-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }
 
 fn read_json(name: &str) -> GeneratedText {
@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);
 
     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }
 
@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -42,7 +42,7 @@ def default_fim_pb_batch(default_fim_pb_request):
     return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)
 
 
-@pytest.mark.skip
+# @pytest.mark.skip
 def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
     batch = CausalLMBatch.from_pb(
         default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     assert next_batch is None
 
     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -65,7 +65,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     )
 
 
-@pytest.mark.skip
+# @pytest.mark.skip
 def test_fim_santacoder_generate_token_completion(
     default_santacoder, default_fim_pb_batch
 ):
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (
@@ -32,7 +32,7 @@ torch.backends.cudnn.allow_tf32 = True
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)
 
     if config.model_type == "bloom":
         if sharded:
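
The get_model hunk above also starts forwarding the revision argument to AutoConfig.from_pretrained. A minimal sketch of what that enables, using a placeholder model name and revision that are not taken from this diff:

# Minimal sketch, assuming the standard transformers API; the model name and
# revision below are placeholders for illustration only.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/bloom-560m", revision="main")
print(config.model_type)  # "bloom"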
@@ -360,11 +360,9 @@ class CausalLM(Model):
 
             if stop:
                 # Decode generated tokens
-                generated_text = self.decode(
+                output_text = self.decode(
                     all_input_ids[-stopping_criteria.current_tokens :, 0]
                 )
-                output_text = request.inputs + generated_text
-
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
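The CausalLM hunk above is the server-side source of the test-expectation changes earlier in this diff: output_text is now just the decoded generated tokens, no longer request.inputs concatenated with the decoded text. A toy illustration of the string-level effect, using values from the causal_lm test expectations:

# Toy illustration only; the strings come from the test expectations above.
prompt = "Test"
decoded_generation = ".java:784) at net.minecraft."

old_output_text = prompt + decoded_generation  # behavior before this commit
new_output_text = decoded_generation           # behavior after this commit

assert old_output_text == "Test.java:784) at net.minecraft."
assert new_output_text == ".java:784) at net.minecraft."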