Mirror of https://github.com/huggingface/text-generation-inference.git
commit 5b1aaeceb2 (parent 630e417ca0)

    update flash
@@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

Other architectures are supported on a best effort basis using:
@@ -482,7 +482,6 @@ class AsyncClient:
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:

                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())
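For reference, the hunk above is the request path of the Python AsyncClient: it POSTs the serialized request to the server and turns non-200 responses into client-side errors via parse_error. A minimal usage sketch, assuming a TGI server is already running at http://127.0.0.1:8080 (the URL and prompt are illustrative, not part of this diff):

import asyncio

from text_generation import AsyncClient


async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    # A non-200 status from the server surfaces here as the error raised by parse_error.
    response = await client.generate("Test request", max_new_tokens=10)
    print(response.generated_text)


asyncio.run(main())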
@@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
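As a hedged illustration of that fallback path, the launcher can be pointed at any Hub model id. The sketch below uses gpt2 purely as a placeholder and assumes text-generation-launcher is installed locally; any other flags are omitted:

import subprocess

# Start TGI for a model that is not on the optimized list above.
# Placeholder model id; performance is not guaranteed for such models.
subprocess.run(
    ["text-generation-launcher", "--model-id", "gpt2", "--port", "8080"],
    check=True,
)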
integration-tests/models/test_flash_mistral.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import pytest


@pytest.fixture(scope="module")
def flash_mistral_handle(launcher):
    with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_mistral(flash_mistral_handle):
    await flash_mistral_handle.health(300)
    return flash_mistral_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral(flash_mistral, response_snapshot):
    response = await flash_mistral.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
    response = await flash_mistral.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 5
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
    responses = await generate_load(
        flash_mistral, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
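The launcher, health, generate_load, and response_snapshot helpers come from the integration-test harness (a shared conftest), which is not part of this diff. Conceptually, generate_load fires n identical requests concurrently and gathers the responses so the test can check that batched decoding stays deterministic; a rough sketch of that behavior, an assumption rather than the actual fixture:

import asyncio


async def generate_load(client, prompt: str, max_new_tokens: int, n: int):
    # Issue n identical generate requests concurrently and collect the results.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)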
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c

flash-attention-v2:
	# Clone flash attention
@@ -290,7 +290,7 @@ class MistralAttention(torch.nn.Module):
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
-               max_past=self.max_past,
+               window_size_left=self.max_past,
            )
        # Decode
        else:
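self.max_past holds Mistral's sliding-window size; the attention kernel now receives it through the window_size_left argument introduced below. The value presumably comes from the model config, which exposes it as sliding_window; a quick way to inspect it, assuming the transformers library is available:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
print(config.sliding_window)  # 4096 for this checkpoint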
@@ -57,7 +57,7 @@ def attention(
    cu_seqlens,
    max_s,
    softmax_scale,
-   max_past=0,
+   window_size_left=0,
):
    if HAS_FLASH_ATTN_V2:
        return flash_attn_2_cuda.varlen_fwd(
@@ -73,14 +73,17 @@ def attention(
            softmax_scale,
            False,
            True,
-           max_past,
+           window_size_left,
            0,
            False,
            None,
        )

    if HAS_FLASH_ATTN:
-       if max_past != 0:
-           raise NotImplementedError("max_past is only available with flash attn v2")
+       if window_size_left != 0:
+           raise NotImplementedError(
+               "window_size_left is only available with flash attn v2"
+           )

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
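The rename from max_past to window_size_left matches flash-attention v2's sliding-window semantics: a query at position i only attends to itself and the last window_size_left positions before it. A small, self-contained illustration of that mask, purely conceptual and unrelated to the CUDA kernel:

import torch


def sliding_window_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    # True where query i may attend to key j: causal and within the window.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j >= i - window_size_left)


print(sliding_window_mask(6, 2).int())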
@@ -53,6 +53,7 @@ try:
except ImportError:
    pass


# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
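The hunk above only shows the head of the monkey-patched loader. What such a loader typically does in this codebase is read the norm's weight and bias out of the sharded checkpoint and wrap them in the layer class; the sketch below is an assumption about that body, not code from this commit:

from torch import nn


def load_layer_norm(cls, prefix, weights, eps):
    # Assumed behavior: `weights` is TGI's checkpoint helper and `prefix` is the
    # parameter prefix of this norm inside the checkpoint.
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    ln = cls(weight.shape[-1], eps=eps)
    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


# The monkey patching then attaches it to the norm class, roughly:
# SomeNormClass.load = classmethod(load_layer_norm)  # class name is hypothetical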
@@ -8,7 +8,9 @@ def main():
    args = parser.parse_args()

-   output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8")
+   output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
+       "utf-8"
+   )
    final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"

    filename = "docs/source/basic_tutorials/launcher.md"
@@ -16,16 +18,20 @@ def main():
    with open(filename, "r") as f:
        doc = f.read()
    if doc != final_doc:
        tmp = "launcher.md"
        with open(tmp, "w") as g:
            g.write(final_doc)
-       diff = subprocess.run(["diff",tmp, filename], capture_output=True).stdout.decode("utf-8")
+       diff = subprocess.run(
+           ["diff", tmp, filename], capture_output=True
+       ).stdout.decode("utf-8")
        print(diff)
-       raise Exception("Doc is not up-to-date, run `python update_doc.py` in order to update it")
+       raise Exception(
+           "Doc is not up-to-date, run `python update_doc.py` in order to update it"
+       )
    else:
        with open(filename, "w") as f:
            f.write(final_doc)


if __name__ == "__main__":
    main()
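As shown in this hunk, update_doc.py regenerates the launcher --help dump, prints a diff and raises when the committed docs/source/basic_tutorials/launcher.md is stale, and writes the file otherwise. A minimal way to invoke it from Python, equivalent to running it from a shell and assuming text-generation-launcher is on PATH:

import subprocess

# Regenerates the --help dump and fails if launcher.md is out of date.
subprocess.run(["python", "update_doc.py"], check=True)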