mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
This PR adds basic modeling for phi-2 run ```bash text-generation-server \ serve \ microsoft/phi-2 \ --revision 834565c23f9b28b96ccbeabe614dd906b6db551a ``` test ```bash curl -s localhost:3000/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' | jq . ``` notes - recently (~1 day ago) the Phi weights and model were updated to accommodate adding [GQA/MQA attention to the model.](https://github.com/huggingface/transformers/pull/28163) This implementation expects the original model format, so a fixed revision is required at the moment. - this PR only includes a basic implementation of the model; it can later be extended to support Flash and Sharded versions, as well as to make use of better optimizations
66 lines
1.8 KiB
Python
66 lines
1.8 KiB
Python
import pytest
|
|
|
|
|
|
@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
    """Launch a single-shard microsoft/phi-2 server and yield its handle.

    Module-scoped so the (expensive) server start happens once for all
    tests in this file; the launcher context manager tears it down after.
    """
    server = launcher("microsoft/phi-2", num_shard=1)
    with server as handle:
        yield handle
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
    """Wait until the phi-2 server reports healthy, then expose its client."""
    # Give the server up to 300 seconds to load weights and come up.
    await flash_phi_handle.health(300)
    client = flash_phi_handle.client
    return client
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
    """Basic greedy generation: pin the token count, the exact text, and the snapshot."""
    result = await flash_phi.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    # Exact continuation produced by greedy decoding for this prompt/revision.
    expected_text = ": {request}\")\n response = self"

    assert result.details.generated_tokens == 10
    assert result.generated_text == expected_text
    assert result == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
    """Exercise the full sampling-parameter surface with a fixed seed.

    seed=0 makes the sampled output reproducible; the "network" stop
    sequence cuts generation short of max_new_tokens.
    """
    sampling_params = dict(
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["network"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )
    response = await flash_phi.generate("Test request", **sampling_params)

    # Stopped by the "network" stop sequence after 6 tokens.
    assert response.details.generated_tokens == 6
    assert response.generated_text == "Test request to send data over a network"
    assert response == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
    """Fire 4 concurrent identical requests and require identical outputs.

    Greedy decoding (no sampling params) must be deterministic, so every
    concurrent response has to match the first one exactly.
    """
    responses = await generate_load(
        flash_phi, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    # Collect the texts once instead of rebuilding the comprehension in both
    # the all() check and the failure message; use a generator in all()
    # rather than materializing a throwaway list.
    generated_texts = [r.generated_text for r in responses]
    assert all(
        text == generated_texts[0] for text in generated_texts
    ), f"{generated_texts}"
    assert responses[0].generated_text == ": {request}\")\n response = self"

    assert responses == response_snapshot
|