update flash

OlivierDehaene 2023-09-27 19:17:39 +02:00
parent 630e417ca0
commit 5b1aaeceb2
9 changed files with 86 additions and 11 deletions

View File

@@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
 - [MPT](https://huggingface.co/mosaicml/mpt-30b)
 - [Llama V2](https://huggingface.co/meta-llama)
 - [Code Llama](https://huggingface.co/codellama)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

 Other architectures are supported on a best effort basis using:
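Once a server is running with one of the models above, generations can be requested over TGI's /generate HTTP endpoint. A minimal sketch, assuming a local server already launched for mistralai/Mistral-7B-Instruct-v0.1; the URL, port and prompt are placeholders:

# Minimal sketch: query a locally running TGI server over its /generate endpoint.
# Assumes a server is already serving mistralai/Mistral-7B-Instruct-v0.1 on port 8080.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is sliding window attention?",
        "parameters": {"max_new_tokens": 32},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])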

View File

@@ -482,7 +482,6 @@ class AsyncClient:
             headers=self.headers, cookies=self.cookies, timeout=self.timeout
         ) as session:
             async with session.post(self.base_url, json=request.dict()) as resp:
                 if resp.status != 200:
                     raise parse_error(resp.status, await resp.json())
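For reference, a typical call path through this client looks like the sketch below; the base URL is a placeholder, and any non-200 status surfaces through parse_error as in the hunk above.

# Sketch of using the async client against a running TGI server (URL is illustrative).
import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    # Server-side errors are converted by parse_error and re-raised here
    response = await client.generate("Test request", max_new_tokens=10)
    print(response.generated_text)


asyncio.run(main())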

View File

@@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom
 - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
 - [MPT](https://huggingface.co/mosaicml/mpt-30b)
 - [Llama V2](https://huggingface.co/meta-llama)
+- [Code Llama](https://huggingface.co/codellama)
+- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:

View File

@@ -0,0 +1,63 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_mistral_handle(launcher):
+    with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_mistral(flash_mistral_handle):
+    await flash_mistral_handle.health(300)
+    return flash_mistral_handle.client
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral(flash_mistral, response_snapshot):
+    response = await flash_mistral.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
+    response = await flash_mistral.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 5
+    assert response == response_snapshot
+
+
+@pytest.mark.skip
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_mistral, "Test request", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
+
+    assert responses == response_snapshot
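The generate_load fixture used by test_flash_mistral_load comes from the shared integration-test conftest and is not part of this diff; conceptually it fires n identical requests concurrently and collects the responses, roughly as in the illustrative sketch below (the name generate_load_sketch and its body are assumptions, not the repository's actual fixture):

# Illustrative only: a concurrent-load helper in the spirit of generate_load.
# The real fixture lives in the integration-test conftest and may differ.
import asyncio


async def generate_load_sketch(client, prompt, max_new_tokens, n):
    # Fire n identical generate requests at once and gather their responses in order
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)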

View File

@@ -1,4 +1,4 @@
-flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c

 flash-attention-v2:
	# Clone flash attention

View File

@@ -290,7 +290,7 @@ class MistralAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
-                max_past=self.max_past,
+                window_size_left=self.max_past,
             )
         # Decode
         else:
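The renamed argument reflects the sliding-window attention used by Mistral: during prefill each query position may only attend to keys within a fixed window to its left (4096 positions for Mistral 7B). A minimal sketch of that masking rule, illustrative rather than the flash-attention kernel's actual implementation:

# Sketch: the masking rule implied by window_size_left (not the flash-attention kernel).
import torch


def sliding_window_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # allowed[i, j] is True when query position i may attend to key position j
    pos = torch.arange(seq_len)
    causal = pos[None, :] <= pos[:, None]                       # never attend to future tokens
    window = (pos[:, None] - pos[None, :]) <= window_size_left  # only the last N tokens
    return causal & window


print(sliding_window_mask(seq_len=6, window_size_left=2).int())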

View File

@@ -57,7 +57,7 @@ def attention(
     cu_seqlens,
     max_s,
     softmax_scale,
-    max_past=0,
+    window_size_left=0,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
@@ -73,14 +73,17 @@ def attention(
             softmax_scale,
             False,
             True,
-            max_past,
+            window_size_left,
+            0,
             False,
             None,
         )

     if HAS_FLASH_ATTN:
-        if max_past != 0:
-            raise NotImplementedError("max_past is only available with flash attn v2")
+        if window_size_left != 0:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
+            )

         # Flash attention v1 requires q, k and v to have the same number of heads
         if k.shape[1] != q.shape[1]:

View File

@@ -53,6 +53,7 @@ try:
 except ImportError:
     pass

 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):

View File

@@ -8,7 +8,9 @@ def main():
     args = parser.parse_args()

-    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8")
+    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
+        "utf-8"
+    )

     final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"

     filename = "docs/source/basic_tutorials/launcher.md"
@@ -16,16 +18,20 @@ def main():
         with open(filename, "r") as f:
             doc = f.read()
             if doc != final_doc:
                 tmp = "launcher.md"
                 with open(tmp, "w") as g:
                     g.write(final_doc)
-                diff = subprocess.run(["diff",tmp, filename], capture_output=True).stdout.decode("utf-8")
+                diff = subprocess.run(
+                    ["diff", tmp, filename], capture_output=True
+                ).stdout.decode("utf-8")
                 print(diff)
-                raise Exception("Doc is not up-to-date, run `python update_doc.py` in order to update it")
+                raise Exception(
+                    "Doc is not up-to-date, run `python update_doc.py` in order to update it"
+                )
     else:
         with open(filename, "w") as f:
             f.write(final_doc)


 if __name__ == "__main__":
     main()