text-generation-inference/integration-tests/models/test_flash_phi.py
drbh b2fc097b2b feat: adds phi model (#1442)
This PR adds basic modeling support for phi-2.

run
```bash
text-generation-server \
    serve \
    microsoft/phi-2 \
    --revision 834565c23f9b28b96ccbeabe614dd906b6db551a
```

test
```bash
curl -s localhost:3000/generate \
   -X POST \
   -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
   -H 'Content-Type: application/json' | jq .
```

notes
- Recently (~1 day ago) the Phi weights and model were updated to
accommodate adding [GQA/MQA attention to the
model](https://github.com/huggingface/transformers/pull/28163). This
implementation expects the original model format, so a fixed revision is
required at the moment.
- This PR only includes a basic implementation of the model; it can
later be extended to support Flash and sharded versions, as well as to
make use of better optimizations.
2024-04-22 13:06:38 +03:00

66 lines
1.8 KiB
Python

import pytest
@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
    """Launch a single-shard ``microsoft/phi-2`` server shared by this module.

    The handle produced by the ``launcher`` context manager is yielded to the
    dependent fixtures; leaving the ``with`` block after the module's tests
    finish hands teardown back to that context manager.

    NOTE(review): the PR that added this test (#1442) served phi-2 at a pinned
    revision (834565c23f9b28b96ccbeabe614dd906b6db551a) because later upstream
    weights changed format. No revision is passed here — confirm whether
    ``launcher`` accepts one and whether it should be pinned.
    """
    server_context = launcher("microsoft/phi-2", num_shard=1)
    with server_context as phi_server:
        yield phi_server
@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
    """Wait until the phi-2 server reports healthy, then expose its client.

    ``health(300)`` is awaited before the client is handed out; 300 is
    presumably a timeout in seconds (TODO confirm against the launcher
    handle's API).
    """
    server = flash_phi_handle
    await server.health(300)
    phi_client = server.client
    return phi_client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
    """Generate 10 new tokens for ``"Test request"`` and verify the result.

    The response must report exactly 10 generated tokens, contain the
    expected continuation text, and equal the stored response snapshot.
    """
    result = await flash_phi.generate(
        "Test request",
        decoder_input_details=True,
        max_new_tokens=10,
    )

    assert result.details.generated_tokens == 10
    assert result.generated_text == ": {request}\")\n response = self"
    assert result == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
    """Exercise sampling, stopping and input/output options together.

    With a fixed ``seed`` the expected text ends in the ``"network"`` stop
    sequence after 6 of the allowed 10 tokens, and it begins with the prompt,
    consistent with ``return_full_text=True``.
    """
    result = await flash_phi.generate(
        "Test request",
        # Sampling controls (deterministic thanks to the fixed seed).
        seed=0,
        temperature=0.5,
        top_k=10,
        top_p=0.9,
        typical_p=0.9,
        repetition_penalty=1.2,
        watermark=True,
        # Length / stopping controls.
        max_new_tokens=10,
        stop_sequences=["network"],
        # Input / output shaping.
        truncate=5,
        return_full_text=True,
        decoder_input_details=True,
    )

    assert result.details.generated_tokens == 6
    assert result.generated_text == "Test request to send data over a network"
    assert result == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
    """Send 4 requests for ``"Test request"`` and check they all agree.

    ``generate_load`` presumably issues the ``n`` requests concurrently
    (confirm against conftest). Every response must have the same generated
    text, that text must match the single-request expectation, and the whole
    list must equal the stored snapshot.
    """
    responses = await generate_load(
        flash_phi, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    # Collect the texts once: used both for the comparison and, on failure,
    # rendered as the assertion message (same output as before).
    texts = [r.generated_text for r in responses]
    assert all(text == texts[0] for text in texts), f"{texts}"
    assert responses[0].generated_text == ": {request}\")\n response = self"
    assert responses == response_snapshot