Update the /api/create endpoint to use JSON (#7935)

Replaces `POST /api/create` to use JSON instead of a Modelfile.

This is a breaking change.
This commit is contained in:
Patrick Devine
2024-12-31 18:02:30 -08:00
committed by GitHub
parent 459d822b51
commit 86a622cbdc
17 changed files with 1523 additions and 1094 deletions

57
parser/expandpath_test.go Normal file
View File

@@ -0,0 +1,57 @@
package parser
import (
"os"
"os/user"
"path/filepath"
"testing"
)
func TestExpandPath(t *testing.T) {
mockCurrentUser := func() (*user.User, error) {
return &user.User{
Username: "testuser",
HomeDir: "/home/testuser",
}, nil
}
mockLookupUser := func(username string) (*user.User, error) {
fakeUsers := map[string]string{
"testuser": "/home/testuser",
"anotheruser": "/home/anotheruser",
}
if homeDir, ok := fakeUsers[username]; ok {
return &user.User{
Username: username,
HomeDir: homeDir,
}, nil
}
return nil, os.ErrNotExist
}
tests := []struct {
input string
expected string
windowsExpected string
shouldErr bool
}{
{"~", "/home/testuser", "D:\\home\\testuser", false},
{"~/myfolder/myfile.txt", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
{"~anotheruser/docs/file.txt", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false},
{"~nonexistentuser/file.txt", "", "", true},
{"relative/path/to/file", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false},
{"/absolute/path/to/file", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false},
{".", os.Getenv("PWD"), os.Getenv("PWD"), false},
}
for _, test := range tests {
result, err := expandPathImpl(test.input, mockCurrentUser, mockLookupUser)
if (err != nil) != test.shouldErr {
t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.input, err != nil, test.shouldErr)
}
if result != test.expected && result != test.windowsExpected && !test.shouldErr {
t.Errorf("expandPathImpl(%q) = %q, want %q", test.input, result, test.expected)
}
}
}

View File

@@ -3,21 +3,30 @@ package parser
import (
"bufio"
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
"github.com/ollama/ollama/api"
)
type File struct {
var ErrModelNotFound = errors.New("no Modelfile or safetensors files found")
type Modelfile struct {
Commands []Command
}
func (f File) String() string {
func (f Modelfile) String() string {
var sb strings.Builder
for _, cmd := range f.Commands {
fmt.Fprintln(&sb, cmd.String())
@@ -26,6 +35,223 @@ func (f File) String() string {
return sb.String()
}
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
req := &api.CreateRequest{}
var messages []api.Message
var licenses []string
params := make(map[string]any)
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) {
req.From = c.Args
continue
} else if err != nil {
return nil, err
}
req.Files = digestMap
case "adapter":
path, err := expandPath(c.Args)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if err != nil {
return nil, err
}
req.Adapters = digestMap
case "template":
req.Template = c.Args
case "system":
req.System = c.Args
case "license":
licenses = append(licenses, c.Args)
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
default:
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err != nil {
return nil, err
}
for k, v := range ps {
if ks, ok := params[k].([]string); ok {
params[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
params[k] = vs
} else {
params[k] = v
}
}
}
}
if len(params) > 0 {
req.Parameters = params
}
if len(messages) > 0 {
req.Messages = messages
}
if len(licenses) > 0 {
req.License = licenses
}
return req, nil
}
func fileDigestMap(path string) (map[string]string, error) {
fl := make(map[string]string)
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
var files []string
if fi.IsDir() {
files, err = filesForModel(path)
if err != nil {
return nil, err
}
} else {
files = []string{path}
}
for _, f := range files {
digest, err := digestForFile(f)
if err != nil {
return nil, err
}
fl[f] = digest
}
return fl, nil
}
func digestForFile(filename string) (string, error) {
filepath, err := filepath.EvalSymlinks(filename)
if err != nil {
return "", err
}
bin, err := os.Open(filepath)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
}
}
return matches, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf
files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin
files = append(files, gg...)
} else {
return nil, ErrModelNotFound
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return nil, err
}
files = append(files, js...)
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
if err != nil {
return nil, err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
return files, nil
}
type Command struct {
Name string
Args string
@@ -77,14 +303,14 @@ func (e *ParserError) Error() string {
return e.Msg
}
func ParseFile(r io.Reader) (*File, error) {
func ParseFile(r io.Reader) (*Modelfile, error) {
var cmd Command
var curr state
var currLine int = 1
var b bytes.Buffer
var role string
var f File
var f Modelfile
tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
br := bufio.NewReader(transform.NewReader(r, tr))
@@ -328,3 +554,40 @@ func isValidCommand(cmd string) bool {
return false
}
}
func expandPathImpl(path string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
if strings.HasPrefix(path, "~") {
var homeDir string
if path == "~" || strings.HasPrefix(path, "~/") {
// Current user's home directory
currentUser, err := currentUserFunc()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
}
homeDir = currentUser.HomeDir
path = strings.TrimPrefix(path, "~")
} else {
// Specific user's home directory
parts := strings.SplitN(path[1:], "/", 2)
userInfo, err := lookupUserFunc(parts[0])
if err != nil {
return "", fmt.Errorf("failed to find user '%s': %w", parts[0], err)
}
homeDir = userInfo.HomeDir
if len(parts) > 1 {
path = "/" + parts[1]
} else {
path = ""
}
}
path = filepath.Join(homeDir, path)
}
return filepath.Abs(path)
}
func expandPath(path string) (string, error) {
return expandPathImpl(path, user.Current, user.Lookup)
}

View File

@@ -2,18 +2,24 @@ package parser
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"strings"
"testing"
"unicode/utf16"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/encoding"
"golang.org/x/text/encoding/unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestParseFileFile(t *testing.T) {
@@ -673,3 +679,150 @@ func TestParseMultiByte(t *testing.T) {
})
}
}
func TestCreateRequest(t *testing.T) {
cases := []struct {
input string
expected *api.CreateRequest
}{
{
`FROM test`,
&api.CreateRequest{From: "test"},
},
{
`FROM test
TEMPLATE some template
`,
&api.CreateRequest{
From: "test",
Template: "some template",
},
},
{
`FROM test
LICENSE single license
PARAMETER temperature 0.5
MESSAGE user Hello
`,
&api.CreateRequest{
From: "test",
License: []string{"single license"},
Parameters: map[string]any{"temperature": float32(0.5)},
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
},
},
{
`FROM test
PARAMETER temperature 0.5
PARAMETER top_k 1
SYSTEM You are a bot.
LICENSE license1
LICENSE license2
MESSAGE user Hello there!
MESSAGE assistant Hi! How are you?
`,
&api.CreateRequest{
From: "test",
License: []string{"license1", "license2"},
System: "You are a bot.",
Parameters: map[string]any{"temperature": float32(0.5), "top_k": int64(1)},
Messages: []api.Message{
{Role: "user", Content: "Hello there!"},
{Role: "assistant", Content: "Hi! How are you?"},
},
},
},
}
for _, c := range cases {
s, err := unicode.UTF8.NewEncoder().String(c.input)
if err != nil {
t.Fatal(err)
}
p, err := ParseFile(strings.NewReader(s))
if err != nil {
t.Error(err)
}
actual, err := p.CreateRequest()
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
t.Helper()
h := sha256.New()
n, err := io.Copy(h, r)
if err != nil {
t.Fatal(err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, string) {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := llm.WriteGGUF(f, kv, ti); err != nil {
t.Fatal(err)
}
// Calculate sha256 of file
if _, err := f.Seek(0, 0); err != nil {
t.Fatal(err)
}
digest, _ := getSHA256Digest(t, f)
return f.Name(), digest
}
func TestCreateRequestFiles(t *testing.T) {
name, digest := createBinFile(t, nil, nil)
cases := []struct {
input string
expected *api.CreateRequest
}{
{
fmt.Sprintf("FROM %s", name),
&api.CreateRequest{Files: map[string]string{name: digest}},
},
}
for _, c := range cases {
s, err := unicode.UTF8.NewEncoder().String(c.input)
if err != nil {
t.Fatal(err)
}
p, err := ParseFile(strings.NewReader(s))
if err != nil {
t.Error(err)
}
actual, err := p.CreateRequest()
if err != nil {
t.Error(err)
}
if diff := cmp.Diff(actual, c.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}