Skip to content

Commit

Permalink
update pull handler to use model.Name
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 9, 2024
1 parent 69c0caf commit cff2588
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 84 deletions.
11 changes: 6 additions & 5 deletions server/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
)

const maxRetries = 6
Expand Down Expand Up @@ -332,15 +333,16 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
}

type downloadOpts struct {
mp ModelPath
type downloadOptions struct {
name model.Name
baseURL *url.URL
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
}

// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
func downloadBlob(ctx context.Context, opts downloadOptions) error {
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return err
Expand All @@ -365,8 +367,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return err
Expand Down
86 changes: 26 additions & 60 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
Expand Down Expand Up @@ -548,7 +549,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio

if !envconfig.NoPrune {
if old, err := ParseNamedManifest(name); err == nil {
defer func() { _ = old.RemoveLayers() }()
// nolint: errcheck
defer old.RemoveLayers()
}
}

Expand Down Expand Up @@ -773,59 +775,43 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}

func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)

var manifest *ManifestV2
var err error
var noprune string

// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
old, _ := ParseNamedManifest(name)

if !envconfig.NoPrune {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if !name.IsFullyQualified() {
return model.Unqualified(name)
}

if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
deleteMap[manifest.Config.Digest] = struct{}{}
}
scheme := "https"
if opts.Insecure {
scheme = "http"
}

if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
return err
}

fn(api.ProgressResponse{Status: "pulling manifest"})

manifest, err = pullModelManifest(ctx, mp, regOpts)
m, err := pullModelManifest(ctx, name, baseURL, &opts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}

var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)

layers := append(m.Layers, m.Config)
for _, layer := range layers {
if err := downloadBlob(
ctx,
downloadOpts{
mp: mp,
downloadOptions{
name: name,
baseURL: baseURL,
digest: layer.Digest,
regOpts: regOpts,
regOpts: &opts,
fn: fn,
}); err != nil {
return err
}
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)

fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
Expand All @@ -846,45 +832,25 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}

fn(api.ProgressResponse{Status: "writing manifest"})

manifestJSON, err := json.Marshal(manifest)
if err != nil {
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
return err
}

fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}

err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}

if noprune == "" {
if !envconfig.NoPrune && old != nil {
fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err
}
_ = old.RemoveLayers()
}

fn(api.ProgressResponse{Status: "success"})

return nil
}

func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) {
requestURL := baseURL.JoinPath("manifests", name.Tag)

headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
if err := PullModel(ctx, name, registryOptions{}, fn); err != nil {
return nil, err
}

Expand Down
26 changes: 8 additions & 18 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,24 +406,18 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}

func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.PullRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}

Expand All @@ -434,19 +428,15 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ch <- r
}

regOpts := &registryOptions{
Insecure: req.Insecure,
}

ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
Expand Down

0 comments on commit cff2588

Please sign in to comment.