This commit is contained in:
likelovewant
2025-09-28 12:37:28 +08:00
45 changed files with 1691 additions and 430 deletions

1
.gitignore vendored
View File

@@ -8,6 +8,7 @@
dist dist
build build
.cache .cache
.gocache
*.exe *.exe
.idea .idea
test_data test_data

View File

@@ -99,10 +99,12 @@ check_language(HIP)
if(CMAKE_HIP_COMPILER) if(CMAKE_HIP_COMPILER)
set(HIP_PLATFORM "amd") set(HIP_PLATFORM "amd")
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS) if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|120[01])$") find_package(hip REQUIRED)
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX) list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|90[012]|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[0123]|120[01])$")
endif()
if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX}) list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif() endif()

View File

@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
return nil return nil
} }
if resp.StatusCode == http.StatusUnauthorized {
authError := AuthorizationError{StatusCode: resp.StatusCode}
json.Unmarshal(body, &authError)
return authError
}
apiError := StatusError{StatusCode: resp.StatusCode} apiError := StatusError{StatusCode: resp.StatusCode}
err := json.Unmarshal(body, &apiError) err := json.Unmarshal(body, &apiError)
@@ -215,6 +221,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
for scanner.Scan() { for scanner.Scan() {
var errorResponse struct { var errorResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
SigninURL string `json:"signin_url,omitempty"`
} }
bts := scanner.Bytes() bts := scanner.Bytes()
@@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
} }
if response.StatusCode == http.StatusUnauthorized { if response.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
return AuthorizationError{ return AuthorizationError{
StatusCode: response.StatusCode, StatusCode: response.StatusCode,
Status: response.Status, Status: response.Status,
PublicKey: pubKey, SigninURL: errorResponse.SigninURL,
} }
} else if response.StatusCode >= http.StatusBadRequest { } else if response.StatusCode >= http.StatusBadRequest {
return StatusError{ return StatusError{
@@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil return version.Version, nil
} }
// Signout will disconnect an ollama instance from ollama.com // Signout will signout a client for a local ollama server.
func (c *Client) Signout(ctx context.Context, encodedKey string) error { func (c *Client) Signout(ctx context.Context) error {
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
}
// Disconnect will disconnect an ollama instance from ollama.com.
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
} }

View File

