Skip to content

Commit

Permalink
update push to use model.Name
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 9, 2024
1 parent a4b3abe commit 354ef95
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 50 deletions.
31 changes: 15 additions & 16 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,43 +602,42 @@ func CopyModel(src, dst model.Name) error {
return err
}

func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
fn(api.ProgressResponse{Status: "retrieving manifest"})

if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
m, err := ParseNamedManifest(name)
if err != nil {
return err
}

scheme := "https"
if opts.Insecure {
scheme = "http"
}

manifest, _, err := GetManifest(mp)
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}

var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)

for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
for _, layer := range append(m.Layers, m.Config) {
if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}

fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
requestURL := baseURL.JoinPath("v2", name.Namespace, name.Model, "manifests", name.Tag)

manifestJSON, err := json.Marshal(manifest)
manifestJSON, err := json.Marshal(m)
if err != nil {
return err
}

headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
if err != nil {
return err
}
Expand Down
26 changes: 8 additions & 18 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,24 +445,18 @@ func (s *Server) PullModelHandler(c *gin.Context) {
}

func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
var r api.PushRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}

Expand All @@ -473,19 +467,15 @@ func (s *Server) PushModelHandler(c *gin.Context) {
ch <- r
}

regOpts := &registryOptions{
Insecure: req.Insecure,
}

ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

if err := PushModel(ctx, model, regOpts, fn); err != nil {
if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
Expand Down
39 changes: 23 additions & 16 deletions server/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -360,40 +361,46 @@ func (p *progressWriter) Rollback() {
p.written = 0
}

func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
type uploadOptions struct {
name model.Name
baseURL *url.URL
layer *Layer
regOpts *registryOptions
fn func(api.ProgressResponse)
}

func uploadBlob(ctx context.Context, opts uploadOptions) error {
requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs", opts.layer.Digest)

resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
defer resp.Body.Close()
fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
Digest: opts.layer.Digest,
Total: opts.layer.Size,
Completed: opts.layer.Size,
})

return nil
}

data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
requestURL := opts.baseURL.JoinPath("v2", opts.name.Namespace, opts.name.Model, "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobUploadManager.Delete(opts.layer.Digest)
return err
}

// nolint: contextcheck
go upload.Run(context.Background(), opts)
go upload.Run(context.Background(), opts.regOpts)
}

return upload.Wait(ctx, fn)
return upload.Wait(ctx, opts.fn)
}

0 comments on commit 354ef95

Please sign in to comment.