Commit 8698064
fix conversion for f16 or f32 inputs
mxyng committed May 18, 2024
1 parent 5e3e177 commit 8698064
Showing 7 changed files with 148 additions and 290 deletions.
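
Note: the common thread across these files is that the per-tensor handler callbacks, which re-read and reinterpreted raw bfloat16/float16 bytes themselves, are replaced by a repacker callback that receives tensor data already decoded to []float32. Decoding to float32 before repacking is what makes conversion independent of whether the checkpoint stores bf16, f16, or f32. The repacker signature can be read off the new Repack methods below; the writer-side plumbing lives in changed files not shown in this excerpt (presumably the safetensors/torch writers that carry the repacker field), so the following is only an assumed sketch of how such a hook is wired up, with illustrative names:

package convert

import (
    "encoding/binary"
    "io"
)

// Assumed shape of the hook: decode the checkpoint tensor to float32 whatever
// its on-disk dtype, let the model-specific repacker rearrange it, then write.
type repacker func(name string, data []float32, shape []uint64) ([]float32, error)

func writeRepacked(w io.Writer, bo binary.ByteOrder, name string, shape []uint64, f32s []float32, repack repacker) error {
    if repack != nil {
        var err error
        if f32s, err = repack(name, f32s, shape); err != nil {
            return err
        }
    }
    // the real writer presumably re-encodes to the target GGUF dtype here
    return binary.Write(w, bo, f32s)
}
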
49 changes: 14 additions & 35 deletions convert/gemma.go
@@ -1,14 +1,11 @@
package convert

import (
"encoding/binary"
"fmt"
"io"
"log/slog"
"os"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"

@@ -19,49 +16,27 @@ type GemmaModel struct {
ModelData
}

func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))

data := make([]byte, r.end-r.start)
if err := binary.Read(f, r.bo, data); err != nil {
return err
}

tDataF32 := bfloat16.DecodeFloat32(data)

var err error
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
if err != nil {
return err
}

if err := binary.Write(w, r.bo, tDataF32); err != nil {
return err
}
return nil
}

func addOnes(data []float32, vectorSize int) ([]float32, error) {
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, vectorSize)

var err error
n, err = n.Add(ones)
n, err := n.Add(ones)
if err != nil {
return []float32{}, err
return nil, err
}

newN, err := native.SelectF32(n, 0)
ts, err := native.SelectF32(n, 1)
if err != nil {
return []float32{}, err
return nil, err
}

var fullTensor []float32
for _, v := range newN {
fullTensor = append(fullTensor, v...)
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}

return fullTensor, nil

return f32s, nil
}

func (m *GemmaModel) GetTensors() error {
@@ -74,7 +49,7 @@ func (m *GemmaModel) GetTensors() error {
for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = gemmaLayerHandler
wt.repacker = m.Repack
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
@@ -92,6 +67,10 @@ func (m *GemmaModel) LoadVocab() error {
return nil
}

func (m *GemmaModel) Repack(_ string, data []float32, shape []uint64) ([]float32, error) {
return addOnes(data, int(shape[0]))
}

func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "gemma",
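
Note: in gemma.go the byte-level gemmaLayerHandler (which assumed bfloat16 input) goes away; the norm-weight adjustment now happens in Repack on float32 data, whatever the source dtype. The adjustment itself is unchanged: Gemma's RMSNorm scales activations by (1 + weight), so the converter adds one to every element of each norm.weight tensor. A dependency-free sketch of what addOnes computes, ignoring the tensor-library plumbing:

// Element-wise w + 1 over the flattened norm weight; the real addOnes routes
// this through tensor.New / Add / native.SelectF32, but the arithmetic is the same.
func addOnesPlain(data []float32) []float32 {
    out := make([]float32, len(data))
    for i, v := range data {
        out[i] = v + 1
    }
    return out
}
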
136 changes: 54 additions & 82 deletions convert/llama.go
@@ -1,7 +1,7 @@
package convert

import (
"encoding/binary"
"cmp"
"errors"
"fmt"
"io"
@@ -10,10 +8,8 @@ import (
"regexp"
"strings"

"github.com/nlpodyssey/gopickle/pytorch"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"

"github.com/ollama/ollama/llm"
)
@@ -22,83 +20,6 @@ type LlamaModel struct {
ModelData
}

func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {

var tData []uint16
switch r.storage.(type) {
case *pytorch.HalfStorage:
data := r.storage.(*pytorch.HalfStorage).Data
tData = make([]uint16, len(data))
for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
case *pytorch.BFloat16Storage:
data := r.storage.(*pytorch.BFloat16Storage).Data
tData = make([]uint16, len(data))

for cnt, v := range data {
tData[cnt] = uint16(float16.Fromfloat32(v))
}
default:
return fmt.Errorf("unknown storage type for torch")
}

var err error
var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}

tData, err = llamaRepack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}

if err = binary.Write(w, r.bo, tData); err != nil {
return err
}
return nil
}

func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()

// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}

if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}

if err := n.Reshape(origShape...); err != nil {
return nil, err
}

if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}

var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}

func (m *LlamaModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
@@ -117,11 +38,11 @@ func (m *LlamaModel) GetTensors() error {
switch m.Format.(type) {
case *TorchFormat:
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler
wt.repacker = m.Repack
l.WriterTo = wt
case *SafetensorFormat:
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
wt.repacker = m.Repack
l.WriterTo = wt
}
}
@@ -184,3 +105,54 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {

return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

func (m *LlamaModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
return llamaRepack(name, m.Params, data, shape)
}

func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
if dim != 0 {
dims = append(dims, int(dim))
}
}

var heads int
if strings.HasSuffix(name, "attn_q.weight") {
heads = params.AttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor name: %s", name)
}

n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{heads, 2, dims[0] / heads / 2}, dims[1:]...)...); err != nil {
return nil, err
}

if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}

if err := n.Reshape(dims...); err != nil {
return nil, err
}

if err := n.Transpose(); err != nil {
return nil, err
}

ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}

var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}

return f32s, nil
}
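
Note: llamaRepack applies the same permutation the removed uint16-based handlers above applied, just on float32 data for both torch and safetensors inputs: reshape the Q/K weight to (heads, 2, rows/heads/2, cols), swap the middle two axes, and flatten back, which interleaves the two rotary halves of each head (the same permute llama.cpp's converter applies). The reshape/transpose sequence through the tensor library is equivalent to this direct row permutation (a sketch for clarity, not code from the commit):

// Row-permutation equivalent of llamaRepack's reshape/swap-axes/reshape, i.e.
// w.reshape(heads, 2, rows/heads/2, cols).swapaxes(1, 2).reshape(rows, cols)
// in numpy terms. data is row-major with shape (rows, cols).
func permuteQK(data []float32, rows, cols, heads int) []float32 {
    out := make([]float32, len(data))
    group := rows / heads // rows belonging to one attention head
    half := group / 2     // rows in each rotary half
    for h := 0; h < heads; h++ {
        for j := 0; j < group; j++ {
            src := h*group + (j%2)*half + j/2 // interleave first and second halves
            dst := h*group + j
            copy(out[dst*cols:(dst+1)*cols], data[src*cols:(src+1)*cols])
        }
    }
    return out
}
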
91 changes: 5 additions & 86 deletions convert/mistral.go
@@ -1,17 +1,8 @@
package convert

import (
"encoding/binary"
"fmt"
"io"
"os"
"regexp"
"strings"

"github.com/d4l3k/go-bfloat16"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/x448/float16"

"github.com/ollama/ollama/llm"
)
@@ -20,82 +11,6 @@ type MistralModel struct {
ModelData
}

func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
layerSize := r.end - r.start

var err error
tData := make([]uint16, layerSize/2)
if err = binary.Read(f, r.bo, tData); err != nil {
return err
}

var heads uint32
if strings.Contains(r.t.Name, "attn_q") {
heads = uint32(r.params.AttentionHeads)
} else if strings.Contains(r.t.Name, "attn_k") {
heads = uint32(r.params.KeyValHeads)
if heads == 0 {
heads = uint32(r.params.AttentionHeads)
}
} else {
return fmt.Errorf("unknown layer type")
}

tData, err = repack(tData, int(heads), r.t.Shape)
if err != nil {
return err
}

var buf []byte
for _, n := range tData {
buf = r.bo.AppendUint16(buf, n)
}

tempBuf := make([]uint16, len(tData))
tDataF32 := bfloat16.DecodeFloat32(buf)
for cnt, v := range tDataF32 {
tDataF16 := float16.Fromfloat32(v)
tempBuf[cnt] = uint16(tDataF16)
}

if err = binary.Write(w, r.bo, tempBuf); err != nil {
return err
}
return nil
}

func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
origShape := n.Shape().Clone()

// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
return nil, err
}

if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}

if err := n.Reshape(origShape...); err != nil {
return nil, err
}

if err := n.Transpose(); err != nil {
return nil, err
}
newN, err := native.SelectU16(n, 1)
if err != nil {
return nil, err
}

var fullTensor []uint16
for _, v := range newN {
fullTensor = append(fullTensor, v...)
}
return fullTensor, nil
}

func (m *MistralModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
@@ -112,7 +27,7 @@ func (m *MistralModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
wt.repacker = m.Repack
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
@@ -158,3 +73,7 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {

return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

func (m *MistralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
return llamaRepack(name, m.Params, data, shape)
}
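
Note: mistral.go drops both its bespoke handler and its private copy of the repack routine and now delegates to llamaRepack. The removed mistralLayerHandler also shows where f16/f32 inputs went wrong before: it read the tensor as raw 16-bit words and pushed them through bfloat16.DecodeFloat32 regardless of the checkpoint's actual dtype, then re-encoded as float16. Decoding to float32 up front, before any repacking, sidesteps that. For illustration only (not part of the commit), the same 16 bits mean very different numbers under float16 and bfloat16:

package main

import (
    "fmt"
    "math"

    "github.com/x448/float16"
)

func main() {
    bits := float16.Fromfloat32(1.0).Bits() // 1.0 as IEEE half precision: 0x3C00
    asF16 := float16.Frombits(bits).Float32()
    // bfloat16 is the top 16 bits of a float32, so reinterpret by shifting
    asBF16 := math.Float32frombits(uint32(bits) << 16)
    fmt.Println(asF16, asBF16) // prints 1 0.0078125
}
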
