Skip to content

Commit

Permalink
feat: add flash_attn support
Browse files — browse the repository at this point in the history
  • Loading branch information
sammcj committed May 4, 2024
1 parent 3fca6e7 commit 3024414
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions api/types.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -151,6 +151,7 @@ type Runner struct {
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"`
FlashAttn bool `json:"flash_attn,omitempty"`

// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
Expand Down Expand Up @@ -428,6 +429,7 @@ func DefaultOptions() Options {
UseMLock: false,
UseMMap: true,
UseNUMA: false,
FlashAttn: false, // for CPU only compatibility
},
}
}
Expand Down
9 changes: 6 additions & 3 deletions llm/ext_server/server.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -2508,7 +2508,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
{
params.use_mmap = false;
}
else if (arg == "--numa") {
else if (arg == "--numa")
{
if (++i >= argc) {
invalid_param = true;
break;
Expand Down Expand Up @@ -2540,15 +2541,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
params.n_parallel = std::stoi(argv[i]);
} else if (arg == "-n" || arg == "--n-predict")
}
else if (arg == "-n" || arg == "--n-predict")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "-spf" || arg == "--system-prompt-file")
}
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
Expand Down

0 comments on commit 3024414

Please sign in to comment.