diff --git a/README.md b/README.md
index dc074d50..c6db2822 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
 - [MPT](https://huggingface.co/mosaicml/mpt-30b)
 - [Llama V2](https://huggingface.co/meta-llama)
 - [Code Llama](https://huggingface.co/codellama)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
 
 Other architectures are supported on a best effort basis using:
 
diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index ff7f66a3..0bf80f8c 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -482,7 +482,6 @@ class AsyncClient:
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
             async with session.post(self.base_url, json=request.dict()) as resp:
-
                 if resp.status != 200:
                     raise parse_error(resp.status, await resp.json())
 
diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md
index c5768d9a..5d645759 100644
--- a/docs/source/supported_models.md
+++ b/docs/source/supported_models.md
@@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom
 - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
 - [MPT](https://huggingface.co/mosaicml/mpt-30b)
 - [Llama V2](https://huggingface.co/meta-llama)
+- [Code Llama](https://huggingface.co/codellama)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
 
 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
 
diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py
new file mode 100644
index 00000000..0a1c76ba
--- /dev/null
+++ b/integration-tests/models/test_flash_mistral.py
@@ -0,0 +1,63 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_mistral_handle(launcher):
+    with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_mistral(flash_mistral_handle):
+    await flash_mistral_handle.health(300)
+    return flash_mistral_handle.client
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral(flash_mistral, response_snapshot):
+    response = await flash_mistral.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
+    response = await flash_mistral.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 5
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_mistral, "Test request", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index a7d63356..cdea8431 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
 
 flash-attention-v2:
 	# Clone flash attention
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index f721d51f..77b7f230 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -290,7 +290,7 @@ class MistralAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
-                max_past=self.max_past,
+                window_size_left=self.max_past,
             )
         # Decode
         else:
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index aa02b950..bde0aa76 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -57,7 +57,7 @@ def attention(
     cu_seqlens,
     max_s,
     softmax_scale,
-    max_past=0,
+    window_size_left=0,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
@@ -73,14 +73,17 @@ def attention(
            softmax_scale,
            False,
            True,
-            max_past,
+            window_size_left,
+            0,
            False,
            None,
        )
 
     if HAS_FLASH_ATTN:
-        if max_past != 0:
-            raise NotImplementedError("max_past is only available with flash attn v2")
+        if window_size_left != 0:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
+            )
 
     # Flash attention v1 requires q, k and v to have the same number of heads
     if k.shape[1] != q.shape[1]:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8be2463f..cf61e47b 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -53,6 +53,7 @@ try:
 except ImportError:
     pass
 
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
diff --git a/update_doc.py b/update_doc.py
index a4c95743..7e8fb769 100644
--- a/update_doc.py
+++ b/update_doc.py
@@ -8,7 +8,9 @@ def main():
 
     args = parser.parse_args()
 
-    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8")
+    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
+        "utf-8"
+    )
     final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"
 
     filename = "docs/source/basic_tutorials/launcher.md"
@@ -16,16 +18,20 @@ def main():
         with open(filename, "r") as f:
             doc = f.read()
         if doc != final_doc:
-
             tmp = "launcher.md"
             with open(tmp, "w") as g:
                 g.write(final_doc)
-            diff = subprocess.run(["diff",tmp, filename], capture_output=True).stdout.decode("utf-8")
+            diff = subprocess.run(
+                ["diff", tmp, filename], capture_output=True
+            ).stdout.decode("utf-8")
             print(diff)
-            raise Exception("Doc is not up-to-date, run `python update_doc.py` in order to update it")
+            raise Exception(
+                "Doc is not up-to-date, run `python update_doc.py` in order to update it"
+            )
     else:
         with open(filename, "w") as f:
             f.write(final_doc)
 
+
 if __name__ == "__main__":
     main()
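
Note (not part of the patch): the argument renamed above from max_past to window_size_left is the sliding-window size that Mistral passes as self.max_past, and flash-attn v2 uses it to restrict each query to the most recent keys; flash-attn v1 has no such support, hence the NotImplementedError path. A minimal, self-contained sketch of the equivalent attention mask, purely to illustrate the semantics (the helper name sliding_window_mask is hypothetical and not part of this change):

import torch

def sliding_window_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # True where query position i may attend to key position j:
    # causal (j <= i) and within the window (i - j <= window_size_left).
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]
    in_window = idx[:, None] - idx[None, :] <= window_size_left
    return causal & in_window

# Example: with window_size_left=2, position 4 attends to positions 2, 3 and 4 only.
print(sliding_window_mask(5, 2))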