Impl simple mamba model (#1480)
This draft PR is a work-in-progress implementation of the Mamba model. It
currently loads weights and produces correct logits after a single forward
pass.

The model still needs to be fully integrated so that it produces tokens as
expected, and optimizations are needed to avoid unnecessary copies and
operations at runtime.

Paper: [Mamba: Linear-Time Sequence Modeling with Selective State Spaces
(Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)

Reference implementations:
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- https://github.com/huggingface/transformers/pull/28094
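
For context, the heart of the model is the selective-scan recurrence. Below is a minimal, unoptimized sketch in the spirit of the mamba-minimal reference above; tensor names and shapes are illustrative only and are not taken from this PR's implementation.

```python
import torch

def selective_scan(u, delta, A, B, C, D):
    """Sequential form of the Mamba SSM:
        h_t = dA_t * h_{t-1} + dB_t * u_t,   y_t = C_t h_t + D u_t
    Shapes (illustrative): u, delta: (b, l, d_in); A: (d_in, n);
    B, C: (b, l, n); D: (d_in,).
    """
    b, l, d_in = u.shape
    n = A.shape[1]
    # Discretize the continuous-time parameters (zero-order hold for A,
    # a simple Euler step for B), per input-dependent step size `delta`.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (b, l, d_in, n)
    deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)  # (b, l, d_in, n)

    # Naive sequential scan over the time dimension; the fused CUDA kernel
    # replaces this loop in practice.
    h = torch.zeros(b, d_in, n, device=u.device, dtype=u.dtype)
    ys = []
    for t in range(l):
        h = deltaA[:, t] * h + deltaB_u[:, t]
        y = (h @ C[:, t].unsqueeze(-1)).squeeze(-1)  # (b, d_in)
        ys.append(y)
    y = torch.stack(ys, dim=1)                       # (b, l, d_in)
    return y + u * D                                 # skip connection
```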

Note: this dev work currently targets `state-spaces/mamba-130m`, so if you
want to test, please use that model. Additionally, when starting the router,
the prefill needs to be limited: `cargo run --
--max-batch-prefill-tokens 768 --max-input-length 768`

Integration tests have been added and basic functionality such as model
loading is supported.

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```
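
For reference, these tests follow the repo's existing integration-test pattern. The sketch below shows the shape such a test takes, assuming the shared `launcher` and `response_snapshot` fixtures from the integration-tests conftest; the fixture and test names here are illustrative, not necessarily the exact ones in the PR.

```python
import pytest


@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
    # Spin up a server for the target model for the duration of the module.
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
    # Wait for the server to be healthy before handing out the client.
    await fused_kernel_mamba_handle.health(300)
    return fused_kernel_mamba_handle.client


@pytest.mark.asyncio
async def test_mamba(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10
    )
    assert response.details.generated_tokens == 10
    # Compare against the stored snapshot under __snapshots__.
    assert response == response_snapshot
```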
- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetching the models tested during development:
```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run with:
```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Start the router:
```bash
cargo run
```

Make a request:
```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

Response:
```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```
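
The same request from Python, equivalent to the curl call above (this uses the third-party `requests` library against the endpoint and payload shown in this PR):

```python
import requests

# POST the same JSON payload the curl example sends to the router.
resp = requests.post(
    "http://localhost:3000/generate",
    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```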

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>