@@ -41,7 +41,7 @@ func (e StatusError) Error() string {
type AuthorizationError struct { type AuthorizationError struct {
StatusCode int StatusCode int
Status string Status string
PublicKey string `json:"public_key"` SigninURL string `json:"signin_url"`
} }
func (e AuthorizationError) Error() string { func (e AuthorizationError) Error() string {

View File

@@ -18,46 +18,13 @@ import (
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) { func GetPublicKey() (string, error) {
fileIsReadable := func(fp string) bool {
info, err := os.Stat(fp)
if err != nil {
return false
}
// Check that it's a regular file, not a directory or other file type
if !info.Mode().IsRegular() {
return false
}
// Try to open it to check readability
file, err := os.Open(fp)
if err != nil {
return false
}
file.Close()
return true
}
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
if fileIsReadable(systemPath) {
return systemPath, nil
}
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
return filepath.Join(home, ".ollama", defaultPrivateKey), nil keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
@@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
} }
func Sign(ctx context.Context, bts []byte) (string, error) { func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -5,7 +5,6 @@ import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
@@ -15,7 +14,6 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@@ -37,7 +35,6 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
@@ -50,7 +47,7 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support // ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
// the server and the client both have the same public key
if pubKey == sErr.PublicKey {
h, _ := os.Hostname()
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
if sErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, sErr.SigninURL)
} }
return nil return nil
} }
@@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
user, err := client.Whoami(cmd.Context()) user, err := client.Whoami(cmd.Context())
if err != nil { if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err return err
} }
@@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
return nil return nil
} }
func SignoutHandler(cmd *cobra.Command, args []string) error { func SignoutHandler(cmd *cobra.Command, args []string) error {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
} }
err = client.Signout(cmd.Context(), encKey) err = client.Signout(cmd.Context())
if err != nil { if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You are not signed in to ollama.com")
fmt.Println()
return nil
} else {
return err return err
} }
}
fmt.Println("You have signed out of ollama.com") fmt.Println("You have signed out of ollama.com")
fmt.Println() fmt.Println()
return nil return nil
@@ -546,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
n := model.ParseName(args[0])
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
_, err := client.Whoami(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to push models to ollama.com.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err
}
}
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
@@ -582,7 +595,6 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure} request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
if err := client.Push(cmd.Context(), &request, fn); err != nil { if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
@@ -1106,6 +1118,51 @@ type runOptions struct {
ShowConnect bool ShowConnect bool
} }
func (r runOptions) Copy() runOptions {
var messages []api.Message
if r.Messages != nil {
messages = make([]api.Message, len(r.Messages))
copy(messages, r.Messages)
}
var images []api.ImageData
if r.Images != nil {
images = make([]api.ImageData, len(r.Images))
copy(images, r.Images)
}
var opts map[string]any
if r.Options != nil {
opts = make(map[string]any, len(r.Options))
for k, v := range r.Options {
opts[k] = v
}
}
var think *api.ThinkValue
if r.Think != nil {
cThink := *r.Think
think = &cThink
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
Prompt: r.Prompt,
Messages: messages,
WordWrap: r.WordWrap,
Format: r.Format,
System: r.System,
Images: images,
Options: opts,
MultiModal: r.MultiModal,
KeepAlive: r.KeepAlive,
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
}
}
type displayResponseState struct { type displayResponseState struct {
lineLength int lineLength int
wordBuffer string wordBuffer string

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -491,9 +492,35 @@ func TestPushHandler(t *testing.T) {
w.(http.Flusher).Flush() w.(http.Flusher).Flush()
} }
}, },
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
},
}, },
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
}, },
{
name: "not signed in push",
modelName: "notsignedin-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://somethingsomething",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedOutput: "You need to be signed in to push",
},
{ {
name: "unauthorized push", name: "unauthorized push",
modelName: "unauthorized-model", modelName: "unauthorized-model",
@@ -508,6 +535,11 @@ func TestPushHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}, },
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
},
}, },
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
}, },
@@ -525,6 +557,9 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close() defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
initializeKeypair() initializeKeypair()
cmd := &cobra.Command{} cmd := &cobra.Command{}
@@ -561,7 +596,7 @@ func TestPushHandler(t *testing.T) {
t.Errorf("expected no error, got %v", err) t.Errorf("expected no error, got %v", err)
} }
if tt.expectedOutput != "" { if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput { if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got) t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
} }
} }
@@ -919,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) {
}) })
} }
} }
func TestRunOptions_Copy(t *testing.T) {
// Setup test data
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
originalThink := &api.ThinkValue{Value: "test reasoning"}
original := runOptions{
Model: "test-model",
ParentModel: "parent-model",
Prompt: "test prompt",
Messages: []api.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi there"},
},
WordWrap: true,
Format: "json",
System: "system prompt",
Images: []api.ImageData{
[]byte("image1"),
[]byte("image2"),
},
Options: map[string]any{
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
},
MultiModal: true,
KeepAlive: originalKeepAlive,
Think: originalThink,
HideThinking: false,
ShowConnect: true,
}
// Test the copy
copied := original.Copy()
// Test 1: Verify the copy is not the same instance
if &copied == &original {
t.Error("Copy should return a different instance")
}
// Test 2: Verify all fields are copied correctly
tests := []struct {
name string
got interface{}
want interface{}
}{
{"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel},
{"Prompt", copied.Prompt, original.Prompt},
{"WordWrap", copied.WordWrap, original.WordWrap},
{"Format", copied.Format, original.Format},
{"System", copied.System, original.System},
{"MultiModal", copied.MultiModal, original.MultiModal},
{"HideThinking", copied.HideThinking, original.HideThinking},
{"ShowConnect", copied.ShowConnect, original.ShowConnect},
}
for _, tt := range tests {
if !reflect.DeepEqual(tt.got, tt.want) {
t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want)
}
}
// Test 3: Verify Messages slice is deeply copied
if len(copied.Messages) != len(original.Messages) {
t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages))
}
if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] {
t.Error("Messages should be different instances")
}
// Modify original to verify independence
if len(original.Messages) > 0 {
originalContent := original.Messages[0].Content
original.Messages[0].Content = "modified"
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Messages should be independent after copy")
}
// Restore for other tests
original.Messages[0].Content = originalContent
}
// Test 4: Verify Images slice is deeply copied
if len(copied.Images) != len(original.Images) {
t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images))
}
if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] {
t.Error("Images should be different instances")
}
// Modify original to verify independence
if len(original.Images) > 0 {
originalImage := original.Images[0]
original.Images[0] = []byte("modified")
if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" {
t.Error("Images should be independent after copy")
}
// Restore for other tests
original.Images[0] = originalImage
}
// Test 5: Verify Options map is deeply copied
if len(copied.Options) != len(original.Options) {
t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options))
}
if len(copied.Options) > 0 && &copied.Options == &original.Options {
t.Error("Options map should be different instances")
}
// Modify original to verify independence
if len(original.Options) > 0 {
originalTemp := original.Options["temperature"]
original.Options["temperature"] = 0.9
if copied.Options["temperature"] == 0.9 {
t.Error("Options should be independent after copy")
}
// Restore for other tests
original.Options["temperature"] = originalTemp
}
// Test 6: Verify KeepAlive pointer is copied (shallow copy)
if copied.KeepAlive != original.KeepAlive {
t.Error("KeepAlive pointer should be the same (shallow copy)")
}
// Test 7: Verify Think pointer creates a new instance
if original.Think != nil && copied.Think == original.Think {
t.Error("Think should be a different instance")
}
if original.Think != nil && copied.Think != nil {
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
}
}
// Test 8: Test with zero values
zeroOriginal := runOptions{}
zeroCopy := zeroOriginal.Copy()
if !reflect.DeepEqual(zeroCopy, zeroOriginal) {
fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy)
t.Error("Copy of zero value should equal original zero value")
}
}
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
// Test with empty slices and maps
original := runOptions{
Messages: []api.Message{},
Images: []api.ImageData{},
Options: map[string]any{},
}
copied := original.Copy()
if copied.Messages == nil {
t.Error("Empty Messages slice should remain empty, not nil")
}
if copied.Images == nil {
t.Error("Empty Images slice should remain empty, not nil")
}
if copied.Options == nil {
t.Error("Empty Options map should remain empty, not nil")
}
if len(copied.Messages) != 0 {
t.Error("Empty Messages slice should remain empty")
}
if len(copied.Images) != 0 {
t.Error("Empty Images slice should remain empty")
}
if len(copied.Options) != 0 {
t.Error("Empty Options map should remain empty")
}
}
func TestRunOptions_Copy_NilPointers(t *testing.T) {
// Test with nil pointers
original := runOptions{
KeepAlive: nil,
Think: nil,
}
copied := original.Copy()
if copied.KeepAlive != nil {
t.Error("Nil KeepAlive should remain nil")
}
if copied.Think != nil {
t.Error("Nil Think should remain nil")
}
}
func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
tests := []struct {
name string
think *api.ThinkValue
}{
{"nil Think", nil},
{"bool true", &api.ThinkValue{Value: true}},
{"bool false", &api.ThinkValue{Value: false}},
{"string value", &api.ThinkValue{Value: "reasoning text"}},
{"int value", &api.ThinkValue{Value: 42}},
{"nil value", &api.ThinkValue{Value: nil}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
original := runOptions{Think: tt.think}
copied := original.Copy()
if tt.think == nil {
if copied.Think != nil {
t.Error("Nil Think should remain nil")
}
return
}
if copied.Think == nil {
t.Error("Non-nil Think should not become nil")
return
}
if copied.Think == original.Think {
t.Error("Think should be a different instance")
}
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}
original := runOptions{
Model: "original-model",
Messages: []api.Message{{Role: "user", Content: "original"}},
Options: map[string]any{"key": "value"},
Think: originalThink,
}
copied := original.Copy()
// Modify original
original.Model = "modified-model"
if len(original.Messages) > 0 {
original.Messages[0].Content = "modified"
}
original.Options["key"] = "modified"
if original.Think != nil {
original.Think.Value = "modified"
}
// Verify copy is unchanged
if copied.Model == "modified-model" {
t.Error("Copy Model should not be affected by original modification")
}
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Copy Messages should not be affected by original modification")
}
if copied.Options["key"] == "modified" {
t.Error("Copy Options should not be affected by original modification")
}
if copied.Think != nil && copied.Think.Value == "modified" {
t.Error("Copy Think should not be affected by original modification")
}
}

