feat: add support for flash_attn #4120

Merged · 8 commits · May 20, 2024
Changes from 3 commits
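
In short: this PR adds a -fa / --flash-attn flag to the embedded llama.cpp ext server and teaches ollama's Go runner to pass --flash-attn automatically, but only when every detected GPU supports it (CUDA with a driver major version of 7 or newer, or Metal).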
14 changes: 11 additions & 3 deletions llm/ext_server/server.cpp
@@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
printf(" -spf FNAME, --system-prompt-file FNAME\n");
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         {
             params.use_mmap = false;
         }
-        else if (arg == "--numa") {
+        else if (arg == "--numa")
+        {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -2521,6 +2523,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         {
             params.cont_batching = true;
         }
+        else if (arg == "-fa" || arg == "--flash-attn")
+        {
+            params.flash_attn = true;
+        }
         else if (arg == "-np" || arg == "--parallel")
         {
             if (++i >= argc)
@@ -2529,15 +2535,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_parallel = std::stoi(argv[i]);
-        } else if (arg == "-n" || arg == "--n-predict")
+        }
+        else if (arg == "-n" || arg == "--n-predict")
         {
             if (++i >= argc)
             {
                 invalid_param = true;
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-        } else if (arg == "-spf" || arg == "--system-prompt-file")
+        }
+        else if (arg == "-spf" || arg == "--system-prompt-file")
         {
             if (++i >= argc)
             {
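The flag parsed above is supplied by ollama's Go layer (next file). As a rough sketch of that hand-off, a params slice like the one server.go builds is passed straight through as subprocess arguments; the binary path and model name below are illustrative, not the repo's actual launch code:

// Hypothetical sketch: how a params slice reaches the ext server's argv.
package main

import "os/exec"

func main() {
	// server.go appends "--flash-attn" to a slice like this when supported.
	params := []string{"--model", "model.gguf", "--flash-attn"}
	cmd := exec.Command("./server", params...) // path is illustrative
	if err := cmd.Start(); err != nil {
		panic(err)
	}
}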
20 changes: 20 additions & 0 deletions llm/server.go
@@ -200,6 +200,26 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}

// Only enable flash_attn if all GPUs support it (CUDA 7+ or Metal)
flashAttnSupported := false
sammcj marked this conversation as resolved.
Show resolved Hide resolved
if cpuRunner != "" {
flashAttnSupported = false
} else {
for _, g := range gpus {
if g.Library == "cuda" && g.DriverMajor >= 7 {
flashAttnSupported = true
} else if g.Library == "metal" {
flashAttnSupported = true
} else {
flashAttnSupported = false
break
}
}
}
if flashAttnSupported {
params = append(params, "--flash-attn")
}

numParallel := envconfig.NumParallel

// TODO (jmorganca): multimodal models don't support parallel yet
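
The check above requires every GPU in the list to support flash attention: any non-CUDA, non-Metal library (or a CUDA driver older than major version 7) disables it for the whole run, as does falling back to the CPU runner. A minimal standalone sketch of the same rule, with a simplified gpuInfo struct standing in for ollama's real gpu.GpuInfo type:

// Hypothetical refactor of the support check, for illustration only.
package main

import "fmt"

// gpuInfo mirrors just the fields the check reads; the real gpu.GpuInfo
// type in ollama carries many more.
type gpuInfo struct {
	Library     string
	DriverMajor int
}

// flashAttnSupported reports whether every GPU is either CUDA with a
// driver major version of at least 7, or Metal. An empty list or a
// CPU-only run (non-empty cpuRunner) disables flash attention.
func flashAttnSupported(cpuRunner string, gpus []gpuInfo) bool {
	if cpuRunner != "" || len(gpus) == 0 {
		return false
	}
	for _, g := range gpus {
		cudaOK := g.Library == "cuda" && g.DriverMajor >= 7
		if !cudaOK && g.Library != "metal" {
			return false
		}
	}
	return true
}

func main() {
	mixed := []gpuInfo{{Library: "cuda", DriverMajor: 8}, {Library: "rocm"}}
	fmt.Println(flashAttnSupported("", mixed)) // false: the rocm GPU vetoes it
	ok := []gpuInfo{{Library: "metal"}}
	fmt.Println(flashAttnSupported("", ok)) // true
}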