do not set sliding_window if SUPPORTS_WINDOWING is false #2554

Closed
wants to merge 6 commits

Conversation

sywangyi
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@sywangyi
Contributor Author

sywangyi commented Sep 24, 2024

For models like mistralai/Mistral-7B-v0.1, whose "sliding_window" is not null, it crashes like:
Traceback (most recent call last):
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/interceptor.py", line 21, in intercept
return await response
^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 120, in _unary_interceptor
raise error
File "/opt/conda/lib/python3.11/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 111, in _unary_interceptor
return await behavior(request_or_iterator, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/server.py", line 181, in Decode
generations, next_batch, timings = self.model.generate_token(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/models/flash_causal_lm.py", line 1597, in generate_token
out, speculative_logits = self.forward(batch, adapter_data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/models/flash_causal_lm.py", line 1500, in forward
logits, speculative_logits = self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 514, in forward
seqlen = seqlen.clamp(max=self.max_past_tensor)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/text_generation_server/layers/attention/common.py", line 68, in clamp
raise NotImplementedError("Not implemented seqlen for paged")
NotImplementedError: Not implemented seqlen for paged

During handling of the above exception, another exception occurred:

This happens even though the sliding window size has not been reached.

@sywangyi
Contributor Author

There is already logic in __init__.py, see https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/models/__init__.py#L509-L512, so there is no need to set sliding_window if SUPPORTS_WINDOWING is false.
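
For illustration, a rough sketch of the idea, not the PR's actual diff; the import path and variable names here are assumptions based on the linked __init__.py logic:

    # Rough sketch only (assumed names): only propagate the config's
    # sliding_window when the attention backend can actually apply it,
    # so the clamp path in the modeling code is never reached.
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING

    sliding_window = getattr(config, "sliding_window", None)
    if not SUPPORTS_WINDOWING:
        # backend (e.g. IPEX paged attention) cannot window, so leave it unset
        sliding_window = None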

@sywangyi
Contributor Author

@ErikKaum @Narsil @danieldk please help review. This PR does not only benefit the Intel platform, but also other platforms whose attention ops do not support sliding windows.

@Narsil
Collaborator

Narsil commented Oct 1, 2024

I cannot reproduce the issue. The error is raised if windowed attention is necessary AND the max-total-tokens > window size.

I tried on a target without windowing support and got the correct behavior.
I think someone "fixed" AMD by just doing an early return because clamp is not supported on ROCm, and I then brutally deactivated clamp when we moved to flashdecoding/flashinfer. I think we can just remove the exception.
Is torch.clamp supported on all Intel targets?
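
For reference, a minimal sketch of what dropping that exception could look like, using a simplified stand-in for the Seqlen dataclass in layers/attention/common.py (the real class carries more fields; this is not the actual TGI code):

    import torch
    from dataclasses import dataclass

    @dataclass
    class Seqlen:
        # simplified stand-in for the real Seqlen dataclass
        input_lengths: torch.Tensor

        def clamp(self, max):
            # instead of raising NotImplementedError on the paged path,
            # cap every sequence length at the sliding-window size
            return Seqlen(input_lengths=torch.clamp(self.input_lengths, max=max))

torch.clamp also accepts a tensor bound in recent PyTorch releases, so a call like seqlen.clamp(max=self.max_past_tensor) from the traceback above would still work with this shape of fix.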

@sywangyi
Contributor Author

sywangyi commented Oct 4, 2024

Hi @Narsil, you could use tag 2.3.0 to reproduce it. I could not reproduce the issue on the latest tag either, because enabling mllama breaks the paged attention path on the Intel platform, with an error like
Could not import Flash Attention enabled models: No module named 'flash_attn_2_cuda'
I will fix that in another PR.

This is the launcher command to reproduce it on my side:
docker run --rm --shm-size 1g --network host -e http_proxy=${http_proxy} -e https_proxy=${https_proxy} -e HF_TOKEN=xxxxx --cap-add=sys_nice --ipc=host -v ~/data:/data --device /dev/dri --privileged ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu --model-id mistralai/Mistral-7B-v0.1

The error occurs when max-total-tokens < window size. If max-total-tokens > window size, the error should be raised, since the sliding window is not supported yet in the IPEX paged attention kernel.

@sywangyi
Contributor Author

sywangyi commented Oct 4, 2024

The flash_attn_2_cuda import issue was fixed by #2610.

@sywangyi
Contributor Author

I can now reproduce it on the latest main since the mllama PR was merged. Could you revisit the updated PR? Thanks @Narsil.

Narsil mentioned this pull request Oct 11, 2024
@Narsil
Collaborator

Narsil commented Oct 11, 2024

For the windowing, I reproduced and fixed the bug, I think slightly more simply: #2637

Does that work?
The current windowing system works by refusing to start if total_tokens > window size (when windowing is not supported); if total_tokens <= window size, we ignore the windowing.

The logic was already there; the bug was that the Seqlen logic was flawed for non-flashdecoding/flashinfer.
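
To make that behavior concrete, a hedged sketch with illustrative names (the real check lives in models/__init__.py and the actual fix is in #2637):

    # Illustrative sketch of the two cases described above.
    def resolve_sliding_window(sliding_window, max_total_tokens, supports_windowing):
        """Return the window to use, or None when it can safely be ignored."""
        if sliding_window is None or supports_windowing:
            return sliding_window
        if max_total_tokens > sliding_window:
            # windowed attention would actually be needed -> refuse to start
            raise ValueError("Sliding window attention is not supported by this backend")
        # total tokens never exceed the window, so windowing can be ignored
        return None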

Narsil closed this in #2637 on Oct 11, 2024
Comment on lines +115 to +116
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
Collaborator

What is this doing?
Should this be included in a different PR?
