feat: add flash_attn support
sammcj committed May 5, 2024
1 parent b7a87a2 commit e400b1a
Showing 4 changed files with 26 additions and 4 deletions.
2 changes: 2 additions & 0 deletions api/types.go
@@ -151,6 +151,7 @@ type Runner struct {
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
FlashAttn bool `json:"flash_attn,omitempty"`

// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
@@ -428,6 +429,7 @@ func DefaultOptions() Options {
UseMLock: false,
UseMMap: true,
UseNUMA: false,
FlashAttn: false, // for CPU only compatibility
},
}
}
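The new `FlashAttn` field is serialized as `flash_attn` and defaults to false, so flash attention stays opt-in per request. Below is a minimal sketch (not part of this commit) of enabling it from a client through the options map; the Go API package path `github.com/ollama/ollama/api` and the model name are assumptions for illustration.

```go
// A minimal sketch: enabling the new flash_attn option from a Go client.
// The package path and model name are assumptions, not part of this commit.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llama3", // illustrative model name
		Prompt: "Why is the sky blue?",
		// Request options are overlaid on the defaults, which now include
		// FlashAttn: false, so flash attention only turns on when asked for.
		Options: map[string]interface{}{
			"flash_attn": true,
		},
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```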
17 changes: 14 additions & 3 deletions llm/ext_server/server.cpp
@@ -2106,6 +2106,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2507,7 +2508,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
params.use_mmap = false;
}
else if (arg == "--numa") {
else if (arg == "--numa")
{
if (++i >= argc) {
invalid_param = true;
break;
@@ -2527,6 +2529,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
params.cont_batching = true;
}
else if (arg == "-fa" || arg == "--flash-attn")
{
params.flash_attn = true;
}
else if (arg == "-np" || arg == "--parallel")
{
if (++i >= argc)
@@ -2535,15 +2541,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-n" || arg == "--n-predict")
}
else if (arg == "-n" || arg == "--n-predict")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "-spf" || arg == "--system-prompt-file")
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
@@ -2678,6 +2686,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
exit(1);
}
}

gpt_params_handle_model_default(params);

if (!params.kv_overrides.empty()) {
params.kv_overrides.emplace_back();
params.kv_overrides.back().key[0] = 0;
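On the C++ side, `server_params_parse` now recognizes `-fa` / `--flash-attn` as a boolean toggle that sets `params.flash_attn` before the model is loaded. The sketch below shows how a launcher process could pass the new flag on the command line; the binary name, model path, and port are hypothetical placeholders rather than values from this commit.

```go
// A hedged sketch of launching the ext server with the new flag. The binary
// name, model path, and port are placeholders, not values from this commit.
package main

import (
	"log"
	"os/exec"
)

func main() {
	args := []string{
		"--model", "/models/example.gguf", // placeholder model path
		"--port", "8080", // placeholder port
		"--flash-attn", // new flag added by this commit (short form: -fa)
	}

	cmd := exec.Command("./ollama_llama_server", args...) // placeholder binary name
	if err := cmd.Start(); err != nil {
		log.Fatalf("failed to start ext server: %v", err)
	}
	log.Printf("started %s with args %v", cmd.Path, args)
}
```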
2 changes: 1 addition & 1 deletion llm/llama.cpp
Submodule llama.cpp updated 94 files
+15 −1 .flake8
+13 −1 .github/workflows/bench.yml
+1 −1 .github/workflows/close-issue.yml
+1 −2 .github/workflows/python-lint.yml
+3 −2 .pre-commit-config.yaml
+1 −2 Makefile
+10 −36 README.md
+6 −2 ci/run.sh
+8 −1 common/common.cpp
+2 −1 common/common.h
+2 −2 common/log.h
+97 −75 convert-hf-to-gguf-update.py
+128 −151 convert-hf-to-gguf.py
+27 −24 convert-llama-ggml-to-gguf.py
+17 −15 convert-lora-to-ggml.py
+16 −12 convert-persimmon-to-gguf.py
+36 −24 convert.py
+17 −11 examples/batched-bench/batched-bench.cpp
+15 −4 examples/gguf-split/gguf-split.cpp
+7 −7 examples/gguf-split/tests.sh
+27 −3 examples/llama-bench/llama-bench.cpp
+1 −1 examples/main/main.cpp
+115 −3 examples/perplexity/README.md
+175 −55 examples/perplexity/perplexity.cpp
+1 −0 examples/server/bench/bench.py
+5 −1 examples/server/server.cpp
+56 −32 examples/server/tests/features/results.feature
+100 −48 examples/server/tests/features/steps/steps.py
+6 −0 ggml-cuda.cu
+62 −15 ggml-cuda/common.cuh
+944 −0 ggml-cuda/fattn.cu
+3 −0 ggml-cuda/fattn.cuh
+36 −10 ggml-cuda/softmax.cu
+7 −0 ggml-kompute.cpp
+376 −176 ggml-metal.m
+654 −18 ggml-metal.metal
+5 −1 ggml-sycl.cpp
+5 −0 ggml-vulkan.cpp
+361 −16 ggml.c
+20 −0 ggml.h
+9 −3 ggml_vk_generate_shaders.py
+10 −8 gguf-py/examples/reader.py
+1 −3 gguf-py/gguf/constants.py
+3 −1 gguf-py/gguf/gguf_reader.py
+4 −1 gguf-py/gguf/gguf_writer.py
+13 −29 gguf-py/gguf/vocab.py
+58 −36 gguf-py/scripts/gguf-convert-endian.py
+20 −9 gguf-py/scripts/gguf-dump.py
+18 −13 gguf-py/scripts/gguf-set-metadata.py
+387 −204 llama.cpp
+6 −3 llama.h
+4 −0 models/ggml-vocab-bert-bge.gguf.inp
+2 −0 models/ggml-vocab-bert-bge.gguf.out
+ models/ggml-vocab-command-r.gguf
+106 −0 models/ggml-vocab-command-r.gguf.inp
+43 −0 models/ggml-vocab-command-r.gguf.out
+4 −0 models/ggml-vocab-deepseek-coder.gguf.inp
+2 −0 models/ggml-vocab-deepseek-coder.gguf.out
+4 −0 models/ggml-vocab-deepseek-llm.gguf.inp
+2 −0 models/ggml-vocab-deepseek-llm.gguf.out
+4 −0 models/ggml-vocab-falcon.gguf.inp
+2 −0 models/ggml-vocab-falcon.gguf.out
+4 −0 models/ggml-vocab-gpt-2.gguf.inp
+2 −0 models/ggml-vocab-gpt-2.gguf.out
+4 −0 models/ggml-vocab-llama-bpe.gguf.inp
+2 −0 models/ggml-vocab-llama-bpe.gguf.out
+4 −0 models/ggml-vocab-llama-spm.gguf.inp
+2 −0 models/ggml-vocab-llama-spm.gguf.out
+4 −0 models/ggml-vocab-mpt.gguf.inp
+2 −0 models/ggml-vocab-mpt.gguf.out
+ models/ggml-vocab-phi-3.gguf
+4 −0 models/ggml-vocab-phi-3.gguf.inp
+2 −0 models/ggml-vocab-phi-3.gguf.out
+ models/ggml-vocab-refact.gguf
+106 −0 models/ggml-vocab-refact.gguf.inp
+43 −0 models/ggml-vocab-refact.gguf.out
+4 −0 models/ggml-vocab-starcoder.gguf.inp
+2 −0 models/ggml-vocab-starcoder.gguf.out
+1 −1 requirements/requirements-convert.txt
+16 −24 scripts/compare-llama-bench.py
+66 −0 scripts/gen-unicode-data.py
+10 −4 scripts/run-with-preset.py
+8 −5 scripts/verify-checksum-models.py
+5 −2 tests/CMakeLists.txt
+48 −4 tests/test-backend-ops.cpp
+0 −117 tests/test-tokenizer-0-bpe.py
+0 −114 tests/test-tokenizer-0-spm.py
+30 −9 tests/test-tokenizer-0.cpp
+46 −0 tests/test-tokenizer-0.py
+34 −0 tests/test-tokenizer-0.sh
+458 −416 unicode-data.cpp
+1 −1 unicode-data.h
+11 −11 unicode.cpp
+1 −1 unicode.h
9 changes: 9 additions & 0 deletions llm/server.go
@@ -193,6 +193,15 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}

if opts.FlashAttn {
flashAttnSupported := (gpus[0].Library == "cuda" && gpus[0].Major >= 7 || gpus[0].Library == "metal")
if flashAttnSupported {
params = append(params, "--flash-attn")
} else {
slog.Warn("flash attention is not supported on your current hardware configuration, it is now disabled")
}
}

// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
numParallel := 1
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
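The gate above only forwards `--flash-attn` when the first GPU is either a CUDA device with compute capability 7 or newer, or a Metal device; otherwise the option is dropped with a warning. A standalone restatement of that check is sketched below, using a simplified stand-in for `gpu.GpuInfoList` that carries only the `Library` and `Major` fields referenced by the diff.

```go
// A simplified, standalone restatement of the support check above. GpuInfo
// here is a stand-in with only the fields the diff references, not the real
// gpu.GpuInfoList type.
package main

import (
	"fmt"
	"log/slog"
)

type GpuInfo struct {
	Library string // e.g. "cuda", "metal", "rocm", "cpu"
	Major   int    // CUDA compute capability, major version
}

// flashAttnSupported mirrors the condition in NewLlamaServer: flash attention
// is only enabled on CUDA devices with compute capability >= 7 or on Metal.
func flashAttnSupported(gpus []GpuInfo) bool {
	if len(gpus) == 0 {
		return false
	}
	g := gpus[0]
	return (g.Library == "cuda" && g.Major >= 7) || g.Library == "metal"
}

func main() {
	params := []string{"--model", "/models/example.gguf"} // placeholder args

	gpus := []GpuInfo{{Library: "cuda", Major: 8}} // e.g. an Ampere-class GPU
	if flashAttnSupported(gpus) {
		params = append(params, "--flash-attn")
	} else {
		slog.Warn("flash attention is not supported on the current hardware; leaving it disabled")
	}

	fmt.Println(params)
}
```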
