This draft PR is a work-in-progress implementation of the Mamba model. The PR currently loads weights and produces correct logits after a single forward pass. It still needs to integrate the model so that it produces tokens as expected, and to apply optimizations that avoid unnecessary copies and operations at runtime.

Reference material (a minimal sketch of the recurrence these implement is included at the end of this description):

- [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
- https://github.com/johnma2006/mamba-minimal
- https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
- https://github.com/huggingface/transformers/pull/28094

Notes: this dev work currently targets `state-spaces/mamba-130m`, so please use that model for testing. Additionally, when starting the router, prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`

Integration tests have been added, and basic functionality such as model loading is supported:

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```

- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetch the models tested during dev:

```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

Run the server:

```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Run the router:

```bash
cargo run
```

Make a request:

```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

Response:

```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
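For orientation, the core computation a Mamba block performs is a selective state-space scan. Below is a minimal, unfused sketch of that per-step recurrence in the spirit of mamba-minimal; it is illustrative only (shapes and names here are assumptions, and this is not the fused-kernel path this PR integrates):

```python
import numpy as np


def selective_scan(x, delta, A, B, C, D):
    """Sequential selective scan (illustrative shapes, not the fused kernel):
    x, delta: (L, d); A: (d, n); B, C: (L, n); D: (d,).
    Per step: h = exp(delta*A) * h + (delta*B) * x;  y = C.h + D*x.
    """
    L, d = x.shape
    n = A.shape[1]
    h = np.zeros((d, n))  # hidden state: one n-dim state per channel
    ys = []
    for t in range(L):
        dA = np.exp(delta[t][:, None] * A)                         # discretized state transition
        dBx = (delta[t][:, None] * B[t][None, :]) * x[t][:, None]  # discretized input term
        h = dA * h + dBx                                           # state update
        ys.append((h * C[t][None, :]).sum(-1) + D * x[t])          # readout plus skip connection
    return np.stack(ys)  # (L, d)
```

The fused kernels the integration tests exercise compute the same recurrence in a hardware-aware way; this sketch is only meant to pin down the math being implemented.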
`integration-tests/models/test_fused_kernel_mamba.py` (60 lines, 1.9 KiB, Python):
```python
import pytest


@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
    # Spin up a single-shard server for the dev target model.
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def fused_kernel_mamba(fused_kernel_mamba_handle):
    # Wait up to 300s for the server to become healthy, then hand out the client.
    await fused_kernel_mamba_handle.health(300)
    return fused_kernel_mamba_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
    # Basic greedy generation against a fixed snapshot.
    response = await fused_kernel_mamba.generate(
        "What is Deep Learning?", max_new_tokens=10
    )

    assert response.details.generated_tokens == 10
    assert response.generated_text == "\n\nDeep learning is a new type of machine"
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
    # Exercise the full set of sampling/generation parameters with a fixed seed.
    response = await fused_kernel_mamba.generate(
        "blue, red, yellow, ",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "blue, red, yellow, \nand orange (in the order they appear in"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
    # Fire several identical requests and check the generations all agree.
    responses = await generate_load(
        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"

    assert responses == response_snapshot
```
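A note on `test_mamba_load`: the `generate_load` fixture comes from the shared integration-test conftest. Conceptually it issues `n` identical requests concurrently so the test can check that batched decoding matches single-request output; a rough, hypothetical equivalent (not the actual fixture) would be:

```python
import asyncio


async def generate_load_sketch(client, prompt, max_new_tokens, n):
    # Fire n identical requests concurrently; the test then asserts
    # that all n generations are identical to one another.
    return await asyncio.gather(
        *(client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n))
    )
```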