text-generation-inference/integration-tests/models/__snapshots__
drbh bd405e035b
Impl simple mamba model (#1480)
This draft PR is a work-in-progress implementation of the mamba model. It currently loads weights and produces correct logits after a single forward pass.

The model still needs to be integrated so that it produces tokens as expected, and optimizations are needed to avoid unnecessary copies and operations at runtime.

#### Helpful resources

- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- https://github.com/huggingface/transformers/pull/28094
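For orientation, the core computation is a selective state-space scan: `h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t x_t`, then `y_t = C_t h_t`. Below is a rough, unoptimized PyTorch sketch in the spirit of the mamba-minimal reference above; shapes and names are illustrative, not this PR's implementation:

```python
import torch

def selective_scan(x, delta, A, B, C, D):
    # Illustrative shapes: x, delta: (b, l, d_in); A: (d_in, n); B, C: (b, l, n); D: (d_in,)
    b, l, d_in = x.shape
    # Discretize the continuous-time parameters (zero-order hold for A).
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                        # (b, l, d_in, n)
    deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (b, l, d_in, n)

    # Sequential scan over time: h_t = deltaA_t * h_{t-1} + deltaB_x_t
    h = x.new_zeros(b, d_in, A.shape[1])
    ys = []
    for t in range(l):
        h = deltaA[:, t] * h + deltaB_x[:, t]
        ys.append((h * C[:, t].unsqueeze(1)).sum(-1))  # y_t = C_t h_t
    y = torch.stack(ys, dim=1)                         # (b, l, d_in)
    return y + x * D                                   # skip connection
```

The fused kernels used here replace this O(l) Python loop; the loop form just makes it easy to see the state carried between tokens.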

Note: this dev work currently targets `state-spaces/mamba-130m`, so please use that model for testing. Additionally, when starting the router, the prefill size needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`.


## Update / Current State

Integration tests have been added, and basic functionality such as model loading is supported. Run the tests with:

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```
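The test file follows the suite's usual snapshot pattern, with the `launcher` and `response_snapshot` fixtures coming from the integration tests' `conftest.py`; a rough sketch of its shape (fixture names are illustrative):

```python
import pytest


@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
    # Spin up a server instance for the dev target model.
    with launcher("state-spaces/mamba-130m") as handle:
        yield handle


@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
    # Wait until the server reports healthy before testing.
    await fused_kernel_mamba_handle.health(300)
    return fused_kernel_mamba_handle.client


@pytest.mark.asyncio
async def test_mamba(fused_kernel_mamba, response_snapshot):
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10
    )
    # Compare against the stored snapshot in __snapshots__.
    assert response.details.generated_tokens == 10
    assert response == response_snapshot
```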
- [x] add tests
- [x] load model
- [x] make a simple request
- [ ] resolve warmup issue
- [ ] resolve output issues


Fetch the models tested during development:
```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can be run with:
```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Start the router:
```bash
cargo run
```

Make a request:
```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

Response:
```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```
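The same request can also be made from Python; a minimal equivalent of the curl call above, using `requests`:

```python
import requests

resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 20},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```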

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-02-08 10:19:45 +01:00
Directories in `__snapshots__` and the commit that last touched each:

| Directory | Last commit | Date |
|---|---|---|
| test_bloom_560m | feat(server): support vectorized warpers in flash causal lm (#317) | 2023-05-26 12:30:27 +02:00 |
| test_bloom_560m_sharded | feat(integration-tests): improve comparison and health checks (#336) | 2023-05-16 20:22:11 +02:00 |
| test_flash_awq | Add AWQ quantization inference support (#1019) (#1054) | 2023-09-25 15:31:27 +02:00 |
| test_flash_awq_sharded | Add AWQ quantization inference support (#1019) (#1054) | 2023-09-25 15:31:27 +02:00 |
| test_flash_falcon | feat(server): add retry on download (#384) | 2023-05-31 10:57:53 +02:00 |
| test_flash_llama | Remove the stripping of the prefix space (and any other mangling that tokenizers might do). (#1065) | 2023-09-27 12:13:45 +02:00 |
| test_flash_llama_gptq | feat(server): Add exllama GPTQ CUDA kernel support #553 (#666) | 2023-07-21 10:59:00 +02:00 |
| test_flash_medusa | Speculative (#1308) | 2023-12-11 12:46:30 +01:00 |
| test_flash_mistral | feat: add mistral model (#1071) | 2023-09-28 09:55:47 +02:00 |
| test_flash_neox | fix(server): fix init for flash causal lm (#352) | 2023-05-22 15:05:32 +02:00 |
| test_flash_neox_sharded | fix(server): fix init for flash causal lm (#352) | 2023-05-22 15:05:32 +02:00 |
| test_flash_phi | feat: adds phi model (#1442) | 2024-01-25 15:37:53 +01:00 |
| test_flash_santacoder | feat(integration-tests): improve comparison and health checks (#336) | 2023-05-16 20:22:11 +02:00 |
| test_flash_starcoder | feat(server): Rework model loading (#344) | 2023-06-08 14:51:52 +02:00 |
| test_flash_starcoder_gptq | Reinstate exl2 with tp (#1490) | 2024-01-26 14:00:29 +01:00 |
| test_idefics | Fixing non divisible embeddings. (#1476) | 2024-01-24 13:08:41 +01:00 |
| test_mamba | Impl simple mamba model (#1480) | 2024-02-08 10:19:45 +01:00 |
| test_mpt | feat(server): Add Non flash MPT. (#514) | 2023-07-03 13:01:46 +02:00 |
| test_mt0_base | feat(server): support vectorized warpers in flash causal lm (#317) | 2023-05-26 12:30:27 +02:00 |
| test_neox | feat(server): Rework model loading (#344) | 2023-06-08 14:51:52 +02:00 |
| test_neox_sharded | feat(server): Rework model loading (#344) | 2023-06-08 14:51:52 +02:00 |
| test_t5_sharded | feat(server): support fp16 for t5 (#360) | 2023-05-23 18:16:48 +02:00 |