Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 17, 2024
1 parent 9097d8d commit d56887a
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 51 deletions.
14 changes: 12 additions & 2 deletions convert/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ import (
"github.com/ollama/ollama/llm"
)

// Token type codes for GGUF vocabularies; the values (1..6) mirror the
// GGUFToken* ordering in the llm package.
const (
	tokenTypeNormal      int32 = iota + 1 // 1
	tokenTypeUnknown                      // 2
	tokenTypeControl                      // 3
	tokenTypeUserDefined                  // 4
	tokenTypeUnused                       // 5
	tokenTypeByte                         // 6
)

type Params struct {
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
Expand Down Expand Up @@ -172,7 +182,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
}
v.Tokens = append(v.Tokens, t.key)
v.Scores = append(v.Scores, -1000.0)
v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
v.Types = append(v.Types, tokenTypeUserDefined)
}
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))

Expand All @@ -182,7 +192,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
for cnt := 0; cnt < missingTokens; cnt++ {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
v.Scores = append(v.Scores, -1)
v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
v.Types = append(v.Types, tokenTypeUserDefined)
}
}

Expand Down
2 changes: 0 additions & 2 deletions convert/gemma.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ func (m *GemmaModel) GetTensors() error {
}

slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))

m.Tensors = []llm.Tensor{}
for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
Expand Down
32 changes: 11 additions & 21 deletions convert/llama.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ func (m *LlamaModel) GetTensors() error {
return err
}

m.Tensors = []llm.Tensor{}

pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
Expand All @@ -133,30 +131,22 @@ func (m *LlamaModel) GetTensors() error {
return nil
}

func (m *LlamaModel) LoadVocab() error {
v := &Vocab{}

tokpath := filepath.Join(m.Path, "tokenizer.json")
pre, ts, merges, err := parseTokens(tokpath)
func (m *LlamaModel) LoadVocab() (err error) {
pre, ts, merges, err := parseTokens(filepath.Join(m.Path, "tokenizer.json"))
if errors.Is(err, os.ErrNotExist) {
v, err = LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}
return nil
} else if err != nil {
return err
} else {
for _, t := range ts {
v.Tokens = append(v.Tokens, t.Content)
v.Types = append(v.Types, t.Type())
}

m.Params.PreTokenizer = pre
v.Merges = merges
}

m.Vocab = v
m.Vocab = &Vocab{}
for _, t := range ts {
m.Vocab.Tokens = append(m.Vocab.Tokens, t.Content)
m.Vocab.Types = append(m.Vocab.Types, t.Type())
}

m.Vocab.Merges = merges
m.Params.PreTokenizer = pre
return nil
}

Expand All @@ -174,7 +164,7 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"general.file_type": uint32(2),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "gpt2",

"tokenizer.ggml.pre": m.Params.PreTokenizer,
Expand Down
2 changes: 0 additions & 2 deletions convert/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ func (m *MistralModel) GetTensors() error {
return err
}

m.Tensors = []llm.Tensor{}

pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
Expand Down
2 changes: 0 additions & 2 deletions convert/mixtral.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ func (m *MixtralModel) GetTensors() error {
return err
}

m.Tensors = []llm.Tensor{}

pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions convert/safetensors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path/filepath"
"regexp"
"slices"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/mitchellh/mapstructure"
Expand Down Expand Up @@ -97,6 +98,10 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)

var tensors []llm.Tensor
for _, k := range keys {
if strings.HasSuffix(k, "self_attn.rotary_emb.inv_freq") {
continue
}

vals := parsed[k].(map[string]interface{})
var data tensorMetaData
if err = mapstructure.Decode(vals, &data); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions convert/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ type Token struct {
// Type returns the GGUF token type code for t: control for special
// tokens, user-defined for user-added tokens, and normal otherwise.
func (t *Token) Type() int32 {
	switch {
	case t.Special:
		return tokenTypeControl
	case t.UserDefined:
		return tokenTypeUserDefined
	default:
		return tokenTypeNormal
	}
}

Expand Down
13 changes: 4 additions & 9 deletions convert/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,13 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
slog.Debug("getting torch tensors")

var files []string
var err error
files, err = filepath.Glob(filepath.Join(dirpath, "consolidated.*.pth"))
if err != nil {
files, err = filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
if err != nil {
slog.Error("didn't find any torch files")
return nil, err
}
if pt, _ := filepath.Glob(filepath.Join(dirpath, "consolidated*.pth")); len(pt) > 0 {
files = append(files, pt...)
} else if pt, _ := filepath.Glob(filepath.Join(dirpath, "pytorch_model*.pth")); len(pt) > 0 {
files = append(files, pt...)
}

var offset uint64

var tensors []llm.Tensor
for _, fn := range files {
m, err := pytorch.Load(fn)
Expand Down
10 changes: 0 additions & 10 deletions llm/gguf.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
return model, nil
}

// GGUF vocabulary token type codes: normal=1, unknown=2, control=3,
// user-defined=4, unused=5, byte=6.
const (
	GGUFTokenNormal uint32 = iota + 1
	GGUFTokenUnknown
	GGUFTokenControl
	GGUFTokenUserDefined
	GGUFTokenUnused
	GGUFTokenByte
)

const (
ggufTypeUint8 uint32 = iota
ggufTypeInt8
Expand Down

0 comments on commit d56887a

Please sign in to comment.