
Commit

feat: add flash_attn support
sammcj committed May 5, 2024
1 parent 5631fe7 commit 34136d4
Showing 5 changed files with 18 additions and 19 deletions.
2 changes: 2 additions & 0 deletions api/types.go
@@ -151,6 +151,7 @@ type Runner struct {
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
+FlashAttn bool `json:"flash_attn,omitempty"`

// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
@@ -428,6 +429,7 @@ func DefaultOptions() Options {
UseMLock: false,
UseMMap: true,
UseNUMA: false,
+FlashAttn: false, // for CPU only compatibility
},
}
}
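
For reference, the field added above serializes as a `flash_attn` key in the request options JSON. The sketch below uses a local stand-in struct (not the real `api.Options` type) to show what a client would send; because the default is `false` and the tag is `omitempty`, the key is simply dropped when the option is unset.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// flashAttnOpt is a stand-in mirroring only the new Runner field above;
// the real definition lives in api/types.go.
type flashAttnOpt struct {
	FlashAttn bool `json:"flash_attn,omitempty"`
}

func main() {
	b, _ := json.Marshal(flashAttnOpt{FlashAttn: true})
	fmt.Println(string(b)) // prints {"flash_attn":true}

	b, _ = json.Marshal(flashAttnOpt{})
	fmt.Println(string(b)) // prints {}: an unset option adds nothing to the request
}
```
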
11 changes: 0 additions & 11 deletions docs/faq.md
@@ -232,14 +232,3 @@ curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.

If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.

-## Passing additional parameters to llama.cpp
-
-You can pass additional parameters to the `llama.cpp` binary by setting the `OLLAMA_LLAMA_EXTRA_ARGS` environment variable. This can be useful for debugging or performance testing.
-
-Example - enabling flash_attn and continuous batching:
-
-```shell
-export OLLAMA_LLAMA_EXTRA_ARGS="-fa,-cb"
-ollama serve
-```
14 changes: 11 additions & 3 deletions llm/ext_server/server.cpp
@@ -2107,6 +2107,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" --flash-attn-disable disable Flash Attention\n");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2508,7 +2509,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
params.use_mmap = false;
}
else if (arg == "--numa") {
else if (arg == "--numa")
{
if (++i >= argc) {
invalid_param = true;
break;
@@ -2532,6 +2534,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
params.flash_attn = true;
}
else if (arg == "--flash-attn-disable")
{
params.flash_attn = false;
}
else if (arg == "-np" || arg == "--parallel")
{
if (++i >= argc)
@@ -2540,15 +2546,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-n" || arg == "--n-predict")
}
else if (arg == "-n" || arg == "--n-predict")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "-spf" || arg == "--system-prompt-file")
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
2 changes: 1 addition & 1 deletion llm/llama.cpp
Submodule llama.cpp updated 94 files
+15 −1 .flake8
+13 −1 .github/workflows/bench.yml
+1 −1 .github/workflows/close-issue.yml
+1 −2 .github/workflows/python-lint.yml
+3 −2 .pre-commit-config.yaml
+1 −2 Makefile
+10 −36 README.md
+6 −2 ci/run.sh
+8 −1 common/common.cpp
+2 −1 common/common.h
+2 −2 common/log.h
+97 −75 convert-hf-to-gguf-update.py
+128 −151 convert-hf-to-gguf.py
+27 −24 convert-llama-ggml-to-gguf.py
+17 −15 convert-lora-to-ggml.py
+16 −12 convert-persimmon-to-gguf.py
+36 −24 convert.py
+17 −11 examples/batched-bench/batched-bench.cpp
+15 −4 examples/gguf-split/gguf-split.cpp
+7 −7 examples/gguf-split/tests.sh
+27 −3 examples/llama-bench/llama-bench.cpp
+1 −1 examples/main/main.cpp
+115 −3 examples/perplexity/README.md
+175 −55 examples/perplexity/perplexity.cpp
+1 −0 examples/server/bench/bench.py
+5 −1 examples/server/server.cpp
+56 −32 examples/server/tests/features/results.feature
+100 −48 examples/server/tests/features/steps/steps.py
+6 −0 ggml-cuda.cu
+62 −15 ggml-cuda/common.cuh
+944 −0 ggml-cuda/fattn.cu
+3 −0 ggml-cuda/fattn.cuh
+36 −10 ggml-cuda/softmax.cu
+7 −0 ggml-kompute.cpp
+376 −176 ggml-metal.m
+654 −18 ggml-metal.metal
+5 −1 ggml-sycl.cpp
+5 −0 ggml-vulkan.cpp
+361 −16 ggml.c
+20 −0 ggml.h
+9 −3 ggml_vk_generate_shaders.py
+10 −8 gguf-py/examples/reader.py
+1 −3 gguf-py/gguf/constants.py
+3 −1 gguf-py/gguf/gguf_reader.py
+4 −1 gguf-py/gguf/gguf_writer.py
+13 −29 gguf-py/gguf/vocab.py
+58 −36 gguf-py/scripts/gguf-convert-endian.py
+20 −9 gguf-py/scripts/gguf-dump.py
+18 −13 gguf-py/scripts/gguf-set-metadata.py
+387 −204 llama.cpp
+6 −3 llama.h
+4 −0 models/ggml-vocab-bert-bge.gguf.inp
+2 −0 models/ggml-vocab-bert-bge.gguf.out
+ models/ggml-vocab-command-r.gguf
+106 −0 models/ggml-vocab-command-r.gguf.inp
+43 −0 models/ggml-vocab-command-r.gguf.out
+4 −0 models/ggml-vocab-deepseek-coder.gguf.inp
+2 −0 models/ggml-vocab-deepseek-coder.gguf.out
+4 −0 models/ggml-vocab-deepseek-llm.gguf.inp
+2 −0 models/ggml-vocab-deepseek-llm.gguf.out
+4 −0 models/ggml-vocab-falcon.gguf.inp
+2 −0 models/ggml-vocab-falcon.gguf.out
+4 −0 models/ggml-vocab-gpt-2.gguf.inp
+2 −0 models/ggml-vocab-gpt-2.gguf.out
+4 −0 models/ggml-vocab-llama-bpe.gguf.inp
+2 −0 models/ggml-vocab-llama-bpe.gguf.out
+4 −0 models/ggml-vocab-llama-spm.gguf.inp
+2 −0 models/ggml-vocab-llama-spm.gguf.out
+4 −0 models/ggml-vocab-mpt.gguf.inp
+2 −0 models/ggml-vocab-mpt.gguf.out
+ models/ggml-vocab-phi-3.gguf
+4 −0 models/ggml-vocab-phi-3.gguf.inp
+2 −0 models/ggml-vocab-phi-3.gguf.out
+ models/ggml-vocab-refact.gguf
+106 −0 models/ggml-vocab-refact.gguf.inp
+43 −0 models/ggml-vocab-refact.gguf.out
+4 −0 models/ggml-vocab-starcoder.gguf.inp
+2 −0 models/ggml-vocab-starcoder.gguf.out
+1 −1 requirements/requirements-convert.txt
+16 −24 scripts/compare-llama-bench.py
+66 −0 scripts/gen-unicode-data.py
+10 −4 scripts/run-with-preset.py
+8 −5 scripts/verify-checksum-models.py
+5 −2 tests/CMakeLists.txt
+48 −4 tests/test-backend-ops.cpp
+0 −117 tests/test-tokenizer-0-bpe.py
+0 −114 tests/test-tokenizer-0-spm.py
+30 −9 tests/test-tokenizer-0.cpp
+46 −0 tests/test-tokenizer-0.py
+34 −0 tests/test-tokenizer-0.sh
+458 −416 unicode-data.cpp
+1 −1 unicode-data.h
+11 −11 unicode.cpp
+1 −1 unicode.h
8 changes: 4 additions & 4 deletions llm/server.go
@@ -193,6 +193,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}

+if gpus[0].Library == "cuda" || gpus[0].Library == "metal" || opts.FlashAttn {
+params = append(params, "--flash-attn")
+}

// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
numParallel := 1
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
@@ -205,10 +209,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))

-if other_args := os.Getenv("OLLAMA_LLAMA_EXTRA_ARGS"); other_args != "" {
-params = append(params, strings.Split(other_args, ",")...)
-}

for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]]
if dir == "" {
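
Taken together with the API change, the gating added above means the `--flash-attn` flag is passed to the llama.cpp server automatically for CUDA and Metal backends, and can be forced on elsewhere by setting the `flash_attn` option. A minimal sketch of that condition, using a hypothetical helper name (the real check operates on `gpus[0].Library` inline in `NewLlamaServer`):

```go
package main

import "fmt"

// shouldEnableFlashAttn is a hypothetical helper mirroring the condition in
// NewLlamaServer: Flash Attention is enabled for the "cuda" and "metal" GPU
// libraries, or whenever the user sets the flash_attn option explicitly.
func shouldEnableFlashAttn(gpuLibrary string, flashAttnOpt bool) bool {
	return gpuLibrary == "cuda" || gpuLibrary == "metal" || flashAttnOpt
}

func main() {
	fmt.Println(shouldEnableFlashAttn("cuda", false)) // true: on by default for CUDA
	fmt.Println(shouldEnableFlashAttn("cpu", false))  // false: off for CPU-only runs
	fmt.Println(shouldEnableFlashAttn("cpu", true))   // true: user opted in via flash_attn
}
```

The CUDA and Metal defaults rely on the Flash Attention support pulled in by the llama.cpp submodule update above (for example `ggml-cuda/fattn.cu` and the Metal shader changes).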
