mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
fix
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -8,6 +8,7 @@
|
|||||||
dist
|
dist
|
||||||
build
|
build
|
||||||
.cache
|
.cache
|
||||||
|
.gocache
|
||||||
*.exe
|
*.exe
|
||||||
.idea
|
.idea
|
||||||
test_data
|
test_data
|
||||||
|
|||||||
@@ -99,10 +99,12 @@ check_language(HIP)
|
|||||||
if(CMAKE_HIP_COMPILER)
|
if(CMAKE_HIP_COMPILER)
|
||||||
set(HIP_PLATFORM "amd")
|
set(HIP_PLATFORM "amd")
|
||||||
|
|
||||||
find_package(hip REQUIRED)
|
|
||||||
if(NOT AMDGPU_TARGETS)
|
if(NOT AMDGPU_TARGETS)
|
||||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|120[01])$")
|
find_package(hip REQUIRED)
|
||||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|90[012]|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[0123]|120[01])$")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
authError := AuthorizationError{StatusCode: resp.StatusCode}
|
||||||
|
json.Unmarshal(body, &authError)
|
||||||
|
return authError
|
||||||
|
}
|
||||||
|
|
||||||
apiError := StatusError{StatusCode: resp.StatusCode}
|
apiError := StatusError{StatusCode: resp.StatusCode}
|
||||||
|
|
||||||
err := json.Unmarshal(body, &apiError)
|
err := json.Unmarshal(body, &apiError)
|
||||||
@@ -215,6 +221,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var errorResponse struct {
|
var errorResponse struct {
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
|
SigninURL string `json:"signin_url,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
bts := scanner.Bytes()
|
bts := scanner.Bytes()
|
||||||
@@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
}
|
}
|
||||||
|
|
||||||
if response.StatusCode == http.StatusUnauthorized {
|
if response.StatusCode == http.StatusUnauthorized {
|
||||||
pubKey, pkErr := auth.GetPublicKey()
|
|
||||||
if pkErr != nil {
|
|
||||||
return pkErr
|
|
||||||
}
|
|
||||||
return AuthorizationError{
|
return AuthorizationError{
|
||||||
StatusCode: response.StatusCode,
|
StatusCode: response.StatusCode,
|
||||||
Status: response.Status,
|
Status: response.Status,
|
||||||
PublicKey: pubKey,
|
SigninURL: errorResponse.SigninURL,
|
||||||
}
|
}
|
||||||
} else if response.StatusCode >= http.StatusBadRequest {
|
} else if response.StatusCode >= http.StatusBadRequest {
|
||||||
return StatusError{
|
return StatusError{
|
||||||
@@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
|||||||
return version.Version, nil
|
return version.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signout will disconnect an ollama instance from ollama.com
|
// Signout will signout a client for a local ollama server.
|
||||||
func (c *Client) Signout(ctx context.Context, encodedKey string) error {
|
func (c *Client) Signout(ctx context.Context) error {
|
||||||
|
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect will disconnect an ollama instance from ollama.com.
|
||||||
|
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
|
||||||
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func (e StatusError) Error() string {
|
|||||||
type AuthorizationError struct {
|
type AuthorizationError struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Status string
|
Status string
|
||||||
PublicKey string `json:"public_key"`
|
SigninURL string `json:"signin_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e AuthorizationError) Error() string {
|
func (e AuthorizationError) Error() string {
|
||||||
|
|||||||
40
auth/auth.go
40
auth/auth.go
@@ -18,46 +18,13 @@ import (
|
|||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func GetPublicKey() (string, error) {
|
||||||
fileIsReadable := func(fp string) bool {
|
|
||||||
info, err := os.Stat(fp)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that it's a regular file, not a directory or other file type
|
|
||||||
if !info.Mode().IsRegular() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to open it to check readability
|
|
||||||
file, err := os.Open(fp)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
file.Close()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
|
|
||||||
if fileIsReadable(systemPath) {
|
|
||||||
return systemPath, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
}
|
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
|
||||||
keyPath, err := keyPath()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
@@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
keyPath, err := keyPath()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
|||||||
117
cmd/cmd.go
117
cmd/cmd.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -15,7 +14,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -37,7 +35,6 @@ import (
|
|||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
@@ -50,7 +47,7 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
|
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||||
|
|
||||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||||
@@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
var sErr api.AuthorizationError
|
var sErr api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||||
pubKey, pkErr := auth.GetPublicKey()
|
|
||||||
if pkErr != nil {
|
|
||||||
return pkErr
|
|
||||||
}
|
|
||||||
// the server and the client both have the same public key
|
|
||||||
if pubKey == sErr.PublicKey {
|
|
||||||
h, _ := os.Hostname()
|
|
||||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
|
||||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
|
||||||
|
if sErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, sErr.SigninURL)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
user, err := client.Whoami(cmd.Context())
|
user, err := client.Whoami(cmd.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var aErr api.AuthorizationError
|
||||||
|
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if aErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKey, pkErr := auth.GetPublicKey()
|
|
||||||
if pkErr != nil {
|
|
||||||
return pkErr
|
|
||||||
}
|
|
||||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
|
||||||
|
|
||||||
h, _ := os.Hostname()
|
|
||||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
||||||
pubKey, pkErr := auth.GetPublicKey()
|
|
||||||
if pkErr != nil {
|
|
||||||
return pkErr
|
|
||||||
}
|
|
||||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
|
||||||
|
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Signout(cmd.Context(), encKey)
|
err = client.Signout(cmd.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var aErr api.AuthorizationError
|
||||||
|
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Println("You are not signed in to ollama.com")
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Println("You have signed out of ollama.com")
|
fmt.Println("You have signed out of ollama.com")
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
return nil
|
return nil
|
||||||
@@ -546,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
n := model.ParseName(args[0])
|
||||||
|
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||||
|
_, err := client.Whoami(cmd.Context())
|
||||||
|
if err != nil {
|
||||||
|
var aErr api.AuthorizationError
|
||||||
|
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||||
|
fmt.Println("You need to be signed in to push models to ollama.com.")
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if aErr.SigninURL != "" {
|
||||||
|
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
|
|
||||||
@@ -582,7 +595,6 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||||
|
|
||||||
n := model.ParseName(args[0])
|
|
||||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||||
if spinner != nil {
|
if spinner != nil {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
@@ -1106,6 +1118,51 @@ type runOptions struct {
|
|||||||
ShowConnect bool
|
ShowConnect bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r runOptions) Copy() runOptions {
|
||||||
|
var messages []api.Message
|
||||||
|
if r.Messages != nil {
|
||||||
|
messages = make([]api.Message, len(r.Messages))
|
||||||
|
copy(messages, r.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
var images []api.ImageData
|
||||||
|
if r.Images != nil {
|
||||||
|
images = make([]api.ImageData, len(r.Images))
|
||||||
|
copy(images, r.Images)
|
||||||
|
}
|
||||||
|
|
||||||
|
var opts map[string]any
|
||||||
|
if r.Options != nil {
|
||||||
|
opts = make(map[string]any, len(r.Options))
|
||||||
|
for k, v := range r.Options {
|
||||||
|
opts[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var think *api.ThinkValue
|
||||||
|
if r.Think != nil {
|
||||||
|
cThink := *r.Think
|
||||||
|
think = &cThink
|
||||||
|
}
|
||||||
|
|
||||||
|
return runOptions{
|
||||||
|
Model: r.Model,
|
||||||
|
ParentModel: r.ParentModel,
|
||||||
|
Prompt: r.Prompt,
|
||||||
|
Messages: messages,
|
||||||
|
WordWrap: r.WordWrap,
|
||||||
|
Format: r.Format,
|
||||||
|
System: r.System,
|
||||||
|
Images: images,
|
||||||
|
Options: opts,
|
||||||
|
MultiModal: r.MultiModal,
|
||||||
|
KeepAlive: r.KeepAlive,
|
||||||
|
Think: think,
|
||||||
|
HideThinking: r.HideThinking,
|
||||||
|
ShowConnect: r.ShowConnect,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type displayResponseState struct {
|
type displayResponseState struct {
|
||||||
lineLength int
|
lineLength int
|
||||||
wordBuffer string
|
wordBuffer string
|
||||||
|
|||||||
320
cmd/cmd_test.go
320
cmd/cmd_test.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -491,9 +492,35 @@ func TestPushHandler(t *testing.T) {
|
|||||||
w.(http.Flusher).Flush()
|
w.(http.Flusher).Flush()
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST request, got %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "not signed in push",
|
||||||
|
modelName: "notsignedin-model",
|
||||||
|
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||||
|
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST request, got %s", r.Method)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
err := json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": "unauthorized",
|
||||||
|
"signin_url": "https://somethingsomething",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedOutput: "You need to be signed in to push",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "unauthorized push",
|
name: "unauthorized push",
|
||||||
modelName: "unauthorized-model",
|
modelName: "unauthorized-model",
|
||||||
@@ -508,6 +535,11 @@ func TestPushHandler(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST request, got %s", r.Method)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||||
},
|
},
|
||||||
@@ -525,6 +557,9 @@ func TestPushHandler(t *testing.T) {
|
|||||||
defer mockServer.Close()
|
defer mockServer.Close()
|
||||||
|
|
||||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
t.Setenv("USERPROFILE", tmpDir)
|
||||||
initializeKeypair()
|
initializeKeypair()
|
||||||
|
|
||||||
cmd := &cobra.Command{}
|
cmd := &cobra.Command{}
|
||||||
@@ -561,7 +596,7 @@ func TestPushHandler(t *testing.T) {
|
|||||||
t.Errorf("expected no error, got %v", err)
|
t.Errorf("expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
if tt.expectedOutput != "" {
|
if tt.expectedOutput != "" {
|
||||||
if got := string(stdout); got != tt.expectedOutput {
|
if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
|
||||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -919,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRunOptions_Copy(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
|
||||||
|
originalThink := &api.ThinkValue{Value: "test reasoning"}
|
||||||
|
|
||||||
|
original := runOptions{
|
||||||
|
Model: "test-model",
|
||||||
|
ParentModel: "parent-model",
|
||||||
|
Prompt: "test prompt",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "hello"},
|
||||||
|
{Role: "assistant", Content: "hi there"},
|
||||||
|
},
|
||||||
|
WordWrap: true,
|
||||||
|
Format: "json",
|
||||||
|
System: "system prompt",
|
||||||
|
Images: []api.ImageData{
|
||||||
|
[]byte("image1"),
|
||||||
|
[]byte("image2"),
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1000,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
MultiModal: true,
|
||||||
|
KeepAlive: originalKeepAlive,
|
||||||
|
Think: originalThink,
|
||||||
|
HideThinking: false,
|
||||||
|
ShowConnect: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the copy
|
||||||
|
copied := original.Copy()
|
||||||
|
|
||||||
|
// Test 1: Verify the copy is not the same instance
|
||||||
|
if &copied == &original {
|
||||||
|
t.Error("Copy should return a different instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Verify all fields are copied correctly
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
got interface{}
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
{"Model", copied.Model, original.Model},
|
||||||
|
{"ParentModel", copied.ParentModel, original.ParentModel},
|
||||||
|
{"Prompt", copied.Prompt, original.Prompt},
|
||||||
|
{"WordWrap", copied.WordWrap, original.WordWrap},
|
||||||
|
{"Format", copied.Format, original.Format},
|
||||||
|
{"System", copied.System, original.System},
|
||||||
|
{"MultiModal", copied.MultiModal, original.MultiModal},
|
||||||
|
{"HideThinking", copied.HideThinking, original.HideThinking},
|
||||||
|
{"ShowConnect", copied.ShowConnect, original.ShowConnect},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if !reflect.DeepEqual(tt.got, tt.want) {
|
||||||
|
t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Verify Messages slice is deeply copied
|
||||||
|
if len(copied.Messages) != len(original.Messages) {
|
||||||
|
t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] {
|
||||||
|
t.Error("Messages should be different instances")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify original to verify independence
|
||||||
|
if len(original.Messages) > 0 {
|
||||||
|
originalContent := original.Messages[0].Content
|
||||||
|
original.Messages[0].Content = "modified"
|
||||||
|
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||||
|
t.Error("Messages should be independent after copy")
|
||||||
|
}
|
||||||
|
// Restore for other tests
|
||||||
|
original.Messages[0].Content = originalContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 4: Verify Images slice is deeply copied
|
||||||
|
if len(copied.Images) != len(original.Images) {
|
||||||
|
t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] {
|
||||||
|
t.Error("Images should be different instances")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify original to verify independence
|
||||||
|
if len(original.Images) > 0 {
|
||||||
|
originalImage := original.Images[0]
|
||||||
|
original.Images[0] = []byte("modified")
|
||||||
|
if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" {
|
||||||
|
t.Error("Images should be independent after copy")
|
||||||
|
}
|
||||||
|
// Restore for other tests
|
||||||
|
original.Images[0] = originalImage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 5: Verify Options map is deeply copied
|
||||||
|
if len(copied.Options) != len(original.Options) {
|
||||||
|
t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Options) > 0 && &copied.Options == &original.Options {
|
||||||
|
t.Error("Options map should be different instances")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify original to verify independence
|
||||||
|
if len(original.Options) > 0 {
|
||||||
|
originalTemp := original.Options["temperature"]
|
||||||
|
original.Options["temperature"] = 0.9
|
||||||
|
if copied.Options["temperature"] == 0.9 {
|
||||||
|
t.Error("Options should be independent after copy")
|
||||||
|
}
|
||||||
|
// Restore for other tests
|
||||||
|
original.Options["temperature"] = originalTemp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 6: Verify KeepAlive pointer is copied (shallow copy)
|
||||||
|
if copied.KeepAlive != original.KeepAlive {
|
||||||
|
t.Error("KeepAlive pointer should be the same (shallow copy)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 7: Verify Think pointer creates a new instance
|
||||||
|
if original.Think != nil && copied.Think == original.Think {
|
||||||
|
t.Error("Think should be a different instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
if original.Think != nil && copied.Think != nil {
|
||||||
|
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||||
|
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 8: Test with zero values
|
||||||
|
zeroOriginal := runOptions{}
|
||||||
|
zeroCopy := zeroOriginal.Copy()
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(zeroCopy, zeroOriginal) {
|
||||||
|
fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy)
|
||||||
|
t.Error("Copy of zero value should equal original zero value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
||||||
|
// Test with empty slices and maps
|
||||||
|
original := runOptions{
|
||||||
|
Messages: []api.Message{},
|
||||||
|
Images: []api.ImageData{},
|
||||||
|
Options: map[string]any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
|
||||||
|
if copied.Messages == nil {
|
||||||
|
t.Error("Empty Messages slice should remain empty, not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Images == nil {
|
||||||
|
t.Error("Empty Images slice should remain empty, not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Options == nil {
|
||||||
|
t.Error("Empty Options map should remain empty, not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Messages) != 0 {
|
||||||
|
t.Error("Empty Messages slice should remain empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Images) != 0 {
|
||||||
|
t.Error("Empty Images slice should remain empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Options) != 0 {
|
||||||
|
t.Error("Empty Options map should remain empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOptions_Copy_NilPointers(t *testing.T) {
|
||||||
|
// Test with nil pointers
|
||||||
|
original := runOptions{
|
||||||
|
KeepAlive: nil,
|
||||||
|
Think: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
|
||||||
|
if copied.KeepAlive != nil {
|
||||||
|
t.Error("Nil KeepAlive should remain nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Think != nil {
|
||||||
|
t.Error("Nil Think should remain nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
think *api.ThinkValue
|
||||||
|
}{
|
||||||
|
{"nil Think", nil},
|
||||||
|
{"bool true", &api.ThinkValue{Value: true}},
|
||||||
|
{"bool false", &api.ThinkValue{Value: false}},
|
||||||
|
{"string value", &api.ThinkValue{Value: "reasoning text"}},
|
||||||
|
{"int value", &api.ThinkValue{Value: 42}},
|
||||||
|
{"nil value", &api.ThinkValue{Value: nil}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
original := runOptions{Think: tt.think}
|
||||||
|
copied := original.Copy()
|
||||||
|
|
||||||
|
if tt.think == nil {
|
||||||
|
if copied.Think != nil {
|
||||||
|
t.Error("Nil Think should remain nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Think == nil {
|
||||||
|
t.Error("Non-nil Think should not become nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Think == original.Think {
|
||||||
|
t.Error("Think should be a different instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||||
|
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||||
|
// Test that modifications to original don't affect copy
|
||||||
|
originalThink := &api.ThinkValue{Value: "original"}
|
||||||
|
original := runOptions{
|
||||||
|
Model: "original-model",
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "original"}},
|
||||||
|
Options: map[string]any{"key": "value"},
|
||||||
|
Think: originalThink,
|
||||||
|
}
|
||||||
|
|
||||||
|
copied := original.Copy()
|
||||||
|
|
||||||
|
// Modify original
|
||||||
|
original.Model = "modified-model"
|
||||||
|
if len(original.Messages) > 0 {
|
||||||
|
original.Messages[0].Content = "modified"
|
||||||
|
}
|
||||||
|
original.Options["key"] = "modified"
|
||||||
|
if original.Think != nil {
|
||||||
|
original.Think.Value = "modified"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify copy is unchanged
|
||||||
|
if copied.Model == "modified-model" {
|
||||||
|
t.Error("Copy Model should not be affected by original modification")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||||
|
t.Error("Copy Messages should not be affected by original modification")
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Options["key"] == "modified" {
|
||||||
|
t.Error("Copy Options should not be affected by original modification")
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied.Think != nil && copied.Think.Value == "modified" {
|
||||||
|
t.Error("Copy Think should not be affected by original modification")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fmt.Println("Usage:\n /load <modelname>")
|
fmt.Println("Usage:\n /load <modelname>")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
origOpts := opts.Copy()
|
||||||
|
|
||||||
opts.Model = args[1]
|
opts.Model = args[1]
|
||||||
opts.Messages = []api.Message{}
|
opts.Messages = []api.Message{}
|
||||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "not found") {
|
||||||
|
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||||
|
opts = origOpts.Copy()
|
||||||
|
continue
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
fmt.Printf("error: %v\n", err)
|
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||||
|
opts = origOpts.Copy()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.Contains(err.Error(), "does not support thinking") {
|
if strings.Contains(err.Error(), "does not support thinking") {
|
||||||
|
|||||||
40
docs/cloud.md
Normal file
40
docs/cloud.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Cloud
|
||||||
|
|
||||||
|
| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).
|
||||||
|
|
||||||
|
## Cloud Models
|
||||||
|
|
||||||
|
[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer.
|
||||||
|
|
||||||
|
Ollama currently supports the following cloud models, with more coming soon:
|
||||||
|
|
||||||
|
- `gpt-oss:20b-cloud`
|
||||||
|
- `gpt-oss:120b-cloud`
|
||||||
|
- `deepseek-v3.1:671b-cloud`
|
||||||
|
- `qwen3-coder:480b-cloud`
|
||||||
|
|
||||||
|
### Get started
|
||||||
|
|
||||||
|
To run a cloud model, open the terminal and run:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama run gpt-oss:120b-cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
To run cloud models with integrations that work with Ollama, first download the cloud model:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama pull qwen3-coder:480b-cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
Then sign in to Ollama:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama signin
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
|
||||||
|
|
||||||
|
## Cloud API access
|
||||||
|
|
||||||
|
Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).
|
||||||
107
docs/turbo.md
107
docs/turbo.md
@@ -1,107 +0,0 @@
|
|||||||
# Turbo
|
|
||||||
|
|
||||||
> ⚠️ Turbo is preview
|
|
||||||
|
|
||||||
Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
|
|
||||||
|
|
||||||
Currently, the following models are available in Turbo:
|
|
||||||
|
|
||||||
- `gpt-oss:20b`
|
|
||||||
- `gpt-oss:120b`
|
|
||||||
|
|
||||||
## Get started
|
|
||||||
|
|
||||||
### Ollama for macOS & Windows
|
|
||||||
|
|
||||||
Download Ollama
|
|
||||||
|
|
||||||
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
|
|
||||||
- Click on **Turbo**. You’ll be prompted to create an account or sign in
|
|
||||||
|
|
||||||
### Ollama’s CLI
|
|
||||||
|
|
||||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
|
||||||
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
|
|
||||||
|
|
||||||
On macOS and Linux:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
cat ~/.ollama/id_ed25519.pub
|
|
||||||
```
|
|
||||||
|
|
||||||
On Windows:
|
|
||||||
|
|
||||||
```
|
|
||||||
type "%USERPROFILE%\.ollama\id_ed25519.pub"
|
|
||||||
```
|
|
||||||
|
|
||||||
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
|
|
||||||
```shell
|
|
||||||
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
|
|
||||||
```
|
|
||||||
|
|
||||||
### Ollama’s Python library
|
|
||||||
|
|
||||||
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
|
|
||||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
|
||||||
- Create an API key by visiting https://ollama.com/settings/keys
|
|
||||||
|
|
||||||
```python
|
|
||||||
from ollama import Client
|
|
||||||
|
|
||||||
client = Client(
|
|
||||||
host="https://ollama.com",
|
|
||||||
headers={'Authorization': '<api key>'}
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
'role': 'user',
|
|
||||||
'content': 'Why is the sky blue?',
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
|
|
||||||
print(part['message']['content'], end='', flush=True)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Ollama’s JavaScript library
|
|
||||||
|
|
||||||
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
|
|
||||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
|
||||||
- Create an API key by visiting https://ollama.com/settings/keys
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
import { Ollama } from 'ollama';
|
|
||||||
|
|
||||||
const ollama = new Ollama({
|
|
||||||
host: 'https://ollama.com',
|
|
||||||
headers: {
|
|
||||||
Authorization: "Bearer <api key>"
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await ollama.chat({
|
|
||||||
model: 'gpt-oss:120b',
|
|
||||||
messages: [{ role: 'user', content: 'Explain quantum computing' }],
|
|
||||||
stream: true
|
|
||||||
});
|
|
||||||
|
|
||||||
for await (const part of response) {
|
|
||||||
process.stdout.write(part.message.content)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Community integrations
|
|
||||||
|
|
||||||
Turbo mode is also compatible with several community integrations.
|
|
||||||
|
|
||||||
#### Open WebUI
|
|
||||||
|
|
||||||
- Go to **settings** → **Admin settings** → **Connections**
|
|
||||||
- Under **Ollama API,** click **+**
|
|
||||||
- For the **URL** put `https://ollama.com`
|
|
||||||
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
|
|
||||||
- Click **Save**
|
|
||||||
|
|
||||||
Now, if you navigate to the model selector, Turbo models should be available under **External**.
|
|
||||||
@@ -244,6 +244,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
|||||||
"gemma3n",
|
"gemma3n",
|
||||||
"mistral3",
|
"mistral3",
|
||||||
"qwen3",
|
"qwen3",
|
||||||
|
"qwen3moe",
|
||||||
"llama4",
|
"llama4",
|
||||||
"mllama",
|
"mllama",
|
||||||
"qwen25vl",
|
"qwen25vl",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package harmony
|
package harmony
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -265,6 +266,8 @@ type HarmonyMessageHandler struct {
|
|||||||
state harmonyMessageState
|
state harmonyMessageState
|
||||||
HarmonyParser *HarmonyParser
|
HarmonyParser *HarmonyParser
|
||||||
FunctionNameMap *FunctionNameMap
|
FunctionNameMap *FunctionNameMap
|
||||||
|
toolAccumulator *HarmonyToolCallAccumulator
|
||||||
|
convertedTools map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHarmonyMessageHandler creates a new message handler
|
// NewHarmonyMessageHandler creates a new message handler
|
||||||
@@ -277,6 +280,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
|||||||
HeaderEndTag: "<|message|>",
|
HeaderEndTag: "<|message|>",
|
||||||
},
|
},
|
||||||
FunctionNameMap: NewFunctionNameMap(),
|
FunctionNameMap: NewFunctionNameMap(),
|
||||||
|
convertedTools: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -384,8 +388,85 @@ func NewFunctionNameMap() *FunctionNameMap {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Init initializes the handler with tools and optional last message
|
||||||
|
// Implements the Parser interface
|
||||||
|
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
|
// Initialize the harmony parser
|
||||||
|
if h.HarmonyParser == nil {
|
||||||
|
h.HarmonyParser = &HarmonyParser{
|
||||||
|
MessageStartTag: "<|start|>",
|
||||||
|
MessageEndTag: "<|end|>",
|
||||||
|
HeaderEndTag: "<|message|>",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle prefill for chat mode
|
||||||
|
if lastMessage != nil {
|
||||||
|
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||||
|
} else {
|
||||||
|
h.HarmonyParser.AddImplicitStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize tool accumulator
|
||||||
|
h.toolAccumulator = h.CreateToolParser()
|
||||||
|
|
||||||
|
// Process tools and return renamed versions
|
||||||
|
if len(tools) == 0 {
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
processedTools := make([]api.Tool, len(tools))
|
||||||
|
copy(processedTools, tools)
|
||||||
|
for i, tool := range processedTools {
|
||||||
|
if tool.Function.Name != "" {
|
||||||
|
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
|
h.convertedTools[tool.Function.Name] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return processedTools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
|
||||||
|
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
|
||||||
|
if toolContent != "" {
|
||||||
|
h.toolAccumulator.Add(toolContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tool calls always happen one at a time, and always at the end of a message,
|
||||||
|
// so for simplicity we defer parsing them until we know we're done
|
||||||
|
if done {
|
||||||
|
toolName, raw := h.toolAccumulator.Drain()
|
||||||
|
if toolName != nil {
|
||||||
|
name := strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
name = h.FunctionNameMap.OriginalFromConverted(name)
|
||||||
|
var args api.ToolCallFunctionArguments
|
||||||
|
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||||
|
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
|
||||||
|
}
|
||||||
|
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return content, thinking, calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasToolSupport implements the Parser interface
|
||||||
|
func (h *HarmonyMessageHandler) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasThinkingSupport implements the Parser interface
|
||||||
|
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||||
harmonyFunctionName := m.deriveName(userFunctionName)
|
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||||
|
// built-in functions should not be renamed
|
||||||
|
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
|
||||||
|
harmonyFunctionName = userFunctionName
|
||||||
|
}
|
||||||
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||||
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||||
return harmonyFunctionName
|
return harmonyFunctionName
|
||||||
|
|||||||
@@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
|||||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||||
|
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|||||||
@@ -12,3 +12,6 @@ The integration tests have 2 modes of operating.
|
|||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||||
|
|
||||||
|
|
||||||
|
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
|
||||||
@@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) {
|
|||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "why is the sky blue? be brief",
|
Prompt: blueSkyPrompt,
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
anyResp := []string{"rayleigh", "scattering"}
|
|
||||||
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) {
|
|||||||
// Verify the response contains the expected data
|
// Verify the response contains the expected data
|
||||||
response := buf.String()
|
response := buf.String()
|
||||||
atLeastOne := false
|
atLeastOne := false
|
||||||
for _, resp := range anyResp {
|
for _, resp := range blueSkyExpected {
|
||||||
if strings.Contains(strings.ToLower(response), resp) {
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
atLeastOne = true
|
atLeastOne = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !atLeastOne {
|
if !atLeastOne {
|
||||||
t.Errorf("none of %v found in %s", anyResp, response)
|
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for generate")
|
t.Error("outer test context done while waiting for generate")
|
||||||
@@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) {
|
|||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: "why is the sky blue? be brief",
|
Content: blueSkyPrompt,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
@@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) {
|
|||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
anyResp := []string{"rayleigh", "scattering"}
|
|
||||||
|
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) {
|
|||||||
// Verify the response contains the expected data
|
// Verify the response contains the expected data
|
||||||
response := buf.String()
|
response := buf.String()
|
||||||
atLeastOne := false
|
atLeastOne := false
|
||||||
for _, resp := range anyResp {
|
for _, resp := range blueSkyExpected {
|
||||||
if strings.Contains(strings.ToLower(response), resp) {
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
atLeastOne = true
|
atLeastOne = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !atLeastOne {
|
if !atLeastOne {
|
||||||
t.Errorf("none of %v found in %s", anyResp, response)
|
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for chat")
|
t.Error("outer test context done while waiting for chat")
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) {
|
|||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: blueSkyPrompt,
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnicode(t *testing.T) {
|
func TestUnicode(t *testing.T) {
|
||||||
@@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) {
|
|||||||
|
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: blueSkyPrompt,
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
}
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||||
func TestGenerateWithHistory(t *testing.T) {
|
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
req, resp := GenerateRequests()
|
req, resp := GenerateRequests()
|
||||||
numParallel := 2
|
numParallel := 2
|
||||||
@@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send generate requests with prior context and ensure the response is coherant and expected
|
||||||
|
func TestGenerateWithHistory(t *testing.T) {
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: smol,
|
||||||
|
Prompt: rainbowPrompt,
|
||||||
|
Stream: &stream,
|
||||||
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_ctx": 16384,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial request
|
||||||
|
slog.Info("loading", "model", req.Model)
|
||||||
|
err := client.Generate(ctx,
|
||||||
|
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||||
|
func(response api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||||
|
|
||||||
|
for i := 0; i < len(rainbowFollowups); i++ {
|
||||||
|
req.Prompt = rainbowFollowups[i]
|
||||||
|
if time.Now().Sub(started) > softTimeout {
|
||||||
|
slog.Info("exceeded soft timeout, winding down test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||||
func TestChatWithHistory(t *testing.T) {
|
func TestParallelChatWithHistory(t *testing.T) {
|
||||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
req, resp := ChatRequests()
|
req, resp := ChatRequests()
|
||||||
numParallel := 2
|
numParallel := 2
|
||||||
@@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) {
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send generate requests with prior context and ensure the response is coherant and expected
|
||||||
|
func TestChatWithHistory(t *testing.T) {
|
||||||
|
req := api.ChatRequest{
|
||||||
|
Model: smol,
|
||||||
|
Stream: &stream,
|
||||||
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
|
Options: map[string]any{
|
||||||
|
"num_ctx": 16384,
|
||||||
|
},
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: rainbowPrompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial request
|
||||||
|
slog.Info("loading", "model", req.Model)
|
||||||
|
err := client.Generate(ctx,
|
||||||
|
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||||
|
func(response api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||||
|
|
||||||
|
for i := 0; i < len(rainbowFollowups); i++ {
|
||||||
|
if time.Now().Sub(started) > softTimeout {
|
||||||
|
slog.Info("exceeded soft timeout, winding down test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Messages = append(req.Messages,
|
||||||
|
*assistant,
|
||||||
|
api.Message{Role: "user", Content: rainbowFollowups[i]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||||
|
if assistant == nil {
|
||||||
|
t.Fatalf("didn't get an assistant response for context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ package integration
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||||
|
|
||||||
chatModels := libraryChatModels
|
chatModels := libraryChatModels
|
||||||
for _, model := range chatModels {
|
for _, model := range chatModels {
|
||||||
@@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
|||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||||
t.Fatalf("pull failed %s", err)
|
t.Fatalf("pull failed %s", err)
|
||||||
}
|
}
|
||||||
|
if targetArch != "" {
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to show model: %s", err)
|
||||||
|
}
|
||||||
|
arch := resp.ModelInfo["general.architecture"].(string)
|
||||||
|
if arch != targetArch {
|
||||||
|
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||||
|
}
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: blueSkyPrompt,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
"temperature": 0.1,
|
"temperature": 0.1,
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
|
anyResp := blueSkyExpected
|
||||||
// Special cases
|
// Special cases
|
||||||
if model == "duckdb-nsql" {
|
if model == "duckdb-nsql" {
|
||||||
anyResp = []string{"select", "from"}
|
anyResp = []string{"select", "from"}
|
||||||
|
|||||||
@@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) {
|
|||||||
// TODO - fiddle with context size
|
// TODO - fiddle with context size
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: blueSkyPrompt,
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
|
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,18 @@ var (
|
|||||||
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
||||||
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
||||||
func TestModelsPerf(t *testing.T) {
|
func TestModelsPerf(t *testing.T) {
|
||||||
|
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||||
|
doModelPerfTest(t, ollamaEngineChatModels)
|
||||||
|
} else {
|
||||||
|
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLibraryModelsPerf(t *testing.T) {
|
||||||
|
doModelPerfTest(t, libraryChatModels)
|
||||||
|
}
|
||||||
|
|
||||||
|
func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||||
softTimeout, hardTimeout := getTimeouts(t)
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
@@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
longPrompt := "summarize the following: " + string(data)
|
longPrompt := "summarize the following: " + string(data)
|
||||||
|
|
||||||
var chatModels []string
|
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
|
||||||
chatModels = ollamaEngineChatModels
|
|
||||||
} else {
|
|
||||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, model := range chatModels {
|
for _, model := range chatModels {
|
||||||
|
if !strings.Contains(model, ":") {
|
||||||
|
model = model + ":latest"
|
||||||
|
}
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
if time.Now().Sub(started) > softTimeout {
|
if time.Now().Sub(started) > softTimeout {
|
||||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
@@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
arch := resp.ModelInfo["general.architecture"].(string)
|
arch := resp.ModelInfo["general.architecture"].(string)
|
||||||
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
||||||
|
if targetArch != "" && arch != targetArch {
|
||||||
|
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||||
|
}
|
||||||
|
|
||||||
if maxVram > 0 {
|
if maxVram > 0 {
|
||||||
resp, err := client.List(ctx)
|
resp, err := client.List(ctx)
|
||||||
@@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) {
|
|||||||
prompt string
|
prompt string
|
||||||
anyResp []string
|
anyResp []string
|
||||||
}{
|
}{
|
||||||
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
|
{blueSkyPrompt, blueSkyExpected},
|
||||||
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
|
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
|
||||||
}
|
}
|
||||||
var gpuPercent int
|
var gpuPercent int
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
|
||||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
||||||
"MODEL",
|
"MODEL",
|
||||||
"CONTEXT",
|
"CONTEXT",
|
||||||
"GPU PERCENT",
|
"GPU PERCENT",
|
||||||
"PROMPT COUNT",
|
"APPROX PROMPT COUNT",
|
||||||
"LOAD TIME",
|
"LOAD TIME",
|
||||||
"PROMPT EVAL TPS",
|
"PROMPT EVAL TPS",
|
||||||
"EVAL TPS",
|
"EVAL TPS",
|
||||||
@@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) {
|
|||||||
model,
|
model,
|
||||||
numCtx,
|
numCtx,
|
||||||
gpuPercent,
|
gpuPercent,
|
||||||
resp.PromptEvalCount,
|
(resp.PromptEvalCount/10)*10,
|
||||||
float64(resp.LoadDuration)/1000000000.0,
|
float64(resp.LoadDuration)/1000000000.0,
|
||||||
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
||||||
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) {
|
|||||||
stream := true
|
stream := true
|
||||||
genReq := api.GenerateRequest{
|
genReq := api.GenerateRequest{
|
||||||
Model: newName,
|
Model: newName,
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: blueSkyPrompt,
|
||||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
@@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) {
|
|||||||
|
|
||||||
// Some smaller quantizations can cause models to have poor quality
|
// Some smaller quantizations can cause models to have poor quality
|
||||||
// or get stuck in repetition loops, so we stop as soon as we have any matches
|
// or get stuck in repetition loops, so we stop as soon as we have any matches
|
||||||
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
|
|
||||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||||
atLeastOne := false
|
atLeastOne := false
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
genfn := func(response api.GenerateResponse) error {
|
genfn := func(response api.GenerateResponse) error {
|
||||||
buf.Write([]byte(response.Response))
|
buf.Write([]byte(response.Response))
|
||||||
fullResp := strings.ToLower(buf.String())
|
fullResp := strings.ToLower(buf.String())
|
||||||
for _, resp := range anyResp {
|
for _, resp := range blueSkyExpected {
|
||||||
if strings.Contains(fullResp, resp) {
|
if strings.Contains(fullResp, resp) {
|
||||||
atLeastOne = true
|
atLeastOne = true
|
||||||
t.Log(fullResp)
|
t.Log(fullResp)
|
||||||
|
|||||||
@@ -256,13 +256,29 @@ var (
|
|||||||
"snowflake-arctic-embed",
|
"snowflake-arctic-embed",
|
||||||
"snowflake-arctic-embed2",
|
"snowflake-arctic-embed2",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
|
||||||
|
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
|
||||||
|
|
||||||
|
rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
|
||||||
|
rainbowFollowups = []string{
|
||||||
|
"Explain the physics involved in them. Be breif in your reply",
|
||||||
|
"Explain the chemistry involved in them. Be breif in your reply",
|
||||||
|
"Explain the quantum mechanics involved in them. Be breif in your reply",
|
||||||
|
"What are common myths related to them? Be brief in your reply",
|
||||||
|
"What are common fairytales related to them? Be brief in your reply",
|
||||||
|
"Can they form if there is no rain? Be breif in your reply",
|
||||||
|
"Can they form if there are no clouds? Be breif in your reply",
|
||||||
|
"Do they happen on other planets? Be brief in your reply",
|
||||||
|
}
|
||||||
|
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"}
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
lifecycle.InitLogging()
|
lifecycle.InitLogging()
|
||||||
custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL")
|
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
|
||||||
if custom != "" {
|
if custom != "" {
|
||||||
slog.Info("setting smol test model to " + custom)
|
slog.Info("setting default test model to " + custom)
|
||||||
smol = custom
|
smol = custom
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -577,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
[][]string{
|
[][]string{
|
||||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
{"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"},
|
||||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"},
|
||||||
{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
|
{"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"},
|
||||||
{"fourth", "july", "declaration", "independence"},
|
{"fourth", "july", "declaration", "independence"},
|
||||||
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
|
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
"github.com/dlclark/regexp2"
|
||||||
@@ -13,16 +14,28 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type BytePairEncoding struct {
|
type BytePairEncoding struct {
|
||||||
pre *regexp2.Regexp
|
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
|
regexps []*regexp2.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||||
|
|
||||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||||
|
if len(pretokenizers) == 0 {
|
||||||
|
// set default byte-level pretokenizer if none provided, e.g.
|
||||||
|
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||||
|
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||||
|
}
|
||||||
|
|
||||||
return BytePairEncoding{
|
return BytePairEncoding{
|
||||||
pre: regexp2.MustCompile(pre, regexp2.None),
|
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
|
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||||
|
for _, p := range pretokenizers {
|
||||||
|
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||||
return func(yield func(string) bool) {
|
parts := []string{s}
|
||||||
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
|
for _, re := range bpe.regexps {
|
||||||
|
parts = slices.Collect(func(yield func(string) bool) {
|
||||||
|
for _, part := range parts {
|
||||||
|
r := []rune(part)
|
||||||
|
var offset int
|
||||||
|
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||||
|
if offset-m.Index != 0 {
|
||||||
|
if !yield(string(r[:m.Index])) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !yield(m.String()) {
|
if !yield(m.String()) {
|
||||||
break
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset = m.Index + m.Length
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset < len(r) {
|
||||||
|
if !yield(string(r[offset:])) {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Values(parts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fragment is a string fragment and their corresponding token IDs
|
// fragment is a string fragment and their corresponding token IDs
|
||||||
|
|||||||
@@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return NewBytePairEncoding(
|
return NewBytePairEncoding(
|
||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
|
||||||
&Vocabulary{
|
&Vocabulary{
|
||||||
Values: tokens,
|
Values: tokens,
|
||||||
Types: types,
|
Types: types,
|
||||||
Merges: merges,
|
Merges: merges,
|
||||||
},
|
},
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSplit(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
patterns,
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode",
|
||||||
|
patterns: []string{
|
||||||
|
"\\p{N}{1,3}",
|
||||||
|
`[一-龥-ゟ゠-ヿ]+`,
|
||||||
|
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
},
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "individual digits",
|
||||||
|
patterns: []string{
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
},
|
||||||
|
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||||
|
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||||
|
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -171,35 +172,44 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
|||||||
// make a copy
|
// make a copy
|
||||||
tagsCopy := tags
|
tagsCopy := tags
|
||||||
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||||
tagsCopy = append(tagsCopy, ParseTags(tag))
|
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||||
vv.Set(reflect.ValueOf(base))
|
vv.Set(reflect.ValueOf(base))
|
||||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||||
var fn func([]Tag) [][]string
|
var fn func([]Tag, string, string) [][]string
|
||||||
fn = func(tags []Tag) (names [][]string) {
|
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||||
if len(tags) > 0 {
|
if len(tags) > 0 {
|
||||||
localNames := []string{tags[0].Name}
|
var names []string
|
||||||
localNames = append(localNames, tags[0].Alternate...)
|
if tags[0].name != "" {
|
||||||
|
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||||
for _, localName := range localNames {
|
names = append(names, prefix+n+suffix)
|
||||||
fullName := []string{localName}
|
}
|
||||||
nested := fn(tags[1:])
|
}
|
||||||
if len(nested) > 0 {
|
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
|
||||||
for _, rest := range nested {
|
if len(names) == 0 {
|
||||||
names = append(names, append(fullName, rest...))
|
// current tag has no name, use child names only
|
||||||
|
fullNames = append(fullNames, childNames...)
|
||||||
|
} else if len(childNames) == 0 {
|
||||||
|
// current tag has names but no children, create branches for each name
|
||||||
|
for _, name := range names {
|
||||||
|
fullNames = append(fullNames, []string{name})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
names = append(names, fullName)
|
// merge each name with each child
|
||||||
|
for _, name := range names {
|
||||||
|
for _, childName := range childNames {
|
||||||
|
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return names
|
return fullNames
|
||||||
}
|
}
|
||||||
|
|
||||||
names := fn(tagsCopy)
|
names := fn(tagsCopy, "", "")
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||||
logutil.Trace("found tensor", "", tensor)
|
logutil.Trace("found tensor", "", tensor)
|
||||||
@@ -213,9 +223,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
|||||||
for i := range vv.Len() {
|
for i := range vv.Len() {
|
||||||
vvv := vv.Index(i)
|
vvv := vv.Index(i)
|
||||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||||
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||||
} else {
|
} else {
|
||||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -254,18 +264,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Tag struct {
|
type Tag struct {
|
||||||
Name string
|
name,
|
||||||
Alternate []string
|
// prefix and suffix are applied to child tags
|
||||||
|
prefix,
|
||||||
|
suffix string
|
||||||
|
alternatives []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseTags(s string) (tag Tag) {
|
func parseTag(s string) (tag Tag) {
|
||||||
parts := strings.Split(s, ",")
|
parts := strings.Split(s, ",")
|
||||||
if len(parts) > 0 {
|
if len(parts) > 0 {
|
||||||
tag.Name = parts[0]
|
tag.name = parts[0]
|
||||||
|
|
||||||
for _, part := range parts[1:] {
|
for _, part := range parts[1:] {
|
||||||
if value, ok := strings.CutPrefix(part, "alt:"); ok {
|
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||||
tag.Alternate = append(tag.Alternate, value)
|
// elevate alternative to primary if no primary given
|
||||||
|
tag.name = value
|
||||||
|
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||||
|
} else if ok {
|
||||||
|
tag.alternatives = append(tag.alternatives, value)
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||||
|
tag.prefix = value
|
||||||
|
}
|
||||||
|
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||||
|
tag.suffix = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) {
|
|||||||
{
|
{
|
||||||
value: "output",
|
value: "output",
|
||||||
want: Tag{
|
want: Tag{
|
||||||
Name: "output",
|
name: "output",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
value: "output,alt:token_embd",
|
value: "output,alt:token_embd",
|
||||||
want: Tag{
|
want: Tag{
|
||||||
Name: "output",
|
name: "output",
|
||||||
Alternate: []string{
|
alternatives: []string{
|
||||||
"token_embd",
|
"token_embd",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.value, func(t *testing.T) {
|
t.Run(tt.value, func(t *testing.T) {
|
||||||
got := ParseTags(tt.value)
|
got := parseTag(tt.value)
|
||||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
|
||||||
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
|
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -125,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
Input *nn.Embedding `gguf:"input"`
|
Input *nn.Embedding `gguf:"input"`
|
||||||
Output *nn.Linear `gguf:"output,alt:input"`
|
Output *nn.Linear `gguf:"output,alt:input"`
|
||||||
Nested *nested `gguf:"nested"`
|
Nested *nested `gguf:"nested"`
|
||||||
|
Tensor ml.Tensor `gguf:"leaf,alt:tensor"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var m fakeModel
|
var m fakeModel
|
||||||
@@ -133,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
names: []string{
|
names: []string{
|
||||||
"input.weight",
|
"input.weight",
|
||||||
"nested.b.weight",
|
"nested.b.weight",
|
||||||
|
"leaf",
|
||||||
},
|
},
|
||||||
}}, v.Elem()))
|
}}, v.Elem()))
|
||||||
|
|
||||||
@@ -142,6 +144,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
Nested: &nested{
|
Nested: &nested{
|
||||||
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
|
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
|
||||||
},
|
},
|
||||||
|
Tensor: &fakeTensor{Name: "leaf"},
|
||||||
|
}, m); diff != "" {
|
||||||
|
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
|
||||||
|
type fakeBlock struct {
|
||||||
|
A *nn.Linear `gguf:"a"`
|
||||||
|
B *nn.Linear `gguf:",pre:b_"`
|
||||||
|
C *nn.Linear `gguf:",suf:_c"`
|
||||||
|
XY *nn.Linear `gguf:",pre:x_,suf:_y"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeModel struct {
|
||||||
|
Blocks []fakeBlock `gguf:"blk"`
|
||||||
|
}
|
||||||
|
|
||||||
|
m := fakeModel{
|
||||||
|
Blocks: make([]fakeBlock, 2),
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(&m)
|
||||||
|
v.Elem().Set(populateFields(Base{b: &fakeBackend{
|
||||||
|
names: []string{
|
||||||
|
"blk.0.a.weight",
|
||||||
|
"blk.0.b_weight",
|
||||||
|
"blk.0.b_bias",
|
||||||
|
"blk.0.weight_c",
|
||||||
|
"blk.0.x_weight_y",
|
||||||
|
"blk.1.a.weight",
|
||||||
|
"blk.1.b_weight",
|
||||||
|
"blk.1.b_bias",
|
||||||
|
"blk.1.weight_c",
|
||||||
|
"blk.1.x_weight_y",
|
||||||
|
},
|
||||||
|
}}, v.Elem()))
|
||||||
|
|
||||||
|
if diff := cmp.Diff(fakeModel{
|
||||||
|
Blocks: []fakeBlock{
|
||||||
|
{
|
||||||
|
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
|
||||||
|
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
|
||||||
|
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
|
||||||
|
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
|
||||||
|
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
|
||||||
|
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
|
||||||
|
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
}, m); diff != "" {
|
}, m); diff != "" {
|
||||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|||||||
324
model/models/deepseek2/model.go
Normal file
324
model/models/deepseek2/model.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package deepseek2
|
||||||
|
|
||||||
|
// uses deepseek 2 architecture but written based on deepseek 3 model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/fast"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
numExpertsUsed int
|
||||||
|
numExperts int
|
||||||
|
normTopKProb bool
|
||||||
|
routedScalingFactor float32
|
||||||
|
|
||||||
|
kvLoraRank,
|
||||||
|
qkNopeHeadDim,
|
||||||
|
qkRopeHeadDim,
|
||||||
|
kqNopeHeadDim,
|
||||||
|
qkHeadDim int
|
||||||
|
qLoraRank int
|
||||||
|
vHeadDim int
|
||||||
|
|
||||||
|
hiddenSize,
|
||||||
|
numHeads,
|
||||||
|
numKVHeads,
|
||||||
|
keyLength,
|
||||||
|
valueLength,
|
||||||
|
originalContextLength int
|
||||||
|
|
||||||
|
eps,
|
||||||
|
ropeBase,
|
||||||
|
ropeScale float32
|
||||||
|
kqScale float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) RoPEOptions() []func(*rope.Options) {
|
||||||
|
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||||
|
return []func(*rope.Options){
|
||||||
|
rope.WithOriginalContextLength(o.originalContextLength),
|
||||||
|
rope.WithExtrapolationFactor(1.),
|
||||||
|
rope.WithAttentionFactor(attnFactor),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
Q *nn.Linear `gguf:"attn_q"`
|
||||||
|
|
||||||
|
QA *nn.Linear `gguf:"attn_q_a"`
|
||||||
|
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
|
||||||
|
QB *nn.Linear `gguf:"attn_q_b"`
|
||||||
|
|
||||||
|
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||||
|
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||||
|
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||||
|
|
||||||
|
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
seqLength := hiddenStates.Dim(1)
|
||||||
|
|
||||||
|
var query ml.Tensor
|
||||||
|
if opts.qLoraRank == 0 { // nil {
|
||||||
|
query = attn.Q.Forward(ctx, hiddenStates)
|
||||||
|
} else {
|
||||||
|
query = attn.QA.Forward(ctx, hiddenStates)
|
||||||
|
query = attn.QANorm.Forward(ctx, query, opts.eps)
|
||||||
|
query = attn.QB.Forward(ctx, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
|
||||||
|
|
||||||
|
qPass := query.View(ctx, 0,
|
||||||
|
opts.qkNopeHeadDim, query.Stride(1),
|
||||||
|
query.Dim(1), query.Stride(2),
|
||||||
|
query.Dim(2))
|
||||||
|
|
||||||
|
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
|
||||||
|
opts.qkRopeHeadDim, query.Stride(1),
|
||||||
|
query.Dim(1), query.Stride(2),
|
||||||
|
query.Dim(2))
|
||||||
|
|
||||||
|
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
|
||||||
|
|
||||||
|
kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
|
||||||
|
kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
|
||||||
|
opts.qkRopeHeadDim, compressedKV.Stride(1),
|
||||||
|
1, compressedKV.Stride(1),
|
||||||
|
compressedKV.Dim(1))
|
||||||
|
|
||||||
|
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||||
|
kPass = attn.KVB.Forward(ctx, kPass)
|
||||||
|
|
||||||
|
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||||
|
kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
|
||||||
|
value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
|
||||||
|
opts.vHeadDim, kv.Stride(1),
|
||||||
|
kv.Dim(1), kv.Stride(2),
|
||||||
|
kv.Dim(2)).Contiguous(ctx)
|
||||||
|
|
||||||
|
qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||||
|
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||||
|
|
||||||
|
kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
|
||||||
|
|
||||||
|
query = qRot.Concat(ctx, qPass, 0)
|
||||||
|
key := kRot.Concat(ctx, kPass, 0)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
|
||||||
|
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||||
|
return attn.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP interface {
|
||||||
|
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
type sparse struct {
|
||||||
|
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||||
|
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||||
|
SharedExpert *dense `gguf:",suf:_shexp"`
|
||||||
|
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||||
|
|
||||||
|
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||||
|
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||||
|
hiddenStates = hiddenStates.SILU(ctx, upStates)
|
||||||
|
|
||||||
|
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||||
|
experts = experts.Mul(ctx, topKWeights)
|
||||||
|
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||||
|
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||||
|
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||||
|
}
|
||||||
|
return nextStates
|
||||||
|
}
|
||||||
|
|
||||||
|
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||||
|
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
|
||||||
|
return topKIndices
|
||||||
|
}
|
||||||
|
|
||||||
|
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
residuals := hiddenStates
|
||||||
|
|
||||||
|
routerLogits := moe.Router.Forward(ctx, hiddenStates)
|
||||||
|
scores := routerLogits.Sigmoid(ctx)
|
||||||
|
topKIndices := moe.topKIndices(ctx, scores, opts)
|
||||||
|
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
|
||||||
|
|
||||||
|
if opts.normTopKProb {
|
||||||
|
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||||
|
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
|
||||||
|
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
|
||||||
|
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
|
||||||
|
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
|
||||||
|
return hiddenStates
|
||||||
|
}
|
||||||
|
|
||||||
|
type dense struct {
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenStates)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||||
|
Attention *Attention
|
||||||
|
|
||||||
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||||
|
MLP MLP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
|
residual := hiddenStates
|
||||||
|
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||||
|
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
residual = hiddenStates
|
||||||
|
|
||||||
|
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
return hiddenStates
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.BytePairEncoding
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
Layers []Layer `gguf:"blk"`
|
||||||
|
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
*Options
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
layers := make([]Layer, c.Uint("block_count"))
|
||||||
|
|
||||||
|
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
||||||
|
for i := range layers {
|
||||||
|
if i < firstDenseLayerIndex {
|
||||||
|
layers[i].MLP = &dense{}
|
||||||
|
} else {
|
||||||
|
layers[i].MLP = &sparse{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
|
||||||
|
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
// Split regex into multiple parts (according to DeepSeek3's regex)
|
||||||
|
"\\p{N}{1,3}",
|
||||||
|
`[一-龥-ゟ゠-ヿ]+`,
|
||||||
|
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
),
|
||||||
|
Layers: layers,
|
||||||
|
Options: &Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
keyLength: int(c.Uint("attention.key_length")),
|
||||||
|
valueLength: int(c.Uint("attention.value_length")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base"),
|
||||||
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
|
numExperts: int(c.Uint("expert_count")),
|
||||||
|
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||||
|
normTopKProb: c.Bool("expert_weights_norm", true),
|
||||||
|
|
||||||
|
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
|
||||||
|
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
|
||||||
|
qkHeadDim: int(c.Uint("attention.key_length")),
|
||||||
|
vHeadDim: int(c.Uint("attention.value_length")),
|
||||||
|
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
|
||||||
|
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||||
|
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||||
|
|
||||||
|
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||||
|
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
|
|
||||||
|
kqScale: kqScale,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("deepseek2", New)
|
||||||
|
}
|
||||||
@@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
m := Transformer{
|
m := Transformer{
|
||||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer",
|
|
||||||
strings.Join([]string{
|
|
||||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
|
||||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
|
||||||
`\p{N}{1,3}`,
|
|
||||||
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
|
|
||||||
`\s*[\r\n]+`,
|
|
||||||
`\s+(?!\S)`,
|
|
||||||
`\s+`,
|
|
||||||
}, "|"),
|
|
||||||
),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
strings.Join([]string{
|
||||||
|
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||||
|
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||||
|
`\p{N}{1,3}`,
|
||||||
|
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
|
||||||
|
`\s*[\r\n]+`,
|
||||||
|
`\s+(?!\S)`,
|
||||||
|
`\s+`,
|
||||||
|
}, "|"),
|
||||||
),
|
),
|
||||||
Options: Options{
|
Options: Options{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
|||||||
@@ -54,10 +54,30 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
switch c.String("tokenizer.ggml.model") {
|
switch c.String("tokenizer.ggml.model") {
|
||||||
case "gpt2":
|
case "gpt2":
|
||||||
processor = model.NewBytePairEncoding(
|
var pretokenizers []string
|
||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
switch c.String("tokenizer.ggml.pre") {
|
||||||
&vocabulary,
|
case "default":
|
||||||
)
|
// no-op use the default bpe pretokenizer
|
||||||
|
case "qwen2":
|
||||||
|
pretokenizers = []string{
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
}
|
||||||
|
case "refact":
|
||||||
|
pretokenizers = []string{
|
||||||
|
`\p{N}`,
|
||||||
|
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
|
||||||
|
}
|
||||||
|
case "tekken":
|
||||||
|
pretokenizers = []string{
|
||||||
|
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// use a llama-style pretokenizer
|
||||||
|
pretokenizers = []string{
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||||
case "llama":
|
case "llama":
|
||||||
processor = model.NewSentencePiece(&vocabulary)
|
processor = model.NewSentencePiece(&vocabulary)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer",
|
|
||||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: newImageProcessor(c),
|
||||||
VisionModel: newVisionModel(c),
|
VisionModel: newVisionModel(c),
|
||||||
|
|||||||
@@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
|||||||
return nextStates
|
return nextStates
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextSharedExpert is TextMLP with different tensor names
|
|
||||||
type TextSharedExpert struct {
|
|
||||||
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
|
|
||||||
Up *nn.Linear `gguf:"ffn_up_shexp"`
|
|
||||||
Down *nn.Linear `gguf:"ffn_down_shexp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
||||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
|
||||||
return mlp.Down.Forward(ctx, hiddenStates)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextMOE struct {
|
type TextMOE struct {
|
||||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||||
Experts *TextExperts
|
Experts *TextExperts
|
||||||
SharedExpert *TextSharedExpert
|
SharedExpert *TextMLP `gguf:",suf:_shexp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil)
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := &Model{
|
m := &Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
VisionModel: newVisionModel(c),
|
VisionModel: newVisionModel(c),
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ const (
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: newImageProcessor(c),
|
||||||
VisionModel: newVisionModel(c),
|
VisionModel: newVisionModel(c),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/ollama/ollama/model/models/bert"
|
_ "github.com/ollama/ollama/model/models/bert"
|
||||||
|
_ "github.com/ollama/ollama/model/models/deepseek2"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||||
|
|||||||
@@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
m := Model{
|
m := Model{
|
||||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
Options: Options{
|
Options: Options{
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := &Model{
|
m := &Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
TextModel: NewTextModel(c),
|
TextModel: NewTextModel(c),
|
||||||
VisionModel: newVisionModel(c),
|
VisionModel: newVisionModel(c),
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
m := embedModel{
|
m := embedModel{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -48,6 +47,7 @@ func newEmbed(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
Model: &Model{
|
Model: &Model{
|
||||||
Layers: layers,
|
Layers: layers,
|
||||||
|
|||||||
@@ -200,7 +200,6 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
|
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
@@ -213,6 +212,7 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||||
),
|
),
|
||||||
Layers: layers,
|
Layers: layers,
|
||||||
Options: &Options{
|
Options: &Options{
|
||||||
|
|||||||
@@ -2,10 +2,16 @@ package parsers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/harmony"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Parser interface {
|
type Parser interface {
|
||||||
Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
|
// Init initializes the parser with tools and optional last message for chat prefill
|
||||||
|
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
|
||||||
|
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
|
||||||
|
// Add processes streamed content and returns parsed content, thinking, and tool calls
|
||||||
|
// The done flag indicates if this is the last chunk (used for draining accumulators)
|
||||||
|
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
|
||||||
HasToolSupport() bool
|
HasToolSupport() bool
|
||||||
HasThinkingSupport() bool
|
HasThinkingSupport() bool
|
||||||
}
|
}
|
||||||
@@ -17,6 +23,8 @@ func ParserForName(name string) Parser {
|
|||||||
return parser
|
return parser
|
||||||
case "passthrough":
|
case "passthrough":
|
||||||
return &PassthroughParser{}
|
return &PassthroughParser{}
|
||||||
|
case "harmony":
|
||||||
|
return harmony.NewHarmonyMessageHandler()
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -24,7 +32,11 @@ func ParserForName(name string) Parser {
|
|||||||
|
|
||||||
type PassthroughParser struct{}
|
type PassthroughParser struct{}
|
||||||
|
|
||||||
func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
|
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
|
return tools // passthrough doesn't modify tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
return s, "", nil, nil
|
return s, "", nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
@@ -31,6 +32,7 @@ const (
|
|||||||
type Qwen3CoderParser struct {
|
type Qwen3CoderParser struct {
|
||||||
state qwenParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
acc strings.Builder
|
||||||
|
tools []api.Tool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||||
@@ -41,7 +43,12 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
|
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
|
p.tools = tools
|
||||||
|
return tools // Qwen doesn't modify tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
p.acc.WriteString(s)
|
p.acc.WriteString(s)
|
||||||
|
|
||||||
events := p.parseEvents()
|
events := p.parseEvents()
|
||||||
@@ -51,7 +58,7 @@ func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thin
|
|||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
switch event := event.(type) {
|
switch event := event.(type) {
|
||||||
case qwenEventRawToolCall:
|
case qwenEventRawToolCall:
|
||||||
toolCall, err := parseToolCall(event, tools)
|
toolCall, err := parseToolCall(event, p.tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
@@ -198,12 +205,21 @@ func overlap(s, delim string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func trailingWhitespaceLen(s string) int {
|
func trailingWhitespaceLen(s string) int {
|
||||||
for i := len(s) - 1; i >= 0; i-- {
|
remaining := s
|
||||||
if !unicode.IsSpace(rune(s[i])) {
|
total := 0
|
||||||
return len(s) - i - 1
|
for len(remaining) > 0 {
|
||||||
|
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||||
|
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||||
|
if r == utf8.RuneError && size == 1 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
if !unicode.IsSpace(r) {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return len(s)
|
total += size
|
||||||
|
remaining = remaining[:len(remaining)-size]
|
||||||
|
}
|
||||||
|
return total
|
||||||
}
|
}
|
||||||
|
|
||||||
type XMLFunctionCall struct {
|
type XMLFunctionCall struct {
|
||||||
@@ -359,7 +375,7 @@ func parseValue(raw string, paramType api.PropertyType) any {
|
|||||||
|
|
||||||
// Try array
|
// Try array
|
||||||
if typeSet["array"] {
|
if typeSet["array"] {
|
||||||
var arr []interface{}
|
var arr []any
|
||||||
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
|
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
|
||||||
return arr
|
return arr
|
||||||
}
|
}
|
||||||
@@ -371,7 +387,7 @@ func parseValue(raw string, paramType api.PropertyType) any {
|
|||||||
|
|
||||||
// Try object
|
// Try object
|
||||||
if typeSet["object"] {
|
if typeSet["object"] {
|
||||||
var obj map[string]interface{}
|
var obj map[string]any
|
||||||
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -166,6 +166,137 @@ func TestQwenParserStreaming(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "unicode content",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "你好 🌍<tool_call>test</tool_call>مرحبا",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "你好 🌍"},
|
||||||
|
qwenEventRawToolCall{raw: "test"},
|
||||||
|
qwenEventContent{content: "مرحبا"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "arabic text handling",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "مرحبا بالعالم",
|
||||||
|
wantEvents: []qwenEvent{qwenEventContent{content: "مرحبا بالعالم"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "emoji passthrough",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "✅",
|
||||||
|
wantEvents: []qwenEvent{qwenEventContent{content: "✅"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "emoji after tool call",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<tool_call>test</tool_call>完成 ✅",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventRawToolCall{raw: "test"},
|
||||||
|
qwenEventContent{content: "完成 ✅"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "unicode streaming with whitespace handling",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "مرحبا",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "مرحبا"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: " \n",
|
||||||
|
wantEvents: []qwenEvent{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "世界",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: " \n世界"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "non-breaking space withheld across chunks",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "Hello\u00a0",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "world",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "\u00a0world"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ideographic space before partial tool",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "Hello\u3000<tool",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "_call>abc",
|
||||||
|
wantEvents: []qwenEvent{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "</tool_call>def",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventRawToolCall{raw: "abc"},
|
||||||
|
qwenEventContent{content: "def"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ideographic space before partial tool fakeout",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "Hello\u3000<tool",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "Hello"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "fakeout>abc",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "\u3000<toolfakeout>abc"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "unicode with partial tool tag",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "测试🎯 <to",
|
||||||
|
wantEvents: []qwenEvent{
|
||||||
|
qwenEventContent{content: "测试🎯"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
anyOnlies := false
|
anyOnlies := false
|
||||||
@@ -347,6 +478,27 @@ ls && echo "a > b and a < b"
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "unicode in function names and parameters",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=获取天气>
|
||||||
|
<parameter=城市>
|
||||||
|
北京
|
||||||
|
</parameter>
|
||||||
|
<parameter=message>
|
||||||
|
Hello! 你好! 🌟 مرحبا
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "获取天气",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"城市": "北京",
|
||||||
|
"message": "Hello! 你好! 🌟 مرحبا",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, step := range steps {
|
for i, step := range steps {
|
||||||
@@ -360,6 +512,42 @@ ls && echo "a > b and a < b"
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTrailingWhitespaceLenUnicode(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ascii space",
|
||||||
|
input: "Hello ",
|
||||||
|
want: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-breaking space",
|
||||||
|
input: "Hello\u00a0",
|
||||||
|
want: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ideographic space",
|
||||||
|
input: "Hello\u3000",
|
||||||
|
want: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple runes of whitespace",
|
||||||
|
input: "Hi\u00a0\u3000",
|
||||||
|
want: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
got := trailingWhitespaceLen(tc.input)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("%s: trailingWhitespaceLen(%q) = %d, want %d", tc.name, tc.input, got, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQwenToolCallValueParsing(t *testing.T) {
|
func TestQwenToolCallValueParsing(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
desc string
|
desc string
|
||||||
@@ -867,6 +1055,8 @@ func TestTrailingWhitespaceLen(t *testing.T) {
|
|||||||
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
|
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
|
||||||
{desc: "only whitespace", s: " \n ", want: 4},
|
{desc: "only whitespace", s: " \n ", want: 4},
|
||||||
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
|
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
|
||||||
|
{desc: "unicode with trailing space", s: "测试🎯 ", want: 1},
|
||||||
|
{desc: "unicode with trailing tab and newline", s: "مرحبا\t\n", want: 2},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
@@ -876,3 +1066,30 @@ func TestTrailingWhitespaceLen(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOverlapFunction(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
s string
|
||||||
|
delim string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{desc: "no overlap", s: "hello", delim: "<tool", want: 0},
|
||||||
|
{desc: "full overlap", s: "hello<tool", delim: "<tool>", want: 5},
|
||||||
|
{desc: "partial overlap", s: "hello<to", delim: "<tool>", want: 3},
|
||||||
|
{desc: "unicode with partial overlap", s: "测试🎯<to", delim: "<tool>", want: 3},
|
||||||
|
{desc: "unicode string with no overlap", s: "مرحبا", delim: "<tool>", want: 0},
|
||||||
|
{desc: "unicode at boundary", s: "世界<", delim: "<tool>", want: 1},
|
||||||
|
{desc: "unicode delimiter single rune", s: "hello🔧", delim: "🔧工具", want: len("🔧")},
|
||||||
|
{desc: "unicode delimiter multiple runes", s: "hello🔧工", delim: "🔧工具", want: len("🔧工")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
|
got := overlap(tc.s, tc.delim)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("overlap(%q, %q) = %d, want %d", tc.s, tc.delim, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
|||||||
merges := make([]string, 0, 1)
|
merges := make([]string, 0, 1)
|
||||||
// Only need vocab for Grammar Test
|
// Only need vocab for Grammar Test
|
||||||
return model.NewBytePairEncoding(
|
return model.NewBytePairEncoding(
|
||||||
``,
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: tokens,
|
Values: tokens,
|
||||||
Types: make([]int32, len(vocab)),
|
Types: make([]int32, len(vocab)),
|
||||||
|
|||||||
222
server/routes.go
222
server/routes.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -34,7 +35,6 @@ import (
|
|||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/harmony"
|
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/model/parsers"
|
"github.com/ollama/ollama/model/parsers"
|
||||||
@@ -49,6 +49,8 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||||
|
|
||||||
func shouldUseHarmony(model *Model) bool {
|
func shouldUseHarmony(model *Model) bool {
|
||||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||||
// heuristic to check whether the template expects to be parsed via harmony:
|
// heuristic to check whether the template expects to be parsed via harmony:
|
||||||
@@ -151,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
|||||||
return runner.llama, model, &opts, nil
|
return runner.llama, model, &opts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func signinURL() (string, error) {
|
||||||
|
pubKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||||
|
h, _ := os.Hostname()
|
||||||
|
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
@@ -251,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Generate(c, &req, fn)
|
err = client.Generate(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &authError) {
|
||||||
pk, pkErr := auth.GetPublicKey()
|
sURL, sErr := signinURL()
|
||||||
if pkErr != nil {
|
if sErr != nil {
|
||||||
slog.Error("couldn't get public key", "error", pkErr)
|
slog.Error(sErr.Error())
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": "unauthorized",
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||||
"public_key": pk,
|
return
|
||||||
})
|
}
|
||||||
|
var apiError api.StatusError
|
||||||
|
if errors.As(err, &apiError) {
|
||||||
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -291,17 +307,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(m) && !req.Raw
|
var builtinParser parsers.Parser
|
||||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
if shouldUseHarmony(m) && m.Config.Parser == "" {
|
||||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
m.Config.Parser = "harmony"
|
||||||
if useHarmony {
|
|
||||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
|
||||||
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
if !req.Raw && m.Config.Parser != "" {
|
||||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
builtinParser = parsers.ParserForName(m.Config.Parser)
|
||||||
|
if builtinParser != nil {
|
||||||
|
// no tools or last message for generate endpoint
|
||||||
|
builtinParser.Init(nil, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||||
|
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -425,7 +445,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var thinkingState *thinking.Parser
|
var thinkingState *thinking.Parser
|
||||||
if !useHarmony {
|
if builtinParser == nil {
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||||
thinkingState = &thinking.Parser{
|
thinkingState = &thinking.Parser{
|
||||||
@@ -462,11 +482,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if useHarmony {
|
if builtinParser != nil {
|
||||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
|
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
|
||||||
|
if err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
res.Response = content
|
res.Response = content
|
||||||
res.Thinking = thinking
|
res.Thinking = thinking
|
||||||
harmonyToolParser.Add(toolContent)
|
if cr.Done && len(toolCalls) > 0 {
|
||||||
|
res.ToolCalls = toolCalls
|
||||||
|
}
|
||||||
} else if thinkingState != nil {
|
} else if thinkingState != nil {
|
||||||
thinking, content := thinkingState.AddContent(cr.Content)
|
thinking, content := thinkingState.AddContent(cr.Content)
|
||||||
res.Thinking = thinking
|
res.Thinking = thinking
|
||||||
@@ -478,26 +504,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cr.Done {
|
if cr.Done {
|
||||||
if useHarmony {
|
|
||||||
toolName, toolContent := harmonyToolParser.Drain()
|
|
||||||
if toolName != nil {
|
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
|
||||||
var args api.ToolCallFunctionArguments
|
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
|
||||||
ch <- gin.H{"error": errStr}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: *toolName,
|
|
||||||
Arguments: args,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res.DoneReason = cr.DoneReason.String()
|
res.DoneReason = cr.DoneReason.String()
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
@@ -512,7 +518,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if useHarmony {
|
if builtinParser != nil {
|
||||||
// only send messages with meaningful content (empty messages confuse clients)
|
// only send messages with meaningful content (empty messages confuse clients)
|
||||||
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
|
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
|
||||||
ch <- res
|
ch <- res
|
||||||
@@ -1423,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.POST("/api/show", s.ShowHandler)
|
r.POST("/api/show", s.ShowHandler)
|
||||||
r.DELETE("/api/delete", s.DeleteHandler)
|
r.DELETE("/api/delete", s.DeleteHandler)
|
||||||
|
|
||||||
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
|
||||||
r.POST("/api/me", s.WhoamiHandler)
|
r.POST("/api/me", s.WhoamiHandler)
|
||||||
|
|
||||||
|
r.POST("/api/signout", s.SignoutHandler)
|
||||||
|
// deprecated
|
||||||
|
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
r.POST("/api/create", s.CreateHandler)
|
r.POST("/api/create", s.CreateHandler)
|
||||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||||
@@ -1636,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error(err.Error())
|
slog.Error(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// user isn't signed in
|
||||||
|
if user != nil && user.Name == "" {
|
||||||
|
sURL, sErr := signinURL()
|
||||||
|
if sErr != nil {
|
||||||
|
slog.Error(sErr.Error())
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, user)
|
c.JSON(http.StatusOK, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) SignoutHandler(c *gin.Context) {
|
func (s *Server) SignoutHandler(c *gin.Context) {
|
||||||
encodedKey := c.Param("encodedKey")
|
pubKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't get public key", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||||
|
|
||||||
// todo allow other hosts
|
// todo allow other hosts
|
||||||
u, err := url.Parse("https://ollama.com")
|
u, err := url.Parse("https://ollama.com")
|
||||||
@@ -1651,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := api.NewClient(u, http.DefaultClient)
|
client := api.NewClient(u, http.DefaultClient)
|
||||||
err = client.Signout(c, encodedKey)
|
err = client.Disconnect(c, encKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error(err.Error())
|
var authError api.AuthorizationError
|
||||||
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
|
if errors.As(err, &authError) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||||
@@ -1813,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Chat(c, &req, fn)
|
err = client.Chat(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &authError) {
|
||||||
pk, pkErr := auth.GetPublicKey()
|
sURL, sErr := signinURL()
|
||||||
if pkErr != nil {
|
if sErr != nil {
|
||||||
slog.Error("couldn't get public key", "error", pkErr)
|
slog.Error(sErr.Error())
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": "unauthorized",
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||||
"public_key": pk,
|
return
|
||||||
})
|
}
|
||||||
|
var apiError api.StatusError
|
||||||
|
if errors.As(err, &apiError) {
|
||||||
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1870,32 +1903,23 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
var builtinParser parsers.Parser
|
if shouldUseHarmony(m) && m.Config.Parser == "" {
|
||||||
if m.Config.Parser != "" {
|
m.Config.Parser = "harmony"
|
||||||
builtinParser = parsers.ParserForName(m.Config.Parser)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
var builtinParser parsers.Parser
|
||||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
|
|
||||||
|
|
||||||
processedTools := req.Tools
|
processedTools := req.Tools
|
||||||
if useHarmony {
|
|
||||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
if m.Config.Parser != "" {
|
||||||
|
builtinParser = parsers.ParserForName(m.Config.Parser)
|
||||||
|
if builtinParser != nil {
|
||||||
|
// Determine last message for chat prefill
|
||||||
var lastMessage *api.Message
|
var lastMessage *api.Message
|
||||||
if len(msgs) > 0 {
|
if len(msgs) > 0 {
|
||||||
lastMessage = &msgs[len(msgs)-1]
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
}
|
}
|
||||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
// Initialize parser and get processed tools
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
processedTools = builtinParser.Init(req.Tools, lastMessage)
|
||||||
|
|
||||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
|
||||||
// renamed to be valid Harmony function names.
|
|
||||||
processedTools = make([]api.Tool, len(req.Tools))
|
|
||||||
copy(processedTools, req.Tools)
|
|
||||||
for i, tool := range processedTools {
|
|
||||||
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1919,8 +1943,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1939,7 +1963,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var toolParser *tools.Parser
|
var toolParser *tools.Parser
|
||||||
if len(req.Tools) > 0 && !useHarmony {
|
if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) {
|
||||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1971,38 +1995,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic
|
if builtinParser != nil {
|
||||||
if useHarmony {
|
|
||||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
|
|
||||||
res.Message.Content = content
|
|
||||||
res.Message.Thinking = thinking
|
|
||||||
harmonyToolParser.Add(toolContent)
|
|
||||||
|
|
||||||
if r.Done {
|
|
||||||
toolName, toolContent := harmonyToolParser.Drain()
|
|
||||||
if toolName != nil {
|
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
|
||||||
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
|
||||||
var args api.ToolCallFunctionArguments
|
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
|
||||||
ch <- gin.H{"error": errStr}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// only send messages with meaningful content (empty messages confuse clients)
|
|
||||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
|
||||||
ch <- res
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
} else if builtinParser != nil {
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||||
|
|
||||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
|
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -273,9 +273,21 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||||||
if args, ok := obj["arguments"].(map[string]any); ok {
|
if args, ok := obj["arguments"].(map[string]any); ok {
|
||||||
return args, true
|
return args, true
|
||||||
}
|
}
|
||||||
|
if argsStr, ok := obj["arguments"].(string); ok {
|
||||||
|
var argsData map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
||||||
|
return argsData, ok
|
||||||
|
}
|
||||||
|
}
|
||||||
if args, ok := obj["parameters"].(map[string]any); ok {
|
if args, ok := obj["parameters"].(map[string]any); ok {
|
||||||
return args, true
|
return args, true
|
||||||
}
|
}
|
||||||
|
if argsStr, ok := obj["parameters"].(string); ok {
|
||||||
|
var argsData map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
||||||
|
return argsData, ok
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1274,6 +1274,22 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"items": []any{"{", "}", map[string]any{"key": "value"}},
|
"items": []any{"{", "}", map[string]any{"key": "value"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "stringified arguments",
|
||||||
|
buffer: []byte(`{"name": "get_temperature", "arguments": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stringified parameters",
|
||||||
|
buffer: []byte(`{"name": "get_temperature", "parameters": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
Reference in New Issue
Block a user