Commit
feat: enable flash attention if supported
sammcj committed May 16, 2024
1 parent 5bece94 commit 7857efa
Showing 2 changed files with 27 additions and 3 deletions.
14 changes: 11 additions & 3 deletions llm/ext_server/server.cpp
@@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 {
 params.use_mmap = false;
 }
-else if (arg == "--numa") {
+else if (arg == "--numa")
+{
 if (++i >= argc) {
 invalid_param = true;
 break;
@@ -2521,6 +2523,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 {
 params.cont_batching = true;
 }
+else if (arg == "-fa" || arg == "--flash-attn")
+{
+params.flash_attn = true;
+}
 else if (arg == "-np" || arg == "--parallel")
 {
 if (++i >= argc)
@@ -2529,15 +2535,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 break;
 }
 params.n_parallel = std::stoi(argv[i]);
-} else if (arg == "-n" || arg == "--n-predict")
+}
+else if (arg == "-n" || arg == "--n-predict")
 {
 if (++i >= argc)
 {
 invalid_param = true;
 break;
 }
 params.n_predict = std::stoi(argv[i]);
-} else if (arg == "-spf" || arg == "--system-prompt-file")
+}
+else if (arg == "-spf" || arg == "--system-prompt-file")
 {
 if (++i >= argc)
 {
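The new -fa/--flash-attn branch only sets gpt_params.flash_attn; the value takes effect where the server builds its llama.cpp context. A minimal sketch of that hand-off, not part of this commit and assuming the flash_attn field that llama_context_params exposes in llama.cpp builds from this period:

// Sketch only (fragment): forwarding the parsed flag into the context params.
// Assumes llama_context_params carries a bool flash_attn field, as in
// llama.cpp releases contemporary with this commit.
llama_context_params cparams = llama_context_default_params();
cparams.flash_attn = params.flash_attn; // request the fused flash-attention path
llama_context *ctx = llama_new_context_with_model(model, cparams); // model loaded elsewhere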
16 changes: 16 additions & 0 deletions llm/server.go
@@ -200,6 +200,22 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}

// Only enable flash_attn if all GPUs support it (CUDA 7+ or Metal)
flashAttnSupported := false
for _, g := range gpus {
if g.Library == "cuda" && g.DriverMajor >= 7 {
flashAttnSupported = true
} else if g.Library == "metal" {
flashAttnSupported = true
} else {
flashAttnSupported = false
break
}
}
if flashAttnSupported {
params = append(params, "--flash-attn")
}

numParallel := envconfig.NumParallel

// TODO (jmorganca): multimodal models don't support parallel yet
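The gating in NewLlamaServer is all-or-nothing: --flash-attn is only appended when every detected GPU passes the check, and a CPU-only run (empty gpus list) leaves it off because flashAttnSupported is never set. The same check as a standalone helper, purely as an illustrative sketch (the helper name is hypothetical; the commit keeps the loop inline):

// flashAttnUsable is a hypothetical helper mirroring the inlined check above:
// true only when at least one GPU is present and every GPU is either Metal or
// a CUDA device with major version >= 7.
func flashAttnUsable(gpus gpu.GpuInfoList) bool {
	if len(gpus) == 0 {
		return false
	}
	for _, g := range gpus {
		if g.Library == "metal" {
			continue
		}
		if g.Library == "cuda" && g.DriverMajor >= 7 {
			continue
		}
		return false // any unsupported GPU disables flash attention for the whole run
	}
	return true
}

Factoring it out this way would make mixed-GPU configurations easy to unit-test, but the behaviour matches the inlined loop, including bailing out on the first unsupported device.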
