parse special tokens
mxyng committed May 16, 2024
1 parent 97897ea commit 27588a7
Showing 5 changed files with 88 additions and 12 deletions.
3 changes: 2 additions & 1 deletion convert/convert.go
@@ -33,8 +33,9 @@ type Params struct {
     NormEPS           float64 `json:"rms_norm_eps"`
     BoSTokenID        int     `json:"bos_token_id"`
     EoSTokenID        int     `json:"eos_token_id"`
+    PadTokenID        int     `json:"pad_token_id"`
+    UnkTokenID        int     `json:"unk_token_id"`
     HeadDimension     int     `json:"head_dim"`
-    PaddingTokenID    int     `json:"pad_token_id"`
     RopeFrequencyBase float64 `json:"rope_theta"`

     Experts int `json:"num_local_experts"`
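For context (not part of the committed diff): these struct tags are what map keys in the model's JSON metadata (e.g. a config.json) onto Params fields when it is decoded. A minimal standalone sketch of that mapping, using a hypothetical metadata fragment that carries only the token-ID fields touched here:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // tokenIDs mirrors just the token-ID fields of Params; the real struct in
    // convert/convert.go carries many more model parameters.
    type tokenIDs struct {
        BoSTokenID int `json:"bos_token_id"`
        EoSTokenID int `json:"eos_token_id"`
        PadTokenID int `json:"pad_token_id"`
        UnkTokenID int `json:"unk_token_id"`
    }

    func main() {
        // Hypothetical config.json fragment.
        cfg := []byte(`{"bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0, "unk_token_id": 3}`)

        var ids tokenIDs
        if err := json.Unmarshal(cfg, &ids); err != nil {
            panic(err)
        }
        fmt.Printf("%+v\n", ids) // {BoSTokenID:1 EoSTokenID:2 PadTokenID:0 UnkTokenID:3}
    }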
2 changes: 1 addition & 1 deletion convert/gemma.go
@@ -114,7 +114,7 @@ func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {

"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
"tokenizer.ggml.padding_token_id": uint32(m.Params.PadTokenID),
"tokenizer.ggml.unknown_token_id": uint32(3),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
42 changes: 33 additions & 9 deletions convert/llama.go
@@ -2,10 +2,8 @@ package convert

 import (
     "encoding/binary"
-    "errors"
     "fmt"
     "io"
-    "os"
     "path/filepath"
     "regexp"
     "strings"
@@ -132,17 +130,31 @@ func (m *LlamaModel) GetTensors() error {
 }

 func (m *LlamaModel) LoadVocab() (err error) {
-    pre, ts, merges, err := parseTokens(filepath.Join(m.Path, "tokenizer.json"))
-    if errors.Is(err, os.ErrNotExist) {
-        return nil
-    } else if err != nil {
+    c, err := parseSpecialTokens(filepath.Join(m.Path, "special_tokens_map.json"))
+    if err != nil {
         return err
     }

+    pre, ts, merges, err := parseTokenizer(filepath.Join(m.Path, "tokenizer.json"))
+    if err != nil {
+        return err
+    }
+
     m.Vocab = &Vocab{}
     for _, t := range ts {
         m.Vocab.Tokens = append(m.Vocab.Tokens, t.Content)
         m.Vocab.Types = append(m.Vocab.Types, t.Type())
+
+        switch t.Content {
+        case c.bos:
+            m.Params.BoSTokenID = t.ID
+        case c.eos:
+            m.Params.EoSTokenID = t.ID
+        case c.pad:
+            m.Params.PadTokenID = t.ID
+        case c.unk:
+            m.Params.UnkTokenID = t.ID
+        }
     }

     m.Vocab.Merges = merges
@@ -170,10 +182,22 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
"tokenizer.ggml.pre": m.Params.PreTokenizer,
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.token_type": m.Vocab.Types,
}

if m.Params.BoSTokenID != 0 {
kv["tokenizer.ggml.bos_token_id"] = uint32(m.Params.BoSTokenID)
}

if m.Params.EoSTokenID != 0 {
kv["tokenizer.ggml.eos_token_id"] = uint32(m.Params.EoSTokenID)
}

if m.Params.PadTokenID != 0 {
kv["tokenizer.ggml.pad_token_id"] = uint32(m.Params.PadTokenID)
}

"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
if m.Params.UnkTokenID != 0 {
kv["tokenizer.ggml.unk_token_id"] = uint32(m.Params.UnkTokenID)
}

if len(m.Vocab.Merges) > 0 {
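The net effect of the LoadVocab change above is a content-to-ID lookup: each special token string read from special_tokens_map.json is matched against the tokenizer vocabulary, and the matching token's ID is stored on Params; fields left at zero are then skipped by the != 0 guards in WriteGGUF. A rough standalone sketch of that matching, with a hypothetical token type standing in for the converter's Token:

    package main

    import "fmt"

    // token stands in for the converter's Token type, keeping only the
    // fields the new switch in LoadVocab relies on.
    type token struct {
        ID      int
        Content string
    }

    // findTokenID returns the ID of the vocabulary entry whose content
    // equals s, or -1 if there is no match. In LoadVocab a missing match
    // simply leaves the corresponding Params field at its zero value.
    func findTokenID(vocab []token, s string) int {
        for _, t := range vocab {
            if t.Content == s {
                return t.ID
            }
        }
        return -1
    }

    func main() {
        vocab := []token{{ID: 0, Content: "<unk>"}, {ID: 1, Content: "<s>"}, {ID: 2, Content: "</s>"}}
        fmt.Println(findTokenID(vocab, "</s>")) // 2
        fmt.Println(findTokenID(vocab, "<pad>")) // -1
    }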
51 changes: 51 additions & 0 deletions convert/tokenizer.go
@@ -0,0 +1,51 @@
+package convert
+
+import (
+    "encoding/json"
+    "os"
+)
+
+type specialTokens struct {
+    bos, eos, pad, unk string
+}
+
+func parseSpecialTokens(p string) (*specialTokens, error) {
+    f, err := os.Open(p)
+    if err != nil {
+        return nil, err
+    }
+    defer f.Close()
+
+    var m map[string]json.RawMessage
+    if err := json.NewDecoder(f).Decode(&m); err != nil {
+        return nil, err
+    }
+
+    parse := func(v json.RawMessage) string {
+        var s struct {
+            Content string `json:"content"`
+        }
+
+        if err := json.Unmarshal(v, &s); err != nil {
+            return string(v)
+        }
+
+        return s.Content
+    }
+
+    var sts specialTokens
+    for k, v := range m {
+        switch k {
+        case "bos_token":
+            sts.bos = parse(v)
+        case "eos_token":
+            sts.eos = parse(v)
+        case "pad_token":
+            sts.pad = parse(v)
+        case "unk_token":
+            sts.unk = parse(v)
+        }
+    }
+
+    return &sts, nil
+}
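For reference (an assumption about the input format, not shown in this commit): special_tokens_map.json typically stores each special token either as a bare string or as an object with a "content" field, which is why parse falls back to the raw value when the object form does not decode. A small test-style sketch, inside the convert package, exercising the object form with a hypothetical file:

    package convert

    import (
        "os"
        "path/filepath"
        "testing"
    )

    // Hypothetical test exercising parseSpecialTokens with the object form
    // commonly produced by Hugging Face tokenizers.
    func TestParseSpecialTokens(t *testing.T) {
        p := filepath.Join(t.TempDir(), "special_tokens_map.json")

        data := []byte(`{
            "bos_token": {"content": "<s>"},
            "eos_token": {"content": "</s>"},
            "unk_token": {"content": "<unk>"}
        }`)
        if err := os.WriteFile(p, data, 0o644); err != nil {
            t.Fatal(err)
        }

        c, err := parseSpecialTokens(p)
        if err != nil {
            t.Fatal(err)
        }
        if c.bos != "<s>" || c.eos != "</s>" || c.unk != "<unk>" {
            t.Fatalf("unexpected special tokens: %+v", c)
        }
    }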
2 changes: 1 addition & 1 deletion convert/tokenizer_bpe.go
@@ -61,7 +61,7 @@ func (t *Tokenizer) maxID() int {
     )
 }

-func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
+func parseTokenizer(dirpath string) (pre string, tokens []Token, merges []string, err error) {
     f, err := os.Open(dirpath)
     if err != nil {
         panic(err)