View File

@@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Usage:\n /load <modelname>") fmt.Println("Usage:\n /load <modelname>")
continue continue
} }
origOpts := opts.Copy()
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet) opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue
}
return err return err
} }
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err) fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue continue
} }
if strings.Contains(err.Error(), "does not support thinking") { if strings.Contains(err.Error(), "does not support thinking") {

40
docs/cloud.md Normal file
View File

@@ -0,0 +1,40 @@
# Cloud
| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).
## Cloud Models
[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldnt fit on a personal computer.
Ollama currently supports the following cloud models, with more coming soon:
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `deepseek-v3.1:671b-cloud`
- `qwen3-coder:480b-cloud`
### Get started
To run a cloud model, open the terminal and run:
```
ollama run gpt-oss:120b-cloud
```
To run cloud models with integrations that work with Ollama, first download the cloud model:
```
ollama pull qwen3-coder:480b-cloud
```
Then sign in to Ollama:
```
ollama signin
```
Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
## Cloud API access
Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).

View File

@@ -1,107 +0,0 @@
# Turbo
>  Turbo is preview
Ollamas [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
Currently, the following models are available in Turbo:
- `gpt-oss:20b`
- `gpt-oss:120b`
## Get started
### Ollama for macOS & Windows
Download Ollama
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
- Click on **Turbo**. Youll be prompted to create an account or sign in
### Ollamas CLI
- [Sign up](https://ollama.com/signup) for an Ollama account
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
On macOS and Linux:
```shell
cat ~/.ollama/id_ed25519.pub
```
On Windows:
```
type "%USERPROFILE%\.ollama\id_ed25519.pub"
```
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
```shell
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
```
### Ollamas Python library
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```python
from ollama import Client
client = Client(
host="https://ollama.com",
headers={'Authorization': '<api key>'}
)
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)
```
### Ollamas JavaScript library
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```typescript
import { Ollama } from 'ollama';
const ollama = new Ollama({
host: 'https://ollama.com',
headers: {
Authorization: "Bearer <api key>"
}
});
const response = await ollama.chat({
model: 'gpt-oss:120b',
messages: [{ role: 'user', content: 'Explain quantum computing' }],
stream: true
});
for await (const part of response) {
process.stdout.write(part.message.content)
}
```
### Community integrations
Turbo mode is also compatible with several community integrations.
#### Open WebUI
- Go to **settings** → **Admin settings** → **Connections**
- Under **Ollama API,** click **+**
- For the **URL** put `https://ollama.com`
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
- Click **Save**
Now, if you navigate to the model selector, Turbo models should be available under **External**.

View File

@@ -244,6 +244,7 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3n", "gemma3n",
"mistral3", "mistral3",
"qwen3", "qwen3",
"qwen3moe",
"llama4", "llama4",
"mllama", "mllama",
"qwen25vl", "qwen25vl",

View File

@@ -1,6 +1,7 @@
package harmony package harmony
import ( import (
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"strings" "strings"
@@ -265,6 +266,8 @@ type HarmonyMessageHandler struct {
state harmonyMessageState state harmonyMessageState
HarmonyParser *HarmonyParser HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap FunctionNameMap *FunctionNameMap
toolAccumulator *HarmonyToolCallAccumulator
convertedTools map[string]struct{}
} }
// NewHarmonyMessageHandler creates a new message handler // NewHarmonyMessageHandler creates a new message handler
@@ -277,6 +280,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>", HeaderEndTag: "<|message|>",
}, },
FunctionNameMap: NewFunctionNameMap(), FunctionNameMap: NewFunctionNameMap(),
convertedTools: make(map[string]struct{}),
} }
} }
@@ -384,8 +388,85 @@ func NewFunctionNameMap() *FunctionNameMap {
} }
} }
// Init initializes the handler with tools and optional last message
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{
MessageStartTag: "<|start|>",
MessageEndTag: "<|end|>",
HeaderEndTag: "<|message|>",
}
}
// Handle prefill for chat mode
if lastMessage != nil {
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
} else {
h.HarmonyParser.AddImplicitStart()
}
// Initialize tool accumulator
h.toolAccumulator = h.CreateToolParser()
// Process tools and return renamed versions
if len(tools) == 0 {
return tools
}
processedTools := make([]api.Tool, len(tools))
copy(processedTools, tools)
for i, tool := range processedTools {
if tool.Function.Name != "" {
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
h.convertedTools[tool.Function.Name] = struct{}{}
}
}
return processedTools
}
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
if toolContent != "" {
h.toolAccumulator.Add(toolContent)
}
// tool calls always happen one at a time, and always at the end of a message,
// so for simplicity we defer parsing them until we know we're done
if done {
toolName, raw := h.toolAccumulator.Drain()
if toolName != nil {
name := strings.TrimPrefix(*toolName, "functions.")
name = h.FunctionNameMap.OriginalFromConverted(name)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
}
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
}
}
return content, thinking, calls, nil
}
// HasToolSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasToolSupport() bool {
return true
}
// HasThinkingSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
return true
}
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
harmonyFunctionName := m.deriveName(userFunctionName) harmonyFunctionName := m.deriveName(userFunctionName)
// built-in functions should not be renamed
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
harmonyFunctionName = userFunctionName
}
m.userToHarmony[userFunctionName] = harmonyFunctionName m.userToHarmony[userFunctionName] = harmonyFunctionName
m.harmonyToUser[harmonyFunctionName] = userFunctionName m.harmonyToUser[harmonyFunctionName] = userFunctionName
return harmonyFunctionName return harmonyFunctionName

View File

