
feat: add support for flash_attn #4120

Open · sammcj wants to merge 8 commits into main

Conversation

@sammcj commented May 3, 2024

Flash attention is only enabled by default when a supported CUDA version or Metal is detected; it is configurable via parameters and the API.

Credit to @wanderingmeow who took my broken idea and made it work 🎉

Fixes #4051
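For anyone wanting to try this from the API once it lands, here is a minimal Go sketch of toggling the option per request. The flash_attn key in the options map follows the parameter name discussed in this PR and should be treated as an assumption until the final API settles.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Standard /api/generate request body, with the (assumed) flash_attn
	// option from this PR passed through the per-request options map.
	body, _ := json.Marshal(map[string]any{
		"model":  "llama3",
		"prompt": "hi",
		"stream": false,
		"options": map[string]any{
			"flash_attn": true, // option name taken from this PR; may change
		},
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(out["response"])
}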

@sammcj requested a review from bsdnet on May 3, 2024
@wanderingmeow

How about adding a flash_attn flag to the Runner struct, so it can be set via /set parameter flash_attn true or an API call? I'm not sure if this would be an API-breaking change, but I think it's worth considering. Currently, flash attention doesn't work for CPU or pre-Tensor-core GPU inference, so it might not be desirable to enable it by default.
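For context, a minimal sketch of what that flag could look like on the Runner options struct in api/types.go; the neighbouring fields and the JSON tag shown here are illustrative, not the exact final definition.

package api

// Runner holds options that control how the model runner process is started.
type Runner struct {
	NumGPU    int  `json:"num_gpu,omitempty"`
	NumThread int  `json:"num_thread,omitempty"`
	UseMMap   bool `json:"use_mmap,omitempty"`

	// FlashAttn toggles llama.cpp's flash attention kernel; it would be
	// settable via `/set parameter flash_attn true` or the request options.
	FlashAttn bool `json:"flash_attn,omitempty"`
}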

Also, a quick note: you've forgotten to update the llama.cpp dependency.

@sammcj commented May 4, 2024

> How about adding a flash_attn flag to the Runner struct, so it can be set via /set parameter flash_attn true or an API call? I'm not sure if this would be an API-breaking change, but I think it's worth considering. Currently, flash attention doesn't work for CPU or pre-Tensor-core GPU inference, so it might not be desirable to enable it by default.

Yeah, good idea, added.

n.b. I feel like there are quite a few repeated parameter definitions throughout the server/client definitions that really could be pulled in from a single source of truth, but that's a problem for another day.

> Also, a quick note: you've forgotten to update the llama.cpp dependency.

Whoops! Thanks, fixed now.

@jukofyork commented May 4, 2024

👍

Hopefully this does get merged and not just left to die a painful "death-by-conflicts" like so many other PRs have already! ☹️

I'm getting double the context on some of the bigger models using llama.cpp server directly!

@sammcj commented May 4, 2024

@bsdnet any chance of an approval?

@wanderingmeow

@sammcj The API binding is missing in server.go. Specifically, the -fa flag is not being passed to server.cpp when opts.FlashAttn is set:

if opts.FlashAttn {
    params = append(params, "-fa")
}

Additionally, I'm not sure if it's the right time to introduce the OLLAMA_LLAMA_EXTRA_ARGS environment variable. One concern is that it doesn't handle parameter overrides, and some flags don't have a disable equivalent (e.g., -fa enables flash attention but there's no --no-flash-attn to disable it). Maybe we should wait until after the revamp of server.cpp, especially server_params_parse(), or consider calling C++ directly in Go without spawning another process?

@sammcj force-pushed the main branch 3 times, most recently from 3b249b4 to ffb2e2a on May 5, 2024
@sammcj commented May 5, 2024

Forgot to git add server.go after making an update. Thanks @wanderingmeow, fixed now.

I hear you re: the server params. What I've done is remove the additional params, default to enabling flash_attn if CUDA or Metal is detected, and add an option to explicitly disable it.

@sammcj changed the title from "feat: option to parse args to llama.cpp, support flash_attn" to "feat: add support for flash_attn" on May 5, 2024
@wanderingmeow

Just wanted to point out that pre-Turing NVIDIA cards (CC < 7.0) don't support flash attention, as mentioned in ggerganov/llama.cpp/issues/7055.

To avoid any issues or confusion, I think we should add a check before enabling it to ensure it's supported by the user's hardware:

if opts.FlashAttn {
    flashAttnSupported := (gpus[0].Library == "cuda" && gpus[0].Major >= 7 || gpus[0].Library == "metal")
    if flashAttnSupported {
        params = append(params, "--flash-attn")
    } else {
        slog.Warn("flash attention is not supported on your current hardware configuration, it is now disabled")
    }
}

Considering this feature is opt-in, and with this check in place, the --disable-flash-attention flag in server.cpp can be removed.

@sammcj commented May 5, 2024

That's a nice way of handling it! Thanks, I'm learning every day 😄.
PR updated.

@jukofyork

Working really well! 👍

@chigkim commented May 6, 2024

Is there a disadvantage? If not, should it be enabled by default for everyone other than systems that can't support it?

@sammcj commented May 6, 2024

> Is there a disadvantage? If not, should it be enabled by default for everyone other than systems that can't support it?

While I agree, I feel like the most important thing for this right now is to get someone to approve it (as an optional feature) so folks can actually start using it and reporting their findings.

Ollama maintainers / @bsdnet or maybe @jmorganca: if there's anything I can do to help get this PR moving along, please do let me know.

@sammcj commented May 7, 2024

Looks like it needs approval to allow the workflow to run.

@sammcj force-pushed the main branch 2 times, most recently from 05407fd to 4bbd583 on May 8, 2024
@sammcj commented May 8, 2024

@dhiltgen any chance of your eyes on this one?

(Review comments on api/types.go and llm/server.go were marked resolved.)
@sammcj commented May 9, 2024

Cleaned up logic and rebased.

@sammcj commented May 11, 2024

I'm all ears if any of the Ollama contributors think the PR needs improvements; please just let me know, @jmorganca or @dhiltgen.

This will resolve #4051.

I'm somewhat afraid this PR will sit there if I don't keep nagging. I'm sure folks are just very busy, but given the number of open PRs, I suspect the project would benefit from embracing some additional automation around the review and merge process for features/fixes, especially if it continues to maintain a partial fork of llama.cpp's server.


@jmorganca (Member)

Hi @sammcj, this is close to merging. I'm just going to update the llama.cpp submodule in another PR, since we can remove a temporary patch (#4414).

@sammcj commented May 14, 2024

@jmorganca great, thanks!

Let me know if you want me to update or remove the bump of the submodule; otherwise, I'm guessing updating from main once your other PR is merged should do it.

@sammcj commented May 15, 2024

Daily reminder about this PR @jmorganca 🤣

@jmorganca (Member) left a comment


Needs one more rebase

@sammcj commented May 16, 2024

@jmorganca done :)

@jmorganca (Member) left a comment


On second look, I'm seeing some strange results on Metal with partial offloading. I'm not sure if this might be fixed in a more recent commit of llama.cpp, but it spits out strange characters:

% ./ollama run llama3
>>> /set parameter num_gpu 10
Set parameter 'num_gpu' to '10'
>>> hi
.8;!177'1G.DC3"6C,64;B!H/G-<392^C

We might need to only enable this on fully loaded models.
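A rough sketch of that gating in llm/server.go, assuming the start-up code already knows how many layers were offloaded; offloadedLayers and totalLayers are placeholder names, not the actual variables in this PR.

// Only pass the flash attention flag when every layer is offloaded to the GPU;
// partial offload on Metal produced the garbled output shown above.
if opts.FlashAttn && offloadedLayers >= totalLayers {
    params = append(params, "--flash-attn")
}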

@jmorganca commented May 16, 2024

(Also: thanks so much for rebasing this quite a few times)

@sammcj commented May 16, 2024

@jmorganca I haven't noticed this!

But to be safe, I just pushed an update to ensure it's disabled if we're using the CPU runner.
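Roughly, that kind of guard could look like the following; the cpuRunner name here is illustrative and not necessarily how the PR implements it.

if cpuRunner != "" {
    // Falling back to the CPU runner: force flash attention off, since the
    // llama.cpp flash attention path needs CUDA (CC >= 7.0) or Metal.
    opts.FlashAttn = false
}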

@sammcj requested a review from jmorganca on May 16, 2024
(A review comment on llm/server.go was marked resolved.)
@sammcj requested a review from jmorganca on May 17, 2024
@dpublic commented May 17, 2024

I know this is late for this pull request, but can you add a new env var so people can force the use of flash_attn?
This would be a way for people to try it on CPU, as the problems noted could be model-related.
This would be similar to how OLLAMA_NUM_PARALLEL sets a parameter.
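For illustration, such an override could be wired up the same way other OLLAMA_* settings are read; the OLLAMA_FLASH_ATTENTION name and this placement are hypothetical, not something the PR adds.

// Hypothetical override: force flash attention on when the environment
// variable is set to a truthy value, even on otherwise unsupported hardware.
if force, err := strconv.ParseBool(os.Getenv("OLLAMA_FLASH_ATTENTION")); err == nil && force {
    opts.FlashAttn = true
}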

@sammcj commented May 17, 2024

> I know this is late for this pull request, but can you add a new env var so people can force the use of flash_attn? This would be a way for people to try it on CPU, as the problems noted could be model-related. This would be similar to how OLLAMA_NUM_PARALLEL sets a parameter.

@dpublic To avoid dragging this PR on, I think I'd rather just leave it as is (auto-enabled).

I did originally have this; however, the preference was to simply enable it when supported, at least until Ollama implements a proper config file / dotenv combo (which would be great 🤞), as there are a lot of settings for both Ollama and llama.cpp that would benefit from centralised configuration.

@sammcj commented May 20, 2024

ping @jmorganca


Successfully merging this pull request may close these issues.

Enable Flash Attention on GGML/GGUF (feature now merged into llama.cpp)