Skip to content

Commit

Permalink
remove last bits of ParseModelPath
Browse files Browse the repository at this point in the history
  • Loading branch information
mxyng committed May 9, 2024
1 parent 5ca2349 commit 2ee1424
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 403 deletions.
94 changes: 9 additions & 85 deletions server/images.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ import (
"bytes"
"cmp"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
Expand Down Expand Up @@ -42,9 +39,8 @@ type registryOptions struct {
}

type Model struct {
Name string `json:"name"`
Name model.Name
Config ConfigV2
ShortName string
ModelPath string
ParentModel string
AdapterPaths []string
Expand Down Expand Up @@ -161,46 +157,17 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"`
}

func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
return nil, "", err
}

if _, err = os.Stat(fp); err != nil {
return nil, "", err
}

var manifest *ManifestV2

bts, err := os.ReadFile(fp)
if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
}

shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])

if err := json.Unmarshal(bts, &manifest); err != nil {
return nil, "", err
}

return manifest, shaStr, nil
}

func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
manifest, digest, err := GetManifest(mp)
func GetModel(name model.Name) (*Model, error) {
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}

model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Name: name,
Digest: manifest.digest,
Template: "{{ .Prompt }}",
License: []string{},
}

filename, err := GetBlobsPath(manifest.Config.Digest)
Expand Down Expand Up @@ -688,18 +655,8 @@ func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn fu

fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
if err := layer.Verify(); err != nil {
_ = layer.Remove()
return err
}
}
Expand Down Expand Up @@ -737,17 +694,6 @@ func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, o
return m, err
}

// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
log.Fatal(err)
}

return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}

var errUnauthorized = fmt.Errorf("unauthorized: access denied")

// getTokenSubject returns the subject of a JWT token, it does not validate the token
Expand Down Expand Up @@ -907,25 +853,3 @@ func parseRegistryChallenge(authStr string) registryChallenge {
Scope: getValue(authStr, "scope"),
}
}

var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")

func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}

f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()

fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
}

return nil
}
20 changes: 20 additions & 0 deletions server/layer.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,26 @@ func (l *Layer) Remove() error {
return os.Remove(blob)
}

func (l *Layer) Verify() error {
rc, err := l.Open()
if err != nil {
return err
}
defer rc.Close()

sha256sum := sha256.New()
if _, err := io.Copy(sha256sum, rc); err != nil {
return err
}

digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
if digest != l.Digest {
return fmt.Errorf("digest mismatch: %s != %s", digest, l.Digest)
}

return nil
}

func Layers() (map[string]*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions server/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,17 @@ func Manifests() (map[model.Name]*Manifest, error) {

return ms, nil
}

func GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}

path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
}

return path, nil
}
120 changes: 0 additions & 120 deletions server/modelpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,105 +2,16 @@ package server

import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
)

type ModelPath struct {
ProtocolScheme string
Registry string
Namespace string
Repository string
Tag string
}

const (
DefaultRegistry = "registry.ollama.ai"
DefaultNamespace = "library"
DefaultTag = "latest"
DefaultProtocolScheme = "https"
)

var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format")
)

func ParseModelPath(name string) ModelPath {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}

before, after, found := strings.Cut(name, "://")
if found {
mp.ProtocolScheme = before
name = after
}

name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
mp.Registry = parts[0]
mp.Namespace = parts[1]
mp.Repository = parts[2]
case 2:
mp.Namespace = parts[0]
mp.Repository = parts[1]
case 1:
mp.Repository = parts[0]
}

if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo
mp.Tag = tag
}

return mp
}

var errModelPathInvalid = errors.New("invalid model path")

func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}

if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}

return nil
}

func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}

func (mp ModelPath) GetFullTagname() string {
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

func (mp ModelPath) GetShortTagname() string {
if mp.Registry == DefaultRegistry {
if mp.Namespace == DefaultNamespace {
return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
}
return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
}

// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
// The models directory is where Ollama stores its model files and manifests.
func modelsDir() (string, error) {
Expand All @@ -114,37 +25,6 @@ func modelsDir() (string, error) {
return filepath.Join(home, ".ollama", "models"), nil
}

// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}

return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
}

func (mp ModelPath) BaseURL() *url.URL {
return &url.URL{
Scheme: mp.ProtocolScheme,
Host: mp.Registry,
}
}

func GetManifestPath() (string, error) {
dir, err := modelsDir()
if err != nil {
return "", err
}

path := filepath.Join(dir, "manifests")
if err := os.MkdirAll(path, 0o755); err != nil {
return "", err
}

return path, nil
}

func GetBlobsPath(digest string) (string, error) {
dir, err := modelsDir()
if err != nil {
Expand Down

0 comments on commit 2ee1424

Please sign in to comment.