@@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) {
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
} }
for i, tt := range tests { for i, tt := range tests {

View File

@@ -12,3 +12,6 @@ The integration tests have 2 modes of operating.
> [!IMPORTANT] > [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. > Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`

View File

@@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: "why is the sky blue? be brief", Prompt: blueSkyPrompt,
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
@@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) {
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
for _, resp := range anyResp { for _, resp := range blueSkyExpected {
if strings.Contains(strings.ToLower(response), resp) { if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true atLeastOne = true
break break
} }
} }
if !atLeastOne { if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response) t.Errorf("none of %v found in %s", blueSkyExpected, response)
} }
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for generate") t.Error("outer test context done while waiting for generate")
@@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) {
Messages: []api.Message{ Messages: []api.Message{
{ {
Role: "user", Role: "user",
Content: "why is the sky blue? be brief", Content: blueSkyPrompt,
}, },
}, },
Options: map[string]interface{}{ Options: map[string]interface{}{
@@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) {
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
@@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) {
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
for _, resp := range anyResp { for _, resp := range blueSkyExpected {
if strings.Contains(strings.ToLower(response), resp) { if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true atLeastOne = true
break break
} }
} }
if !atLeastOne { if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response) t.Errorf("none of %v found in %s", blueSkyExpected, response)
} }
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for chat") t.Error("outer test context done while waiting for chat")

View File

@@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: "why is the sky blue?", Prompt: blueSkyPrompt,
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) GenerateTestHelper(ctx, t, req, blueSkyExpected)
} }
func TestUnicode(t *testing.T) { func TestUnicode(t *testing.T) {
@@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) {
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: "why is the sky blue?", Prompt: blueSkyPrompt,
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) GenerateTestHelper(ctx, t, req, blueSkyExpected)
} }

View File

@@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err) t.Fatalf("PullIfMissing failed: %v", err)
} }
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second) DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
} }
// Send multiple generate requests with prior context and ensure the response is coherant and expected // Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) { func TestParallelGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests() req, resp := GenerateRequests()
numParallel := 2 numParallel := 2
@@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) {
wg.Wait() wg.Wait()
} }
// Send generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: rainbowPrompt,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
req.Prompt = rainbowFollowups[i]
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
}
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected // Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) { func TestParallelChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests() req, resp := ChatRequests()
numParallel := 2 numParallel := 2
@@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
// Send generate requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
req := api.ChatRequest{
Model: smol,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
Messages: []api.Message{
{
Role: "user",
Content: rainbowPrompt,
},
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Messages = append(req.Messages,
*assistant,
api.Message{Role: "user", Content: rainbowFollowups[i]},
)
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
}
}

View File

@@ -4,7 +4,9 @@ package integration
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"os"
"testing" "testing"
"time" "time"
@@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) {
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
chatModels := libraryChatModels chatModels := libraryChatModels
for _, model := range chatModels { for _, model := range chatModels {
@@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) {
if err := PullIfMissing(ctx, client, model); err != nil { if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err) t.Fatalf("pull failed %s", err)
} }
if targetArch != "" {
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
if arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
}
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: model, Model: model,
Prompt: "why is the sky blue?", Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0.1, "temperature": 0.1,
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"} anyResp := blueSkyExpected
// Special cases // Special cases
if model == "duckdb-nsql" { if model == "duckdb-nsql" {
anyResp = []string{"select", "from"} anyResp = []string{"select", "from"}

View File

@@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) {
// TODO - fiddle with context size // TODO - fiddle with context size
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: model, Model: model,
Prompt: "why is the sky blue?", Prompt: blueSkyPrompt,
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"} DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
}) })
} }
} }

View File

@@ -40,6 +40,18 @@ var (
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv // cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv // cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) { func TestModelsPerf(t *testing.T) {
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
doModelPerfTest(t, ollamaEngineChatModels)
} else {
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
}
}
func TestLibraryModelsPerf(t *testing.T) {
doModelPerfTest(t, libraryChatModels)
}
func doModelPerfTest(t *testing.T, chatModels []string) {
softTimeout, hardTimeout := getTimeouts(t) softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) {
} }
longPrompt := "summarize the following: " + string(data) longPrompt := "summarize the following: " + string(data)
var chatModels []string targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels { for _, model := range chatModels {
if !strings.Contains(model, ":") {
model = model + ":latest"
}
t.Run(model, func(t *testing.T) { t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout { if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime") t.Skip("skipping remaining tests to avoid excessive runtime")
@@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) {
} }
arch := resp.ModelInfo["general.architecture"].(string) arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64)) maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if targetArch != "" && arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
if maxVram > 0 { if maxVram > 0 {
resp, err := client.List(ctx) resp, err := client.List(ctx)
@@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) {
prompt string prompt string
anyResp []string anyResp []string
}{ }{
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}}, {blueSkyPrompt, blueSkyExpected},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}}, {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
} }
var gpuPercent int var gpuPercent int
for _, tc := range testCases { for _, tc := range testCases {
@@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) {
} }
} }
} }
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n", fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL", "MODEL",
"CONTEXT", "CONTEXT",
"GPU PERCENT", "GPU PERCENT",
"PROMPT COUNT", "APPROX PROMPT COUNT",
"LOAD TIME", "LOAD TIME",
"PROMPT EVAL TPS", "PROMPT EVAL TPS",
"EVAL TPS", "EVAL TPS",
@@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) {
model, model,
numCtx, numCtx,
gpuPercent, gpuPercent,
resp.PromptEvalCount, (resp.PromptEvalCount/10)*10,
float64(resp.LoadDuration)/1000000000.0, float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0), float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0), float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),

View File

@@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) {
stream := true stream := true
genReq := api.GenerateRequest{ genReq := api.GenerateRequest{
Model: newName, Model: newName,
Prompt: "why is the sky blue?", Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 3 * time.Second}, KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{ Options: map[string]any{
"seed": 42, "seed": 42,
@@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) {
// Some smaller quantizations can cause models to have poor quality // Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches // or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx) reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false atLeastOne := false
var buf bytes.Buffer var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error { genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response)) buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String()) fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp { for _, resp := range blueSkyExpected {
if strings.Contains(fullResp, resp) { if strings.Contains(fullResp, resp) {
atLeastOne = true atLeastOne = true
t.Log(fullResp) t.Log(fullResp)

View File

@@ -256,13 +256,29 @@ var (
"snowflake-arctic-embed", "snowflake-arctic-embed",
"snowflake-arctic-embed2", "snowflake-arctic-embed2",
} }
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
rainbowFollowups = []string{
"Explain the physics involved in them. Be breif in your reply",
"Explain the chemistry involved in them. Be breif in your reply",
"Explain the quantum mechanics involved in them. Be breif in your reply",
"What are common myths related to them? Be brief in your reply",
"What are common fairytales related to them? Be brief in your reply",
"Can they form if there is no rain? Be breif in your reply",
"Can they form if there are no clouds? Be breif in your reply",
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"}
) )
func init() { func init() {
lifecycle.InitLogging() lifecycle.InitLogging()
custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL") custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
if custom != "" { if custom != "" {
slog.Info("setting smol test model to " + custom) slog.Info("setting default test model to " + custom)
smol = custom smol = custom
} }
} }
@@ -577,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
}, },
}, },
[][]string{ [][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"},
{"water", "droplet", "refracted", "reflect", "color", "spectrum"}, {"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"}, {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"},
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"iter" "iter"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
@@ -13,16 +14,28 @@ import (
) )
type BytePairEncoding struct { type BytePairEncoding struct {
pre *regexp2.Regexp
vocab *Vocabulary vocab *Vocabulary
regexps []*regexp2.Regexp
} }
var _ TextProcessor = (*BytePairEncoding)(nil) var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{ return BytePairEncoding{
pre: regexp2.MustCompile(pre, regexp2.None),
vocab: vocab, vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
} }
} }
@@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool {
} }
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
return func(yield func(string) bool) { parts := []string{s}
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) { if !yield(m.String()) {
break return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
} }
} }
} }
})
}
return slices.Values(parts)
} }
// fragment is a string fragment and their corresponding token IDs // fragment is a string fragment and their corresponding token IDs

View File

@@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding {
} }
return NewBytePairEncoding( return NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&Vocabulary{ &Vocabulary{
Values: tokens, Values: tokens,
Types: types, Types: types,
Merges: merges, Merges: merges,
}, },
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
) )
} }
@@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}) })
} }
} }
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"log/slog"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@@ -171,35 +172,44 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
// make a copy // make a copy
tagsCopy := tags tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" { if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, ParseTags(tag)) tagsCopy = append(tagsCopy, parseTag(tag))
} }
if tt == reflect.TypeOf((*Base)(nil)).Elem() { if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base)) vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag) [][]string var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag) (names [][]string) { fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 { if len(tags) > 0 {
localNames := []string{tags[0].Name} var names []string
localNames = append(localNames, tags[0].Alternate...) if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
for _, localName := range localNames { names = append(names, prefix+n+suffix)
fullName := []string{localName} }
nested := fn(tags[1:]) }
if len(nested) > 0 { childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
for _, rest := range nested { if len(names) == 0 {
names = append(names, append(fullName, rest...)) // current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
} }
} else { } else {
names = append(names, fullName) // merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
} }
} }
} }
return names return fullNames
} }
names := fn(tagsCopy) names := fn(tagsCopy, "", "")
for _, name := range names { for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor) logutil.Trace("found tensor", "", tensor)
@@ -213,9 +223,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
for i := range vv.Len() { for i := range vv.Len() {
vvv := vv.Index(i) vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else { } else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
} }
} }
} }
@@ -254,18 +264,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
} }
type Tag struct { type Tag struct {
Name string name,
Alternate []string // prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
} }
func ParseTags(s string) (tag Tag) { func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",") parts := strings.Split(s, ",")
if len(parts) > 0 { if len(parts) > 0 {
tag.Name = parts[0] tag.name = parts[0]
for _, part := range parts[1:] { for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok { if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
tag.Alternate = append(tag.Alternate, value) // elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
} }
} }
} }

View File

@@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) {
{ {
value: "output", value: "output",
want: Tag{ want: Tag{
Name: "output", name: "output",
}, },
}, },
{ {
value: "output,alt:token_embd", value: "output,alt:token_embd",
want: Tag{ want: Tag{
Name: "output", name: "output",
Alternate: []string{ alternatives: []string{
"token_embd", "token_embd",
}, },
}, },
@@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) {
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) { t.Run(tt.value, func(t *testing.T) {
got := ParseTags(tt.value) got := parseTag(tt.value)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
} }
}) })
@@ -125,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Input *nn.Embedding `gguf:"input"` Input *nn.Embedding `gguf:"input"`
Output *nn.Linear `gguf:"output,alt:input"` Output *nn.Linear `gguf:"output,alt:input"`
Nested *nested `gguf:"nested"` Nested *nested `gguf:"nested"`
Tensor ml.Tensor `gguf:"leaf,alt:tensor"`
} }
var m fakeModel var m fakeModel
@@ -133,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
names: []string{ names: []string{
"input.weight", "input.weight",
"nested.b.weight", "nested.b.weight",
"leaf",
}, },
}}, v.Elem())) }}, v.Elem()))
@@ -142,6 +144,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Nested: &nested{ Nested: &nested{
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}}, Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
}, },
Tensor: &fakeTensor{Name: "leaf"},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
type fakeBlock struct {
A *nn.Linear `gguf:"a"`
B *nn.Linear `gguf:",pre:b_"`
C *nn.Linear `gguf:",suf:_c"`
XY *nn.Linear `gguf:",pre:x_,suf:_y"`
}
type fakeModel struct {
Blocks []fakeBlock `gguf:"blk"`
}
m := fakeModel{
Blocks: make([]fakeBlock, 2),
}
v := reflect.ValueOf(&m)
v.Elem().Set(populateFields(Base{b: &fakeBackend{
names: []string{
"blk.0.a.weight",
"blk.0.b_weight",
"blk.0.b_bias",
"blk.0.weight_c",
"blk.0.x_weight_y",
"blk.1.a.weight",
"blk.1.b_weight",
"blk.1.b_bias",
"blk.1.weight_c",
"blk.1.x_weight_y",
},
}}, v.Elem()))
if diff := cmp.Diff(fakeModel{
Blocks: []fakeBlock{
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
},
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
},
},
}, m); diff != "" { }, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
} }

View File

@@ -0,0 +1,324 @@
package deepseek2
// uses deepseek 2 architecture but written based on deepseek 3 model
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
originalContextLength int
eps,
ropeBase,
ropeScale float32
kqScale float64
}
func (o Options) RoPEOptions() []func(*rope.Options) {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
return []func(*rope.Options){
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
}
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 { // nil {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
qPass := query.View(ctx, 0,
opts.qkNopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
opts.qkRopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
opts.qkRopeHeadDim, compressedKV.Stride(1),
1, compressedKV.Stride(1),
compressedKV.Dim(1))
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
opts.vHeadDim, kv.Stride(1),
kv.Dim(1), kv.Stride(2),
kv.Dim(2)).Contiguous(ctx)
qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
query = qRot.Concat(ctx, qPass, 0)
key := kRot.Concat(ctx, kPass, 0)
attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
scores = scores.Add(ctx, moe.ExpProbsBias)
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: int(c.Uint("attention.key_length")),
vHeadDim: int(c.Uint("attention.value_length")),
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseek2", New)
}

View File

@@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) {
m := Transformer{ m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
), ),
Options: Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),

View File

@@ -54,10 +54,30 @@ func New(c fs.Config) (model.Model, error) {
} }
switch c.String("tokenizer.ggml.model") { switch c.String("tokenizer.ggml.model") {
case "gpt2": case "gpt2":
processor = model.NewBytePairEncoding( var pretokenizers []string
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, switch c.String("tokenizer.ggml.pre") {
&vocabulary, case "default":
) // no-op use the default bpe pretokenizer
case "qwen2":
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "refact":
pretokenizers = []string{
`\p{N}`,
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
}
case "tekken":
pretokenizers = []string{
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
default:
// use a llama-style pretokenizer
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama": case "llama":
processor = model.NewSentencePiece(&vocabulary) processor = model.NewSentencePiece(&vocabulary)
default: default:

View File

@@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
return nextStates return nextStates
} }
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct { type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts Experts *TextExperts
SharedExpert *TextSharedExpert SharedExpert *TextMLP `gguf:",suf:_shexp"`
} }
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {

View File

@@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := &Model{ m := &Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
TextModel: newTextModel(c), TextModel: newTextModel(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -33,7 +33,6 @@ const (
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -2,6 +2,7 @@ package models
import ( import (
_ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n" _ "github.com/ollama/ollama/model/models/gemma3n"

View File

@@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")), Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
Options: Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),

View File

@@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := &Model{ m := &Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
TextModel: NewTextModel(c), TextModel: NewTextModel(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -35,7 +35,6 @@ func newEmbed(c fs.Config) (model.Model, error) {
} }
m := embedModel{ m := embedModel{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -48,6 +47,7 @@ func newEmbed(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
Model: &Model{ Model: &Model{
Layers: layers, Layers: layers,

View File

@@ -200,7 +200,6 @@ func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -213,6 +212,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
Layers: layers, Layers: layers,
Options: &Options{ Options: &Options{

View File

@@ -2,10 +2,16 @@ package parsers
import ( import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
) )
type Parser interface { type Parser interface {
Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) // Init initializes the parser with tools and optional last message for chat prefill
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
HasToolSupport() bool HasToolSupport() bool
HasThinkingSupport() bool HasThinkingSupport() bool
} }
@@ -17,6 +23,8 @@ func ParserForName(name string) Parser {
return parser return parser
case "passthrough": case "passthrough":
return &PassthroughParser{} return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
default: default:
return nil return nil
} }
@@ -24,7 +32,11 @@ func ParserForName(name string) Parser {
type PassthroughParser struct{} type PassthroughParser struct{}
func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
return tools // passthrough doesn't modify tools
}
func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
return s, "", nil, nil return s, "", nil, nil
} }

View File

@@ -11,6 +11,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"unicode" "unicode"
"unicode/utf8"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
@@ -31,6 +32,7 @@ const (
type Qwen3CoderParser struct { type Qwen3CoderParser struct {
state qwenParserState state qwenParserState
acc strings.Builder acc strings.Builder
tools []api.Tool
} }
func (p *Qwen3CoderParser) HasToolSupport() bool { func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -41,7 +43,12 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false return false
} }
func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.acc.WriteString(s) p.acc.WriteString(s)
events := p.parseEvents() events := p.parseEvents()
@@ -51,7 +58,7 @@ func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thin
for _, event := range events { for _, event := range events {
switch event := event.(type) { switch event := event.(type) {
case qwenEventRawToolCall: case qwenEventRawToolCall:
toolCall, err := parseToolCall(event, tools) toolCall, err := parseToolCall(event, p.tools)
if err != nil { if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err) slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
@@ -198,12 +205,21 @@ func overlap(s, delim string) int {
} }
func trailingWhitespaceLen(s string) int { func trailingWhitespaceLen(s string) int {
for i := len(s) - 1; i >= 0; i-- { remaining := s
if !unicode.IsSpace(rune(s[i])) { total := 0
return len(s) - i - 1 for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
} }
if !unicode.IsSpace(r) {
break
} }
return len(s) total += size
remaining = remaining[:len(remaining)-size]
}
return total
} }
type XMLFunctionCall struct { type XMLFunctionCall struct {
@@ -359,7 +375,7 @@ func parseValue(raw string, paramType api.PropertyType) any {
// Try array // Try array
if typeSet["array"] { if typeSet["array"] {
var arr []interface{} var arr []any
if err := json.Unmarshal([]byte(raw), &arr); err == nil { if err := json.Unmarshal([]byte(raw), &arr); err == nil {
return arr return arr
} }
@@ -371,7 +387,7 @@ func parseValue(raw string, paramType api.PropertyType) any {
// Try object // Try object
if typeSet["object"] { if typeSet["object"] {
var obj map[string]interface{} var obj map[string]any
if err := json.Unmarshal([]byte(raw), &obj); err == nil { if err := json.Unmarshal([]byte(raw), &obj); err == nil {
return obj return obj
} }

View File

@@ -166,6 +166,137 @@ func TestQwenParserStreaming(t *testing.T) {
}, },
}, },
}, },
{
desc: "unicode content",
steps: []step{
{
input: "你好 🌍<tool_call>test</tool_call>مرحبا",
wantEvents: []qwenEvent{
qwenEventContent{content: "你好 🌍"},
qwenEventRawToolCall{raw: "test"},
qwenEventContent{content: "مرحبا"},
},
},
},
},
{
desc: "arabic text handling",
steps: []step{
{
input: "مرحبا بالعالم",
wantEvents: []qwenEvent{qwenEventContent{content: "مرحبا بالعالم"}},
},
},
},
{
desc: "emoji passthrough",
steps: []step{
{
input: "✅",
wantEvents: []qwenEvent{qwenEventContent{content: "✅"}},
},
},
},
{
desc: "emoji after tool call",
steps: []step{
{
input: "<tool_call>test</tool_call>完成 ✅",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "test"},
qwenEventContent{content: "完成 ✅"},
},
},
},
},
{
desc: "unicode streaming with whitespace handling",
steps: []step{
{
input: "مرحبا",
wantEvents: []qwenEvent{
qwenEventContent{content: "مرحبا"},
},
},
{
input: " \n",
wantEvents: []qwenEvent{},
},
{
input: "世界",
wantEvents: []qwenEvent{
qwenEventContent{content: " \n世界"},
},
},
},
},
{
desc: "non-breaking space withheld across chunks",
steps: []step{
{
input: "Hello\u00a0",
wantEvents: []qwenEvent{
qwenEventContent{content: "Hello"},
},
},
{
input: "world",
wantEvents: []qwenEvent{
qwenEventContent{content: "\u00a0world"},
},
},
},
},
{
desc: "ideographic space before partial tool",
steps: []step{
{
input: "Hello\u3000<tool",
wantEvents: []qwenEvent{
qwenEventContent{content: "Hello"},
},
},
{
input: "_call>abc",
wantEvents: []qwenEvent{},
},
{
input: "</tool_call>def",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "abc"},
qwenEventContent{content: "def"},
},
},
},
},
{
desc: "ideographic space before partial tool fakeout",
steps: []step{
{
input: "Hello\u3000<tool",
wantEvents: []qwenEvent{
qwenEventContent{content: "Hello"},
},
},
{
input: "fakeout>abc",
wantEvents: []qwenEvent{
qwenEventContent{content: "\u3000<toolfakeout>abc"},
},
},
},
},
{
desc: "unicode with partial tool tag",
steps: []step{
{
input: "测试🎯 <to",
wantEvents: []qwenEvent{
qwenEventContent{content: "测试🎯"},
},
},
},
},
} }
anyOnlies := false anyOnlies := false
@@ -347,6 +478,27 @@ ls && echo "a > b and a < b"
}, },
}, },
}, },
{
name: "unicode in function names and parameters",
tools: []api.Tool{},
rawToolCall: `<function=获取天气>
<parameter=城市>
北京
</parameter>
<parameter=message>
Hello! 你好! 🌟 مرحبا
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: map[string]any{
"城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا",
},
},
},
},
} }
for i, step := range steps { for i, step := range steps {
@@ -360,6 +512,42 @@ ls && echo "a > b and a < b"
} }
} }
func TestTrailingWhitespaceLenUnicode(t *testing.T) {
cases := []struct {
name string
input string
want int
}{
{
name: "ascii space",
input: "Hello ",
want: 1,
},
{
name: "non-breaking space",
input: "Hello\u00a0",
want: 2,
},
{
name: "ideographic space",
input: "Hello\u3000",
want: 3,
},
{
name: "multiple runes of whitespace",
input: "Hi\u00a0\u3000",
want: 5,
},
}
for _, tc := range cases {
got := trailingWhitespaceLen(tc.input)
if got != tc.want {
t.Errorf("%s: trailingWhitespaceLen(%q) = %d, want %d", tc.name, tc.input, got, tc.want)
}
}
}
func TestQwenToolCallValueParsing(t *testing.T) { func TestQwenToolCallValueParsing(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string
@@ -867,6 +1055,8 @@ func TestTrailingWhitespaceLen(t *testing.T) {
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, {desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
{desc: "only whitespace", s: " \n ", want: 4}, {desc: "only whitespace", s: " \n ", want: 4},
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, {desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
{desc: "unicode with trailing space", s: "测试🎯 ", want: 1},
{desc: "unicode with trailing tab and newline", s: "مرحبا\t\n", want: 2},
} }
for _, tc := range cases { for _, tc := range cases {
@@ -876,3 +1066,30 @@ func TestTrailingWhitespaceLen(t *testing.T) {
} }
} }
} }
func TestOverlapFunction(t *testing.T) {
cases := []struct {
desc string
s string
delim string
want int
}{
{desc: "no overlap", s: "hello", delim: "<tool", want: 0},
{desc: "full overlap", s: "hello<tool", delim: "<tool>", want: 5},
{desc: "partial overlap", s: "hello<to", delim: "<tool>", want: 3},
{desc: "unicode with partial overlap", s: "测试🎯<to", delim: "<tool>", want: 3},
{desc: "unicode string with no overlap", s: "مرحبا", delim: "<tool>", want: 0},
{desc: "unicode at boundary", s: "世界<", delim: "<tool>", want: 1},
{desc: "unicode delimiter single rune", s: "hello🔧", delim: "🔧工具", want: len("🔧")},
{desc: "unicode delimiter multiple runes", s: "hello🔧工", delim: "🔧工具", want: len("🔧工")},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got := overlap(tc.s, tc.delim)
if got != tc.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tc.s, tc.delim, got, tc.want)
}
})
}
}

