text-generation-inference/integration-tests/models/test_mamba.py
drbh 51a4e62ed4 Impl simple mamba model (#1480)
This draft PR is a work-in-progress implementation of the Mamba model. It
currently loads weights and produces correct logits after a single forward
pass.

This PR still needs to integrate the model end to end so that it produces
tokens as expected, and to apply optimizations that avoid unnecessary copies
and operations at runtime.

[Mamba: Linear-Time Sequence Modeling with Selective State Spaces
(Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
https://github.com/johnma2006/mamba-minimal

https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
https://github.com/huggingface/transformers/pull/28094
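
For orientation, the recurrence at the heart of these references reduces to a per-timestep state update and readout. A minimal NumPy sketch of that shape of computation (illustration only; Mamba makes the parameters input-dependent and replaces the loop with fused scan kernels, and the names and shapes here are mine):

```python
import numpy as np

def ssm_scan(A, B, C, x):
    """Discrete state-space scan: h_t = A @ h_{t-1} + B @ x_t, y_t = C @ h_t.

    A: (d_state, d_state), B: (d_state, d_in), C: (d_out, d_state), x: (T, d_in).
    """
    h = np.zeros(A.shape[0])
    ys = []
    for t in range(x.shape[0]):
        h = A @ h + B @ x[t]  # state update
        ys.append(C @ h)      # readout
    return np.stack(ys)       # (T, d_out)
```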

Notes: this dev work currently targets `state-spaces/mamba-130m`, so if you
want to test, please use that model. Additionally, when starting the router,
prefill needs to be limited: `cargo run --
--max-batch-prefill-tokens 768 --max-input-length 768`

Integration tests have been added and basic functionality such as model
loading is supported.

```bash
cd integration-tests
pytest -vv models/test_mamba.py
```
- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

fetching the models tested during dev:
```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run:
```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

run the router:
```bash
cargo run
```

make a request:
```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

response:
```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```
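
The same request can also be made from Python; a small sketch using `requests`, assuming the router is listening on `localhost:3000` as above:

```python
import requests

# Same payload as the curl example; /generate returns {"generated_text": ...}.
resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```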

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-04-23 11:45:11 +03:00


import pytest


@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
    # Launch a single-shard server for the small Mamba checkpoint used in dev.
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
    # Wait up to 300s for the server to become healthy, then hand out a client.
    await fused_kernel_mamba_handle.health(300)
    return fused_kernel_mamba_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
    # Greedy generation: the output must match the recorded snapshot exactly.
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10
    )

    assert response.details.generated_tokens == 10
    assert response.generated_text == "\n\nDeep learning is a new type of machine"
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
    # Exercise the full set of decoding parameters at once, with a fixed seed.
    response = await fused_kernel_mamba.generate(
        "blue, red, yellow, ",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "blue, red, yellow, \nand orange (in the order they appear in"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
    # generate_load (a shared fixture) issues n identical requests concurrently;
    # all four completions must agree with each other and with the snapshot.
    responses = await generate_load(
        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all(r.generated_text == responses[0].generated_text for r in responses)
    assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
    assert responses == response_snapshot