Commit Graph

660 Commits

Author SHA1 Message Date
drbh
bc5e202d2c fix: adjust video process, reduce to 1 fps and adjust tensor shape 2024-12-23 13:47:18 -05:00
Miquel Farre
e65ead12bb moving video sampling and resize to validation. downstream we receive frames 2024-12-23 13:47:18 -05:00
David Holtz
322165d767 fix: remove unused deps and imports 2024-12-23 13:47:18 -05:00
David Holtz
b2c557594f feat: support video input chunks and enable qwen2 vl to process video 2024-12-23 13:47:18 -05:00
Miquel Farre
3c07391e8e fix 2024-12-23 13:47:18 -05:00
Miquel Farre
a25c3ecefc refactoring 2024-12-23 13:47:18 -05:00
Miquel Farre
464609fd43 fix 2024-12-23 13:47:18 -05:00
Miquel Farre
b9c8152ac6 downloading videos 2024-12-23 13:47:18 -05:00
Miquel Farre
c7c2fdae8c fix 2024-12-23 13:47:18 -05:00
Miquel Farre
05464d26bf connecting video to qwen2 2024-12-23 13:47:18 -05:00
Miquel Farre
18c9f06ded WIP video support 2024-12-23 13:47:18 -05:00
Mohit Sharma
8f66d323d0
Update vllm kernels for ROCM (#2826)
* (vllm) updated vllm rocm kernels

* revert silu

* update partition size

* remove grouped_topk

* (nit) remove log

* update moe-kernels commit
2024-12-18 12:44:42 +01:00
janne-alatalo
7eeefa3b57
Qwen2-VL runtime error fix when prompted with multiple images (#2840)
* Fix runtime error when Qwen2-VL was prompted with multiple images

Fix runtime error when Qwen2-VL model is prompted with prompt with more
than one image. The runtime error was:

 File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 459, in get_position_ids
    text_pos_ids = torch.arange(text_length, device=d)
RuntimeError: upper bound and larger bound inconsistent with step sign

The error was caused by text_length variable going to negative value
when multiple images caused multiple loops in the get_position_ids
function's main loop.

The error is a simple logic mistake where next_image_pos is initialized
as relative offset from current_pos, but was used like it was absolute
position from zero.

* Fix runtime error when Qwen2-VL was prompted with multiple images

Fix runtime error when Qwen2-VL model is prompted with prompt with more
than one image. The runtime error was:

File "text-generation-inference/server/text_generation_server/models/custom_modeling/qwen2_vl.py", line 534, in forward
    inputs_embeds[input_ids == self.image_token_id] = image_embeds
RuntimeError: shape mismatch: value tensor of shape [512, 3584] cannot be broadcast to indexing result of shape [1024, 3584]

(The error message shape numbers can be different depending on the input
image resolutions)

The error was caused by adding the wrong number of <|image_pad|> tokens
to the tokenized input in the image_text_replacement function.

The error is a simple logical mistake where the number of image pad
tokens is checked from pixel_value_shape tensor's first dimension
length. However, the pixel_value_shape contains patches from all of the
images. Therefore the code added the total number of required image pad
tokens for the whole input to each of the images locations. This
resulted to extra image pad tokens to be present in the tokenized input.

The fix was to check the number of required tokens from the
image_grid_thw tensor. The tensor includes grid_t, grid_h, and grid_w
values for each image. grid_t * grid_h * grid_w results to the total
number of patches for the image [1]. The number of required image pad
tokens is number_of_patches // 4.

[1] 31f9a289a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py (L311)

---------

Co-authored-by: Janne Alatalo <janne.alatalo@jamk.fi>
2024-12-16 22:55:11 -05:00
Nicolas Patry
3bb3fd19ae
Fixup opt to reduce the amount of odd if statements. (#2833)
* Fixup opt to reduce the amount of odd if statements.

* Fixing cargo lock
2024-12-12 18:20:13 +01:00
Wang, Yi
bf59118a93
fix facebook/opt-125m not working issue (#2824)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-12-12 14:41:30 +01:00
Nicolas Patry
82c24f7420
Using both value from config as they might not be correct. (#2817)
* Using both value from config as they might not be correct.

* Fixing max_position_embeddings for falcon.

* Simple attempt to fix the healthcheck block allocation.

* Much simpler solution.

* Default value for Backend start_health
2024-12-10 19:37:09 +01:00
Nicolas Patry
a04356fb8c
Attempt for cleverer auto batch_prefill values (some simplifications). (#2808)
* Attempt for cleverer auto batch_prefill values (some simplifications).

* Less flaky tests.

* Fixing typo insertion.

* Update launcher/src/main.rs

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* Adding small comment for source of calculation.

* Adding L40.

* Adding L40s.

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2024-12-09 19:44:32 +01:00
drbh
9f5c9a5e22
Enable paligemma2 (#2807)
* feat: support loading gemma2 as vlm text model

* feat: add test for paligemma2
2024-12-06 14:41:49 -05:00
Nicolas Patry
08f6fa0b59
Removing experimental to prefill chunking. 2024-12-06 19:09:40 +01:00
Nicolas Patry
5df8059037
Auto max prefill (#2797)
* Attempt at automatic max batch prefill.

* Taking into account number of shards.

* Adding more cards.

* Adding A100 + H100

* Adding a few more cards.

* Logprobs cost too much.

* h100 better name, and keep factor of 2

* Damn inflated sparse tflops.

* Typo in h100.

* Updated the flops calculation (checked with fvcore).

* chunking by default.

* Fix prefix caching for chat completion since we removed logprobs.

* More tests.

* Dropping all the prefill logprobs.

* Add a flag that enables users to get logprobs back.

* Repairing prompt token counting.

* Fixing a few tests.

* Remove some scaffolding.

* Attempting to reduces the issues (workarounds for now).
2024-12-06 05:52:00 +01:00
drbh
e0db633396
fix: avoid setting use_sgmv if no kernels present (#2796) 2024-12-04 15:26:09 -05:00
Nicolas Patry
b57f370386
Saving some VRAM. (#2790)
* Saving some VRAM.

- 8B on 4xL4 attention=flashdecoding . Before 4.28GB left, After 4.32GB
  left, so 400MB saved.

- Effect not as visible on attention=flashinfer and n_shard=1. I suspect
  it's linked to the torch allocator.

* Adding assertion.
2024-12-03 04:04:21 +01:00
Daniël de Kok
2003d8be0c
Sync (most) server dependencies with Nix (#2782)
* Sync (most) server dependencies with Nix

Skipped most grpcio packages, because of protobuf version
incompatibility with the opentelemetry packages.

* Add a primitive script to generate Poetry commands to sync with Nix

This is not fully automated, since getting the Nix versions may be
unresolvable. However, it does take most of the work out of doing
this manually.

* Upgrade eetq ?

* Fmt.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-12-03 04:04:06 +01:00
Dmitry Rogozhkin
535149d872
fix: only use eos_token_id as pad_token_id if int (#2774)
LLama 3 has a list of values as eos_token_id:
  "['<|end_of_text|>', '<|eom_id|>', '<|eot_id|>']"
This breaks tokenizer since it expects single value. This
commit uses tokenizer.eos_token_id instead in such a case.

Fixes: #2440

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
2024-12-02 06:26:37 +01:00
Daniël de Kok
72ab60fdd5
Use FP8 KV cache when specified by compressed-tensors (#2761)
The compressed-tensors configuration can specify the configuration of
the KV cache as well. Use an FP8 KV cache when the configuration tells
us to do so (all other options and types are ignored for now).
2024-11-26 08:27:41 +01:00
Daniël de Kok
289aa48554
Move JSON grammar -> regex grammar conversion to the router (#2772)
* Move JSON grammar -> regex grammar conversion to the router

This change moves the JSON grammar -> regex grammar conversion to the
router by adding a dependency on the `outlines-core` Rust crate. In
contrast to the Python implementation, the conversions are not LRU-cached
since they seem to be fast enough:

simple schema           time:   [5.8293 µs 5.8307 µs 5.8320 µs]
                        change: [-13.166% -12.884% -12.641%] (p = 0.00 < 0.05)
                        Performance has improved.

complex schema          time:   [14.875 µs 14.881 µs 14.887 µs]
                        change: [-2.1637% -1.9914% -1.7852%] (p = 0.00 < 0.05)
                        Performance has improved.

Using the schemas from:
https://github.com/dottxt-ai/outlines-core/blob/main/benchmarks/bench_json_schema.py
2024-11-25 18:47:34 +01:00
Daniël de Kok
e87893d38e
chore: Update to marlin-kernels 0.3.6 (#2771)
This fixes a bug in 2:4 Marlin:
https://github.com/vllm-project/vllm/pull/10464
2024-11-22 14:44:47 +00:00
OlivierDehaene
ab7ccf5bc3
feat: add payload limit (#2726)
* feat: add payload limit

* update launcher
2024-11-21 18:20:15 +00:00
drbh
6ee8d6dd3b
fix: set outlines version to 0.1.3 to avoid caching serialization issue (#2766)
fix: set outlines version to 0.1.3 to avoid bug
2024-11-20 18:09:39 -05:00
Daniël de Kok
46a5a7e73e
Add support for wNa16 int 2:4 compressed-tensors checkpoints (#2758)
This change adds support for wNa16 int checkpoints with 2:4 sparsity
using Marlin 2:4 kernels.
2024-11-20 18:25:23 +01:00
drbh
bd6e8b3c13
fix: adjust llama MLP name from dense to mlp to correctly apply lora (#2760) 2024-11-19 15:10:22 -05:00
Daniël de Kok
2007a9473a
Update to moe-kernels 0.7.0 (#2720)
This version syncs with the vLLM kernels and brings some performance
improvements.
2024-11-19 14:55:29 +01:00
Daniël de Kok
b4ec427ad0
Simplify two ipex conditions (#2755) 2024-11-19 08:04:23 +01:00
drbh
38cff84a3e
feat: support flash attention 2 in qwen2 vl vision blocks (#2721)
* feat: support flash attention 2 in qwen2 vl vision blocks

* fix: calc max_seqlen once and small refactors
2024-11-18 12:46:40 -05:00
Daniël de Kok
3c9df21ff8
Add support for compressed-tensors w8a8 int checkpoints (#2745)
* Add support for compressed-tensors w8a8 int checkpoints

This change adds a loader for w8a8 int checkpoints. One large benefit of
int8 support is that the corresponding cutlass matmul kernels also work on
compute capability 7.5.

Evaluation on neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8:

|     Tasks     |Version|     Filter     |n-shot|        Metric         |   |Value |   |Stderr|
|---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------|
|gsm8k_cot_llama|      3|flexible-extract|     8|exact_match            |↑  |0.8431|±  |0.0100|
|               |       |strict-match    |     8|exact_match            |↑  |0.8393|±  |0.0101|
|ifeval         |      4|none            |     0|inst_level_loose_acc   |↑  |0.8597|±  |   N/A|
|               |       |none            |     0|inst_level_strict_acc  |↑  |0.8201|±  |   N/A|
|               |       |none            |     0|prompt_level_loose_acc |↑  |0.7967|±  |0.0173|
|               |       |none            |     0|prompt_level_strict_acc|↑  |0.7468|±  |0.0187|

Which is the same ballpark as vLLM.

As usual, lots of thanks to Neural Magic/vLLM for the kernels.

* Always use dynamic input quantization for w8a8 int

It's far less flaky and gives better output.

* Use marlin-kernels 0.3.5

* Fix a typo

Co-authored-by: drbh <david.richard.holtz@gmail.com>

* Small fixes

---------

Co-authored-by: drbh <david.richard.holtz@gmail.com>
2024-11-18 17:20:31 +01:00
Wang, Yi
a5ecd6e586
add ipex moe implementation to support Mixtral and PhiMoe (#2707)
* add ipex moe implementation to support Mixtral and PhiMoe

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update to ipex xpu 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* torch has xpu support in 2.5

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix oneapi basekit version

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
2024-11-18 17:16:55 +01:00
drbh
fea62e928f
fix: improve find_segments via numpy diff (#2686) 2024-11-18 09:51:06 -05:00
Daniël de Kok
52e48739a5
Remove vLLM dependency for CUDA (#2751)
* Remove vLLM dependency for CUDA

This change adds `attention-kernels` as a dependency for paged
attention and cache reshaping. With that, we don't use vLLM
anywhere for CUDA.

Tested run (since we don't have paged attention in CI):

```
❯ ATTENTION=paged python -m pytest integration-tests -k "llama and awq" --release
[...]
5 snapshots passed.
```

* Fix clippy warning
2024-11-17 17:34:50 +01:00
Nicolas Patry
34a3bdedc3
Upgrading our deps. (#2750)
* Upgrading our deps.

* fixup.

* Fixup.
2024-11-15 14:03:27 +01:00
Alex Weston
4580ced091
Upgrade outlines to 0.1.1 (#2742)
* Upgrade outlines to 0.1.1

* Update for new API

* Check if allowed tokens is None

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2024-11-15 13:22:52 +01:00
Billel Mokeddem
4f4857a4ac
Fix: Change embeddings to embedding (#2738)
fix: change embeddings to embedding

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:16:15 +01:00
Billel Mokeddem
f9ee46f740
Fix: Change model_type from ssm to mamba (#2740)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-135.us-west-2.compute.internal>
2024-11-15 13:15:36 +01:00
Daniël de Kok
a785000842
Add initial support for compressed-tensors checkpoints (#2732)
compressed-tensors is a safetensors extension for sparse, quantized
tensors. The format is more powerful than earlier AWQ/GPTQ/FP8
quantization, because

- Different quantizer configurations can be used for different targets.
- The format can specify input/output quantizers in addition to weight
  quantizers.
- Configurable exclusions for quantization.

This change adds a dependency on the `compressed-tensors` package for
its configuration parsing and layer matching functionality.

The following types of quantization are supported in this PR:

- W8A16 and W4A16 INT using GPTQ-Marlin kernels.
- W8A8 and W8A16 FP using FP8-Marlin and cutlass kernels.

Support for other quantization types will be added in subsequent PRs.
2024-11-10 13:54:07 +01:00
Wang, Yi
b1f9044d6c
fix incorrect output of Qwen2-7B-Instruct-GPTQ-Int4 and Qwen2-7B-Inst… (#2717)
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
Close stale issues and PRs / stale (push) Has been cancelled
Nightly load test / load-tests (push) Has been cancelled
fix incorrect output of Qwen2-7B-Instruct-GPTQ-Int4 and Qwen2-7B-Instruct-AWQ
ipex kernel provide func like add_bias, so no need add it outside

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-04 16:07:51 +01:00
Nicolas Patry
9fde566602
Fixing linting on main. (#2719) 2024-11-04 15:21:41 +01:00
Travis Addair
aadc9cb485
Fix prefix caching + speculative decoding (#2711) 2024-11-04 15:08:43 +01:00
Nicolas Patry
a5593ba83e
Hotfixing auto length (warmup max_s was wrong). (#2716)
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
2024-11-04 09:55:54 +01:00
drbh
6e3220529d
fix: create position ids for text only input (#2714)
* fix: create position ids for text only input

* fix: prefer repeat over expand to avoid clone
2024-11-02 08:40:05 +08:00
drbh
01dacf8e8f
fix cuda graphs for qwen2-vl (#2708)
* feat: support multidimensional position ids on batch to enable cuda graphs on qwen2-vl

* fix: only check model type if config exists

* fix: adjust sharding and lm head logic

* fix qwen2 failure in intel cpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix: return correct shape logits and add streaming test

* fix: remove unused import and refactor test

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2024-11-01 03:05:34 +01:00
drbh
befd9f6735
Support qwen2 vl (#2689)
* feat: add support for qwen2 vl model

* feat: fix token padding, enable warmup and process basic request

* fix: improve get_position_ids, add lift embed_tokens

* fix: remove get_cos_sin_hack dev function

* feat: add simple test chat with meesage and text

* fix: lint test

* fix: adjust positional embeddings for multi dimensional position ids

* fix: update docs and lint unused vars

* fix: include linted file

* fix: add norm after text output

* fix: format model file

* fix: adjust for ruff lints

* fix: remove unused rotate_half

* feat: refactors and calc num features

* fix: prefer position_ids passed from vlm causal lm and reset ids on batch

* fix: adjust get_position_ids if not available and add required args to signatures

* fix: adjust resize case for qwen2_vl warmup

* fix: avoid qwen2 vl specific paths with qwen2
2024-10-30 12:40:51 -04:00