diff --git a/Dockerfile b/Dockerfile
index c84b5239..ffaa31a5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,7 @@
# vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH}
+ARG PARALLEL=8
ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
@@ -34,46 +35,51 @@ ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
- && cmake --build --parallel --preset 'CPU' \
- && cmake --install build --component CPU --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'CPU' \
+ && cmake --install build --component CPU --strip --parallel ${PARALLEL}
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
- && cmake --build --parallel --preset 'CUDA 11' \
- && cmake --install build --component CUDA --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
+ && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
- && cmake --build --parallel --preset 'CUDA 12' \
- && cmake --install build --component CUDA --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
+ && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
- && cmake --build --parallel --preset 'CUDA 13' \
- && cmake --install build --component CUDA --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
+ && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
- && cmake --build --parallel --preset 'ROCm 6' \
- && cmake --install build --component HIP --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
+ && cmake --install build --component HIP --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION
@@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \
- && cmake --build --parallel --preset 'JetPack 5' \
- && cmake --install build --component CUDA --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
+ && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION
@@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
+ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \
- && cmake --build --parallel --preset 'JetPack 6' \
- && cmake --install build --component CUDA --strip --parallel 8
+ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
+ && cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
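The new `PARALLEL` build argument replaces the previously hardcoded `--parallel 8`. Since an `ARG` declared before the first `FROM` is not visible inside individual build stages, each stage re-declares `ARG PARALLEL` before its build step. A sketch of overriding the default at build time; the `cpu` target name comes from the stage above:

```shell
# Build only the CPU stage with 4 parallel build/install jobs instead of the default 8
docker build --build-arg PARALLEL=4 --target cpu .
```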
diff --git a/api/client.go b/api/client.go
index 7cc2acb3..20e6d795 100644
--- a/api/client.go
+++ b/api/client.go
@@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err)
}
- if response.StatusCode >= http.StatusBadRequest {
+ if response.StatusCode == http.StatusUnauthorized {
+ pubKey, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ return pkErr
+ }
+ return AuthorizationError{
+ StatusCode: response.StatusCode,
+ Status: response.Status,
+ PublicKey: pubKey,
+ }
+ } else if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
@@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
+
+// Signout disconnects an ollama instance from ollama.com.
+func (c *Client) Signout(ctx context.Context, encodedKey string) error {
+ return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
+}
+
+// Whoami returns information about the user currently signed in to ollama.com.
+func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
+ var resp UserResponse
+ if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
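`Whoami` and `Signout` are thin wrappers over the client's `do` helper. A minimal sketch of driving them from Go, assuming a running local server; the unpadded base64url key encoding mirrors the CLI handlers added in cmd/cmd.go below:

```go
package main

import (
	"context"
	"encoding/base64"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Report the ollama.com user this instance is signed in as, if any.
	user, err := client.Whoami(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("signed in as:", user.Name)

	// Signout expects the public key in unpadded base64url form.
	pubKey, err := auth.GetPublicKey()
	if err != nil {
		log.Fatal(err)
	}
	encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
	if err := client.Signout(context.Background(), encKey); err != nil {
		log.Fatal(err)
	}
}
```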
diff --git a/api/types.go b/api/types.go
index a7ddbc37..5b8e034c 100644
--- a/api/types.go
+++ b/api/types.go
@@ -11,6 +11,8 @@ import (
"strings"
"time"
+ "github.com/google/uuid"
+
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
@@ -36,6 +38,19 @@ func (e StatusError) Error() string {
}
}
+type AuthorizationError struct {
+ StatusCode int
+ Status string
+ PublicKey string `json:"public_key"`
+}
+
+func (e AuthorizationError) Error() string {
+ if e.Status != "" {
+ return e.Status
+ }
+ return "something went wrong, please see the ollama server logs for details"
+}
+
// ImageData represents the raw binary data of an image file.
type ImageData []byte
@@ -313,13 +328,29 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
- Model string `json:"model"`
- CreatedAt time.Time `json:"created_at"`
- Message Message `json:"message"`
- DoneReason string `json:"done_reason,omitempty"`
+ // Model is the model name that generated the response.
+ Model string `json:"model"`
+ // RemoteModel is the name of the upstream model that generated the response.
+ RemoteModel string `json:"remote_model,omitempty"`
+
+ // RemoteHost is the URL of the upstream Ollama host that generated the response.
+ RemoteHost string `json:"remote_host,omitempty"`
+
+ // CreatedAt is the timestamp of the response.
+ CreatedAt time.Time `json:"created_at"`
+
+ // Message contains the message or part of a message from the model.
+ Message Message `json:"message"`
+
+ // Done specifies if the response is complete.
Done bool `json:"done"`
+ // DoneReason is the reason the model stopped generating text.
+ DoneReason string `json:"done_reason,omitempty"`
+
+ DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
+
Metrics
}
@@ -329,13 +360,6 @@ type DebugInfo struct {
ImageCount int `json:"image_count,omitempty"`
}
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
- Model string `json:"model"`
- CreatedAt time.Time `json:"created_at"`
- DebugInfo DebugInfo `json:"_debug_info"`
-}
-
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
@@ -431,18 +455,47 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
- Model string `json:"model"`
- Stream *bool `json:"stream,omitempty"`
+ // Model is the model name to create.
+ Model string `json:"model"`
+
+ // Stream specifies whether the response is streaming; it is true by default.
+ Stream *bool `json:"stream,omitempty"`
+
+ // Quantize is the quantization format for the model; leave blank to not change the quantization level.
Quantize string `json:"quantize,omitempty"`
- From string `json:"from,omitempty"`
- Files map[string]string `json:"files,omitempty"`
- Adapters map[string]string `json:"adapters,omitempty"`
- Template string `json:"template,omitempty"`
- License any `json:"license,omitempty"`
- System string `json:"system,omitempty"`
- Parameters map[string]any `json:"parameters,omitempty"`
- Messages []Message `json:"messages,omitempty"`
+ // From is the name of the model or file to use as the source.
+ From string `json:"from,omitempty"`
+
+ // RemoteHost is the URL of the upstream ollama API for the model (if any).
+ RemoteHost string `json:"remote_host,omitempty"`
+
+ // Files is a map of files to include when creating the model.
+ Files map[string]string `json:"files,omitempty"`
+
+ // Adapters is a map of LoRA adapters to include when creating the model.
+ Adapters map[string]string `json:"adapters,omitempty"`
+
+ // Template is the template used when constructing a request to the model.
+ Template string `json:"template,omitempty"`
+
+ // License is a string or list of strings for licenses.
+ License any `json:"license,omitempty"`
+
+ // System is the system prompt for the model.
+ System string `json:"system,omitempty"`
+
+ // Parameters is a map of hyper-parameters which are applied to the model.
+ Parameters map[string]any `json:"parameters,omitempty"`
+
+ // Messages is a list of messages added to the model before chat and generation requests.
+ Messages []Message `json:"messages,omitempty"`
+
+ // Renderer is the name of the prompt renderer for the model.
+ Renderer string `json:"renderer,omitempty"`
+
+ // Parser is the name of the output parser for the model.
+ Parser string `json:"parser,omitempty"`
+
+ // Info is a map of additional information for the model.
+ Info map[string]any `json:"info,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -480,8 +533,12 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
+ Renderer string `json:"renderer,omitempty"`
+ Parser string `json:"parser,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
+ RemoteModel string `json:"remote_model,omitempty"`
+ RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
@@ -540,12 +597,14 @@ type ProcessResponse struct {
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
- Name string `json:"name"`
- Model string `json:"model"`
- ModifiedAt time.Time `json:"modified_at"`
- Size int64 `json:"size"`
- Digest string `json:"digest"`
- Details ModelDetails `json:"details,omitempty"`
+ Name string `json:"name"`
+ Model string `json:"model"`
+ RemoteModel string `json:"remote_model,omitempty"`
+ RemoteHost string `json:"remote_host,omitempty"`
+ ModifiedAt time.Time `json:"modified_at"`
+ Size int64 `json:"size"`
+ Digest string `json:"digest"`
+ Details ModelDetails `json:"details,omitempty"`
}
// ProcessModelResponse is a single model description in [ProcessResponse].
@@ -569,6 +628,12 @@ type GenerateResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"`
+ // RemoteModel is the name of the upstream model that generated the response.
+ RemoteModel string `json:"remote_model,omitempty"`
+
+ // RemoteHost is the URL of the upstream Ollama host that generated the response.
+ RemoteHost string `json:"remote_host,omitempty"`
+
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
@@ -592,6 +657,8 @@ type GenerateResponse struct {
Metrics
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
+ DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
}
// ModelDetails provides details about a model.
@@ -604,6 +671,18 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
+// UserResponse provides information about a user.
+type UserResponse struct {
+ ID uuid.UUID `json:"id"`
+ Email string `json:"email"`
+ Name string `json:"name"`
+ Bio string `json:"bio,omitempty"`
+ AvatarURL string `json:"avatarurl,omitempty"`
+ FirstName string `json:"firstname,omitempty"`
+ LastName string `json:"lastname,omitempty"`
+ Plan string `json:"plan,omitempty"`
+}
+
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
diff --git a/auth/auth.go b/auth/auth.go
index e1d85412..b26e2315 100644
--- a/auth/auth.go
+++ b/auth/auth.go
@@ -19,6 +19,31 @@ import (
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
+ fileIsReadable := func(fp string) bool {
+ info, err := os.Stat(fp)
+ if err != nil {
+ return false
+ }
+
+ // Check that it's a regular file, not a directory or other file type
+ if !info.Mode().IsRegular() {
+ return false
+ }
+
+ // Try to open it to check readability
+ file, err := os.Open(fp)
+ if err != nil {
+ return false
+ }
+ file.Close()
+ return true
+ }
+
+ systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
+ if fileIsReadable(systemPath) {
+ return systemPath, nil
+ }
+
home, err := os.UserHomeDir()
if err != nil {
return "", err
diff --git a/cmd/cmd.go b/cmd/cmd.go
index 19f1e192..294e1662 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -5,6 +5,7 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
+ "encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
@@ -14,6 +15,7 @@ import (
"math"
"net"
"net/http"
+ "net/url"
"os"
"os/signal"
"path/filepath"
@@ -35,6 +37,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -47,6 +50,8 @@ import (
"github.com/ollama/ollama/version"
)
+const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
+
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
@@ -286,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
Think: opts.Think,
}
- return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
+ return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
+ if r.RemoteModel != "" && opts.ShowConnect {
+ p.StopAndClear()
+ if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
+ fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
+ } else {
+ fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
+ }
+ }
+ return nil
+ })
}
func StopHandler(cmd *cobra.Command, args []string) error {
@@ -307,9 +322,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
opts := runOptions{
- Model: args[0],
- WordWrap: os.Getenv("TERM") == "xterm-256color",
- Options: map[string]any{},
+ Model: args[0],
+ WordWrap: os.Getenv("TERM") == "xterm-256color",
+ Options: map[string]any{},
+ ShowConnect: true,
}
format, err := cmd.Flags().GetString("format")
@@ -367,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
prompts = append([]string{string(in)}, prompts...)
+ opts.ShowConnect = false
opts.WordWrap = false
interactive = false
}
@@ -433,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
+ var sErr api.AuthorizationError
+ if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+ pubKey, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ return pkErr
+ }
+ // only show sign-in instructions when the server and client share the same public key
+ if pubKey == sErr.PublicKey {
+ h, _ := os.Hostname()
+ encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+ fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
+ fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
+ }
+ return nil
+ }
return err
}
@@ -453,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
+func SigninHandler(cmd *cobra.Command, args []string) error {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+
+ user, err := client.Whoami(cmd.Context())
+ if err != nil {
+ return err
+ }
+
+ if user != nil && user.Name != "" {
+ fmt.Printf("You are already signed in as user '%s'\n", user.Name)
+ fmt.Println()
+ return nil
+ }
+
+ pubKey, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ return pkErr
+ }
+ encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+
+ h, _ := os.Hostname()
+ fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
+
+ return nil
+}
+
+func SignoutHandler(cmd *cobra.Command, args []string) error {
+ pubKey, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ return pkErr
+ }
+ encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
+
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+
+ err = client.Signout(cmd.Context(), encKey)
+ if err != nil {
+ return err
+ }
+ fmt.Println("You have signed out of ollama.com")
+ fmt.Println()
+ return nil
+}
+
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -505,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if spinner != nil {
spinner.Stop()
}
- if strings.Contains(err.Error(), "access denied") {
+ errStr := strings.ToLower(err.Error())
+ if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
return err
@@ -539,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
- data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
+ var size string
+ if m.RemoteModel != "" {
+ size = "-"
+ } else {
+ size = format.HumanBytes(m.Size)
+ }
+
+ data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
}
}
@@ -624,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
KeepAlive: &api.Duration{Duration: 0},
}
if err := loadOrUnloadModel(cmd, opts); err != nil {
- if !strings.Contains(err.Error(), "not found") {
- return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
+ if !strings.Contains(strings.ToLower(err.Error()), "not found") {
+ fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
}
}
@@ -736,12 +826,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
tableRender("Model", func() (rows [][]string) {
+ if resp.RemoteHost != "" {
+ rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
+ rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
+ }
+
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
- rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
- rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
- rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
+
+ var paramStr string
+ if resp.Details.ParameterSize != "" {
+ paramStr = resp.Details.ParameterSize
+ } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
+ if f, ok := v.(float64); ok {
+ paramStr = format.HumanNumber(uint64(f))
+ }
+ }
+ rows = append(rows, []string{"", "parameters", paramStr})
+
+ if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
+ if f, ok := v.(float64); ok {
+ rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
+ }
+ }
+
+ if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
+ if f, ok := v.(float64); ok {
+ rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
+ }
+ }
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
@@ -989,6 +1103,7 @@ type runOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
+ ShowConnect bool
}
type displayResponseState struct {
@@ -1544,6 +1659,22 @@ func NewCLI() *cobra.Command {
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
+ signinCmd := &cobra.Command{
+ Use: "signin",
+ Short: "Sign in to ollama.com",
+ Args: cobra.ExactArgs(0),
+ PreRunE: checkServerHeartbeat,
+ RunE: SigninHandler,
+ }
+
+ signoutCmd := &cobra.Command{
+ Use: "signout",
+ Short: "Sign out from ollama.com",
+ Args: cobra.ExactArgs(0),
+ PreRunE: checkServerHeartbeat,
+ RunE: SignoutHandler,
+ }
+
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
@@ -1638,6 +1769,8 @@ func NewCLI() *cobra.Command {
stopCmd,
pullCmd,
pushCmd,
+ signinCmd,
+ signoutCmd,
listCmd,
psCmd,
copyCmd,
diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go
index cf5fe7ca..bb793572 100644
--- a/cmd/cmd_test.go
+++ b/cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"encoding/json"
+ "fmt"
"io"
"net/http"
"net/http/httptest"
@@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
+ errPayload := `{"error":"model '%s' not found"}`
+ w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
}
return
}
@@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) {
}
err := DeleteHandler(cmd, []string{"test-model-not-found"})
- if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
+ if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
}
}
@@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
- "error": "access denied",
+ "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
})
if err != nil {
t.Fatal(err)
@@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
+ initializeKeypair()
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
diff --git a/convert/convert_bert.go b/convert/convert_bert.go
index a9f4b8a7..6b0d0030 100644
--- a/convert/convert_bert.go
+++ b/convert/convert_bert.go
@@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
+ normalizeEmbeddings bool
PoolingType uint32
}
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string
for _, m := range modules {
- if m.Type == "sentence_transformers.models.Pooling" {
+ switch m.Type {
+ case "sentence_transformers.models.Pooling":
pooling = m.Path
- break
+ case "sentence_transformers.models.Normalize":
+ p.normalizeEmbeddings = true
}
}
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
+ kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go
index 7f029f93..eea0de2f 100644
--- a/convert/reader_safetensors.go
+++ b/convert/reader_safetensors.go
@@ -96,7 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
- if st.dtype == "BF16" && kind != tensorKindFP32 {
+ if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}
diff --git a/convert/reader_test.go b/convert/reader_test.go
index 6dbe32a5..c3d094f1 100644
--- a/convert/reader_test.go
+++ b/convert/reader_test.go
@@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
})
}
}
+
+func TestSafetensorKind(t *testing.T) {
+ tests := []struct {
+ name string
+ st safetensor
+ expected uint32
+ }{
+ {
+ name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
+ st: safetensor{
+ tensorBase: &tensorBase{
+ name: "weight.matrix",
+ shape: []uint64{10, 10}, // will default to FP16
+ },
+ dtype: "BF16",
+ },
+ expected: tensorKindBF16,
+ },
+ {
+ name: "BF16 dtype with v. prefix should return base kind",
+ st: safetensor{
+ tensorBase: &tensorBase{
+ name: "v.weight.matrix",
+ shape: []uint64{10, 10}, // will default to FP16
+ },
+ dtype: "BF16",
+ },
+ expected: tensorKindFP16,
+ },
+ {
+ name: "BF16 dtype with FP32 base kind should return FP32",
+ st: safetensor{
+ tensorBase: &tensorBase{
+ name: "weight.matrix",
+ shape: []uint64{10}, // will default to FP32
+ },
+ dtype: "BF16",
+ },
+ expected: tensorKindFP32,
+ },
+ {
+ name: "Non-BF16 dtype should return base kind",
+ st: safetensor{
+ tensorBase: &tensorBase{
+ name: "weight.matrix",
+ shape: []uint64{10, 10}, // will default to FP16
+ },
+ dtype: "FP16",
+ },
+ expected: tensorKindFP16,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.st.Kind()
+ if result != tt.expected {
+ t.Errorf("Kind() = %d, expected %d", result, tt.expected)
+ }
+ })
+ }
+}
diff --git a/discover/cuda_common.go b/discover/cuda_common.go
index ca008af6..a2c43420 100644
--- a/discover/cuda_common.go
+++ b/discover/cuda_common.go
@@ -16,7 +16,7 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
-func cudaVariant(gpuInfo CudaGPUInfo) string {
+func cudaVariant(gpuInfos []CudaGPUInfo) string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".")
@@ -45,12 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
- if gpuInfo.DriverMajor < 13 {
+ // Check GPU compute capability first; use the lowest common denominator when multiple GPUs are present
+ for _, gpuInfo := range gpuInfos {
+ if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
+ // GPU is older than Turing (CC < 7.5) - use CUDA v12 (supports down to CC 6.1)
+ return "v12"
+ }
+ }
+
+ // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
+ if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
- if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
- slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
- }
+ slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
return "v12"
}
return "v13"
diff --git a/discover/gpu.go b/discover/gpu.go
index 95070ecb..4bb0d94e 100644
--- a/discover/gpu.go
+++ b/discover/gpu.go
@@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
- variant := cudaVariant(gpuInfo)
- // Start with our bundled libraries
- if variant != "" {
- variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
- if _, err := os.Stat(variantPath); err == nil {
- // Put the variant directory first in the search path to avoid runtime linking to the wrong library
- gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
- }
- }
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
- gpuInfo.Variant = variant
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
unsupportedGPUs = append(unsupportedGPUs,
@@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
+ // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
+ variant := cudaVariant(cudaGPUs)
+ var variantPath string
+ // Start with our bundled libraries
+ if variant != "" {
+ variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
+ if _, err := os.Stat(variantPath); err != nil {
+ variantPath = ""
+ }
+ }
+
+ for i := range cudaGPUs {
+ cudaGPUs[i].Variant = variant
+ if variantPath != "" {
+ // Put the variant directory first in the search path to avoid runtime linking to the wrong library
+ cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
+ }
+ }
}
// Intel
diff --git a/docs/development.md b/docs/development.md
index 9726b5d9..ff07b5fb 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```
+> [!NOTE]
+> Ollama includes native code compiled with CGO. From time to time the data structures shared between Go and this native code change, and a stale CGO build cache can get out of sync with them, resulting in unexpected crashes. You can force a full rebuild of the native code by running `go clean -cache` first.
+
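+For example, to force a clean rebuild:
+
+```shell
+go clean -cache
+go run . serve
+```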
+
## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
diff --git a/envconfig/config.go b/envconfig/config.go
index 7fc01887..09243ab9 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) {
return loadTimeout
}
+func Remotes() []string {
+ var r []string
+ raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
+ if raw == "" {
+ r = []string{"ollama.com"}
+ } else {
+ r = strings.Split(raw, ",")
+ }
+ return r
+}
+
func Bool(k string) func() bool {
return func() bool {
if s := Var(k); s != "" {
@@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
+ "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 6b582b49..5da902bc 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -243,6 +243,7 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3",
"gemma3n",
"mistral3",
+ "qwen3",
"llama4",
"mllama",
"qwen25vl",
diff --git a/integration/embed_test.go b/integration/embed_test.go
index eb00f4ba..a6852448 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
}
res, err := embeddingTestHelper(ctx, client, t, req)
-
if err != nil {
- t.Fatalf("error: %v", err)
+ t.Fatal(err)
}
if len(res.Embedding) != 384 {
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
}
res, err := embedTestHelper(ctx, client, t, req)
-
if err != nil {
- t.Fatalf("error: %v", err)
+ t.Fatal(err)
}
if len(res.Embeddings) != 1 {
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
}
res, err := embedTestHelper(ctx, client, t, req)
-
if err != nil {
- t.Fatalf("error: %v", err)
+ t.Fatal(err)
}
if len(res.Embeddings) != 2 {
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
truncTrue, truncFalse := true, false
- type testReq struct {
- Name string
- Request api.EmbedRequest
+ want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
+ Model: "all-minilm",
+ Input: "why",
+ })
+ if err != nil {
+ t.Fatal(err)
}
- reqs := []testReq{
+ cases := []struct {
+ name string
+ request api.EmbedRequest
+ check func(*api.EmbedResponse, error)
+ }{
{
- Name: "Target Truncation",
- Request: api.EmbedRequest{
+ name: "target truncation",
+ request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
- },
- {
- Name: "Default Truncate",
- Request: api.EmbedRequest{
- Model: "all-minilm",
- Input: "why is the sky blue?",
- Options: map[string]any{"num_ctx": 1},
+ check: func(got *api.EmbedResponse, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+ t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+ }
},
},
{
- Name: "Explicit Truncate",
- Request: api.EmbedRequest{
+ name: "default truncate",
+ request: api.EmbedRequest{
+ Model: "all-minilm",
+ Input: "why is the sky blue?",
+ Options: map[string]any{"num_ctx": 3},
+ },
+ check: func(got *api.EmbedResponse, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+ t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+ }
+ },
+ },
+ {
+ name: "explicit truncate",
+ request: api.EmbedRequest{
+ Model: "all-minilm",
+ Input: "why is the sky blue?",
+ Truncate: &truncTrue,
+ Options: map[string]any{"num_ctx": 3},
+ },
+ check: func(got *api.EmbedResponse, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
+ t.Errorf("embedding mismatch (-want +got):\n%s", diff)
+ }
+ },
+ },
+ {
+ name: "truncate error",
+ request: api.EmbedRequest{
+ Model: "all-minilm",
+ Input: "why is the sky blue?",
+ Truncate: &truncFalse,
+ Options: map[string]any{"num_ctx": 3},
+ },
+ check: func(res *api.EmbedResponse, err error) {
+ if err.Error() != "input exceeds maximum context length" {
+ t.Fatalf("expected truncation error, got: %v", err)
+ }
+ },
+ },
+ {
+ name: "input after truncate error",
+ request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
+ check: func(res *api.EmbedResponse, err error) {
+ if err.Error() != "input after truncation exceeds maximum context length" {
+ t.Fatalf("expected truncation error, got: %v", err)
+ }
+ },
+ },
+ {
+ name: "input after truncate error",
+ request: api.EmbedRequest{
+ Model: "all-minilm",
+ Input: "why is the sky blue?",
+ Truncate: &truncTrue,
+ Options: map[string]any{"num_ctx": 0},
+ },
+ check: func(res *api.EmbedResponse, err error) {
+ if err.Error() != "input after truncation exceeds maximum context length" {
+ t.Fatalf("expected truncation error, got: %v", err)
+ }
+ },
},
}
- res := make(map[string]*api.EmbedResponse)
-
- for _, req := range reqs {
- response, err := embedTestHelper(ctx, client, t, req.Request)
- if err != nil {
- t.Fatalf("error: %v", err)
- }
- res[req.Name] = response
- }
-
- if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
- t.Fatal("expected default request to truncate correctly")
- }
-
- if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
- t.Fatal("expected default request and truncate true request to be the same")
- }
-
- // check that truncate set to false returns an error if context length is exceeded
- _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
- Model: "all-minilm",
- Input: "why is the sky blue?",
- Truncate: &truncFalse,
- Options: map[string]any{"num_ctx": 1},
- })
-
- if err == nil {
- t.Fatal("expected error, got nil")
+ for _, req := range cases {
+ t.Run(req.name, func(t *testing.T) {
+ req.check(embedTestHelper(ctx, client, t, req.request))
+ })
}
}
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
+ t.Helper()
+
if err := PullIfMissing(ctx, client, req.Model); err != nil {
- t.Fatalf("failed to pull model %s: %v", req.Model, err)
+ t.Fatal(err)
}
- response, err := client.Embeddings(ctx, &req)
-
- if err != nil {
- return nil, err
- }
-
- return response, nil
+ return client.Embeddings(ctx, &req)
}
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
+ t.Helper()
+
if err := PullIfMissing(ctx, client, req.Model); err != nil {
- t.Fatalf("failed to pull model %s: %v", req.Model, err)
+ t.Fatal(err)
}
- response, err := client.Embed(ctx, &req)
-
- if err != nil {
- return nil, err
- }
-
- return response, nil
+ return client.Embed(ctx, &req)
}
diff --git a/logutil/logutil.go b/logutil/logutil.go
index fff277b8..00daf6a6 100644
--- a/logutil/logutil.go
+++ b/logutil/logutil.go
@@ -5,6 +5,8 @@ import (
"io"
"log/slog"
"path/filepath"
+ "runtime"
+ "time"
)
const LevelTrace slog.Level = -8
@@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
}))
}
+type key string
+
func Trace(msg string, args ...any) {
- slog.Log(context.TODO(), LevelTrace, msg, args...)
+ TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
}
func TraceContext(ctx context.Context, msg string, args ...any) {
- slog.Log(ctx, LevelTrace, msg, args...)
+ if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
+ skip, _ := ctx.Value(key("skip")).(int)
+ pc, _, _, _ := runtime.Caller(1 + skip)
+ record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
+ record.Add(args...)
+ logger.Handler().Handle(ctx, record)
+ }
}
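The record is built by hand so that the source location attached to a trace line belongs to the caller: `runtime.Caller(1+skip)` captures the frame above `TraceContext` (two above, for the `Trace` wrapper, via the skip value carried through the context). A minimal sketch, assuming the `NewLogger` constructor in this file is used to enable trace-level output:

```go
package main

import (
	"log/slog"
	"os"

	"github.com/ollama/ollama/logutil"
)

func main() {
	// Route trace-level records through a handler that prints source locations.
	slog.SetDefault(logutil.NewLogger(os.Stderr, logutil.LevelTrace))

	// Attributed to this call site, not to logutil internals.
	logutil.Trace("cache hit", "layer", 3)
}
```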
diff --git a/ml/backend.go b/ml/backend.go
index 154a0f1b..455715b0 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
+ L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
- GELU(ctx Context) Tensor
- QuickGELU(ctx Context) Tensor
- SILU(ctx Context) Tensor
- RELU(ctx Context) Tensor
+ GELU(ctx Context, up ...Tensor) Tensor
+ SILU(ctx Context, up ...Tensor) Tensor
+ RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
- SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+ // SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
+ SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 931386d5..49dc3e1a 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
}
}
+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+ }
+}
+
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
+ }
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
- return &Tensor{
- b: t.b,
- t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
+ }
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
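The activation methods now take an optional up-projection tensor: with no argument they apply the plain in-place activation, and with a gate/up pair they dispatch to the fused split kernels (`ggml_geglu_split`, `ggml_swiglu_split`, `ggml_reglu_split`). The resulting call-site pattern, mirroring the gemma2 MLP change below:

```go
// Before: activate the gate, then multiply by the up projection.
hidden = mlp.Gate.Forward(ctx, x).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, x))

// After: one fused GEGLU over the gate/up pair.
hidden = mlp.Gate.Forward(ctx, x).GELU(ctx, mlp.Up.Forward(ctx, x))
```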
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 21b4a28a..94dbde0b 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+ ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
+ ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go
new file mode 100644
index 00000000..63b63b3a
--- /dev/null
+++ b/ml/nn/pooling/pooling.go
@@ -0,0 +1,42 @@
+package pooling
+
+import (
+ "github.com/ollama/ollama/ml"
+)
+
+type Type uint32
+
+const (
+ TypeNone Type = iota
+ TypeMean
+ TypeCLS
+ TypeLast
+)
+
+func (t Type) String() string {
+ switch t {
+ case TypeMean:
+ return "Mean"
+ case TypeCLS:
+ return "CLS"
+ case TypeLast:
+ return "Last"
+ default:
+ return "Unknown"
+ }
+}
+
+func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
+ switch t {
+ case TypeMean:
+ hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
+ return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ case TypeCLS:
+ return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+ case TypeLast:
+ hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
+ return hiddenStates
+ default:
+ panic("unknown pooling type")
+ }
+}
diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go
new file mode 100644
index 00000000..c8001945
--- /dev/null
+++ b/ml/nn/pooling/pooling_test.go
@@ -0,0 +1,79 @@
+package pooling_test
+
+import (
+ "bytes"
+ "os"
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/ollama/ollama/discover"
+ fsggml "github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/backend/ggml"
+ "github.com/ollama/ollama/ml/nn/pooling"
+)
+
+func setup(tb testing.TB, n int) ml.Backend {
+ tb.Helper()
+
+ f, err := os.CreateTemp(tb.TempDir(), "*.bin")
+ if err != nil {
+ tb.Fatal(err)
+ }
+ defer f.Close()
+
+ if err := fsggml.WriteGGUF(f, fsggml.KV{
+ "general.architecture": "test",
+ "test.block_count": uint32(1),
+ }, []*fsggml.Tensor{
+ {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
+ }); err != nil {
+ tb.Fatal(err)
+ }
+
+ var gpuLayers ml.GPULayersList
+ if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
+ gpuLayers = append(gpuLayers, ml.GPULayers{
+ ID: gpus[0].ID,
+ Layers: slices.Collect(func(yield func(int) bool) {
+ for i := range n {
+ if !yield(i) {
+ return
+ }
+ }
+ }),
+ })
+ }
+ b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
+ if err != nil {
+ tb.Fatal(err)
+ }
+
+ return b
+}
+
+func TestForward(t *testing.T) {
+ cases := map[pooling.Type][]float32{
+ pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
+ pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
+ pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
+ }
+ for typ, want := range cases {
+ t.Run(typ.String(), func(t *testing.T) {
+ b := setup(t, 99)
+ defer b.Close()
+
+ ctx := b.NewContext()
+ defer ctx.Close()
+
+ tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
+ tt = typ.Forward(ctx, tt)
+
+ ctx.Forward(tt).Compute(tt)
+ if diff := cmp.Diff(want, tt.Floats()); diff != "" {
+ t.Error(diff)
+ }
+ })
+ }
+}
diff --git a/model/input/input.go b/model/input/input.go
index bd9b53ec..35dc41b3 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
- // Multimodal is a set of multimodal embeddings previously created by
- // EncodeMultimodal, along with an index into Inputs. Unused for text-only
- // models or for batches without multimodal elements.
- Multimodal []MultimodalIndex
+ // Outputs are the set of indices into Inputs for which output data should
+ // be returned.
+ Outputs ml.Tensor
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
- // Outputs are the set of indicies into Inputs for which output data should
- // be returned.
- Outputs []int32
+ // Multimodal is a set of multimodal embeddings previously created by
+ // EncodeMultimodal, along with an index into Inputs. Unused for text-only
+ // models or for batches without multimodal elements.
+ Multimodal []MultimodalIndex
}
diff --git a/model/model.go b/model/model.go
index 3a72f09a..f3d6bb3d 100644
--- a/model/model.go
+++ b/model/model.go
@@ -5,7 +5,6 @@ import (
"fmt"
_ "image/jpeg"
_ "image/png"
- "math"
"os"
"reflect"
"strconv"
@@ -21,10 +20,15 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
)
-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+ ErrNoVisionModel = errors.New("this model is missing data required for image input")
+ ErrUnsupportedModel = errors.New("model not supported")
+ ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
@@ -103,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return nil, err
}
- arch := b.Config().Architecture()
- if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
- arch = arch + "_embed"
- }
-
- f, ok := models[arch]
- if !ok {
- return nil, fmt.Errorf("unsupported model architecture %q", arch)
- }
-
- m, err := f(b.Config())
+ m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
-
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
@@ -131,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
+
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
- return getTextProcessor(meta.KV())
-}
-func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
- arch := kv.Architecture()
- f, ok := models[arch]
- if !ok {
- return nil, fmt.Errorf("unsupported model architecture %q", arch)
- }
- m, err := f(kv)
+ m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
+
tp, ok := m.(TextProcessor)
if !ok {
- return nil, fmt.Errorf("%v is not a TextProcessor", m)
+ return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
+func modelForArch(c fs.Config) (Model, error) {
+ arch := c.Architecture()
+ if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
+ arch = arch + "_embed"
+ }
+
+ f, ok := models[arch]
+ if !ok {
+ return nil, ErrUnsupportedModel
+ }
+
+ return f(c)
+}
+
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
@@ -242,7 +243,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem()
}
- vv = vv.Elem()
+ vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
diff --git a/model/model_test.go b/model/model_test.go
index 020f9ffb..01080ffd 100644
--- a/model/model_test.go
+++ b/model/model_test.go
@@ -1,9 +1,9 @@
package model
import (
+ "errors"
"reflect"
"slices"
- "strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
- "github.com/ollama/ollama/model/input"
)
func TestParseTags(t *testing.T) {
@@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
}
}
-func TestGetTextProcessor(t *testing.T) {
- tp, err := getTextProcessor(fsggml.KV{})
- if err == nil {
- t.Error("expected error")
- } else if !strings.Contains(err.Error(), "unsupported model architecture") {
- t.Errorf("unexpected error: %v", err)
- } else if tp != nil {
- t.Error("expected nil tp")
+func TestModelForArch(t *testing.T) {
+ type fakeModel struct {
+ Model
}
- models["dummy"] = func(fs.Config) (Model, error) {
- return notTextProcessorModel{}, nil
+ type fakeEmbeddingModel struct {
+ Model
}
- tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
- if err == nil {
- t.Error("expected error")
- } else if !strings.Contains(err.Error(), "not a TextProcessor") {
- t.Errorf("unexpected error: %v", err)
- } else if tp != nil {
- t.Error("expected nil tp")
+
+ models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
+ models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
+
+ cases := []struct {
+ name string
+ config fs.Config
+ want any
+ err error
+ }{
+ {
+ name: "model",
+ config: fsggml.KV{
+ "general.architecture": "model",
+ },
+ want: fakeModel{},
+ },
+ {
+ name: "embedding",
+ config: fsggml.KV{
+ "general.architecture": "model",
+ "model.pooling_type": uint32(1),
+ },
+ want: fakeEmbeddingModel{},
+ },
+ {
+ name: "unsupported",
+ config: fsggml.KV{
+ "general.architecture": "unsupported",
+ },
+ err: ErrUnsupportedModel,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := modelForArch(tt.config)
+ if !errors.Is(err, tt.err) {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff(tt.want, got); diff != "" {
+ t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
+ }
+ })
}
}
-
-type notTextProcessorModel struct{}
-
-func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
- panic("unimplemented")
-}
-
-func (notTextProcessorModel) Backend() ml.Backend {
- panic("unimplemented")
-}
-
-func (notTextProcessorModel) Config() config {
- panic("unimplemented")
-}
diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go
new file mode 100644
index 00000000..166c11e1
--- /dev/null
+++ b/model/models/bert/embed.go
@@ -0,0 +1,181 @@
+package bert
+
+import (
+ "cmp"
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/pooling"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+ model.Base
+ model.TextProcessor
+
+ TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+ TypeEmbedding *nn.Embedding `gguf:"token_types"`
+ PositionEmbedding *nn.Embedding `gguf:"position_embd"`
+ TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
+
+ Layers []EncoderLayer `gguf:"blk"`
+
+ Options
+}
+
+// Forward implements model.Model.
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+ hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+ hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
+ hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
+
+ for _, layer := range m.Layers {
+ hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
+ }
+
+ hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
+ if m.normalize {
+ hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+ }
+
+ return hiddenStates, nil
+}
+
+type EncoderLayer struct {
+ *Attention
+ AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
+
+ *MLP
+ MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
+}
+
+func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ // Attention
+ residual := hiddenStates
+ hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+ hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ // MLP
+ residual = hiddenStates
+ hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+ hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ return hiddenStates
+}
+
+type Attention struct {
+ Query *nn.Linear `gguf:"attn_q"`
+ QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
+
+ Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
+
+ Value *nn.Linear `gguf:"attn_v"`
+
+ Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ batchSize := hiddenStates.Dim(1)
+
+ query := a.Query.Forward(ctx, hiddenStates)
+ if a.QueryNorm != nil {
+ query = a.QueryNorm.Forward(ctx, query, opts.eps)
+ }
+ query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
+
+ key := a.Key.Forward(ctx, hiddenStates)
+ if a.KeyNorm != nil {
+ key = a.KeyNorm.Forward(ctx, key, opts.eps)
+ }
+ key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+ value := a.Value.Forward(ctx, hiddenStates)
+ value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+ attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+ attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+ return a.Output.Forward(ctx, attention)
+}
+
+type MLP struct {
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
+}
+
+type Options struct {
+ hiddenSize,
+ numHeads,
+ numKVHeads,
+ keyLength,
+ valueLength int
+ poolingType pooling.Type
+ eps float32
+ normalize bool
+}
+
+func (o Options) headDim() int {
+ return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
+}
+
+func New(c fs.Config) (model.Model, error) {
+ var processor model.TextProcessor
+ switch c.String("tokenizer.ggml.model", "bert") {
+ case "bert":
+ processor = model.NewWordPiece(
+ &model.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.cls_token_id"),
+ c.Uint("tokenizer.ggml.bos_token_id"),
+ )),
+ },
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
+ EOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.separator_token_id"),
+ //nolint:misspell
+ // NOTE: "seperator_token_id" is a typo in model metadata but we need to
+ // support it for compatibility.
+ c.Uint("tokenizer.ggml.seperator_token_id"),
+ c.Uint("tokenizer.ggml.eos_token_id"),
+ )),
+ },
+ },
+ )
+ default:
+ return nil, model.ErrUnsupportedTokenizer
+ }
+
+ return &Model{
+ TextProcessor: processor,
+ Layers: make([]EncoderLayer, c.Uint("block_count")),
+ Options: Options{
+ hiddenSize: int(c.Uint("embedding_length")),
+ numHeads: int(c.Uint("attention.head_count")),
+ numKVHeads: int(c.Uint("attention.head_count_kv")),
+ eps: c.Float("attention.layer_norm_epsilon"),
+ poolingType: pooling.Type(c.Uint("pooling_type")),
+ normalize: c.Bool("normalize_embeddings", true),
+ },
+ }, nil
+}
+
+func init() {
+ model.Register("bert", New)
+ model.Register("bert_embed", New)
+}
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index e621d03a..2b16dc62 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -24,7 +24,7 @@ type Options struct {
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
- ropeScale: c.Float("rope.freq_scale", 1.0),
+ ropeScale: c.Float("rope.scaling.factor", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
- q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
- k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
}
type MLP struct {
@@ -138,7 +138,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
- lastLayerOutputs = outputs
+ lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go
index 16c299e2..52554776 100644
--- a/model/models/gemma3/embed.go
+++ b/model/models/gemma3/embed.go
@@ -1,49 +1,38 @@
package gemma3
import (
- "errors"
-
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*TextModel
- PoolingType uint32
+ poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-
- switch m.PoolingType {
- case 0: // None
- case 1: // Mean
- hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
- hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
- default:
- return nil, errors.New("unsupported pooling type")
- }
-
+ hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
-
+ hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
},
),
TextModel: newTextModel(c),
- PoolingType: c.Uint("pooling_type", 0),
+ poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 5c92b6bf..27da889e 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) {
m := Model{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 2a3b2393..631baecc 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
- ropeScale: c.Float("rope.freq_scale", 1.0),
+ ropeScale: 1,
+ // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
+ // (8 instead of 1)
+ // ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
@@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
- q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
- k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase
}
- return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
@@ -123,7 +126,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -161,7 +164,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +196,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
- lastLayerOutputs = outputs
+ lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go
index 6e83a972..e59e3193 100644
--- a/model/models/gemma3n/model.go
+++ b/model/models/gemma3n/model.go
@@ -10,7 +10,7 @@ import (
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*TextModel
}
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go
index b75a2abb..d0e9a026 100644
--- a/model/models/gemma3n/model_text.go
+++ b/model/models/gemma3n/model_text.go
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
- hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+ hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.ropeBaseLocal
}
- return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
}
active = d.PerLayerInputGate.Forward(ctx, active)
- active = active.GELU(ctx)
- active = active.Mul(ctx, perLayerInput)
+ active = active.GELU(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
- query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
- key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
- hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+ hiddenStates = hiddenStates.GELU(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
@@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
- ropeScale: c.Float("rope.freq_scale", 1.0),
+ ropeScale: c.Float("rope.scaling.factor", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"),
diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go
index 3ef07809..8456ea5f 100644
--- a/model/models/gptoss/model.go
+++ b/model/models/gptoss/model.go
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
var outputs ml.Tensor
- if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ if i == len(m.TransformerBlocks)-1 {
+ outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
}
- hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+ hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 77d8f36d..f6ec0227 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -2,7 +2,6 @@ package llama
import (
"cmp"
- "fmt"
"math"
"github.com/ollama/ollama/fs"
@@ -23,51 +22,60 @@ type Options struct {
type Model struct {
model.Base
- model.BytePairEncoding
+ model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
- *Options
+ Options
}
func New(c fs.Config) (model.Model, error) {
- // This model currently only supports the gpt2 tokenizer
- if c.String("tokenizer.ggml.model") == "llama" {
- return nil, fmt.Errorf("unsupported tokenizer: llama")
+ if c.Uint("expert_count") > 0 {
+ // TODO: support mixtures of experts
+ return nil, model.ErrUnsupportedModel
}
- // Best effort detection of library/deepseek-coder model(s) which are incompatible
- if c.String("general.name") == "deepseek-ai" {
- return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
- }
- m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
- &model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Types: c.Ints("tokenizer.ggml.token_type"),
- Merges: c.Strings("tokenizer.ggml.merges"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
- BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
- EOS: append(
- []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
- c.Ints("tokenizer.ggml.eos_token_ids")...,
- ),
- },
+
+ var processor model.TextProcessor
+ vocabulary := model.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: append(
+ []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+ c.Ints("tokenizer.ggml.eos_token_ids")...,
),
- Layers: make([]Layer, c.Uint("block_count")),
- Options: &Options{
+ }
+ switch c.String("tokenizer.ggml.model") {
+ case "gpt2":
+ processor = model.NewBytePairEncoding(
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ &vocabulary,
+ )
+ case "llama":
+ processor = model.NewSentencePiece(&vocabulary)
+ default:
+ return nil, model.ErrUnsupportedTokenizer
+ }
+
+ m := Model{
+ TextProcessor: processor,
+ Layers: make([]Layer, c.Uint("block_count")),
+ Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
- ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeBase: c.Float("rope.freq_base", 1e5),
+ ropeScale: c.Float("rope.scaling.factor", 1),
},
}
@@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
- key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
- return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
+ return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
}
type MLP struct {
@@ -118,7 +126,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -160,10 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
- hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
+ hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 99a898d2..9cb2efc8 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index 045ab403..e0f93260 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
- query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
- key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
}
if opts.useQKNorm {
@@ -58,14 +58,14 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
- Gate *nn.Linear `gguf:"ffn_gate_exps"`
- Up *nn.Linear `gguf:"ffn_up_exps"`
- Down *nn.Linear `gguf:"ffn_down_exps"`
+ Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+ Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+ Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
- upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
- gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
- downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+ upStates := e.Up.Forward(ctx, hiddenStates, experts)
+ gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+ downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
@@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
+ return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
}
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 408e54d3..435b1a30 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 19c36f9f..d2e2eac6 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+ k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
+ return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
}
type MLP struct {
@@ -65,7 +65,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
},
}
}
diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go
index 65bdcff2..3bfb8c90 100644
--- a/model/models/mistral3/model_vision.go
+++ b/model/models/mistral3/model_vision.go
@@ -51,7 +51,7 @@ type VisionMLP struct {
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index d0ad4670..239d999d 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 47a518ce..65f0a827 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
- query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
+ key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
- return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
+ return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
}
return key, nil
@@ -58,7 +58,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
},
}
diff --git a/model/models/models.go b/model/models/models.go
index c880a472..cc998078 100644
--- a/model/models/models.go
+++ b/model/models/models.go
@@ -1,6 +1,7 @@
package models
import (
+ _ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go
index 3c662f06..5a345837 100644
--- a/model/models/qwen2/model.go
+++ b/model/models/qwen2/model.go
@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
- key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+ key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -59,7 +59,7 @@ type MLP struct {
}
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
- return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
func New(c fs.Config) (model.Model, error) {
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
},
}
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index d73f499d..6c76305d 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
}
func init() {
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index 4b6bc166..e6c6e6c1 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
originalContextLength: int(c.Uint("context_length", 128000)),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
},
}
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
- q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+ q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
- k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
+ k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
@@ -90,7 +90,7 @@ type MLP struct {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 4d7afaa1..3dd60e3b 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -100,8 +100,7 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
- upOutput := mlp.Up.Forward(ctx, hiddenStates)
- hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
+ hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go
new file mode 100644
index 00000000..9a77efea
--- /dev/null
+++ b/model/models/qwen3/embed.go
@@ -0,0 +1,73 @@
+package qwen3
+
+import (
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn/pooling"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+)
+
+type embedModel struct {
+ model.Base
+ model.BytePairEncoding
+
+ *Model
+ poolingType pooling.Type
+}
+
+func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ hiddenStates, err := m.forward(ctx, batch)
+ if err != nil {
+ return nil, err
+ }
+
+ hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
+ hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+ return hiddenStates, nil
+}
+
+func newEmbed(c fs.Config) (model.Model, error) {
+ layers := make([]Layer, c.Uint("block_count"))
+ for i := range layers {
+ layers[i].MLP = &dense{}
+ }
+ m := embedModel{
+ BytePairEncoding: model.NewBytePairEncoding(
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ &model.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: append(
+ []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+ c.Ints("tokenizer.ggml.eos_token_ids")...,
+ ),
+ },
+ ),
+ Model: &Model{
+ Layers: layers,
+ Options: &Options{
+ hiddenSize: int(c.Uint("embedding_length")),
+ numHeads: int(c.Uint("attention.head_count")),
+ numKVHeads: int(c.Uint("attention.head_count_kv")),
+ keyLength: int(c.Uint("attention.key_length")),
+ valueLength: int(c.Uint("attention.value_length")),
+ eps: c.Float("attention.layer_norm_rms_epsilon"),
+ ropeBase: c.Float("rope.freq_base"),
+ ropeScale: c.Float("rope.freq_scale", 1),
+ numExperts: int(c.Uint("expert_count")),
+ numExpertsUsed: int(c.Uint("expert_used_count")),
+ normTopKProb: c.Bool("norm_top_k_prob", true),
+ },
+ },
+ poolingType: pooling.Type(c.Uint("pooling_type")),
+ }
+
+ m.Cache = kvcache.NewCausalCache(m.Shift)
+ return &m, nil
+}
diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go
index 7a83e0d0..35226834 100644
--- a/model/models/qwen3/model.go
+++ b/model/models/qwen3/model.go
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
}
type Attention struct {
- QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Query *nn.Linear `gguf:"attn_q"`
- KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+ QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
- query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
- key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
+ query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
+ key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -65,10 +65,10 @@ type MLP interface {
}
type sparse struct {
- Router *nn.Linear `gguf:"ffn_gate_inp"`
- Gate *nn.Linear `gguf:"ffn_gate_exps"`
- Up *nn.Linear `gguf:"ffn_up_exps"`
- Down *nn.Linear `gguf:"ffn_down_exps"`
+ Router *nn.Linear `gguf:"ffn_gate_inp"`
+ Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+ Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+ Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
- upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
- hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
- hiddenStates = hiddenStates.SILU(ctx)
- hiddenStates = hiddenStates.Mul(ctx, upStates)
-
- experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+ experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
+ SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -154,29 +151,39 @@ type Model struct {
*Options
}
-// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ hiddenStates, err := m.forward(ctx, batch)
+ if err != nil {
+ return nil, err
+ }
+
+ return m.Output.Forward(ctx, hiddenStates), nil
+}
+
+// forward runs the transformer layers and the final output norm; Forward
+// applies the output projection on top of it.
+func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
- m.Cache.SetLayer(i)
+ if m.Cache != nil {
+ m.Cache.SetLayer(i)
+ }
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
- hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
- return m.Output.Forward(ctx, hiddenStates), nil
+ return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+ return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
var _ model.Model = (*Model)(nil)
@@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
- ropeScale: c.Float("rope.freq_scale", 1),
+ ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
@@ -230,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
func init() {
model.Register("qwen3", New)
model.Register("qwen3moe", New)
+ model.Register("qwen3_embed", newEmbed)
}
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
new file mode 100644
index 00000000..e6dbd1f4
--- /dev/null
+++ b/model/parsers/parsers.go
@@ -0,0 +1,37 @@
+package parsers
+
+import (
+ "github.com/ollama/ollama/api"
+)
+
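+// Parser incrementally extracts plain content, thinking text, and tool calls
+// from raw model output.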
+type Parser interface {
+ Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
+ HasToolSupport() bool
+ HasThinkingSupport() bool
+}
+
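+// ParserForName returns the parser registered for the given name, or nil if
+// the name is unknown.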
+func ParserForName(name string) Parser {
+ switch name {
+ case "qwen3-coder":
+ parser := &Qwen3CoderParser{}
+ return parser
+ case "passthrough":
+ return &PassthroughParser{}
+ default:
+ return nil
+ }
+}
+
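+// PassthroughParser returns model output unchanged and reports no tool or
+// thinking support.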
+type PassthroughParser struct{}
+
+func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+ return s, "", nil, nil
+}
+
+func (p *PassthroughParser) HasToolSupport() bool {
+ return false
+}
+
+func (p *PassthroughParser) HasThinkingSupport() bool {
+ return false
+}
diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go
new file mode 100644
index 00000000..a7838fb7
--- /dev/null
+++ b/model/parsers/qwen3coder.go
@@ -0,0 +1,447 @@
+package parsers
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "log/slog"
+ "math"
+ "regexp"
+ "strconv"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/logutil"
+)
+
+type qwenParserState int
+
+const (
+ toolOpenTag = ""
+ toolCloseTag = ""
+)
+
+const (
+ qwenParserState_LookingForToolStart qwenParserState = iota
+ qwenParserState_CollectingToolContent
+)
+
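+// Qwen3CoderParser accumulates streamed model output in acc and tracks
+// whether it is currently inside a tool call block.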
+type Qwen3CoderParser struct {
+ state qwenParserState
+ acc strings.Builder
+}
+
+func (p *Qwen3CoderParser) HasToolSupport() bool {
+ return true
+}
+
+func (p *Qwen3CoderParser) HasThinkingSupport() bool {
+ return false
+}
+
+func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.acc.WriteString(s)
+
+ events := p.parseEvents()
+
+ var toolCalls []api.ToolCall
+ var sb strings.Builder
+ for _, event := range events {
+ switch event := event.(type) {
+ case qwenEventRawToolCall:
+ toolCall, err := parseToolCall(event, tools)
+ if err != nil {
+ slog.Warn("qwen tool call parsing failed", "error", err)
+ return "", "", nil, err
+ }
+ toolCalls = append(toolCalls, toolCall)
+ case qwenEventContent:
+ // TODO(drifkin): if the same turn contains multiple interleaved content
+ // events, we naively append them together here. See the note below about
+ // `qwenEvent`s for more details
+ sb.WriteString(event.content)
+ }
+ }
+
+ return sb.String(), "", toolCalls, nil
+}
+
+func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
+ var all []qwenEvent
+
+ keepLooping := true
+ for keepLooping {
+ var events []qwenEvent
+ events, keepLooping = eat(p)
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ if len(all) > 0 {
+ slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
+ }
+
+ return all
+}
+
+// we use some internal event types in order to communicate between `Add` and
+// `eat`. We do this to support interleaving content and parallel tool calls in
+// the parser, even though qwen3-coder isn't supposed to do this. Our API
+// doesn't currently support models outputting multiple messages in a turn, so
+// we wouldn't be able to represent it yet, but there's no reason to prevent the
+// parser from supporting it, especially for future models if they end up using
+// a similar format.
+type qwenEvent interface {
+ isQwenEvent()
+}
+
+type qwenEventRawToolCall struct {
+ raw string
+}
+
+type qwenEventContent struct {
+ content string
+}
+
+func (qwenEventContent) isQwenEvent() {}
+func (qwenEventRawToolCall) isQwenEvent() {}
+
+// eat consumes the parser's buffer, and returns a list of any unambiguous
+// events from the current parser state. If the parser transitions to another
+// state, it may have additional events to emit on the next call, which is what
+// the second return value indicates.
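+//
+// For example, with a buffered "hello <tool_", eat emits "hello" as content
+// right away and withholds the trailing space and the partial "<tool_" tag
+// until more input disambiguates them.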
+func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
+ var events []qwenEvent
+
+ switch p.state {
+ case qwenParserState_LookingForToolStart:
+ if strings.Contains(p.acc.String(), toolOpenTag) {
+ // we found a full tool open tag, so we can emit the content before the
+ // tag, being sure to trim any trailing whitespace
+ split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
+ before := split[0]
+ before = strings.TrimRightFunc(before, unicode.IsSpace)
+ if len(before) > 0 {
+ events = append(events, qwenEventContent{content: before})
+ }
+ after := split[1]
+ p.acc.Reset()
+ p.acc.WriteString(after)
+ p.state = qwenParserState_CollectingToolContent
+ return events, true
+ } else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
+ // we found a partial tool open tag, so we can emit the unambiguous part,
+ // which is the (trailing-whitespace trimmed) content before the partial
+ // tool open tag
+ beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
+ trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
+ unambiguous := p.acc.String()[:ambiguousStart]
+ ambiguous := p.acc.String()[ambiguousStart:]
+ p.acc.Reset()
+ p.acc.WriteString(ambiguous)
+ events = append(events, qwenEventContent{content: unambiguous})
+ return events, false
+ } else {
+ // we found content that is entirely not a tool call. We should withhold
+ // any trailing whitespace in case this is the end of the content
+ whitespaceLen := trailingWhitespaceLen(p.acc.String())
+ ambiguousStart := len(p.acc.String()) - whitespaceLen
+ unambiguous := p.acc.String()[:ambiguousStart]
+ ambiguous := p.acc.String()[ambiguousStart:]
+ p.acc.Reset()
+ p.acc.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwenEventContent{content: unambiguous})
+ }
+ return events, false
+ }
+ case qwenParserState_CollectingToolContent:
+ if strings.Contains(p.acc.String(), toolCloseTag) {
+ split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
+ before := split[0]
+ if len(before) == 0 {
+ slog.Warn("qwen tool call closing tag found but no content before it")
+ }
+ // remove any whitespace between the tool call and any content after it
+ after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
+ p.acc.Reset()
+ p.acc.WriteString(after)
+ events = append(events, qwenEventRawToolCall{raw: before})
+ p.state = qwenParserState_LookingForToolStart
+ return events, true
+ } else {
+ // note that we don't need to check the overlap here because we only plan
+ // on parsing the tool call once we see the full closing tag. We don't
+ // stream back the unparsed tool content, so there's no need to be eager
+ // here
+ return events, false
+ }
+ default:
+ panic("unreachable")
+ }
+}
+
+// TODO(drifkin): move this to a shared location
+// overlap returns the length of the longest overlap between a suffix of s and
+// a prefix of delim.
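+// For example, overlap("hello <tool", "<tool_call>") == 5 and
+// overlap("hello", "<tool_call>") == 0.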
+func overlap(s, delim string) int {
+ max := min(len(delim), len(s))
+ for i := max; i > 0; i-- {
+ if strings.HasSuffix(s, delim[:i]) {
+ return i
+ }
+ }
+ return 0
+}
+
+func trailingWhitespaceLen(s string) int {
+ for i := len(s) - 1; i >= 0; i-- {
+ if !unicode.IsSpace(rune(s[i])) {
+ return len(s) - i - 1
+ }
+ }
+ return len(s)
+}
+
+type XMLFunctionCall struct {
+ XMLName xml.Name `xml:"function"`
+ Name string `xml:"name,attr"`
+ Parameters []XMLParameter `xml:"parameter"`
+}
+
+type XMLParameter struct {
+ Name string `xml:"name,attr"`
+ Value string `xml:",chardata"`
+}
+
+// parseToolCall parses a raw tool call string into an api.ToolCall.
+// The raw string follows an xml-like format, here's an example:
+//
+// <function=get_current_temperature>
+// <parameter=location>
+// San Francisco
+// </parameter>
+// <parameter=unit>
+// celsius
+// </parameter>
+// </function>
+func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+ toolCall := api.ToolCall{}
+
+ xmlString := transformToXML(raw.raw)
+
+ var functionCall XMLFunctionCall
+ err := xml.Unmarshal([]byte(xmlString), &functionCall)
+ if err != nil {
+ return api.ToolCall{}, err
+ }
+
+ toolCall.Function = api.ToolCallFunction{
+ Name: functionCall.Name,
+ }
+
+ // Find the matching tool to get parameter types
+ var matchedTool *api.Tool
+ for i := range tools {
+ if tools[i].Function.Name == functionCall.Name {
+ matchedTool = &tools[i]
+ break
+ }
+ }
+
+ toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
+ for _, parameter := range functionCall.Parameters {
+ // Look up the parameter type if we found the tool
+ var paramType api.PropertyType
+ if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+ if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
+ paramType = prop.Type
+ }
+ }
+
+ toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
+ }
+
+ return toolCall, nil
+}
+
+// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
+//
+// For union types (multiple types in PropertyType, which we support even
+// though the reference parser doesn't appear to do type coercion with them in
+// mind) we use a type precedence approach:
+// 1. null - checked first regardless of declared types (matches reference implementation)
+// 2. boolean - only "true"/"false" are valid booleans
+// 3. integer - must parse as a whole number
+// 4. number - must parse as numeric (returns int if no decimal part)
+// 5. array - must parse as valid JSON array
+// 6. object - must parse as valid JSON object
+// 7. string - always succeeds (least specific type)
+//
+// This precedence ensures we return the most specific type that successfully parses,
+// following the principle of least surprise. For example, with PropertyType{"string", "number"},
+// "123" becomes 123 (number), while "hello" becomes "hello" (string).
+func parseValue(raw string, paramType api.PropertyType) any {
+ // first remove a single leading newline and a single trailing newline (if
+ // they exist). This follows the reference implementation.
+ raw = strings.TrimPrefix(raw, "\n")
+ raw = strings.TrimSuffix(raw, "\n")
+
+ // Check for null first (case-insensitive) - this takes precedence over any type
+ if strings.ToLower(raw) == "null" {
+ return nil
+ }
+
+ // If no type is specified, default to string
+ if len(paramType) == 0 {
+ return raw
+ }
+
+ // Check if any of the specified types match, using type precedence
+ // Order: boolean -> integer -> number -> array -> object -> string
+ typeSet := make(map[string]bool)
+ for _, t := range paramType {
+ typeSet[t] = true
+ }
+
+ // Try boolean first (most restrictive)
+ if typeSet["boolean"] {
+ lower := strings.ToLower(raw)
+ switch lower {
+ case "true":
+ return true
+ case "false":
+ return false
+ }
+ // If not a valid boolean but boolean is the only type, return false (matching reference)
+ if len(paramType) == 1 {
+ return false
+ }
+ // Otherwise try other types
+ }
+
+ // Try integer
+ if typeSet["integer"] {
+ if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
+ // Return as int if it fits in int32, otherwise int64
+ if i >= math.MinInt32 && i <= math.MaxInt32 {
+ return int(i)
+ }
+ return i
+ }
+ // If integer is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try number (float)
+ if typeSet["number"] {
+ if f, err := strconv.ParseFloat(raw, 64); err == nil {
+ // If the number has no decimal part, return as int (matching reference)
+ if f == math.Trunc(f) {
+ i := int64(f)
+ if i >= math.MinInt32 && i <= math.MaxInt32 {
+ return int(i)
+ }
+ return i
+ }
+ return f
+ }
+ // If number is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try array
+ if typeSet["array"] {
+ var arr []any
+ if err := json.Unmarshal([]byte(raw), &arr); err == nil {
+ return arr
+ }
+ // If array is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try object
+ if typeSet["object"] {
+ var obj map[string]any
+ if err := json.Unmarshal([]byte(raw), &obj); err == nil {
+ return obj
+ }
+ // If object is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // String always succeeds (or if "string" is in the type set)
+ if typeSet["string"] {
+ return raw
+ }
+
+ // If we get here, none of the types matched and string wasn't an option
+ // We return string as a fallback. The reference implementation will attempt
+ // to parse the value as a python literal, but we purposefully don't support
+ // that
+ return raw
+}
+
+var (
+ qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
+ qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
+)
+
+// transformToXML transforms a raw qwen tool call with xml-like tags into valid
+// xml so that it can be parsed by any xml parser
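+//
+// For example, `<function=add>` becomes `<function name="add">`, and character
+// data between the rewritten tags (e.g. `1 < 2`) is escaped to `1 &lt; 2`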
+func transformToXML(raw string) string {
+// take the form `<function=example>` and transform it to
+// `<function name="example">`, taking care to properly escape the string that
+// becomes the attribute value
+ transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
+ groups := qwenTagRegex.FindStringSubmatch(match)
+ tag := groups[1]
+ var escapedValue strings.Builder
+ xml.EscapeText(&escapedValue, []byte(groups[2]))
+ return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
+ })
+
+ // Walk the resulting string, escaping any character data that sits between the
+ // xml tags we just emitted
+ var out strings.Builder
+ lastIdx := 0
+ for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
+ if loc[0] > lastIdx {
+ escapeTextNode(&out, transformed[lastIdx:loc[0]])
+ }
+ out.WriteString(transformed[loc[0]:loc[1]])
+ lastIdx = loc[1]
+ }
+ if lastIdx < len(transformed) {
+ escapeTextNode(&out, transformed[lastIdx:])
+ }
+
+ return out.String()
+}
+
+// escapeTextNode escapes XML character data without altering other characters
+// like newlines or tabs (which is why we don't use xml.EscapeText for this)
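+// e.g. escapeTextNode(sb, "a < b & c\n") writes "a &lt; b &amp; c\n"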
+func escapeTextNode(sb *strings.Builder, s string) {
+ for _, r := range s {
+ switch r {
+ case '&':
+ sb.WriteString("&amp;")
+ case '<':
+ sb.WriteString("&lt;")
+ case '>':
+ sb.WriteString("&gt;")
+ default:
+ sb.WriteRune(r)
+ }
+ }
+}
diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go
new file mode 100644
index 00000000..43823e6f
--- /dev/null
+++ b/model/parsers/qwen3coder_test.go
@@ -0,0 +1,878 @@
+package parsers
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+// tool creates a test tool with the given name and properties
+func tool(name string, props map[string]api.ToolProperty) api.Tool {
+ t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
+ t.Function.Parameters.Type = "object"
+ t.Function.Parameters.Properties = props
+ return t
+}
+
+func TestQwenParserStreaming(t *testing.T) {
+ type step struct {
+ input string
+ wantEvents []qwenEvent
+ }
+
+ cases := []struct {
+ desc string
+ steps []step
+ only bool
+ }{
+ {
+ desc: "simple message streamed word by word",
+ steps: []step{
+ {
+ input: "hi",
+ wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
+ },
+ {
+ input: " there",
+ wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
+ },
+ },
+ },
+ {
+ desc: "content before tool call",
+ steps: []step{
+ {
+ input: "hi there",
+ wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
+ },
+ },
+ },
+ {
+ desc: "multiple tool calls in one message",
+ steps: []step{
+ {
+ input: "before1in tool callafter1in tool call 2after2",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "before1"},
+ qwenEventRawToolCall{raw: "in tool call"},
+ qwenEventContent{content: "after1"},
+ qwenEventRawToolCall{raw: "in tool call 2"},
+ qwenEventContent{content: "after2"},
+ },
+ },
+ },
+ },
+ {
+ desc: "tool calls with split tags",
+ steps: []step{
+ {
+ input: "before<tool_",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "before"},
+ },
+ },
+ {
+ input: "call>in tool call</tool_call>af",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "in tool call"},
+ qwenEventContent{content: "af"},
+ },
+ },
+ {
+ input: "ter",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "ter"},
+ },
+ },
+ },
+ },
+ {
+ desc: "trailing whitespace between content and tool call",
+ steps: []step{
+ {
+ input: "abc\ndef",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "abc"},
+ qwenEventRawToolCall{raw: "def"},
+ },
+ },
+ },
+ },
+ {
+ desc: "trailing whitespace between tool call and content",
+ steps: []step{
+ {
+ input: "abc\ndef",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "abc"},
+ qwenEventContent{content: "def"},
+ },
+ },
+ },
+ },
+ {
+ desc: "empty content before tool call",
+ steps: []step{
+ {
+ input: "\nabc",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "abc"},
+ },
+ },
+ },
+ },
+ {
+ desc: "partial tool open tag fakeout",
+ steps: []step{
+ {
+ input: "abc\n
+
+San Francisco
+
+
+celsius
+
+`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "get_current_temperature",
+ Arguments: map[string]any{
+ "location": "San Francisco",
+ "unit": "celsius",
+ },
+ },
+ },
+ },
+ {
+ name: "names with spaces",
+ tools: []api.Tool{},
+ rawToolCall: `<function=get current temperature>
+<parameter=location with spaces>
+San Francisco
+</parameter>
+<parameter=unit with spaces>
+celsius
+</parameter>
+</function>`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "get current temperature",
+ Arguments: map[string]any{
+ "location with spaces": "San Francisco",
+ "unit with spaces": "celsius",
+ },
+ },
+ },
+ },
+ // this mirrors the reference implementation's behavior, but it's unclear
+ // whether quotes ever actually appear here. If they do, we should probably
+ // strip them instead; this test just documents the current behavior and
+ // checks that we don't get xml errors
+ {
+ name: "names with quotes",
+ tools: []api.Tool{},
+ rawToolCall: `<function="get current temperature">
+<parameter="location with spaces">
+San Francisco
+</parameter>
+<parameter="unit with spaces">
+"celsius"
+</parameter>
+</function>`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "\"get current temperature\"",
+ Arguments: map[string]any{
+ "\"location with spaces\"": "San Francisco",
+ "\"unit with spaces\"": "\"celsius\"",
+ },
+ },
+ },
+ },
+ {
+ name: "tool call with typed parameters",
+ tools: []api.Tool{
+ tool("calculate", map[string]api.ToolProperty{
+ "x": {Type: api.PropertyType{"number"}},
+ "y": {Type: api.PropertyType{"integer"}},
+ "enabled": {Type: api.PropertyType{"boolean"}},
+ "items": {Type: api.PropertyType{"array"}},
+ }),
+ },
+ rawToolCall: `<function=calculate>
+<parameter=x>
+3.14
+</parameter>
+<parameter=y>
+42
+</parameter>
+<parameter=enabled>
+true
+</parameter>
+<parameter=items>
+["a", "b", "c"]
+</parameter>
+</function>`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "calculate",
+ Arguments: map[string]any{
+ "x": 3.14,
+ "y": 42,
+ "enabled": true,
+ "items": []any{"a", "b", "c"},
+ },
+ },
+ },
+ },
+ // regression test for
+ {
+ name: "ampersands in parameter values",
+ tools: []api.Tool{},
+ rawToolCall: `<function=exec>
+<parameter=command>
+ls && echo "done"
+</parameter>
+</function>`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "exec",
+ Arguments: map[string]any{
+ "command": "ls && echo \"done\"",
+ },
+ },
+ },
+ },
+ {
+ name: "angle brackets in parameter values",
+ tools: []api.Tool{},
+ rawToolCall: `<function=exec>
+<parameter=command>
+ls && echo "a > b and a < b"
+</parameter>
+</function>`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "exec",
+ Arguments: map[string]any{
+ "command": "ls && echo \"a > b and a < b\"",
+ },
+ },
+ },
+ },
+ }
+
+ for i, step := range steps {
+ gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
+ if err != nil {
+ t.Errorf("step %d (%s): %v", i, step.name, err)
+ }
+ if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
+ t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
+ }
+ }
+}
+
+func TestQwenToolCallValueParsing(t *testing.T) {
+ cases := []struct {
+ desc string
+ raw string
+ paramType api.PropertyType
+ want any
+ }{
+ {
+ desc: "default string value (no type specified)",
+ paramType: api.PropertyType{},
+ raw: "some-string",
+ want: "some-string",
+ },
+ {
+ desc: "trim a single leading and trailing newline",
+ paramType: api.PropertyType{},
+ raw: "\nsome-string\n",
+ want: "some-string",
+ },
+ {
+ desc: "trim at most one leading and trailing newline",
+ paramType: api.PropertyType{},
+ raw: "\n\nsome-string\n\n",
+ want: "\nsome-string\n",
+ },
+ {
+ desc: "newline really has to be the first character to be trimmed",
+ paramType: api.PropertyType{},
+ raw: " \nsome-string\n ",
+ want: " \nsome-string\n ",
+ },
+ {
+ desc: "numeric type",
+ paramType: api.PropertyType{"number"},
+ raw: "123",
+ want: 123,
+ },
+ // Integer parsing tests
+ {
+ desc: "integer type",
+ paramType: api.PropertyType{"integer"},
+ raw: "42",
+ want: 42,
+ },
+ {
+ desc: "negative integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "-100",
+ want: -100,
+ },
+ {
+ desc: "zero integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "0",
+ want: 0,
+ },
+ {
+ desc: "integer with leading zeros",
+ paramType: api.PropertyType{"integer"},
+ raw: "007",
+ want: 7,
+ },
+ {
+ desc: "large integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "2147483648", // Just beyond int32 max
+ want: int64(2147483648),
+ },
+ // Float/number parsing tests
+ {
+ desc: "float type",
+ paramType: api.PropertyType{"number"},
+ raw: "3.14",
+ want: 3.14,
+ },
+ {
+ desc: "negative float",
+ paramType: api.PropertyType{"number"},
+ raw: "-273.15",
+ want: -273.15,
+ },
+ {
+ desc: "float without decimal part",
+ paramType: api.PropertyType{"number"},
+ raw: "100.0",
+ want: 100,
+ },
+ {
+ desc: "scientific notation positive",
+ paramType: api.PropertyType{"number"},
+ raw: "1.23e5",
+ want: 123000, // Will be int since it has no decimal part
+ },
+ {
+ desc: "scientific notation negative",
+ paramType: api.PropertyType{"number"},
+ raw: "1.5e-3",
+ want: 0.0015,
+ },
+ {
+ desc: "very small float",
+ paramType: api.PropertyType{"number"},
+ raw: "0.00000001",
+ want: 0.00000001,
+ },
+ // String parsing tests
+ {
+ desc: "explicit string type",
+ paramType: api.PropertyType{"string"},
+ raw: "hello world",
+ want: "hello world",
+ },
+ {
+ desc: "string with special characters",
+ paramType: api.PropertyType{"string"},
+ raw: "/usr/local/bin/test-file_v2.0.sh",
+ want: "/usr/local/bin/test-file_v2.0.sh",
+ },
+ {
+ desc: "string with quotes",
+ paramType: api.PropertyType{"string"},
+ raw: `He said "hello" to me`,
+ want: `He said "hello" to me`,
+ },
+ {
+ desc: "multiline string",
+ paramType: api.PropertyType{"string"},
+ raw: "line one\nline two\nline three",
+ want: "line one\nline two\nline three",
+ },
+ {
+ desc: "empty string",
+ paramType: api.PropertyType{"string"},
+ raw: "",
+ want: "",
+ },
+ {
+ desc: "string that looks like a number",
+ paramType: api.PropertyType{"string"},
+ raw: "12345",
+ want: "12345",
+ },
+ // Boolean parsing tests
+ {
+ desc: "boolean true",
+ paramType: api.PropertyType{"boolean"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "boolean false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "false",
+ want: false,
+ },
+ {
+ desc: "boolean case insensitive true",
+ paramType: api.PropertyType{"boolean"},
+ raw: "True",
+ want: true,
+ },
+ {
+ desc: "boolean case insensitive false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "FALSE",
+ want: false,
+ },
+ // Null parsing tests
+ {
+ desc: "null value lowercase",
+ paramType: api.PropertyType{"string"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "null value case insensitive",
+ paramType: api.PropertyType{"integer"},
+ raw: "NULL",
+ want: nil,
+ },
+ // Array parsing tests
+ {
+ desc: "array of strings",
+ paramType: api.PropertyType{"array"},
+ raw: `["foo", "bar", "baz"]`,
+ want: []any{"foo", "bar", "baz"},
+ },
+ {
+ desc: "array of numbers",
+ paramType: api.PropertyType{"array"},
+ raw: `[1, 2.5, 3]`,
+ want: []any{float64(1), 2.5, float64(3)},
+ },
+ {
+ desc: "array of mixed types",
+ paramType: api.PropertyType{"array"},
+ raw: `["string", 123, true, null]`,
+ want: []any{"string", float64(123), true, nil},
+ },
+ {
+ desc: "empty array",
+ paramType: api.PropertyType{"array"},
+ raw: `[]`,
+ want: []any{},
+ },
+ // Object parsing tests
+ {
+ desc: "simple object",
+ paramType: api.PropertyType{"object"},
+ raw: `{"key": "value", "number": 42}`,
+ want: map[string]any{"key": "value", "number": float64(42)},
+ },
+ {
+ desc: "nested object",
+ paramType: api.PropertyType{"object"},
+ raw: `{"outer": {"inner": "value"}}`,
+ want: map[string]any{"outer": map[string]any{"inner": "value"}},
+ },
+ {
+ desc: "empty object",
+ paramType: api.PropertyType{"object"},
+ raw: `{}`,
+ want: map[string]any{},
+ },
+ // Error cases and fallback behavior
+ {
+ desc: "invalid integer falls back to string",
+ paramType: api.PropertyType{"integer"},
+ raw: "not-a-number",
+ want: "not-a-number",
+ },
+ {
+ desc: "invalid float falls back to string",
+ paramType: api.PropertyType{"number"},
+ raw: "3.14.159",
+ want: "3.14.159",
+ },
+ {
+ desc: "invalid boolean falls back to false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "yes",
+ want: false,
+ },
+ {
+ desc: "invalid JSON array falls back to string",
+ paramType: api.PropertyType{"array"},
+ raw: "[1, 2, unclosed",
+ want: "[1, 2, unclosed",
+ },
+ {
+ desc: "invalid JSON object falls back to string",
+ paramType: api.PropertyType{"object"},
+ raw: `{"key": unclosed`,
+ want: `{"key": unclosed`,
+ },
+ // Edge cases
+ {
+ desc: "integer overflow should use int64",
+ paramType: api.PropertyType{"integer"},
+ raw: "2147483648", // Beyond int32 max
+ want: int64(2147483648),
+ },
+ {
+ desc: "float with many decimal places",
+ paramType: api.PropertyType{"number"},
+ raw: "3.141592653589793",
+ want: 3.141592653589793,
+ },
+ {
+ desc: "string with JSON-like content",
+ paramType: api.PropertyType{"string"},
+ raw: `{"this": "is", "just": "a string"}`,
+ want: `{"this": "is", "just": "a string"}`,
+ },
+ {
+ desc: "whitespace-only string",
+ paramType: api.PropertyType{"string"},
+ raw: " ",
+ want: " ",
+ },
+ // Unknown parameter (no type specified in tools)
+ {
+ desc: "parameter not in tool definition defaults to string",
+ paramType: api.PropertyType{},
+ raw: "some value",
+ want: "some value",
+ },
+ // Union type tests
+ {
+ desc: "string or number union - valid number",
+ paramType: api.PropertyType{"string", "number"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "string or number union - non-numeric string",
+ paramType: api.PropertyType{"string", "number"},
+ raw: "hello",
+ want: "hello",
+ },
+ {
+ desc: "number or string union - valid number (order shouldn't matter)",
+ paramType: api.PropertyType{"number", "string"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "integer or null union - valid integer",
+ paramType: api.PropertyType{"integer", "null"},
+ raw: "123",
+ want: 123,
+ },
+ {
+ desc: "integer or null union - null value",
+ paramType: api.PropertyType{"integer", "null"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "null or integer union - null value (order shouldn't matter)",
+ paramType: api.PropertyType{"null", "integer"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "boolean or string union - valid boolean",
+ paramType: api.PropertyType{"boolean", "string"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "boolean or string union - non-boolean becomes string",
+ paramType: api.PropertyType{"boolean", "string"},
+ raw: "yes",
+ want: "yes",
+ },
+ {
+ desc: "string or boolean union - valid boolean (precedence test)",
+ paramType: api.PropertyType{"string", "boolean"},
+ raw: "false",
+ want: false, // Should be boolean, not string "false"
+ },
+ {
+ desc: "integer or number union - integer value",
+ paramType: api.PropertyType{"integer", "number"},
+ raw: "42",
+ want: 42,
+ },
+ {
+ desc: "integer or number union - float value",
+ paramType: api.PropertyType{"integer", "number"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "number or integer union - integer value (precedence test)",
+ paramType: api.PropertyType{"number", "integer"},
+ raw: "42",
+ want: 42, // Should try integer first due to precedence
+ },
+ {
+ desc: "array or object union - valid array",
+ paramType: api.PropertyType{"array", "object"},
+ raw: `[1, 2, 3]`,
+ want: []any{float64(1), float64(2), float64(3)},
+ },
+ {
+ desc: "array or object union - valid object",
+ paramType: api.PropertyType{"array", "object"},
+ raw: `{"key": "value"}`,
+ want: map[string]any{"key": "value"},
+ },
+ {
+ desc: "object or array union - valid array (precedence test)",
+ paramType: api.PropertyType{"object", "array"},
+ raw: `[1, 2, 3]`,
+ want: []any{float64(1), float64(2), float64(3)},
+ },
+ {
+ desc: "complex multi-type union - null",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "complex multi-type union - boolean",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "complex multi-type union - number",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "3.14",
+ want: 3.14,
+ },
+ {
+ desc: "complex multi-type union - string",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "hello",
+ want: "hello",
+ },
+ {
+ desc: "integer string union - integer string becomes integer",
+ paramType: api.PropertyType{"integer", "string"},
+ raw: "123",
+ want: 123,
+ },
+ {
+ desc: "string integer union - integer string becomes integer (precedence)",
+ paramType: api.PropertyType{"string", "integer"},
+ raw: "123",
+ want: 123, // Integer has higher precedence than string
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.desc, func(t *testing.T) {
+ got := parseValue(tc.raw, tc.paramType)
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
+ }
+ })
+ }
+}
+
+func TestQwenXMLTransform(t *testing.T) {
+ cases := []struct {
+ desc string
+ raw string
+ want string
+ }{
+ {
+ desc: "simple example",
+ raw: `<function=get_current_temperature>
+<parameter=location>
+San Francisco
+</parameter>
+<parameter=unit>
+celsius
+</parameter>
+</function>`,
+ want: `<function name="get_current_temperature">
+<parameter name="location">
+San Francisco
+</parameter>
+<parameter name="unit">
+celsius
+</parameter>
+</function>`,
+ },
+ // even though quotes aren't expected in these tags, we have these tests to
+ // make sure they're escaped so they don't blow up the xml parser in case
+ // they happen
+ {
+ desc: "names with quotes",
+ raw: `<function="get_current_temperature">
+<parameter="location">
+San Francisco
+</parameter>
+<parameter="unit">
+celsius
+</parameter>
+</function>`,
+ want: `<function name="&#34;get_current_temperature&#34;">
+<parameter name="&#34;location&#34;">
+San Francisco
+</parameter>
+<parameter name="&#34;unit&#34;">
+celsius
+</parameter>
+</function>`,
+ },
+ {
+ desc: "ampersands in parameter values",
+ raw: `<function=get_current_temperature>
+<parameter=location>
+ San Francisco & San Jose
+</parameter>
+</function> `,
+ want: `<function name="get_current_temperature">
+<parameter name="location">
+ San Francisco &amp; San Jose
+</parameter>
+</function> `,
+ },
+ }
+
+ for _, tc := range cases {
+ got := transformToXML(tc.raw)
+ if got != tc.want {
+ t.Errorf("got %q, want %q", got, tc.want)
+ }
+ }
+}
+
+func TestTrailingWhitespaceLen(t *testing.T) {
+ cases := []struct {
+ desc string
+ s string
+ want int
+ }{
+ {desc: "no whitespace", s: "abc", want: 0},
+ {desc: "trailing whitespace", s: "abc ", want: 1},
+ {desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
+ {desc: "only whitespace", s: " \n ", want: 4},
+ {desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
+ }
+
+ for _, tc := range cases {
+ got := trailingWhitespaceLen(tc.s)
+ if got != tc.want {
+ t.Errorf("got %d, want %d", got, tc.want)
+ }
+ }
+}
diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go
new file mode 100644
index 00000000..df3b3a45
--- /dev/null
+++ b/model/renderers/qwen3coder.go
@@ -0,0 +1,217 @@
+package renderers
+
+import (
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+var (
+ imStartTag = "<|im_start|>"
+ imEndTag = "<|im_end|>"
+)
+
+// renderAdditionalKeys renders all JSON fields except the ones in handledKeys.
+// This follows the same approach as the reference implementation, which gives
+// a particular key ordering.
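+//
+// For example, a property with an "enum" key of ["celsius", "fahrenheit"]
+// renders as "\n<enum>[\"celsius\",\"fahrenheit\"]</enum>".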
+func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
+ data, err := json.Marshal(obj)
+ if err != nil {
+ return ""
+ }
+
+ var m map[string]any
+ if err := json.Unmarshal(data, &m); err != nil {
+ return ""
+ }
+
+ var sb strings.Builder
+ for key, value := range m {
+ if handledKeys[key] {
+ continue
+ }
+
+ // Check if value is a map or array (needs JSON serialization)
+ switch v := value.(type) {
+ case map[string]any, []any:
+ jsonBytes, _ := json.Marshal(v)
+ // TODO(drifkin): it would be nice to format the JSON here similarly to
+ // python's default json.dumps behavior (spaces after commas and colons).
+ // This would let us be byte-for-byte compatible with the reference
+ // implementation for most common inputs
+ jsonStr := string(jsonBytes)
+ sb.WriteString("\n<" + key + ">" + jsonStr + "" + key + ">")
+ case nil:
+ continue
+ default:
+ // Simple types, convert to string
+ sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "" + key + ">")
+ }
+ }
+
+ return sb.String()
+}
+
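+// Qwen3CoderRenderer renders messages and tools into the prompt format the
+// qwen3-coder models expect: <|im_start|>/<|im_end|> delimited blocks, with
+// tool definitions and tool calls expressed as xml-ish tags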
+func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ // filter out system messages and let the first one (if any) win
+ var systemMessage string
+ var filteredMessages []api.Message
+ for _, message := range messages {
+ if message.Role != "system" {
+ filteredMessages = append(filteredMessages, message)
+ continue
+ }
+
+ if systemMessage == "" {
+ systemMessage = message.Content
+ }
+ }
+
+ if systemMessage != "" || len(tools) > 0 {
+ sb.WriteString(imStartTag + "system\n")
+
+ // if we have tools but no system message, match the reference implementation by providing a default system message
+ if systemMessage == "" {
+ systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
+ }
+
+ sb.WriteString(systemMessage)
+
+ if len(tools) > 0 {
+ sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
+ sb.WriteString("")
+ for _, tool := range tools {
+ sb.WriteString("\n")
+ sb.WriteString("\n")
+ sb.WriteString("" + tool.Function.Name + "")
+ if tool.Function.Description != "" {
+ sb.WriteString("\n" + tool.Function.Description + "")
+ }
+ sb.WriteString("\n")
+
+ for name, prop := range tool.Function.Parameters.Properties {
+ sb.WriteString("\n")
+ sb.WriteString("\n" + name + "")
+
+ if len(prop.Type) > 0 {
+ // TODO(!!!)(drifkin): we should match the reference implementation for
+ // more complex types here instead of using this format
+ sb.WriteString("\n" + prop.ToTypeScriptType() + "")
+ }
+
+ if prop.Description != "" {
+ sb.WriteString("\n" + prop.Description + "")
+ }
+
+ // Render any additional keys not already handled
+ handledKeys := map[string]bool{
+ "type": true,
+ "description": true,
+ }
+ sb.WriteString(renderAdditionalKeys(prop, handledKeys))
+
+ sb.WriteString("\n")
+ }
+
+ // Render extra keys for parameters (everything except 'type' and 'properties')
+ paramHandledKeys := map[string]bool{
+ "type": true,
+ "properties": true,
+ }
+ sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
+
+ sb.WriteString("\n")
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n")
+ sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n")
+ }
+
+ sb.WriteString(imEndTag + "\n")
+ }
+
+ for i, message := range filteredMessages {
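+ // a trailing assistant message is a prefill: leave the assistant turn
+ // open so the model continues from the provided content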
+ lastMessage := i == len(filteredMessages)-1
+ prefill := lastMessage && message.Role == "assistant"
+ switch message.Role {
+ case "assistant":
+ if len(message.ToolCalls) > 0 {
+ sb.WriteString(imStartTag + "assistant\n")
+ if message.Content != "" {
+ sb.WriteString(message.Content + "\n")
+ }
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("\n\n")
+ for name, value := range toolCall.Function.Arguments {
+ valueStr := formatToolCallArgument(value)
+ sb.WriteString("\n\n" + valueStr + "\n")
+ }
+ sb.WriteString("\n\n")
+ }
+ sb.WriteString("<|im_end|>\n")
+ } else {
+ sb.WriteString(imStartTag + "assistant\n")
+ sb.WriteString(message.Content)
+ if !prefill {
+ sb.WriteString(imEndTag + "\n")
+ }
+ }
+ case "tool":
+ // consecutive tool responses should share a single `user`, but
+ // have their own <tool_response> tags
+
+ // only start a new user block if this is the first tool response
+ if i == 0 || filteredMessages[i-1].Role != "tool" {
+ sb.WriteString(imStartTag + "user\n")
+ }
+
+ sb.WriteString("\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("\n\n")
+
+ // close the user block only if this is the last tool response
+ if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
+ sb.WriteString(imEndTag + "\n")
+ }
+ default:
+ sb.WriteString(imStartTag + message.Role + "\n")
+ sb.WriteString(message.Content)
+ sb.WriteString(imEndTag + "\n")
+ }
+
+ if lastMessage && !prefill {
+ sb.WriteString(imStartTag + "assistant\n")
+ }
+ }
+
+ return sb.String(), nil
+}
+
+func formatToolCallArgument(value any) string {
+ if value == nil {
+ return "null"
+ }
+
+ switch v := value.(type) {
+ case string:
+ return v
+ case []byte:
+ return string(v)
+ }
+
+ if reflect.TypeOf(value) != nil {
+ kind := reflect.TypeOf(value).Kind()
+ if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
+ if marshalled, err := json.Marshal(value); err == nil {
+ return string(marshalled)
+ }
+ }
+ }
+
+ return fmt.Sprintf("%v", value)
+}
diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go
new file mode 100644
index 00000000..4aaa066d
--- /dev/null
+++ b/model/renderers/qwen3coder_test.go
@@ -0,0 +1,338 @@
+package renderers
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/ollama/ollama/api"
+)
+
+func TestQwen3CoderRenderer(t *testing.T) {
+ tests := []struct {
+ name string
+ msgs []api.Message
+ tools []api.Tool
+ expected string
+ }{
+ {
+ name: "basic",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Hello, how are you?"},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Hello, how are you?<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "with tools and response",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant with access to tools."},
+ {Role: "user", Content: "What is the weather like in San Francisco?"},
+ {
+ Role: "assistant",
+ Content: "I'll check the weather in San Francisco for you.",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: map[string]any{
+ "unit": "fahrenheit",
+ },
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
+ {Role: "user", Content: "That sounds nice! What about New York?"},
+ },
+ tools: []api.Tool{
+ {Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: api.ToolFunctionParameters{
+ Required: []string{"unit"},
+ Properties: map[string]api.ToolProperty{
+ "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
+ // TODO(drifkin): add multiple params back once we have predictable
+ // order via some sort of ordered map type (see
+ // )
+ /*
+ "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
+ */
+ },
+ },
+ }},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>get_weather</name>
+<description>Get the current weather in a given location</description>
+<parameters>
+<parameter>
+<name>unit</name>
+<type>string</type>
+<description>The unit of temperature</description>
+<enum>["celsius","fahrenheit"]</enum>
+</parameter>
+<required>["unit"]</required>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+What is the weather like in San Francisco?<|im_end|>
+<|im_start|>assistant
+I'll check the weather in San Francisco for you.
+
+<tool_call>
+<function=get_weather>
+<parameter=unit>
+fahrenheit
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
+</tool_response>
+<|im_end|>
+<|im_start|>user
+That sounds nice! What about New York?<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "parallel tool calls",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant with access to tools."},
+ {Role: "user", Content: "call double(1) and triple(2)"},
+ {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
+ {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
+ {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
+ }},
+ {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
+ {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
+ },
+ tools: []api.Tool{
+ {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+ "number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
+ }}}},
+ {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+ "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
+ }}}},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>double</name>
+<description>Double a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to double</description>
+</parameter>
+</parameters>
+</function>
+<function>
+<name>triple</name>
+<description>Triple a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to triple</description>
+</parameter>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+call double(1) and triple(2)<|im_end|>
+<|im_start|>assistant
+I'll call double(1) and triple(2) for you.
+
+<tool_call>
+<function=double>
+<parameter=number>
+1
+</parameter>
+</function>
+</tool_call>
+<tool_call>
+<function=triple>
+<parameter=number>
+2
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"number": 2}
+</tool_response>
+<tool_response>
+{"number": 6}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "prefill",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Tell me something interesting."},
+ {Role: "assistant", Content: "I'll tell you something interesting about cats"},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Tell me something interesting.<|im_end|>
+<|im_start|>assistant
+I'll tell you something interesting about cats`,
+ },
+ {
+ name: "complex tool call arguments should remain json encoded",
+ msgs: []api.Message{
+ {Role: "user", Content: "call tool"},
+ {Role: "assistant", ToolCalls: []api.ToolCall{
+ {Function: api.ToolCallFunction{
+ Name: "echo",
+ Arguments: map[string]any{
+ "payload": map[string]any{"foo": "bar"},
+ },
+ }},
+ }},
+ {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
+ },
+ expected: `<|im_start|>user
+call tool<|im_end|>
+<|im_start|>assistant
+
+<tool_call>
+<function=echo>
+<parameter=payload>
+{"foo":"bar"}
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"payload": {"foo": "bar"}}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(rendered, tt.expected); diff != "" {
+ t.Errorf("mismatch (-got +want):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestFormatToolCallArgument(t *testing.T) {
+ tests := []struct {
+ name string
+ arg any
+ expected string
+ }{
+ {
+ name: "string",
+ arg: "foo",
+ // notice no quotes around the string
+ expected: "foo",
+ },
+ {
+ name: "map",
+ arg: map[string]any{"foo": "bar"},
+ expected: "{\"foo\":\"bar\"}",
+ },
+ {
+ name: "number",
+ arg: 1,
+ expected: "1",
+ },
+ {
+ name: "boolean",
+ arg: true,
+ expected: "true",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := formatToolCallArgument(tt.arg)
+ if got != tt.expected {
+ t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
+ }
+ })
+ }
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
new file mode 100644
index 00000000..2dfb51e4
--- /dev/null
+++ b/model/renderers/renderer.go
@@ -0,0 +1,26 @@
+package renderers
+
+import (
+ "fmt"
+
+ "github.com/ollama/ollama/api"
+)
+
+type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
+
+func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+ renderer := rendererForName(name)
+ if renderer == nil {
+ return "", fmt.Errorf("unknown renderer %q", name)
+ }
+ return renderer(msgs, tools, think)
+}
+
+func rendererForName(name string) rendererFunc {
+ switch name {
+ case "qwen3-coder":
+ return Qwen3CoderRenderer
+ default:
+ return nil
+ }
+}
diff --git a/model/sentencepiece.go b/model/sentencepiece.go
index 827ce00d..db07beee 100644
--- a/model/sentencepiece.go
+++ b/model/sentencepiece.go
@@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁"
-type SentencePieceModel struct {
+type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
-var _ TextProcessor = (*SentencePieceModel)(nil)
+var _ TextProcessor = (*SentencePiece)(nil)
-func (spm SentencePieceModel) Vocabulary() *Vocabulary {
+func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
-func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
+func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
- return SentencePieceModel{
+ return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
-func (spm SentencePieceModel) Is(id int32, special Special) bool {
+func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
-func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
+func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item
}
-func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go
index 50ac2678..8f4570c1 100644
--- a/model/sentencepiece_test.go
+++ b/model/sentencepiece_test.go
@@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
-func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
+func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
- return NewSentencePieceModel(&v)
+ return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
})
}
-func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0},
}
- spm := NewSentencePieceModel(vocab)
+ spm := NewSentencePiece(vocab)
tests := []struct {
name string
diff --git a/model/wordpiece.go b/model/wordpiece.go
new file mode 100644
index 00000000..e8d5e848
--- /dev/null
+++ b/model/wordpiece.go
@@ -0,0 +1,167 @@
+package model
+
+import (
+ "fmt"
+ "iter"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/logutil"
+)
+
+type WordPiece struct {
+ vocab *Vocabulary
+}
+
+// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
+// this differs from original word piece which uses "##" to indicate subwords.
+const ggmlPrefix = "▁"
+
+var wordPieceReplacer = strings.NewReplacer(
+ " .", ".",
+ " ?", "?",
+ " !", "!",
+ " ,", ",",
+ " ' ", "'",
+ " n't", "n't",
+ " 'm", "'m",
+ " do not", " don't",
+ " 's", "'s",
+ " 've", "'ve",
+ " 're", "'re",
+)
+
+// Decode implements TextProcessor.
+func (wpm WordPiece) Decode(ids []int32) (string, error) {
+ var sb strings.Builder
+ for i, id := range ids {
+ if id < 0 || int(id) >= len(wpm.vocab.Values) {
+ return "", fmt.Errorf("invalid token id: %d", id)
+ }
+
+ var separator string
+ piece := wpm.vocab.Values[id]
+ if i > 0 &&
+ (strings.HasPrefix(piece, ggmlPrefix) ||
+ (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
+ separator = " "
+ }
+
+ sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
+ }
+
+ return sb.String(), nil
+}
+
+// words splits a string into words, treating CJK characters as separate words.
+// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
+func (wpm WordPiece) words(s string) iter.Seq[string] {
+ return func(yield func(string) bool) {
+ runes := make([]rune, 0, len(s)*3)
+ for _, r := range s {
+ switch {
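+ // CJK Unified Ideographs, Extensions A through E, compatibility
+ // ideographs, and the compatibility ideographs supplement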
+ case r >= 0x4E00 && r <= 0x9FFF,
+ r >= 0x3400 && r <= 0x4DBF,
+ r >= 0x20000 && r <= 0x2A6DF,
+ r >= 0x2A700 && r <= 0x2B73F,
+ r >= 0x2B740 && r <= 0x2B81F,
+ r >= 0x2B820 && r <= 0x2CEAF,
+ r >= 0xF900 && r <= 0xFAFF,
+ r >= 0x2F800 && r <= 0x2FA1F:
+ runes = append(runes, ' ', r, ' ')
+ default:
+ runes = append(runes, r)
+ }
+ }
+
+ for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
+ // split on punctuation, but keep it as its own piece
+ var start int
+ for start < len(w) {
+ end := strings.IndexFunc(w[start:], unicode.IsPunct)
+ if end < 0 {
+ end = len(w) - start
+ } else if end == 0 {
+ end = 1
+ }
+
+ if !yield(w[start : start+end]) {
+ return
+ }
+
+ start += end
+ }
+ }
+ }
+}
+
+// Encode implements TextProcessor.
+func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
+ var ids []int32
+
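+ // greedy longest-match-first: for each word, repeatedly take the longest
+ // vocabulary piece that prefixes the remaining text; if nothing matches,
+ // the whole word maps to [UNK]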
+ // TODO: use [UNK] from config
+ unk := wpm.vocab.Encode("[UNK]")
+ for word := range wpm.words(s) {
+ var start int
+ var pieces []int32
+ for start < len(word) {
+ end := len(word)
+
+ var piece int32
+ for start < end {
+ subword := word[start:end]
+ if start == 0 {
+ subword = ggmlPrefix + subword
+ }
+
+ // TODO: some models might not want [ToLower]
+ piece = wpm.vocab.Encode(strings.ToLower(subword))
+ if piece >= 0 {
+ break
+ }
+
+ end--
+ }
+
+ if piece < 0 {
+ // Unknown token
+ pieces = pieces[:0]
+ break
+ }
+
+ pieces = append(pieces, piece)
+ start = end
+ }
+
+ if len(pieces) > 0 {
+ ids = append(ids, pieces...)
+ } else {
+ ids = append(ids, unk)
+ }
+ }
+
+ if addSpecial && len(ids) > 0 {
+ ids = wpm.vocab.addSpecials(ids)
+ }
+
+ logutil.Trace("encoded", "string", s, "ids", ids)
+ return ids, nil
+}
+
+// Is implements TextProcessor.
+func (wpm WordPiece) Is(id int32, special Special) bool {
+ return wpm.vocab.Is(id, special)
+}
+
+// Vocabulary implements TextProcessor.
+func (wpm WordPiece) Vocabulary() *Vocabulary {
+ return wpm.vocab
+}
+
+var _ TextProcessor = (*WordPiece)(nil)
+
+func NewWordPiece(vocab *Vocabulary) WordPiece {
+ return WordPiece{
+ vocab: vocab,
+ }
+}
diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go
new file mode 100644
index 00000000..258fbffc
--- /dev/null
+++ b/model/wordpiece_test.go
@@ -0,0 +1,51 @@
+package model
+
+import (
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestWordPiece(t *testing.T) {
+ wpm := NewWordPiece(
+ &Vocabulary{
+ Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
+ AddBOS: true,
+ AddEOS: true,
+ BOS: []int32{1},
+ EOS: []int32{2},
+ })
+
+ ids, err := wpm.Encode("Hello world!", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
+ t.Errorf("unexpected ids (-want +got):\n%s", diff)
+ }
+
+ words, err := wpm.Decode(ids)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+}
+
+func TestWordPieceWords(t *testing.T) {
+ var wpm WordPiece
+
+ basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
+ if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+
+ chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
+ if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+}
diff --git a/openai/openai.go b/openai/openai.go
index b6a8a95e..7ef5ac6d 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
+ DebugRenderOnly bool `json:"_debug_render_only"`
}
type ChatCompletion struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- SystemFingerprint string `json:"system_fingerprint"`
- Choices []Choice `json:"choices"`
- Usage Usage `json:"usage,omitempty"`
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []Choice `json:"choices"`
+ Usage Usage `json:"usage,omitempty"`
+ DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
}
type ChatCompletionChunk struct {
@@ -141,6 +143,7 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
+ DebugRenderOnly bool `json:"_debug_render_only"`
}
type Completion struct {
@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
return nil
}(r.DoneReason),
}},
Usage: toUsage(r),
+ DebugInfo: r.DebugInfo,
}
}
@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
return &api.ChatRequest{
- Model: r.Model,
- Messages: messages,
- Format: format,
- Options: options,
- Stream: &r.Stream,
- Tools: r.Tools,
- Think: think,
+ Model: r.Model,
+ Messages: messages,
+ Format: format,
+ Options: options,
+ Stream: &r.Stream,
+ Tools: r.Tools,
+ Think: think,
+ DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
}
return api.GenerateRequest{
- Model: r.Model,
- Prompt: r.Prompt,
- Options: options,
- Stream: &r.Stream,
- Suffix: r.Suffix,
+ Model: r.Model,
+ Prompt: r.Prompt,
+ Options: options,
+ Stream: &r.Stream,
+ Suffix: r.Suffix,
+ DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
diff --git a/parser/parser.go b/parser/parser.go
index e080f1bb..c2e8f981 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.System = c.Args
case "license":
licenses = append(licenses, c.Args)
+ case "renderer":
+ req.Renderer = c.Args
+ case "parser":
+ req.Parser = c.Args
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
@@ -320,7 +324,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
- case "license", "template", "system", "adapter":
+ case "license", "template", "system", "adapter", "renderer", "parser":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -346,7 +350,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
- errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+ errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
)
type ParserError struct {
@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
- case "from", "license", "template", "system", "adapter", "parameter", "message":
+ case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
return true
default:
return false
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 7d5a808b..1524e890 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
}
}
+func TestParseFileRenderer(t *testing.T) {
+ input := `
+FROM foo
+RENDERER renderer1
+`
+
+ reader := strings.NewReader(input)
+
+ modelfile, err := ParseFile(reader)
+ require.NoError(t, err)
+
+ assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
+}
+
+func TestParseFileParser(t *testing.T) {
+ input := `
+FROM foo
+PARSER parser1
+`
+
+ reader := strings.NewReader(input)
+
+ modelfile, err := ParseFile(reader)
+ require.NoError(t, err)
+
+ assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
+}
+
func TestParseFileMessages(t *testing.T) {
cases := []struct {
input string
diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go
index 44b24613..9ed1c292 100644
--- a/runner/llamarunner/cache.go
+++ b/runner/llamarunner/cache.go
@@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
- discard := targetFree - currentFree
- if discard < 0 {
- discard = 0
- }
-
- return discard
+ return max(targetFree-currentFree, 0)
}
type ErrReprocessInputs struct {
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index f558f7b8..a3ffc3bd 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
- discard := targetFree - currentFree
- if discard < 0 {
- discard = 0
- }
-
- return discard
+ return max(targetFree-currentFree, 0)
}
type ErrReprocessInputs struct {
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 1081a1f5..480cfc19 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -11,7 +11,6 @@ import (
"image"
"log"
"log/slog"
- "math"
"net"
"net/http"
"os"
@@ -32,6 +31,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
- supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
+ supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
for {
@@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
+ var batchOutputs []int32
var batch input.Batch
resumeSeq := -1
@@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
- seq.iBatch = len(batch.Outputs)
- if i+1 == len(seq.inputs) {
- batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+ seq.iBatch = len(batchOutputs)
+ if i+1 == len(seq.inputs) || seq.embeddingOnly {
+ batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
}
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+ batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil {
err = fmt.Errorf("failed to build graph: %w", err)
@@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
// sample a token
- vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
- logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+ vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+ logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -898,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
- if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+ if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
@@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i)
}
- batch.Outputs = make([]int32, s.parallel)
- for i := range batch.Outputs {
- batch.Outputs[i] = int32(i)
- }
-
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+ batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache
if cache != nil {
diff --git a/scripts/env.sh b/scripts/env.sh
index 65a970bd..4f5641fd 100644
--- a/scripts/env.sh
+++ b/scripts/env.sh
@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=OLLAMA_FAST_BUILD \
--build-arg=CUSTOM_CPU_FLAGS \
--build-arg=GPU_RUNNER_CPU_FLAGS \
+ --build-arg=PARALLEL \
--build-arg=AMDGPU_TARGETS"
echo "Building Ollama"
diff --git a/server/create.go b/server/create.go
index bd970876..19f24ec8 100644
--- a/server/create.go
+++ b/server/create.go
@@ -10,8 +10,11 @@ import (
"io"
"io/fs"
"log/slog"
+ "net"
"net/http"
+ "net/url"
"os"
+ "path"
"path/filepath"
"slices"
"strings"
@@ -39,6 +42,14 @@ var (
)
func (s *Server) CreateHandler(c *gin.Context) {
+ config := &ConfigV2{
+ OS: "linux",
+ Architecture: "amd64",
+ RootFS: RootFS{
+ Type: "layers",
+ },
+ }
+
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}
+ config.Renderer = r.Renderer
+ config.Parser = r.Parser
+
for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
@@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) {
oldManifest, _ := ParseNamedManifest(name)
var baseLayers []*layerGGML
+ var err error
+ var remote bool
+
if r.From != "" {
- slog.Debug("create model from model name")
+ slog.Debug("create model from model name", "from", r.From)
fromName := model.ParseName(r.From)
if !fromName.IsValid() {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return
}
+ if r.RemoteHost != "" {
+ ru, err := remoteURL(r.RemoteHost)
+ if err != nil {
+ ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
+ return
+ }
- ctx, cancel := context.WithCancel(c.Request.Context())
- defer cancel()
+ config.RemoteModel = r.From
+ config.RemoteHost = ru
+ remote = true
+ } else {
+ ctx, cancel := context.WithCancel(c.Request.Context())
+ defer cancel()
- baseLayers, err = parseFromModel(ctx, fromName, fn)
- if err != nil {
- ch <- gin.H{"error": err.Error()}
+ baseLayers, err = parseFromModel(ctx, fromName, fn)
+ if err != nil {
+ ch <- gin.H{"error": err.Error()}
+ }
}
} else if r.Files != nil {
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
@@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
var adapterLayers []*layerGGML
- if r.Adapters != nil {
+ if !remote && r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
@@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
baseLayers = append(baseLayers, adapterLayers...)
}
- if err := createModel(r, name, baseLayers, fn); err != nil {
+ // Info is not currently exposed by Modelfiles, but allows overriding various
+ // config values
+ if r.Info != nil {
+ caps, ok := r.Info["capabilities"]
+ if ok {
+ switch tcaps := caps.(type) {
+ case []any:
+ caps := make([]string, 0, len(tcaps))
+ for _, c := range tcaps {
+ str, ok := c.(string)
+ if !ok {
+ // skip non-string entries instead of recording empty capabilities
+ continue
+ }
+ caps = append(caps, str)
+ }
+ config.Capabilities = append(config.Capabilities, caps...)
+ }
+ }
+
+ strFromInfo := func(k string) string {
+ if v, ok := r.Info[k]; ok {
+ // guard the type assertion so a malformed Info value can't panic the handler
+ if val, ok := v.(string); ok {
+ return val
+ }
+ }
+ return ""
+ }
+
+ vFromInfo := func(k string) float64 {
+ if v, ok := r.Info[k]; ok {
+ if val, ok := v.(float64); ok {
+ return val
+ }
+ }
+ return 0
+ }
+
+ config.ModelFamily = strFromInfo("model_family")
+ if config.ModelFamily != "" {
+ config.ModelFamilies = []string{config.ModelFamily}
+ }
+
+ config.BaseName = strFromInfo("base_name")
+ config.FileType = strFromInfo("quantization_level")
+ config.ModelType = strFromInfo("parameter_size")
+ config.ContextLen = int(vFromInfo("context_length"))
+ config.EmbedLen = int(vFromInfo("embedding_length"))
+ }
+
+ if err := createModel(r, name, baseLayers, config, fn); err != nil {
if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
@@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
streamResponse(c, ch)
}
+func remoteURL(raw string) (string, error) {
+ // Special case: the user supplied only a path ("/foo/bar").
+ if strings.HasPrefix(raw, "/") {
+ return (&url.URL{
+ Scheme: "http",
+ Host: net.JoinHostPort("localhost", "11434"),
+ Path: path.Clean(raw),
+ }).String(), nil
+ }
+
+ if !strings.Contains(raw, "://") {
+ raw = "http://" + raw
+ }
+
+ if raw == "ollama.com" || raw == "http://ollama.com" {
+ raw = "https://ollama.com:443"
+ }
+
+ u, err := url.Parse(raw)
+ if err != nil {
+ return "", fmt.Errorf("parse error: %w", err)
+ }
+
+ if u.Host == "" {
+ u.Host = "localhost"
+ }
+
+ if _, _, err := net.SplitHostPort(u.Host); err != nil {
+ // no port given; default to ollama's 11434
+ u.Host = net.JoinHostPort(u.Host, "11434")
+ }
+
+ if u.Path != "" {
+ u.Path = path.Clean(u.Path)
+ }
+
+ if u.Path == "/" {
+ u.Path = ""
+ }
+
+ return u.String(), nil
+}
+
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
switch detectModelTypeFromFiles(files) {
case "safetensors":
@@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return ggml.KV{}, fmt.Errorf("no base model was found")
}
-func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
- config := ConfigV2{
- OS: "linux",
- Architecture: "amd64",
- RootFS: RootFS{
- Type: "layers",
- },
- }
-
+func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
@@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err
}
- configLayer, err := createConfigLayer(layers, config)
+ configLayer, err := createConfigLayer(layers, *config)
if err != nil {
return err
}
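Note: with the changes above, CreateHandler can now mint a config-only model that points at a remote. A minimal client-side sketch, assuming the api package's existing Create helper; "my-remote" and the Info values are illustrative, and From/RemoteHost/Info are the request fields this diff reads:

    package main

    import (
        "context"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := &api.CreateRequest{
            Model:      "my-remote",   // local alias (illustrative)
            From:       "gpt-oss:20b", // model name on the remote
            RemoteHost: "ollama.com",  // remoteURL normalizes this to https://ollama.com:443
            Info: map[string]any{
                "capabilities":       []any{"completion", "tools", "thinking"},
                "model_family":       "gptoss",
                "context_length":     131072,
                "embedding_length":   2880,
                "quantization_level": "MXFP4",
                "parameter_size":     "20.9B",
            },
        }

        // stream progress updates as the config-only model is written
        err = client.Create(context.Background(), req, func(resp api.ProgressResponse) error {
            log.Println(resp.Status)
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }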
diff --git a/server/create_test.go b/server/create_test.go
index 59a07ff1..061efb81 100644
--- a/server/create_test.go
+++ b/server/create_test.go
@@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) {
})
}
}
+
+func TestRemoteURL(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected string
+ hasError bool
+ }{
+ {
+ name: "absolute path",
+ input: "/foo/bar",
+ expected: "http://localhost:11434/foo/bar",
+ hasError: false,
+ },
+ {
+ name: "absolute path with cleanup",
+ input: "/foo/../bar",
+ expected: "http://localhost:11434/bar",
+ hasError: false,
+ },
+ {
+ name: "root path",
+ input: "/",
+ expected: "http://localhost:11434/",
+ hasError: false,
+ },
+ {
+ name: "host without scheme",
+ input: "example.com",
+ expected: "http://example.com:11434",
+ hasError: false,
+ },
+ {
+ name: "host with port",
+ input: "example.com:8080",
+ expected: "http://example.com:8080",
+ hasError: false,
+ },
+ {
+ name: "full URL",
+ input: "https://example.com:8080/path",
+ expected: "https://example.com:8080/path",
+ hasError: false,
+ },
+ {
+ name: "full URL with path cleanup",
+ input: "https://example.com:8080/path/../other",
+ expected: "https://example.com:8080/other",
+ hasError: false,
+ },
+ {
+ name: "ollama.com special case",
+ input: "ollama.com",
+ expected: "https://ollama.com:443",
+ hasError: false,
+ },
+ {
+ name: "http ollama.com special case",
+ input: "http://ollama.com",
+ expected: "https://ollama.com:443",
+ hasError: false,
+ },
+ {
+ name: "URL with only host",
+ input: "http://example.com",
+ expected: "http://example.com:11434",
+ hasError: false,
+ },
+ {
+ name: "URL with root path cleaned",
+ input: "http://example.com/",
+ expected: "http://example.com:11434",
+ hasError: false,
+ },
+ {
+ name: "invalid URL",
+ input: "http://[::1]:namedport", // invalid port
+ expected: "",
+ hasError: true,
+ },
+ {
+ name: "empty string",
+ input: "",
+ expected: "http://localhost:11434",
+ hasError: false,
+ },
+ {
+ name: "host with scheme but no port",
+ input: "http://localhost",
+ expected: "http://localhost:11434",
+ hasError: false,
+ },
+ {
+ name: "complex path cleanup",
+ input: "/a/b/../../c/./d",
+ expected: "http://localhost:11434/c/d",
+ hasError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := remoteURL(tt.input)
+
+ if tt.hasError {
+ if err == nil {
+ t.Errorf("expected error but got none")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+
+ if result != tt.expected {
+ t.Errorf("expected %q, got %q", tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestRemoteURL_Idempotent(t *testing.T) {
+ // Test that applying remoteURL twice gives the same result as applying it once
+ testInputs := []string{
+ "/foo/bar",
+ "example.com",
+ "https://example.com:8080/path",
+ "ollama.com",
+ "http://localhost:11434",
+ }
+
+ for _, input := range testInputs {
+ t.Run(input, func(t *testing.T) {
+ firstResult, err := remoteURL(input)
+ if err != nil {
+ t.Fatalf("first call failed: %v", err)
+ }
+
+ secondResult, err := remoteURL(firstResult)
+ if err != nil {
+ t.Fatalf("second call failed: %v", err)
+ }
+
+ if firstResult != secondResult {
+ t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
+ }
+ })
+ }
+}
diff --git a/server/images.go b/server/images.go
index 504eb95c..9466b7fb 100644
--- a/server/images.go
+++ b/server/images.go
@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
+ "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
@@ -73,29 +74,38 @@ func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}
// Check for completion capability
- f, err := gguf.Open(m.ModelPath)
- if err == nil {
- defer f.Close()
+ if m.ModelPath != "" {
+ f, err := gguf.Open(m.ModelPath)
+ if err == nil {
+ defer f.Close()
- if f.KeyValue("pooling_type").Valid() {
- capabilities = append(capabilities, model.CapabilityEmbedding)
+ if f.KeyValue("pooling_type").Valid() {
+ capabilities = append(capabilities, model.CapabilityEmbedding)
+ } else {
+ // If no embedding is specified, we assume the model supports completion
+ capabilities = append(capabilities, model.CapabilityCompletion)
+ }
+ if f.KeyValue("vision.block_count").Valid() {
+ capabilities = append(capabilities, model.CapabilityVision)
+ }
} else {
- // If no embedding is specified, we assume the model supports completion
- capabilities = append(capabilities, model.CapabilityCompletion)
+ slog.Error("couldn't open model file", "error", err)
}
- if f.KeyValue("vision.block_count").Valid() {
- capabilities = append(capabilities, model.CapabilityVision)
+ } else if len(m.Config.Capabilities) > 0 {
+ for _, c := range m.Config.Capabilities {
+ capabilities = append(capabilities, model.Capability(c))
}
} else {
- slog.Error("couldn't open model file", "error", err)
+ slog.Warn("unknown capabilities for model", "model", m.Name)
}
if m.Template == nil {
return capabilities
}
+ builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability
- if slices.Contains(m.Template.Vars(), "tools") {
+ if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
capabilities = append(capabilities, model.CapabilityTools)
}
@@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityVision)
}
+ // Skip the thinking check if it's already set
+ if slices.Contains(capabilities, model.CapabilityThinking) {
+ return capabilities
+ }
+
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
hasTags := openingTag != "" && closingTag != ""
- if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
+ isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
+ if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
capabilities = append(capabilities, model.CapabilityThinking)
}
@@ -198,6 +214,20 @@ func (m *Model) String() string {
})
}
+ if m.Config.Renderer != "" {
+ modelfile.Commands = append(modelfile.Commands, parser.Command{
+ Name: "renderer",
+ Args: m.Config.Renderer,
+ })
+ }
+
+ if m.Config.Parser != "" {
+ modelfile.Commands = append(modelfile.Commands, parser.Command{
+ Name: "parser",
+ Args: m.Config.Parser,
+ })
+ }
+
for k, v := range m.Options {
switch v := v.(type) {
case []any:
@@ -236,8 +266,19 @@ type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
- ModelType string `json:"model_type"`
- FileType string `json:"file_type"`
+ ModelType string `json:"model_type"` // shown as Parameter Size
+ FileType string `json:"file_type"` // shown as Quantization Level
+ Renderer string `json:"renderer,omitempty"`
+ Parser string `json:"parser,omitempty"`
+
+ RemoteHost string `json:"remote_host,omitempty"`
+ RemoteModel string `json:"remote_model,omitempty"`
+
+ // used for remotes
+ Capabilities []string `json:"capabilities,omitempty"`
+ ContextLen int `json:"context_length,omitempty"`
+ EmbedLen int `json:"embedding_length,omitempty"`
+ BaseName string `json:"base_name,omitempty"`
// required by spec
Architecture string `json:"architecture"`
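For reference, a sketch (inside package server, all values invented) of what a remote model's config blob carries once the fields above are populated; the JSON tags are the ones declared in this diff:

    var remoteCfg = ConfigV2{
        ModelFamily:   "gptoss",
        ModelFamilies: []string{"gptoss"},
        ModelType:     "20.9B", // surfaced as Parameter Size
        FileType:      "MXFP4", // surfaced as Quantization Level
        Parser:        "harmony",
        RemoteHost:    "https://ollama.com:443",
        RemoteModel:   "gpt-oss:20b",
        Capabilities:  []string{"completion", "tools", "thinking"},
        ContextLen:    131072,
        EmbedLen:      2880,
        OS:            "linux",
        Architecture:  "amd64",
        RootFS:        RootFS{Type: "layers"},
    }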
diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go
index 1f0634f7..08b4ed7f 100644
--- a/server/internal/internal/backoff/backoff.go
+++ b/server/internal/internal/backoff/backoff.go
@@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
- d := time.Duration(n*n) * 10 * time.Millisecond
- if d > maxBackoff {
- d = maxBackoff
- }
+ d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
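The min rewrite keeps the clamped n^2 schedule unchanged. A self-contained sketch of the delay computation (same 10ms base and 0.5-1.5x jitter as above; not the package's exported API):

    package main

    import (
        "fmt"
        "math/rand"
        "time"
    )

    // delay mirrors the backoff arithmetic: 10ms * n^2, clamped to maxBackoff,
    // then jittered to avoid thundering-herd retries.
    func delay(n int, maxBackoff time.Duration) time.Duration {
        d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
        return time.Duration(float64(d) * (rand.Float64() + 0.5))
    }

    func main() {
        for n := 1; n <= 5; n++ {
            fmt.Println(n, delay(n, time.Second)) // e.g. n=3 jitters around 90ms
        }
    }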
diff --git a/server/prompt.go b/server/prompt.go
index f1d8020e..56bc6303 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/template"
)
@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
- thinkVal := false
- thinkLevel := ""
- if think != nil {
- thinkVal = think.Bool()
- thinkLevel = think.String()
- }
- var b bytes.Buffer
- if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+ p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
+ if err != nil {
return "", nil, err
}
- s, err := tokenize(ctx, b.String())
+ s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
// truncate any messages that do not fit into the context window
+ p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
+ if err != nil {
+ return "", nil, err
+ }
+
+ return p, images, nil
+}
+
+func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+ if m.Config.Renderer != "" {
+ rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+ if err != nil {
+ return "", err
+ }
+ return rendered, nil
+ }
+
var b bytes.Buffer
thinkVal := false
thinkLevel := ""
@@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
thinkVal = think.Bool()
thinkLevel = think.String()
}
- if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
- return "", nil, err
+ if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+ return "", err
}
-
- return b.String(), images, nil
+ return b.String(), nil
}
diff --git a/server/routes.go b/server/routes.go
index 5114cb74..a08a7289 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"net/netip"
+ "net/url"
"os"
"os/signal"
"slices"
@@ -28,6 +29,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
@@ -35,6 +37,7 @@ import (
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
+ "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
@@ -188,6 +191,87 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
+ if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+ origModel := req.Model
+
+ remoteURL, err := url.Parse(m.Config.RemoteHost)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
+ slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
+ c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
+ return
+ }
+
+ req.Model = m.Config.RemoteModel
+
+ if req.Template == "" && m.Template.String() != "" {
+ req.Template = m.Template.String()
+ }
+
+ if req.Options == nil {
+ req.Options = map[string]any{}
+ }
+
+ for k, v := range m.Options {
+ if _, ok := req.Options[k]; !ok {
+ req.Options[k] = v
+ }
+ }
+
+ // update the system prompt from the model if one isn't already specified
+ if req.System == "" && m.System != "" {
+ req.System = m.System
+ }
+
+ if len(m.Messages) > 0 {
+ slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
+ }
+
+ fn := func(resp api.GenerateResponse) error {
+ resp.Model = origModel
+ resp.RemoteModel = m.Config.RemoteModel
+ resp.RemoteHost = m.Config.RemoteHost
+
+ data, err := json.Marshal(resp)
+ if err != nil {
+ return err
+ }
+
+ if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+ return err
+ }
+ c.Writer.Flush()
+ return nil
+ }
+
+ client := api.NewClient(remoteURL, http.DefaultClient)
+ err = client.Generate(c, &req, fn)
+ if err != nil {
+ var sErr api.AuthorizationError
+ if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+ pk, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ slog.Error("couldn't get public key", "error", pkErr)
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+ return
+ }
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "error": "unauthorized",
+ "public_key": pk,
+ })
+ return
+ }
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ return
+ }
+
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -329,10 +413,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
- c.JSON(http.StatusOK, api.DebugTemplateResponse{
+ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- DebugInfo: api.DebugInfo{
+ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
@@ -348,6 +432,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
OpeningTag: openingTag,
ClosingTag: closingTag,
}
+ if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
+ thinkingState.AddContent(openingTag)
+ }
}
}
@@ -488,7 +575,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
truncate := true
-
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
@@ -551,11 +637,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
- c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+ c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
+ return
+ }
+
+ if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+ ctxLen--
+ }
+
+ if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+ ctxLen--
+ }
+
+ slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
+ if ctxLen <= 0 {
+ // return error if the truncated input would be empty or just special tokens
+ c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
+
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
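A worked example of the reservation logic in the hunk above, with invented values: 10 input tokens, a context window of 8, and both add_bos_token and add_eos_token true while neither token is present yet:

    package main

    import "fmt"

    func main() {
        tokens := make([]int, 10) // 10 token ids (values don't matter here)
        ctxLen := 8
        ctxLen-- // reserve a slot for the BOS token the runner will prepend
        ctxLen-- // reserve a slot for the EOS token the runner will append
        tokens = tokens[:ctxLen]
        fmt.Println(len(tokens), "tokens survive truncation") // 6
    }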
@@ -922,6 +1024,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
ModifiedAt: manifest.fi.ModTime(),
}
+ if m.Config.RemoteHost != "" {
+ resp.RemoteHost = m.Config.RemoteHost
+ resp.RemoteModel = m.Config.RemoteModel
+
+ if m.Config.ModelFamily != "" {
+ resp.ModelInfo = make(map[string]any)
+ resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
+
+ if m.Config.BaseName != "" {
+ resp.ModelInfo["general.basename"] = m.Config.BaseName
+ }
+
+ if m.Config.ContextLen > 0 {
+ resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
+ }
+
+ if m.Config.EmbedLen > 0 {
+ resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
+ }
+ }
+ }
+
var params []string
cs := 30
for k, v := range m.Options {
@@ -952,6 +1076,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
+ // skip loading tensor information if this is a remote model
+ if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+ return resp, nil
+ }
+
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
@@ -1028,11 +1157,13 @@ func (s *Server) ListHandler(c *gin.Context) {
// tag should never be masked
models = append(models, api.ListModelResponse{
- Model: n.DisplayShortest(),
- Name: n.DisplayShortest(),
- Size: m.Size(),
- Digest: m.digest,
- ModifiedAt: m.fi.ModTime(),
+ Model: n.DisplayShortest(),
+ Name: n.DisplayShortest(),
+ RemoteModel: cf.RemoteModel,
+ RemoteHost: cf.RemoteHost,
+ Size: m.Size(),
+ Digest: m.digest,
+ ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1292,6 +1423,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler)
+ r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
+ r.POST("/api/me", s.WhoamiHandler)
+
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@@ -1488,6 +1622,49 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
+func (s *Server) WhoamiHandler(c *gin.Context) {
+ // todo allow other hosts
+ u, err := url.Parse("https://ollama.com")
+ if err != nil {
+ slog.Error(err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
+ return
+ }
+
+ client := api.NewClient(u, http.DefaultClient)
+ user, err := client.Whoami(c)
+ if err != nil {
+ slog.Error(err.Error())
+ }
+ c.JSON(http.StatusOK, user)
+}
+
+func (s *Server) SignoutHandler(c *gin.Context) {
+ encodedKey := c.Param("encodedKey")
+
+ // todo allow other hosts
+ u, err := url.Parse("https://ollama.com")
+ if err != nil {
+ slog.Error(err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
+ return
+ }
+
+ client := api.NewClient(u, http.DefaultClient)
+ err = client.Signout(c, encodedKey)
+ if err != nil {
+ slog.Error(err.Error())
+ if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
+ c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
+ return
+ }
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
+ return
+ }
+
+ c.JSON(http.StatusOK, nil)
+}
+
func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{}
@@ -1544,21 +1721,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
- // expire the runner
- if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
- model, err := GetModel(req.Model)
- if err != nil {
- switch {
- case os.IsNotExist(err):
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
- case err.Error() == errtypes.InvalidModelNameErrMsg:
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- default:
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
- }
- return
+ name := model.ParseName(req.Model)
+ if !name.IsValid() {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+ return
+ }
+
+ name, err := getExistingName(name)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+ return
+ }
+
+ m, err := GetModel(req.Model)
+ if err != nil {
+ switch {
+ case os.IsNotExist(err):
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+ case err.Error() == errtypes.InvalidModelNameErrMsg:
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ default:
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
- s.sched.expireRunner(model)
+ return
+ }
+
+ // expire the runner
+ if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+ s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
@@ -1570,6 +1760,80 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
+ if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+ origModel := req.Model
+
+ remoteURL, err := url.Parse(m.Config.RemoteHost)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
+ slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
+ c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
+ return
+ }
+
+ req.Model = m.Config.RemoteModel
+ if req.Options == nil {
+ req.Options = map[string]any{}
+ }
+
+ msgs := append(m.Messages, req.Messages...)
+ if m.System != "" && (len(req.Messages) == 0 || req.Messages[0].Role != "system") {
+ msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
+ }
+ msgs = filterThinkTags(msgs, m)
+ req.Messages = msgs
+
+ for k, v := range m.Options {
+ if _, ok := req.Options[k]; !ok {
+ req.Options[k] = v
+ }
+ }
+
+ fn := func(resp api.ChatResponse) error {
+ resp.Model = origModel
+ resp.RemoteModel = m.Config.RemoteModel
+ resp.RemoteHost = m.Config.RemoteHost
+
+ data, err := json.Marshal(resp)
+ if err != nil {
+ return err
+ }
+
+ if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+ return err
+ }
+ c.Writer.Flush()
+ return nil
+ }
+
+ client := api.NewClient(remoteURL, http.DefaultClient)
+ err = client.Chat(c, &req, fn)
+ if err != nil {
+ var sErr api.AuthorizationError
+ if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+ pk, pkErr := auth.GetPublicKey()
+ if pkErr != nil {
+ slog.Error("couldn't get public key", "error", pkErr)
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+ return
+ }
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "error": "unauthorized",
+ "public_key": pk,
+ })
+ return
+ }
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ return
+ }
+
caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
@@ -1578,17 +1842,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, model.CapabilityThinking)
}
- name := model.ParseName(req.Model)
- if !name.IsValid() {
- c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
- name, err := getExistingName(name)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
- return
- }
-
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1617,10 +1870,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
+ var builtinParser parsers.Parser
+ if m.Config.Parser != "" {
+ builtinParser = parsers.ParserForName(m.Config.Parser)
+ }
+
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
- useHarmony := shouldUseHarmony(m)
+ useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
processedTools := req.Tools
if useHarmony {
@@ -1650,10 +1908,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
- c.JSON(http.StatusOK, api.DebugTemplateResponse{
+ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- DebugInfo: api.DebugInfo{
+ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
@@ -1713,6 +1971,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
+ // TODO(drifkin): fold this as much as possible into the generic m.Config.Parser logic
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.Content = content
@@ -1739,6 +1998,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch <- res
}
+ return
+ } else if builtinParser != nil {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
+
+ content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
+ if err != nil {
+ ch <- gin.H{"error": err.Error()}
+ return
+ }
+
+ res.Message.Content = content
+ res.Message.Thinking = thinking
+ res.Message.ToolCalls = toolCalls
+
+ if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
+ ch <- res
+ } else {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
+ }
+
return
}
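From a client's perspective the proxied model behaves like a local one: the server rewrites req.Model to the remote name and stamps RemoteModel/RemoteHost onto each streamed response. A minimal sketch using the existing Chat helper ("my-remote" is the hypothetical alias created earlier):

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := &api.ChatRequest{
            Model:    "my-remote",
            Messages: []api.Message{{Role: "user", Content: "hello"}},
        }

        err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
            // resp.Model is the local alias; resp.RemoteModel carries the upstream name
            fmt.Print(resp.Message.Content)
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }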
diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index 3b3d9910..189ef040 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -11,6 +11,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "reflect"
"slices"
"strings"
"testing"
@@ -20,6 +21,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/types/model"
)
var stream bool = false
@@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
})
}
+func TestCreateAndShowRemoteModel(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var s Server
+
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "test",
+ From: "bob",
+ RemoteHost: "https://ollama.com",
+ Info: map[string]any{
+ "capabilities": []string{"completion", "tools", "thinking"},
+ "model_family": "gptoss",
+ "context_length": 131072,
+ "embedding_length": 2880,
+ "quantization_level": "MXFP4",
+ "parameter_size": "20.9B",
+ },
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("exected status code 200, actual %d", w.Code)
+ }
+
+ w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
+ if w.Code != http.StatusOK {
+ t.Fatalf("exected status code 200, actual %d", w.Code)
+ }
+
+ var resp api.ShowResponse
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+ t.Fatal(err)
+ }
+
+ expectedDetails := api.ModelDetails{
+ ParentModel: "",
+ Format: "",
+ Family: "gptoss",
+ Families: []string{"gptoss"},
+ ParameterSize: "20.9B",
+ QuantizationLevel: "MXFP4",
+ }
+
+ if !reflect.DeepEqual(resp.Details, expectedDetails) {
+ t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
+ }
+
+ expectedCaps := []model.Capability{
+ model.Capability("completion"),
+ model.Capability("tools"),
+ model.Capability("thinking"),
+ }
+
+ if !slices.Equal(resp.Capabilities, expectedCaps) {
+ t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
+ }
+
+ v, ok := resp.ModelInfo["gptoss.context_length"]
+ ctxlen, isFloat := v.(float64)
+ if !ok || !isFloat || int(ctxlen) != 131072 {
+ t.Errorf("context len: expected %d, actual %v", 131072, v)
+ }
+
+ v, ok = resp.ModelInfo["gptoss.embedding_length"]
+ embedlen, isFloat := v.(float64)
+ if !ok || !isFloat || int(embedlen) != 2880 {
+ t.Errorf("embed len: expected %d, actual %v", 2880, v)
+ }
+}
+
func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go
index f04a1da9..6507284e 100644
--- a/server/routes_debug_test.go
+++ b/server/routes_debug_test.go
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
- var response api.DebugTemplateResponse
+ var response api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
- var response api.DebugTemplateResponse
+ var response api.ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
diff --git a/server/routes_test.go b/server/routes_test.go
index 87b52663..bb7e2b7c 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) {
t.Fatalf("failed to create model: %v", err)
}
- if err := createModel(r, modelName, baseLayers, fn); err != nil {
+ config := &ConfigV2{
+ OS: "linux",
+ Architecture: "amd64",
+ RootFS: RootFS{
+ Type: "layers",
+ },
+ }
+
+ if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
t.Fatal(err)
}
}
diff --git a/server/sched.go b/server/sched.go
index c501c0e8..74aa406a 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
- numParallel := int(envconfig.NumParallel())
- if numParallel < 1 {
- numParallel = 1
- }
+ numParallel := max(int(envconfig.NumParallel()), 1)
// Embedding models should always be loaded with parallel=1
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {