Skip to content

Commit

Permalink
update create handler to use model.Name
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 9, 2024
1 parent 6a10185 commit 802b2ad
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 55 deletions.
21 changes: 4 additions & 17 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func realpath(rel, from string) string {
return abspath
}

func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
Expand Down Expand Up @@ -546,16 +546,9 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
}

unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}

if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
if !envconfig.NoPrune {
if old, err := ParseNamedManifest(name); err == nil {
defer func() { _ = old.RemoveLayers() }()
}
}

Expand All @@ -564,12 +557,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}

if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref, false); err != nil {
return err
}
}

fn(api.ProgressResponse{Status: "success"})
return nil
}
Expand Down
45 changes: 26 additions & 19 deletions server/manifest.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package server

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -33,12 +32,6 @@ func (m *Manifest) Remove() error {
return err
}

for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}

manifests, err := GetManifestPath()
if err != nil {
return err
Expand All @@ -47,6 +40,16 @@ func (m *Manifest) Remove() error {
return PruneDirectory(manifests)
}

func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}

return nil
}

func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
Expand Down Expand Up @@ -84,30 +87,34 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}, nil
}

func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath()
if err != nil {
return err
}

var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}

modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
f, err := os.Create(p)
if err != nil {
return err
}

if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
m := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}

if err := json.NewEncoder(f).Encode(m); err != nil {
return err
}

return os.WriteFile(manifestPath, b.Bytes(), 0o644)
return nil
}

func Manifests() (map[model.Name]*Manifest, error) {
Expand Down
10 changes: 4 additions & 6 deletions server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,23 @@ type layerWithGGML struct {
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}

modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
case err != nil:
return nil, err
}

for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
Expand Down
29 changes: 17 additions & 12 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,39 +504,39 @@ func (s *Server) PushModelHandler(c *gin.Context) {
}

func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

name := model.ParseName(cmp.Or(req.Model, req.Name))
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}

if req.Path == "" && req.Modelfile == "" {
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}

var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
f, err := os.Open(req.Path)
var rr io.Reader = strings.NewReader(r.Modelfile)
if r.Path != "" && r.Modelfile == "" {
ff, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer f.Close()
defer ff.Close()

r = f
rr = ff
}

modelfile, err := model.ParseFile(r)
f, err := model.ParseFile(rr)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
Expand All @@ -552,12 +552,12 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil {
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(r.Quantization), f, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
Expand Down Expand Up @@ -591,6 +591,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}

func (s *Server) ShowModelHandler(c *gin.Context) {
Expand Down
2 changes: 1 addition & 1 deletion server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
assert.Nil(t, err)
}

Expand Down

0 comments on commit 802b2ad

Please sign in to comment.