Skip to content

Commit

Permalink
cache and reuse intermediate blobs
Browse files Browse the repository at this point in the history
particularly useful for zipfiles and f16s
  • Loading branch information
mxyng committed May 16, 2024
1 parent 84ed77c commit 39efb30
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
28 changes: 25 additions & 3 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}
} else if strings.HasPrefix(c.Args, "@") {
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
digest := strings.TrimPrefix(c.Args, "@")
slog.Info("original", "digest", digest)
if ib, ok := intermediateBlobs.Load(digest); ok {
p, err := GetBlobsPath(ib.(string))
if err != nil {
return err
}

if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
// pass
} else if err != nil {
return err
} else {
fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))})
digest = ib.(string)
}
}

blobpath, err := GetBlobsPath(digest)
if err != nil {
return err
}
Expand All @@ -350,14 +368,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
defer blob.Close()

baseLayers, err = parseFromFile(ctx, blob, fn)
baseLayers, err = parseFromFile(ctx, blob, digest, fn)
if err != nil {
return err
}
} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
defer file.Close()

baseLayers, err = parseFromFile(ctx, file, fn)
baseLayers, err = parseFromFile(ctx, file, "", fn)
if err != nil {
return err
}
Expand Down Expand Up @@ -397,10 +415,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}

f16digest := baseLayer.Layer.Digest

baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
if err != nil {
return err
}

intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest)
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
}, nil
}

func (l *Layer) Open() (io.ReadCloser, error) {
func (l *Layer) Open() (io.ReadSeekCloser, error) {
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return nil, err
Expand Down
23 changes: 9 additions & 14 deletions server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import (
"net/http"
"os"
"path/filepath"
"sync"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types/model"
)

var intermediateBlobs sync.Map

type layerWithGGML struct {
*Layer
*llm.GGML
Expand Down Expand Up @@ -76,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return layers, nil
}

func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
stat, err := file.Stat()
if err != nil {
return nil, err
Expand Down Expand Up @@ -169,12 +172,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
return nil, fmt.Errorf("aaa: %w", err)
}

blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}

bin, err := os.Open(blobpath)
bin, err := layer.Open()
if err != nil {
return nil, err
}
Expand All @@ -185,16 +183,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
return nil, err
}

layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
if err != nil {
return nil, err
}

layers = append(layers, &layerWithGGML{layer, ggml})

intermediateBlobs.Store(digest, layer.Digest)
return layers, nil
}

func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
sr := io.NewSectionReader(file, 0, 512)
contentType, err := detectContentType(sr)
if err != nil {
Expand All @@ -205,7 +200,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
case "gguf", "ggla":
// noop
case "application/zip":
return parseFromZipFile(ctx, file, fn)
return parseFromZipFile(ctx, file, digest, fn)
default:
return nil, fmt.Errorf("unsupported content type: %s", contentType)
}
Expand Down
19 changes: 19 additions & 0 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
}

func (s *Server) CreateBlobHandler(c *gin.Context) {
ib, ok := intermediateBlobs.Load(c.Param("digest"))
if ok {
p, err := GetBlobsPath(ib.(string))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
intermediateBlobs.Delete(c.Param("digest"))
} else if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
} else {
c.Status(http.StatusOK)
return
}
}

path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
Expand Down

0 comments on commit 39efb30

Please sign in to comment.