Skip to content

Commit

Permalink
update create handler to use model.Name
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 10, 2024
1 parent 513f3c3 commit 5013cd0
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 66 deletions.
6 changes: 4 additions & 2 deletions app/lifecycle/server_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func terminate(cmd *exec.Cmd) error {
if err != nil {
return err
}
defer dll.Release() // nolint: errcheck
//nolint:errcheck
defer dll.Release()

pid := cmd.Process.Pid

Expand Down Expand Up @@ -73,7 +74,8 @@ func isProcessExited(pid int) (bool, error) {
if err != nil {
return false, fmt.Errorf("failed to open process: %v", err)
}
defer windows.CloseHandle(hProcess) // nolint: errcheck
//nolint:errcheck
defer windows.CloseHandle(hProcess)

var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode)
Expand Down
2 changes: 1 addition & 1 deletion readline/readline.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (i *Instance) Readline() (string, error) {

defer func() {
fd := int(syscall.Stdin)
// nolint: errcheck
//nolint:errcheck
UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false
}()
Expand Down
2 changes: 1 addition & 1 deletion server/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return err
}

// nolint: contextcheck
//nolint:contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts)
}

Expand Down
24 changes: 6 additions & 18 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func realpath(rel, from string) string {
return abspath
}

func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
Expand Down Expand Up @@ -546,16 +546,10 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
}
}

unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}

if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
if !envconfig.NoPrune {
if old, err := ParseNamedManifest(name); err == nil {
//nolint:errcheck
defer old.RemoveLayers()
}
}

Expand All @@ -564,12 +558,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return err
}

if !envconfig.NoPrune {
if err := deleteUnusedLayers(nil, unref); err != nil {
return err
}
}

fn(api.ProgressResponse{Status: "success"})
return nil
}
Expand Down Expand Up @@ -637,7 +625,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
// nolint: nilerr
//nolint:nilerr
return nil
}

Expand Down
46 changes: 27 additions & 19 deletions server/manifest.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package server

import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -34,12 +33,6 @@ func (m *Manifest) Remove() error {
return err
}

for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}

manifests, err := GetManifestPath()
if err != nil {
return err
Expand All @@ -48,6 +41,16 @@ func (m *Manifest) Remove() error {
return PruneDirectory(manifests)
}

func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}

return nil
}

func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
Expand Down Expand Up @@ -85,30 +88,35 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}, nil
}

func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath()
if err != nil {
return err
}

var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}

modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()

m := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}

if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
if err := json.NewEncoder(f).Encode(m); err != nil {
return err
}

return os.WriteFile(manifestPath, b.Bytes(), 0o644)
return nil
}

func Manifests() (map[model.Name]*Manifest, error) {
Expand Down
10 changes: 4 additions & 6 deletions server/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,23 @@ type layerWithGGML struct {
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}

modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
m, err = ParseNamedManifest(name)
if err != nil {
return nil, err
}
case err != nil:
return nil, err
}

for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
Expand Down
35 changes: 18 additions & 17 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,39 +506,39 @@ func (s *Server) PushModelHandler(c *gin.Context) {
}

func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

name := model.ParseName(cmp.Or(req.Model, req.Name))
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}

if req.Path == "" && req.Modelfile == "" {
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return
}

var r io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
f, err := os.Open(req.Path)
var rr io.Reader = strings.NewReader(r.Modelfile)
if r.Path != "" && r.Modelfile == "" {
ff, err := os.Open(r.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
defer f.Close()
defer ff.Close()

r = f
rr = ff
}

modelfile, err := model.ParseFile(r)
f, err := model.ParseFile(rr)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
Expand All @@ -554,17 +554,13 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()

quantization := req.Quantization
if req.Quantize != "" {
quantization = req.Quantize
}

if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

if req.Stream != nil && !*req.Stream {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}
Expand Down Expand Up @@ -598,6 +594,11 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}

func (s *Server) ShowModelHandler(c *gin.Context) {
Expand Down
4 changes: 3 additions & 1 deletion server/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func Test_Routes(t *testing.T) {
}

createTestModel := func(t *testing.T, name string) {
t.Helper()

fname := createTestFile(t, "ollama-model")

r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
Expand All @@ -61,7 +63,7 @@ func Test_Routes(t *testing.T) {
fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status)
}
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
assert.Nil(t, err)
}

Expand Down
2 changes: 1 addition & 1 deletion server/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO
return err
}

// nolint: contextcheck
//nolint:contextcheck
go upload.Run(context.Background(), opts)
}

Expand Down

0 comments on commit 5013cd0

Please sign in to comment.