View File

@@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1) merges := make([]string, 0, 1)
// Only need vocab for Grammar Test // Only need vocab for Grammar Test
return model.NewBytePairEncoding( return model.NewBytePairEncoding(
``,
&model.Vocabulary{ &model.Vocabulary{
Values: tokens, Values: tokens,
Types: make([]int32, len(vocab)), Types: make([]int32, len(vocab)),

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -34,7 +35,6 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/model/parsers"
@@ -49,6 +49,8 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
func shouldUseHarmony(model *Model) bool { func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony: // heuristic to check whether the template expects to be parsed via harmony:
@@ -151,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
return "", err
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
@@ -251,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient) client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn) err = client.Generate(c, &req, fn)
if err != nil { if err != nil {
var sErr api.AuthorizationError var authError api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { if errors.As(err, &authError) {
pk, pkErr := auth.GetPublicKey() sURL, sErr := signinURL()
if pkErr != nil { if sErr != nil {
slog.Error("couldn't get public key", "error", pkErr) slog.Error(sErr.Error())
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return return
} }
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized", c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
"public_key": pk, return
}) }
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -291,17 +307,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
useHarmony := shouldUseHarmony(m) && !req.Raw var builtinParser parsers.Parser
var harmonyMessageHandler *harmony.HarmonyMessageHandler if shouldUseHarmony(m) && m.Config.Parser == "" {
var harmonyToolParser *harmony.HarmonyToolCallAccumulator m.Config.Parser = "harmony"
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
} }
// Validate Think value: string values currently only allowed for gptoss models if !req.Raw && m.Config.Parser != "" {
if req.Think != nil && req.Think.IsString() && !useHarmony { builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil)
}
}
// Validate Think value: string values currently only allowed for harmony/gptoss models
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return return
} }
@@ -425,7 +445,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
var thinkingState *thinking.Parser var thinkingState *thinking.Parser
if !useHarmony { if builtinParser == nil {
openingTag, closingTag := thinking.InferTags(m.Template.Template) openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{ thinkingState = &thinking.Parser{
@@ -462,11 +482,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}, },
} }
if useHarmony { if builtinParser != nil {
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Response = content res.Response = content
res.Thinking = thinking res.Thinking = thinking
harmonyToolParser.Add(toolContent) if cr.Done && len(toolCalls) > 0 {
res.ToolCalls = toolCalls
}
} else if thinkingState != nil { } else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content) thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking res.Thinking = thinking
@@ -478,26 +504,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
if cr.Done { if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String() res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -512,7 +518,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
} }
if useHarmony { if builtinParser != nil {
// only send messages with meaningful content (empty messages confuse clients) // only send messages with meaningful content (empty messages confuse clients)
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
ch <- res ch <- res
@@ -1423,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/show", s.ShowHandler) r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler) r.DELETE("/api/delete", s.DeleteHandler)
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
r.POST("/api/me", s.WhoamiHandler) r.POST("/api/me", s.WhoamiHandler)
r.POST("/api/signout", s.SignoutHandler)
// deprecated
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
// Create // Create
r.POST("/api/create", s.CreateHandler) r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@@ -1636,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
if err != nil { if err != nil {
slog.Error(err.Error()) slog.Error(err.Error())
} }
// user isn't signed in
if user != nil && user.Name == "" {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
c.JSON(http.StatusOK, user) c.JSON(http.StatusOK, user)
} }
func (s *Server) SignoutHandler(c *gin.Context) { func (s *Server) SignoutHandler(c *gin.Context) {
encodedKey := c.Param("encodedKey") pubKey, err := auth.GetPublicKey()
if err != nil {
slog.Error("couldn't get public key", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// todo allow other hosts // todo allow other hosts
u, err := url.Parse("https://ollama.com") u, err := url.Parse("https://ollama.com")
@@ -1651,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) {
} }
client := api.NewClient(u, http.DefaultClient) client := api.NewClient(u, http.DefaultClient)
err = client.Signout(c, encodedKey) err = client.Disconnect(c, encKey)
if err != nil { if err != nil {
slog.Error(err.Error()) var authError api.AuthorizationError
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { if errors.As(err, &authError) {
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
@@ -1813,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient) client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn) err = client.Chat(c, &req, fn)
if err != nil { if err != nil {
var sErr api.AuthorizationError var authError api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { if errors.As(err, &authError) {
pk, pkErr := auth.GetPublicKey() sURL, sErr := signinURL()
if pkErr != nil { if sErr != nil {
slog.Error("couldn't get public key", "error", pkErr) slog.Error(sErr.Error())
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return return
} }
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized", c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
"public_key": pk, return
}) }
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return return
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1870,32 +1903,23 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
var builtinParser parsers.Parser if shouldUseHarmony(m) && m.Config.Parser == "" {
if m.Config.Parser != "" { m.Config.Parser = "harmony"
builtinParser = parsers.ParserForName(m.Config.Parser)
} }
var harmonyMessageHandler *harmony.HarmonyMessageHandler var builtinParser parsers.Parser
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
processedTools := req.Tools processedTools := req.Tools
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler() if m.Config.Parser != "" {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// Determine last message for chat prefill
var lastMessage *api.Message var lastMessage *api.Message
if len(msgs) > 0 { if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1] lastMessage = &msgs[len(msgs)-1]
} }
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) // Initialize parser and get processed tools
harmonyToolParser = harmonyMessageHandler.CreateToolParser() processedTools = builtinParser.Init(req.Tools, lastMessage)
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools)
for i, tool := range processedTools {
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
} }
} }
@@ -1919,8 +1943,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
// Validate Think value: string values currently only allowed for gptoss models // Validate Think value: string values currently only allowed for harmony/gptoss models
if req.Think != nil && req.Think.IsString() && !useHarmony { if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return return
} }
@@ -1939,7 +1963,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
var toolParser *tools.Parser var toolParser *tools.Parser
if len(req.Tools) > 0 && !useHarmony { if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) {
toolParser = tools.NewParser(m.Template.Template, req.Tools) toolParser = tools.NewParser(m.Template.Template, req.Tools)
} }
@@ -1971,38 +1995,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
// TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic if builtinParser != nil {
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}
return
} else if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools) content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
return return

View File

@@ -273,9 +273,21 @@ func findArguments(buffer []byte) (map[string]any, int) {
if args, ok := obj["arguments"].(map[string]any); ok { if args, ok := obj["arguments"].(map[string]any); ok {
return args, true return args, true
} }
if argsStr, ok := obj["arguments"].(string); ok {
var argsData map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
return argsData, ok
}
}
if args, ok := obj["parameters"].(map[string]any); ok { if args, ok := obj["parameters"].(map[string]any); ok {
return args, true return args, true
} }
if argsStr, ok := obj["parameters"].(string); ok {
var argsData map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
return argsData, ok
}
}
return nil, true return nil, true
} }

View File

@@ -1274,6 +1274,22 @@ func TestFindArguments(t *testing.T) {
"items": []any{"{", "}", map[string]any{"key": "value"}}, "items": []any{"{", "}", map[string]any{"key": "value"}},
}, },
}, },
{
name: "stringified arguments",
buffer: []byte(`{"name": "get_temperature", "arguments": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
{
name: "stringified parameters",
buffer: []byte(`{"name": "get_temperature", "parameters": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
want: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
} }
for _, tt := range tests { for _, tt := range tests {