diff --git a/Dockerfile b/Dockerfile index c84b5239..ffaa31a5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ # vim: filetype=dockerfile ARG FLAVOR=${TARGETARCH} +ARG PARALLEL=8 ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 @@ -34,46 +35,51 @@ ENV LDFLAGS=-s FROM base AS cpu RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CPU' \ - && cmake --build --parallel --preset 'CPU' \ - && cmake --install build --component CPU --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'CPU' \ + && cmake --install build --component CPU --strip --parallel ${PARALLEL} FROM base AS cuda-11 ARG CUDA11VERSION=11.8 RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} ENV PATH=/usr/local/cuda-11/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ - && cmake --build --parallel --preset 'CUDA 11' \ - && cmake --install build --component CUDA --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS cuda-12 ARG CUDA12VERSION=12.8 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} ENV PATH=/usr/local/cuda-12/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ - && cmake --build --parallel --preset 'CUDA 12' \ - && cmake --install build --component CUDA --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS cuda-13 ARG CUDA13VERSION=13.0 RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} ENV PATH=/usr/local/cuda-13/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ - && cmake --build --parallel --preset 'CUDA 13' \ - && cmake --install build --component CUDA --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'ROCm 6' \ - && cmake --build --parallel --preset 'ROCm 6' \ - && cmake --install build --component HIP --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ + && cmake --install build --component HIP --strip --parallel ${PARALLEL} FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 ARG CMAKEVERSION @@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 COPY CMakeLists.txt CMakePresets.json . 
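+# NOTE (annotation, not part of the upstream patch): each build stage re-declares the PARALLEL
+# build arg introduced above; the default of 8 can be overridden at build time, e.g.
+# `docker build --build-arg PARALLEL=16 .`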
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'JetPack 5' \ - && cmake --build --parallel --preset 'JetPack 5' \ - && cmake --install build --component CUDA --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 ARG CMAKEVERSION @@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \ && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml +ARG PARALLEL RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'JetPack 6' \ - && cmake --build --parallel --preset 'JetPack 6' \ - && cmake --install build --component CUDA --strip --parallel 8 + && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ + && cmake --install build --component CUDA --strip --parallel ${PARALLEL} FROM base AS build WORKDIR /go/src/github.com/ollama/ollama diff --git a/api/client.go b/api/client.go index 7cc2acb3..20e6d795 100644 --- a/api/client.go +++ b/api/client.go @@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f return fmt.Errorf("unmarshal: %w", err) } - if response.StatusCode >= http.StatusBadRequest { + if response.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + return AuthorizationError{ + StatusCode: response.StatusCode, + Status: response.Status, + PublicKey: pubKey, + } + } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ StatusCode: response.StatusCode, Status: response.Status, @@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } + +// Signout will disconnect an ollama instance from ollama.com +func (c *Client) Signout(ctx context.Context, encodedKey string) error { + return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) +} + +func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { + var resp UserResponse + if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/api/types.go b/api/types.go index a7ddbc37..5b8e034c 100644 --- a/api/types.go +++ b/api/types.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -36,6 +38,19 @@ func (e StatusError) Error() string { } } +type AuthorizationError struct { + StatusCode int + Status string + PublicKey string `json:"public_key"` +} + +func (e AuthorizationError) Error() string { + if e.Status != "" { + return e.Status + } + return "something went wrong, please see the ollama server logs for details" +} + // ImageData represents the raw binary data of an image file. type ImageData []byte @@ -313,13 +328,29 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. 
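+// Streaming callers typically concatenate Message.Content across partial
+// responses and treat the response with Done set to true as the final one.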
 type ChatResponse struct {
-	Model string `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	Message Message `json:"message"`
-	DoneReason string `json:"done_reason,omitempty"`
+	// Model is the model name that generated the response.
+	Model string `json:"model"`
+	// RemoteModel is the name of the upstream model that generated the response.
+	RemoteModel string `json:"remote_model,omitempty"`
+
+	// RemoteHost is the URL of the upstream Ollama host that generated the response.
+	RemoteHost string `json:"remote_host,omitempty"`
+
+	// CreatedAt is the timestamp of the response.
+	CreatedAt time.Time `json:"created_at"`
+
+	// Message contains the message or part of a message from the model.
+	Message Message `json:"message"`
+
+	// Done specifies if the response is complete.
 	Done bool `json:"done"`
+	// DoneReason is the reason the model stopped generating text.
+	DoneReason string `json:"done_reason,omitempty"`
+
+	DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
+
 	Metrics
 }
@@ -329,13 +360,6 @@ type DebugInfo struct {
 	ImageCount int `json:"image_count,omitempty"`
 }
 
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
-	Model string `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	DebugInfo DebugInfo `json:"_debug_info"`
-}
-
 type Metrics struct {
 	TotalDuration time.Duration `json:"total_duration,omitempty"`
 	LoadDuration time.Duration `json:"load_duration,omitempty"`
@@ -431,18 +455,47 @@ type EmbeddingResponse struct {
 
 // CreateRequest is the request passed to [Client.Create].
 type CreateRequest struct {
-	Model string `json:"model"`
-	Stream *bool `json:"stream,omitempty"`
+	// Model is the model name to create.
+	Model string `json:"model"`
+
+	// Stream specifies whether the response is streaming; it is true by default.
+	Stream *bool `json:"stream,omitempty"`
+
+	// Quantize is the quantization format for the model; leave blank to not change the quantization level.
 	Quantize string `json:"quantize,omitempty"`
 
-	From string `json:"from,omitempty"`
-	Files map[string]string `json:"files,omitempty"`
-	Adapters map[string]string `json:"adapters,omitempty"`
-	Template string `json:"template,omitempty"`
-	License any `json:"license,omitempty"`
-	System string `json:"system,omitempty"`
-	Parameters map[string]any `json:"parameters,omitempty"`
-	Messages []Message `json:"messages,omitempty"`
+	// From is the name of the model or file to use as the source.
+	From string `json:"from,omitempty"`
+
+	// RemoteHost is the URL of the upstream Ollama API for the model (if any).
+	RemoteHost string `json:"remote_host,omitempty"`
+
+	// Files is a map of files to include when creating the model.
+	Files map[string]string `json:"files,omitempty"`
+
+	// Adapters is a map of LoRA adapters to include when creating the model.
+	Adapters map[string]string `json:"adapters,omitempty"`
+
+	// Template is the template used when constructing a request to the model.
+	Template string `json:"template,omitempty"`
+
+	// License is a string or list of strings for licenses.
+	License any `json:"license,omitempty"`
+
+	// System is the system prompt for the model.
+	System string `json:"system,omitempty"`
+
+	// Parameters is a map of hyper-parameters which are applied to the model.
+	Parameters map[string]any `json:"parameters,omitempty"`
+
+	// Messages is a list of messages added to the model before chat and generation requests.
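+	// They are prepended to the conversation when the model is used.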
+ Messages []Message `json:"messages,omitempty"` + + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + // Info is a map of additional information for the model + Info map[string]any `json:"info,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -480,8 +533,12 @@ type ShowResponse struct { Parameters string `json:"parameters,omitempty"` Template string `json:"template,omitempty"` System string `json:"system,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"` Tensors []Tensor `json:"tensors,omitempty"` @@ -540,12 +597,14 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. @@ -569,6 +628,12 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` @@ -592,6 +657,8 @@ type GenerateResponse struct { Metrics ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` } // ModelDetails provides details about a model. @@ -604,6 +671,18 @@ type ModelDetails struct { QuantizationLevel string `json:"quantization_level"` } +// UserResponse provides information about a user. +type UserResponse struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Bio string `json:"bio,omitempty"` + AvatarURL string `json:"avatarurl,omitempty"` + FirstName string `json:"firstname,omitempty"` + LastName string `json:"lastname,omitempty"` + Plan string `json:"plan,omitempty"` +} + // Tensor describes the metadata for a given tensor. 
type Tensor struct { Name string `json:"name"` diff --git a/auth/auth.go b/auth/auth.go index e1d85412..b26e2315 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,6 +19,31 @@ import ( const defaultPrivateKey = "id_ed25519" func keyPath() (string, error) { + fileIsReadable := func(fp string) bool { + info, err := os.Stat(fp) + if err != nil { + return false + } + + // Check that it's a regular file, not a directory or other file type + if !info.Mode().IsRegular() { + return false + } + + // Try to open it to check readability + file, err := os.Open(fp) + if err != nil { + return false + } + file.Close() + return true + } + + systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) + if fileIsReadable(systemPath) { + return systemPath, nil + } + home, err := os.UserHomeDir() if err != nil { return "", err diff --git a/cmd/cmd.go b/cmd/cmd.go index 19f1e192..294e1662 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -14,6 +15,7 @@ import ( "math" "net" "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -35,6 +37,7 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -47,6 +50,8 @@ import ( "github.com/ollama/ollama/version" ) +const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" + // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { if name == "" { @@ -286,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { Think: opts.Think, } - return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) + return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { + if r.RemoteModel != "" && opts.ShowConnect { + p.StopAndClear() + if strings.HasPrefix(r.RemoteHost, "https://ollama.com") { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel) + } else { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost) + } + } + return nil + }) } func StopHandler(cmd *cobra.Command, args []string) error { @@ -307,9 +322,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { interactive := true opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]any{}, + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]any{}, + ShowConnect: true, } format, err := cmd.Flags().GetString("format") @@ -367,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { } prompts = append([]string{string(in)}, prompts...) 
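+			// input was piped in on stdin, so run non-interactively and skip the remote connect notice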
+ opts.ShowConnect = false opts.WordWrap = false interactive = false } @@ -433,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error { if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + // the server and the client both have the same public key + if pubKey == sErr.PublicKey { + h, _ := os.Hostname() + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + } + return nil + } return err } @@ -453,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func SigninHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + user, err := client.Whoami(cmd.Context()) + if err != nil { + return err + } + + if user != nil && user.Name != "" { + fmt.Printf("You are already signed in as user '%s'\n", user.Name) + fmt.Println() + return nil + } + + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + h, _ := os.Hostname() + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + + return nil +} + +func SignoutHandler(cmd *cobra.Command, args []string) error { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + err = client.Signout(cmd.Context(), encKey) + if err != nil { + return err + } + fmt.Println("You have signed out of ollama.com") + fmt.Println() + return nil +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -505,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { if spinner != nil { spinner.Stop() } - if strings.Contains(err.Error(), "access denied") { + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } return err @@ -539,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error { for _, m := range models.Models { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { - data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) + var size string + if m.RemoteModel != "" { + size = "-" + } else { + size = format.HumanBytes(m.Size) + } + + data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")}) } } @@ -624,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { KeepAlive: &api.Duration{Duration: 0}, } if err := loadOrUnloadModel(cmd, opts); err != nil { - if !strings.Contains(err.Error(), "not found") { - return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + if !strings.Contains(strings.ToLower(err.Error()), "not found") { + fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) } } @@ -736,12 +826,36 @@ func showInfo(resp 
*api.ShowResponse, verbose bool, w io.Writer) error { } tableRender("Model", func() (rows [][]string) { + if resp.RemoteHost != "" { + rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) + rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) + } + if resp.ModelInfo != nil { arch := resp.ModelInfo["general.architecture"].(string) rows = append(rows, []string{"", "architecture", arch}) - rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}) - rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)}) - rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)}) + + var paramStr string + if resp.Details.ParameterSize != "" { + paramStr = resp.Details.ParameterSize + } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { + if f, ok := v.(float64); ok { + paramStr = format.HumanNumber(uint64(f)) + } + } + rows = append(rows, []string{"", "parameters", paramStr}) + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } } else { rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) @@ -989,6 +1103,7 @@ type runOptions struct { KeepAlive *api.Duration Think *api.ThinkValue HideThinking bool + ShowConnect bool } type displayResponseState struct { @@ -1544,6 +1659,22 @@ func NewCLI() *cobra.Command { pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + signinCmd := &cobra.Command{ + Use: "signin", + Short: "Sign in to ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SigninHandler, + } + + signoutCmd := &cobra.Command{ + Use: "signout", + Short: "Sign out from ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SignoutHandler, + } + listCmd := &cobra.Command{ Use: "list", Aliases: []string{"ls"}, @@ -1638,6 +1769,8 @@ func NewCLI() *cobra.Command { stopCmd, pullCmd, pushCmd, + signinCmd, + signoutCmd, listCmd, psCmd, copyCmd, diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index cf5fe7ca..bb793572 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusNotFound) + errPayload := `{"error":"model '%s' not found"}` + w.Write([]byte(fmt.Sprintf(errPayload, req.Name))) } return } @@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) { } err := DeleteHandler(cmd, []string{"test-model-not-found"}) - if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") { t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) } } @@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) { 
w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) err := json.NewEncoder(w).Encode(map[string]string{ - "error": "access denied", + "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}", }) if err != nil { t.Fatal(err) @@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + initializeKeypair() cmd := &cobra.Command{} cmd.Flags().Bool("insecure", false, "") diff --git a/convert/convert_bert.go b/convert/convert_bert.go index a9f4b8a7..6b0d0030 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -28,6 +28,7 @@ type bertModel struct { LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` + normalizeEmbeddings bool PoolingType uint32 } @@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error { var pooling string for _, m := range modules { - if m.Type == "sentence_transformers.models.Pooling" { + switch m.Type { + case "sentence_transformers.models.Pooling": pooling = m.Path - break + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true } } @@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV { kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false kv["bert.pooling_type"] = p.PoolingType + kv["bert.normalize_embeddings"] = p.normalizeEmbeddings kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index 7f029f93..eea0de2f 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -96,7 +96,7 @@ type safetensor struct { func (st safetensor) Kind() uint32 { kind := st.tensorBase.Kind() - if st.dtype == "BF16" && kind != tensorKindFP32 { + if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 { kind = tensorKindBF16 } diff --git a/convert/reader_test.go b/convert/reader_test.go index 6dbe32a5..c3d094f1 100644 --- a/convert/reader_test.go +++ b/convert/reader_test.go @@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) { }) } } + +func TestSafetensorKind(t *testing.T) { + tests := []struct { + name string + st safetensor + expected uint32 + }{ + { + name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindBF16, + }, + { + name: "BF16 dtype with v. 
prefix should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "v.weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "BF16", + }, + expected: tensorKindFP16, + }, + { + name: "BF16 dtype with FP32 base kind should return FP32", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10}, // will default to FP32 + }, + dtype: "BF16", + }, + expected: tensorKindFP32, + }, + { + name: "Non-BF16 dtype should return base kind", + st: safetensor{ + tensorBase: &tensorBase{ + name: "weight.matrix", + shape: []uint64{10, 10}, // will default to FP16 + }, + dtype: "FP16", + }, + expected: tensorKindFP16, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.st.Kind() + if result != tt.expected { + t.Errorf("Kind() = %d, expected %d", result, tt.expected) + } + }) + } +} diff --git a/discover/cuda_common.go b/discover/cuda_common.go index ca008af6..a2c43420 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -16,7 +16,7 @@ import ( // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. var CudaTegra string = os.Getenv("JETSON_JETPACK") -func cudaVariant(gpuInfo CudaGPUInfo) string { +func cudaVariant(gpuInfos []CudaGPUInfo) string { if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { if CudaTegra != "" { ver := strings.Split(CudaTegra, ".") @@ -45,12 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } - if gpuInfo.DriverMajor < 13 { + // Check GPU compute capability FIRST, lowest common denominator if multi-gpu + for _, gpuInfo := range gpuInfos { + if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) { + // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1) + return "v12" + } + } + + // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA + if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 { // The detected driver is older than 580 (Aug 2025) // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) { - slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) - } + slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor)) return "v12" } return "v13" diff --git a/discover/gpu.go b/discover/gpu.go index 95070ecb..4bb0d94e 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList { gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMinor = driverMinor - variant := cudaVariant(gpuInfo) - // Start with our bundled libraries - if variant != "" { - variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant) - if _, err := os.Stat(variantPath); err == nil { - // Put the variant directory first in the search path to avoid runtime linking to the wrong library - gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...) 
- } - } gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.Variant = variant if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { unsupportedGPUs = append(unsupportedGPUs, @@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList { // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... cudaGPUs = append(cudaGPUs, gpuInfo) } + // Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths + variant := cudaVariant(cudaGPUs) + var variantPath string + // Start with our bundled libraries + if variant != "" { + variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant) + if _, err := os.Stat(variantPath); err != nil { + variantPath = "" + } + } + + for i := range cudaGPUs { + cudaGPUs[i].Variant = variant + if variantPath != "" { + // Put the variant directory first in the search path to avoid runtime linking to the wrong library + cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...) + } + } } // Intel diff --git a/docs/development.md b/docs/development.md index 9726b5d9..ff07b5fb 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository: go run . serve ``` +> [!NOTE] +> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first. + + ## macOS (Apple Silicon) macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required. diff --git a/envconfig/config.go b/envconfig/config.go index 7fc01887..09243ab9 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) { return loadTimeout } +func Remotes() []string { + var r []string + raw := strings.TrimSpace(Var("OLLAMA_REMOTES")) + if raw == "" { + r = []string{"ollama.com"} + } else { + r = strings.Split(raw, ",") + } + return r +} + func Bool(k string) func() bool { return func() bool { if s := Var(k); s != "" { @@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, + "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 6b582b49..5da902bc 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -243,6 +243,7 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3", "gemma3n", "mistral3", + "qwen3", "llama4", "mllama", "qwen25vl", diff --git a/integration/embed_test.go b/integration/embed_test.go index eb00f4ba..a6852448 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) @@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) { } res, err := embeddingTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + 
t.Fatal(err) } if len(res.Embedding) != 384 { @@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) { } res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + t.Fatal(err) } if len(res.Embeddings) != 1 { @@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { } res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatalf("error: %v", err) + t.Fatal(err) } if len(res.Embeddings) != 2 { @@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { truncTrue, truncFalse := true, false - type testReq struct { - Name string - Request api.EmbedRequest + want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ + Model: "all-minilm", + Input: "why", + }) + if err != nil { + t.Fatal(err) } - reqs := []testReq{ + cases := []struct { + name string + request api.EmbedRequest + check func(*api.EmbedResponse, error) + }{ { - Name: "Target Truncation", - Request: api.EmbedRequest{ + name: "target truncation", + request: api.EmbedRequest{ Model: "all-minilm", Input: "why", }, - }, - { - Name: "Default Truncate", - Request: api.EmbedRequest{ - Model: "all-minilm", - Input: "why is the sky blue?", - Options: map[string]any{"num_ctx": 1}, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } }, }, { - Name: "Explicit Truncate", - Request: api.EmbedRequest{ + name: "default truncate", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Options: map[string]any{"num_ctx": 3}, + }, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } + }, + }, + { + name: "explicit truncate", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 3}, + }, + check: func(got *api.EmbedResponse, err error) { + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" { + t.Errorf("embedding mismatch (-want +got):\n%s", diff) + } + }, + }, + { + name: "truncate error", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 3}, + }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, + }, + { + name: "input after truncate error", + request: api.EmbedRequest{ Model: "all-minilm", Input: "why is the sky blue?", Truncate: &truncTrue, Options: map[string]any{"num_ctx": 1}, }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input after truncation exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, + }, + { + name: "input after truncate error", + request: api.EmbedRequest{ + Model: "all-minilm", + Input: "why is the sky blue?", + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 0}, + }, + check: func(res *api.EmbedResponse, err error) { + if err.Error() != "input after truncation exceeds maximum context length" { + t.Fatalf("expected truncation error, got: %v", err) + } + }, }, } - res := make(map[string]*api.EmbedResponse) - - for _, req := range reqs { - 
response, err := embedTestHelper(ctx, client, t, req.Request) - if err != nil { - t.Fatalf("error: %v", err) - } - res[req.Name] = response - } - - if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { - t.Fatal("expected default request to truncate correctly") - } - - if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { - t.Fatal("expected default request and truncate true request to be the same") - } - - // check that truncate set to false returns an error if context length is exceeded - _, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ - Model: "all-minilm", - Input: "why is the sky blue?", - Truncate: &truncFalse, - Options: map[string]any{"num_ctx": 1}, - }) - - if err == nil { - t.Fatal("expected error, got nil") + for _, req := range cases { + t.Run(req.name, func(t *testing.T) { + req.check(embedTestHelper(ctx, client, t, req.request)) + }) } } func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { + t.Helper() + if err := PullIfMissing(ctx, client, req.Model); err != nil { - t.Fatalf("failed to pull model %s: %v", req.Model, err) + t.Fatal(err) } - response, err := client.Embeddings(ctx, &req) - - if err != nil { - return nil, err - } - - return response, nil + return client.Embeddings(ctx, &req) } func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { + t.Helper() + if err := PullIfMissing(ctx, client, req.Model); err != nil { - t.Fatalf("failed to pull model %s: %v", req.Model, err) + t.Fatal(err) } - response, err := client.Embed(ctx, &req) - - if err != nil { - return nil, err - } - - return response, nil + return client.Embed(ctx, &req) } diff --git a/logutil/logutil.go b/logutil/logutil.go index fff277b8..00daf6a6 100644 --- a/logutil/logutil.go +++ b/logutil/logutil.go @@ -5,6 +5,8 @@ import ( "io" "log/slog" "path/filepath" + "runtime" + "time" ) const LevelTrace slog.Level = -8 @@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger { })) } +type key string + func Trace(msg string, args ...any) { - slog.Log(context.TODO(), LevelTrace, msg, args...) + TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...) } func TraceContext(ctx context.Context, msg string, args ...any) { - slog.Log(ctx, LevelTrace, msg, args...) + if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) { + skip, _ := ctx.Value(key("skip")).(int) + pc, _, _, _ := runtime.Caller(1 + skip) + record := slog.NewRecord(time.Now(), LevelTrace, msg, pc) + record.Add(args...) 
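+		// hand the record straight to the handler; calling slog.Log here would
+		// report this function as the source instead of the caller PC captured above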
+		logger.Handler().Handle(ctx, record)
+	}
 }
diff --git a/ml/backend.go b/ml/backend.go
index 154a0f1b..455715b0 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -416,6 +416,7 @@ type Tensor interface {
 	AddID(ctx Context, t2, ids Tensor) Tensor
 
 	Softmax(ctx Context) Tensor
+	L2Norm(ctx Context, eps float32) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	GELU(ctx Context) Tensor
-	QuickGELU(ctx Context) Tensor
-	SILU(ctx Context) Tensor
-	RELU(ctx Context) Tensor
+	GELU(ctx Context, up ...Tensor) Tensor
+	SILU(ctx Context, up ...Tensor) Tensor
+	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
-	SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+	// SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit]
+	SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
 
 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 931386d5..49dc3e1a 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
 	if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
 	}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 21b4a28a..94dbde0b 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key,
value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + ctx.Forward(query) if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) @@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) } + ctx.Forward(key, value) if cache != nil { cache.Put(ctx, key, value) } diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go new file mode 100644 index 00000000..63b63b3a --- /dev/null +++ b/ml/nn/pooling/pooling.go @@ -0,0 +1,42 @@ +package pooling + +import ( + "github.com/ollama/ollama/ml" +) + +type Type uint32 + +const ( + TypeNone Type = iota + TypeMean + TypeCLS + TypeLast +) + +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { + case TypeMean: + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) + return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + case TypeCLS: + return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) + case TypeLast: + hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) + return hiddenStates + default: + panic("unknown pooling type") + } +} diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go new file mode 100644 index 00000000..c8001945 --- /dev/null +++ b/ml/nn/pooling/pooling_test.go @@ -0,0 +1,79 @@ +package pooling_test + +import ( + "bytes" + "os" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/discover" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/ml/nn/pooling" +) + +func setup(tb testing.TB, n int) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := fsggml.WriteGGUF(f, fsggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*fsggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))}, + }); err != nil { + tb.Fatal(err) + } + + var gpuLayers ml.GPULayersList + if gpus := discover.GetGPUInfo(); len(gpus) > 0 { + gpuLayers = append(gpuLayers, ml.GPULayers{ + ID: gpus[0].ID, + Layers: slices.Collect(func(yield func(int) bool) { + for i := range n { + if !yield(i) { + return + } + } + }), + }) + } + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +func TestForward(t *testing.T) { + cases := map[pooling.Type][]float32{ + pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11}, + pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7}, + pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15}, + } + for typ, want := range cases { + t.Run(typ.String(), func(t *testing.T) { + b := setup(t, 99) + defer b.Close() + + ctx := b.NewContext() + defer ctx.Close() + + tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2) + tt = typ.Forward(ctx, tt) + + ctx.Forward(tt).Compute(tt) + if diff := cmp.Diff(want, tt.Floats()); diff != "" { + t.Error(diff) + } + }) 
+	}
+}
diff --git a/model/input/input.go b/model/input/input.go
index bd9b53ec..35dc41b3 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -54,10 +54,9 @@ type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
 	Inputs ml.Tensor
 
-	// Multimodal is a set of multimodal embeddings previously created by
-	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
-	// models or for batches without multimodal elements.
-	Multimodal []MultimodalIndex
+	// Outputs are the set of indices into Inputs for which output data should
+	// be returned.
+	Outputs ml.Tensor
 
 	// Positions is the position for each Input, relative to its sequence. Equal
 	// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
 	// Sequences is the sequence for each Input. Equal in length to Inputs.
 	Sequences []int
 
-	// Outputs are the set of indicies into Inputs for which output data should
-	// be returned.
-	Outputs []int32
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
 }
diff --git a/model/model.go b/model/model.go
index 3a72f09a..f3d6bb3d 100644
--- a/model/model.go
+++ b/model/model.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	_ "image/jpeg"
 	_ "image/png"
-	"math"
 	"os"
 	"reflect"
 	"strconv"
@@ -21,10 +20,15 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model/input"
 )
 
-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+	ErrNoVisionModel = errors.New("this model is missing data required for image input")
+	ErrUnsupportedModel = errors.New("model not supported")
+	ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)
 
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
@@ -103,23 +107,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
 		return nil, err
 	}
 
-	arch := b.Config().Architecture()
-	if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
-		arch = arch + "_embed"
-	}
-
-	f, ok := models[arch]
-	if !ok {
-		return nil, fmt.Errorf("unsupported model architecture %q", arch)
-	}
-
-	m, err := f(b.Config())
+	m, err := modelForArch(b.Config())
 	if err != nil {
 		return nil, err
 	}
 
 	base := Base{b: b, config: m.Config()}
-
 	v := reflect.ValueOf(m)
 	v.Elem().Set(populateFields(base, v.Elem()))
 	return m, nil
@@ -131,30 +124,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
 		return nil, err
 	}
 	defer r.Close()
+
 	meta, err := fsggml.Decode(r, -1)
 	if err != nil {
 		return nil, err
 	}
-	return getTextProcessor(meta.KV())
-}
 
-func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
-	arch := kv.Architecture()
-	f, ok := models[arch]
-	if !ok {
-		return nil, fmt.Errorf("unsupported model architecture %q", arch)
-	}
-	m, err := f(kv)
+	m, err := modelForArch(meta.KV())
 	if err != nil {
 		return nil, err
 	}
+
 	tp, ok := m.(TextProcessor)
 	if !ok {
-		return nil, fmt.Errorf("%v is not a TextProcessor", m)
+		return nil, ErrUnsupportedTokenizer
 	}
 	return tp, nil
 }
 
+func modelForArch(c fs.Config) (Model, error) {
+	arch := c.Architecture()
+	if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
+		arch = arch + "_embed"
+	}
+
+	f, ok := models[arch]
+	if !ok {
+		return nil,
ErrUnsupportedModel + } + + return f(c) +} + func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() @@ -242,7 +243,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { vv = vv.Elem() } - vv = vv.Elem() + vv = reflect.Indirect(vv) if v.IsNil() { vv = reflect.New(v.Type().Elem()).Elem() } diff --git a/model/model_test.go b/model/model_test.go index 020f9ffb..01080ffd 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -1,9 +1,9 @@ package model import ( + "errors" "reflect" "slices" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -12,7 +12,6 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -148,39 +147,58 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } -func TestGetTextProcessor(t *testing.T) { - tp, err := getTextProcessor(fsggml.KV{}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "unsupported model architecture") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") +func TestModelForArch(t *testing.T) { + type fakeModel struct { + Model } - models["dummy"] = func(fs.Config) (Model, error) { - return notTextProcessorModel{}, nil + type fakeEmbeddingModel struct { + Model } - tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"}) - if err == nil { - t.Error("expected error") - } else if !strings.Contains(err.Error(), "not a TextProcessor") { - t.Errorf("unexpected error: %v", err) - } else if tp != nil { - t.Error("expected nil tp") + + models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil } + models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil } + + cases := []struct { + name string + config fs.Config + want any + err error + }{ + { + name: "model", + config: fsggml.KV{ + "general.architecture": "model", + }, + want: fakeModel{}, + }, + { + name: "embedding", + config: fsggml.KV{ + "general.architecture": "model", + "model.pooling_type": uint32(1), + }, + want: fakeEmbeddingModel{}, + }, + { + name: "unsupported", + config: fsggml.KV{ + "general.architecture": "unsupported", + }, + err: ErrUnsupportedModel, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := modelForArch(tt.config) + if !errors.Is(err, tt.err) { + t.Fatal(err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff) + } + }) } } - -type notTextProcessorModel struct{} - -func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) { - panic("unimplemented") -} - -func (notTextProcessorModel) Backend() ml.Backend { - panic("unimplemented") -} - -func (notTextProcessorModel) Config() config { - panic("unimplemented") -} diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go new file mode 100644 index 00000000..166c11e1 --- /dev/null +++ b/model/models/bert/embed.go @@ -0,0 +1,181 @@ +package bert + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + TypeEmbedding *nn.Embedding 
`gguf:"token_types"` + PositionEmbedding *nn.Embedding `gguf:"position_embd"` + TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"` + + Layers []EncoderLayer `gguf:"blk"` + + Options +} + +// Forward implements model.Model. +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) + hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)))) + hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) + } + + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + if m.normalize { + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + } + + return hiddenStates, nil +} + +type EncoderLayer struct { + *Attention + AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"` + + *MLP + MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"` +} + +func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + // Attention + residual := hiddenStates + hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // MLP + residual = hiddenStates + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + + return hiddenStates +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"` + + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"` + + Value *nn.Linear `gguf:"attn_v"` + + Output *nn.Linear `gguf:"attn_output"` +} + +func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := a.Query.Forward(ctx, hiddenStates) + if a.QueryNorm != nil { + query = a.QueryNorm.Forward(ctx, query, opts.eps) + } + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + + key := a.Key.Forward(ctx, hiddenStates) + if a.KeyNorm != nil { + key = a.KeyNorm.Forward(ctx, key, opts.eps) + } + key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + value := a.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + return a.Output.Forward(ctx, attention) +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx)) +} + +type Options struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int + poolingType pooling.Type + eps float32 + normalize bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +func New(c fs.Config) (model.Model, error) { + var processor model.TextProcessor + switch c.String("tokenizer.ggml.model", "bert") { + case "bert": + processor = model.NewWordPiece( + 
&model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.cls_token_id"), + c.Uint("tokenizer.ggml.bos_token_id"), + )), + }, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true), + EOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.separator_token_id"), + //nolint:misspell + // NOTE: "seperator_token_id" is a typo in model metadata but we need to + // support it for compatibility. + c.Uint("tokenizer.ggml.seperator_token_id"), + c.Uint("tokenizer.ggml.eos_token_id"), + )), + }, + }, + ) + default: + return nil, model.ErrUnsupportedTokenizer + } + + return &Model{ + TextProcessor: processor, + Layers: make([]EncoderLayer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_epsilon"), + poolingType: pooling.Type(c.Uint("pooling_type")), + normalize: c.Bool("normalize_embeddings", true), + }, + }, nil +} + +func init() { + model.Register("bert", New) + model.Register("bert_embed", New) +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index e621d03a..2b16dc62 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -24,7 +24,7 @@ type Options struct { type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -40,7 +40,7 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) { attnValLen: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base", 10000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), attnLogitSoftcap: c.Float("attn_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"), }, @@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, 
key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil } type MLP struct { @@ -138,7 +138,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) @@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 16c299e2..52554776 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -1,49 +1,38 @@ package gemma3 import ( - "errors" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type embedModel struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel - PoolingType uint32 + poolingType pooling.Type Dense [2]*nn.Linear `gguf:"dense"` } func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - batch.Outputs = batch.Positions // return all positions hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - - switch m.PoolingType { - case 0: // None - case 1: // Mean - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) - default: - return nil, errors.New("unsupported pooling type") - } - + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } - + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) return hiddenStates, nil } func newEmbedModel(c fs.Config) (model.Model, error) { m := &embedModel{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) { }, ), TextModel: newTextModel(c), - PoolingType: c.Uint("pooling_type", 0), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), } m.Cache = kvcache.NewWrapperCache( diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 5c92b6bf..27da889e 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,7 +16,7 @@ import ( type 
Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *VisionModel `gguf:"v"` *TextModel @@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 2a3b2393..631baecc 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: 1, + // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights + // (8 instead of 1) + // ropeScale: c.Float("rope.scaling.factor", 1.0), }, } @@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil } type TextMLP struct { @@ -123,7 +126,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -161,7 +164,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) @@ -194,7 +196,7 @@ func (m *TextModel) Forward(ctx 
ml.Context, batch input.Batch, cache kvcache.Cac var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go index 6e83a972..e59e3193 100644 --- a/model/models/gemma3n/model.go +++ b/model/models/gemma3n/model.go @@ -10,7 +10,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel } @@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func New(c fs.Config) (model.Model, error) { m := Model{ TextModel: newTextModel(c), - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index b75a2abb..d0e9a026 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) - hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) + hiddenStates = hiddenStates.Rows(ctx, batch.Outputs) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) return m.Output.Forward(ctx, hiddenStates), nil @@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.ropeBaseLocal } - return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } type TextScaledWordEmbedding struct { @@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position } active = d.PerLayerInputGate.Forward(ctx, active) - active = active.GELU(ctx) - active = active.Mul(ctx, perLayerInput) + active = active.GELU(ctx, perLayerInput) active = d.PerLayerProjection.Forward(ctx, active) active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) @@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten query := attn.Query.Forward(ctx, hiddenStates) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = attn.QueryNorm.Forward(ctx, query, opts.eps) - query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) var key, value ml.Tensor if !sharedKV { key = attn.Key.Forward(ctx, hiddenStates) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = attn.KeyNorm.Forward(ctx, key, opts.eps) - key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) value = attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) @@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa hiddenStates = hiddenStates.Sub(ctx, 
cutoff).RELU(ctx) } - hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates) + hiddenStates = hiddenStates.GELU(ctx, upStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates) return hiddenStates } @@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeBase: c.Float("rope.freq_base", 1_000_000), ropeBaseLocal: c.Float("rope.freq_base_local", 10_000), - ropeScale: c.Float("rope.freq_scale", 1.0), + ropeScale: c.Float("rope.scaling.factor", 1.0), slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), activationSparsityScale: c.Floats("activation_sparsity_scale"), diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 3ef07809..8456ea5f 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } var outputs ml.Tensor - if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if i == len(m.TransformerBlocks)-1 { + outputs = batch.Outputs } hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) @@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts * up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) } - hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7) + hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 77d8f36d..f6ec0227 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -2,7 +2,6 @@ package llama import ( "cmp" - "fmt" "math" "github.com/ollama/ollama/fs" @@ -23,51 +22,60 @@ type Options struct { type Model struct { model.Base - model.BytePairEncoding + model.TextProcessor TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` OutputNorm *nn.RMSNorm `gguf:"output_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` - *Options + Options } func New(c fs.Config) (model.Model, error) { - // This model currently only supports the gpt2 tokenizer - if c.String("tokenizer.ggml.model") == "llama" { - return nil, fmt.Errorf("unsupported tokenizer: llama") + if c.Uint("expert_count") > 0 { + // TODO: support mixtures of experts + return nil, model.ErrUnsupportedModel } - // Best effort detection of library/deepseek-coder model(s) which are incompatible - if c.String("general.name") == "deepseek-ai" { - return nil, fmt.Errorf("unsupported model: %s", c.String("general.name")) - } - m := Model{ - BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Ints("tokenizer.ggml.token_type"), - Merges: c.Strings("tokenizer.ggml.merges"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), - }, + + var processor model.TextProcessor + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: 
c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., ), - Layers: make([]Layer, c.Uint("block_count")), - Options: &Options{ + } + switch c.String("tokenizer.ggml.model") { + case "gpt2": + processor = model.NewBytePairEncoding( + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + &vocabulary, + ) + case "llama": + processor = model.NewSentencePiece(&vocabulary) + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeBase: c.Float("rope.freq_base", 1e5), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -98,8 +106,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -108,7 +116,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil } type MLP struct { @@ -118,7 +126,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -160,10 +168,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } - hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) + hiddenState = 
layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 99a898d2..9cb2efc8 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 045ab403..e0f93260 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) if useRope { - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) } if opts.useQKNorm { @@ -58,14 +58,14 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } type TextExperts struct { - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Forward(ctx, hiddenStates, experts) + gateStates := e.Gate.Forward(ctx, hiddenStates, experts) + downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { @@ -96,7 +96,7 @@ type TextSharedExpert struct { } func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + 
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -196,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel { numExpertsUsed: int(c.Uint("expert_used_count")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), noRopeInterval: int(c.Uint("no_rope_interval", 4)), @@ -248,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 408e54d3..435b1a30 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 19c36f9f..d2e2eac6 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil } type MLP struct { @@ -65,7 +65,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: 
int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 65bdcff2..3bfb8c90 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -51,7 +51,7 @@ type VisionMLP struct { } func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index d0ad4670..239d999d 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 47a518ce..65f0a827 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil } return key, nil @@ -58,7 +58,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, 
mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel { ropeDim: int(c.Uint("rope.dimension_count")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } diff --git a/model/models/models.go b/model/models/models.go index c880a472..cc998078 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,7 @@ package models import ( + _ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c662f06..5a345837 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, value := attn.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) @@ -59,7 +59,7 @@ type MLP struct { } func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) @@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) - return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil } func New(c fs.Config) (model.Model, error) { @@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) { headDim: int(c.Uint("attention.key_length")), ropeDim: int(c.Uint("rope.dimension_count")), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), eps: c.Float("attention.layer_norm_rms_epsilon"), }, } diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index d73f499d..6c76305d 100644 --- a/model/models/qwen25vl/model.go +++ 
b/model/models/qwen25vl/model.go @@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) } func init() { diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 4b6bc166..e6c6e6c1 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel { originalContextLength: int(c.Uint("context_length", 128000)), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), + ropeScale: c.Float("rope.scaling.factor", 1), }, } @@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) + k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten // Shift applies rotary position embeddings to the key tensor for causal attention caching func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil } // MLP implements the feed-forward network component with SwiGLU activation @@ -90,7 +90,7 @@ type MLP struct { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { // Apply SwiGLU activation gating - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) // Project back to hidden dimension return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 4d7afaa1..3dd60e3b 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -100,8 +100,7 @@ type VisionMLP struct { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts 
*VisionModelOptions) ml.Tensor { // Using activation as specified in config (likely GELU or SiLU/Swish) gateOutput := mlp.Gate.Forward(ctx, hiddenStates) - upOutput := mlp.Up.Forward(ctx, hiddenStates) - hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput) + hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go new file mode 100644 index 00000000..9a77efea --- /dev/null +++ b/model/models/qwen3/embed.go @@ -0,0 +1,73 @@ +package qwen3 + +import ( + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type embedModel struct { + model.Base + model.BytePairEncoding + + *Model + poolingType pooling.Type +} + +func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates, err := m.forward(ctx, batch) + if err != nil { + return nil, err + } + + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + return hiddenStates, nil +} + +func newEmbed(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + for i := range layers { + layers[i].MLP = &dense{} + } + m := embedModel{ + BytePairEncoding: model.NewBytePairEncoding( + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + ), + Model: &Model{ + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + }, + }, + poolingType: pooling.Type(c.Uint("pooling_type")), + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 7a83e0d0..35226834 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -30,10 +30,10 @@ func (o Options) headDim() int { } type Attention struct { - QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Query *nn.Linear `gguf:"attn_q"` - KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` } @@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = sa.QueryNorm.Forward(ctx, query, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps) - query = fast.RoPE(ctx, query, 
positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX()) + query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) + key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) @@ -65,10 +65,10 @@ type MLP interface { } type sparse struct { - Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts)) - hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = hiddenStates.SILU(ctx) - hiddenStates = hiddenStates.Mul(ctx, upStates) - - experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) @@ -111,7 +107,8 @@ type dense struct { } func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates). + SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -154,29 +151,39 @@ type Model struct { *Options } -// Forward implements model.Model. func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates, err := m.forward(ctx, batch) + if err != nil { + return nil, err + } + + return m.Output.Forward(ctx, hiddenStates), nil +} + +// Forward implements model.Model. 
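+// forward, below, is the shared core: it runs the token embedding and the
+// transformer layers and returns the hidden states after the final output
+// norm. Forward projects those states to logits, while the embedding variant
+// in embed.go pools and normalizes them instead.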
+func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
 	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
+		if m.Cache != nil {
+			m.Cache.SetLayer(i)
+		}
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
 	}
 
-	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
-	return m.Output.Forward(ctx, hiddenStates), nil
+	return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) {
 			valueLength:    int(c.Uint("attention.value_length")),
 			eps:            c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:       c.Float("rope.freq_base"),
-			ropeScale:      c.Float("rope.freq_scale", 1),
+			ropeScale:      c.Float("rope.scaling.factor", 1),
 			numExperts:     int(c.Uint("expert_count")),
 			numExpertsUsed: int(c.Uint("expert_used_count")),
 			normTopKProb:   c.Bool("norm_top_k_prob", true),
@@ -230,4 +237,5 @@ func init() {
 	model.Register("qwen3", New)
 	model.Register("qwen3moe", New)
+	model.Register("qwen3_embed", newEmbed)
 }
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
new file mode 100644
index 00000000..e6dbd1f4
--- /dev/null
+++ b/model/parsers/parsers.go
@@ -0,0 +1,37 @@
+package parsers
+
+import (
+	"github.com/ollama/ollama/api"
+)
+
+type Parser interface {
+	Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
+	HasToolSupport() bool
+	HasThinkingSupport() bool
+}
+
+func ParserForName(name string) Parser {
+	switch name {
+	case "qwen3-coder":
+		parser := &Qwen3CoderParser{}
+		return parser
+	case "passthrough":
+		return &PassthroughParser{}
+	default:
+		return nil
+	}
+}
+
+type PassthroughParser struct{}
+
+func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+	return s, "", nil, nil
+}
+
+func (p *PassthroughParser) HasToolSupport() bool {
+	return false
+}
+
+func (p *PassthroughParser) HasThinkingSupport() bool {
+	return false
+}
diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go
new file mode 100644
index 00000000..a7838fb7
--- /dev/null
+++ b/model/parsers/qwen3coder.go
@@ -0,0 +1,447 @@
+package parsers
+
+import (
+	"context"
+	"encoding/json"
+	"encoding/xml"
+	"fmt"
+	"log/slog"
+	"math"
+	"regexp"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/logutil"
+)
+
+type qwenParserState int
+
+const (
+	toolOpenTag  = "<tool_call>"
+	toolCloseTag = "</tool_call>"
+)
+
+const (
+	qwenParserState_LookingForToolStart qwenParserState = iota
+	qwenParserState_CollectingToolContent
+)
+
+type Qwen3CoderParser struct {
+	state qwenParserState
+	acc   strings.Builder
+}
+
+func (p *Qwen3CoderParser) HasToolSupport() bool {
+	return true
+}
+
+func (p *Qwen3CoderParser) HasThinkingSupport() bool {
+	return false
+}
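+
+// Add implements Parser. A caller streams model output into it chunk by
+// chunk, roughly like this (an illustrative sketch; the server-side wiring is
+// not part of this diff):
+//
+//	p := ParserForName("qwen3-coder")
+//	content, _, calls, err := p.Add(chunk, tools)
+//
+// Text that could still turn into a tool call is buffered internally and only
+// returned once it is unambiguous.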
+func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+	p.acc.WriteString(s)
+
+	events := p.parseEvents()
+
+	var toolCalls []api.ToolCall
+	var sb strings.Builder
+	for _, event := range events {
+		switch event := event.(type) {
+		case qwenEventRawToolCall:
+			toolCall, err := parseToolCall(event, tools)
+			if err != nil {
+				slog.Warn("qwen tool call parsing failed", "error", err)
+				return "", "", nil, err
+			}
+			toolCalls = append(toolCalls, toolCall)
+		case qwenEventContent:
+			// TODO(drifkin): if the same turn contains multiple interleaved content
+			// events, we naively append them together here. See the note below about
+			// `qwenEvent`s for more details
+			sb.WriteString(event.content)
+		}
+	}
+
+	return sb.String(), "", toolCalls, nil
+}
+
+func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
+	var all []qwenEvent
+
+	keepLooping := true
+	for keepLooping {
+		var events []qwenEvent
+		events, keepLooping = eat(p)
+		if len(events) > 0 {
+			all = append(all, events...)
+		}
+	}
+
+	if len(all) > 0 {
+		slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
+	}
+
+	return all
+}
+
+// we use some internal event types in order to communicate between `Add` and
+// `eat`. We do this to support interleaving content and parallel tool calls in
+// the parser, even though qwen3-coder isn't supposed to do this. Our API
+// doesn't currently support models outputting multiple messages in a turn, so
+// we wouldn't be able to represent it yet, but there's no reason to prevent the
+// parser from supporting it, especially for future models if they end up using
+// a similar format.
+type qwenEvent interface {
+	isQwenEvent()
+}
+
+type qwenEventRawToolCall struct {
+	raw string
+}
+
+type qwenEventContent struct {
+	content string
+}
+
+func (qwenEventContent) isQwenEvent()     {}
+func (qwenEventRawToolCall) isQwenEvent() {}
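+
+// For example, feeding `before<tool_call>x</tool_call>after` through Add
+// yields qwenEventContent{"before"}, qwenEventRawToolCall{"x"}, and then
+// qwenEventContent{"after"} (this mirrors the streaming tests in
+// qwen3coder_test.go).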
+
+// eat consumes the parser's buffer, and returns a list of any unambiguous
+// events from the current parser state. If the parser transitions to another
+// state, it may have additional events to emit on the next call, which is what
+// the second return value indicates
+func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
+	var events []qwenEvent
+
+	switch p.state {
+	case qwenParserState_LookingForToolStart:
+		if strings.Contains(p.acc.String(), toolOpenTag) {
+			// we found a full tool open tag, so we can emit the content before the
+			// tag, being sure to trim any trailing whitespace
+			split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
+			before := split[0]
+			before = strings.TrimRightFunc(before, unicode.IsSpace)
+			if len(before) > 0 {
+				events = append(events, qwenEventContent{content: before})
+			}
+			after := split[1]
+			p.acc.Reset()
+			p.acc.WriteString(after)
+			p.state = qwenParserState_CollectingToolContent
+			return events, true
+		} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
+			// we found a partial tool open tag, so we can emit the unambiguous part,
+			// which is the (trailing-whitespace trimmed) content before the partial
+			// tool open tag
+			beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
+			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
+			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
+			unambiguous := p.acc.String()[:ambiguousStart]
+			ambiguous := p.acc.String()[ambiguousStart:]
+			p.acc.Reset()
+			p.acc.WriteString(ambiguous)
+			events = append(events, qwenEventContent{content: unambiguous})
+			return events, false
+		} else {
+			// we found content that is entirely not a tool call. We should withhold
+			// any trailing whitespace in case this is the end of the content
+			whitespaceLen := trailingWhitespaceLen(p.acc.String())
+			ambiguousStart := len(p.acc.String()) - whitespaceLen
+			unambiguous := p.acc.String()[:ambiguousStart]
+			ambiguous := p.acc.String()[ambiguousStart:]
+			p.acc.Reset()
+			p.acc.WriteString(ambiguous)
+			if len(unambiguous) > 0 {
+				events = append(events, qwenEventContent{content: unambiguous})
+			}
+			return events, false
+		}
+	case qwenParserState_CollectingToolContent:
+		if strings.Contains(p.acc.String(), toolCloseTag) {
+			split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
+			before := split[0]
+			if len(before) == 0 {
+				slog.Warn("qwen tool call closing tag found but no content before it")
+			}
+			// remove any whitespace between the tool call and any content after it
+			after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
+			p.acc.Reset()
+			p.acc.WriteString(after)
+			events = append(events, qwenEventRawToolCall{raw: before})
+			p.state = qwenParserState_LookingForToolStart
+			return events, true
+		} else {
+			// note that we don't need to check the overlap here because we only plan
+			// on parsing the tool call once we see the full closing tag. We don't
+			// stream back the unparsed tool content, so there's no need to be eager
+			// here
+			return events, false
+		}
+	default:
+		panic("unreachable")
+	}
+}
+
+// TODO(drifkin): move this to a shared location
+// longest overlap between suffix of s and prefix of delim
+func overlap(s, delim string) int {
+	max := min(len(delim), len(s))
+	for i := max; i > 0; i-- {
+		if strings.HasSuffix(s, delim[:i]) {
+			return i
+		}
+	}
+	return 0
+}
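+
+// e.g. overlap("abc<tool", toolOpenTag) == 5: "abc" is already unambiguous
+// content, while "<tool" could still grow into "<tool_call>" and is withheld.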
+
+func trailingWhitespaceLen(s string) int {
+	for i := len(s) - 1; i >= 0; i-- {
+		if !unicode.IsSpace(rune(s[i])) {
+			return len(s) - i - 1
+		}
+	}
+	return len(s)
+}
+
+type XMLFunctionCall struct {
+	XMLName    xml.Name       `xml:"function"`
+	Name       string         `xml:"name,attr"`
+	Parameters []XMLParameter `xml:"parameter"`
+}
+
+type XMLParameter struct {
+	Name  string `xml:"name,attr"`
+	Value string `xml:",chardata"`
+}
+
+// parseToolCall parses a raw tool call string into an api.ToolCall.
+// The raw string follows an xml-like format, here's an example:
+//
+// <function=get_current_temperature>
+// <parameter=location>
+// San Francisco
+// </parameter>
+// <parameter=unit>
+// celsius
+// </parameter>
+// </function>
+func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+	toolCall := api.ToolCall{}
+
+	xmlString := transformToXML(raw.raw)
+
+	var functionCall XMLFunctionCall
+	err := xml.Unmarshal([]byte(xmlString), &functionCall)
+	if err != nil {
+		return api.ToolCall{}, err
+	}
+
+	toolCall.Function = api.ToolCallFunction{
+		Name: functionCall.Name,
+	}
+
+	// Find the matching tool to get parameter types
+	var matchedTool *api.Tool
+	for i := range tools {
+		if tools[i].Function.Name == functionCall.Name {
+			matchedTool = &tools[i]
+			break
+		}
+	}
+
+	toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
+	for _, parameter := range functionCall.Parameters {
+		// Look up the parameter type if we found the tool
+		var paramType api.PropertyType
+		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+			if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
+				paramType = prop.Type
+			}
+		}
+
+		toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
+	}
+
+	return toolCall, nil
+}
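+
+// (If no tool with a matching name is registered, matchedTool stays nil above,
+// so every parameter reaches parseValue with an empty PropertyType and, aside
+// from the null check, comes back as a plain string.)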
+
+// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
+//
+// For union types (multiple types in PropertyType, which we support even
+// though the reference parser doesn't appear to do type coercion with those
+// types in mind) we use a type precedence approach:
+// 1. null - checked first regardless of declared types (matches reference implementation)
+// 2. boolean - only "true"/"false" are valid booleans
+// 3. integer - must parse as a whole number
+// 4. number - must parse as numeric (returns int if no decimal part)
+// 5. array - must parse as valid JSON array
+// 6. object - must parse as valid JSON object
+// 7. string - always succeeds (least specific type)
+//
+// This precedence ensures we return the most specific type that successfully parses,
+// following the principle of least surprise. For example, with PropertyType{"string", "number"},
+// "123" becomes 123 (number), while "hello" becomes "hello" (string).
+func parseValue(raw string, paramType api.PropertyType) any {
+	// first remove a single leading newline, and a single trailing newline (if
+	// they exist). This follows the reference implementation
+	raw = strings.TrimPrefix(raw, "\n")
+	raw = strings.TrimSuffix(raw, "\n")
+
+	// Check for null first (case-insensitive) - this takes precedence over any type
+	if strings.ToLower(raw) == "null" {
+		return nil
+	}
+
+	// If no type is specified, default to string
+	if len(paramType) == 0 {
+		return raw
+	}
+
+	// Check if any of the specified types match, using type precedence
+	// Order: boolean -> integer -> number -> array -> object -> string
+	typeSet := make(map[string]bool)
+	for _, t := range paramType {
+		typeSet[t] = true
+	}
+
+	// Try boolean first (most restrictive)
+	if typeSet["boolean"] {
+		lower := strings.ToLower(raw)
+		switch lower {
+		case "true":
+			return true
+		case "false":
+			return false
+		}
+		// If not a valid boolean but boolean is the only type, return false (matching reference)
+		if len(paramType) == 1 {
+			return false
+		}
+		// Otherwise try other types
+	}
+
+	// Try integer
+	if typeSet["integer"] {
+		if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
+			// Return as int if it fits in int32, otherwise int64
+			if i >= math.MinInt32 && i <= math.MaxInt32 {
+				return int(i)
+			}
+			return i
+		}
+		// If integer is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try number (float)
+	if typeSet["number"] {
+		if f, err := strconv.ParseFloat(raw, 64); err == nil {
+			// If the number has no decimal part, return as int (matching reference)
+			if f == math.Trunc(f) {
+				i := int64(f)
+				if i >= math.MinInt32 && i <= math.MaxInt32 {
+					return int(i)
+				}
+				return i
+			}
+			return f
+		}
+		// If number is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try array
+	if typeSet["array"] {
+		var arr []interface{}
+		if err := json.Unmarshal([]byte(raw), &arr); err == nil {
+			return arr
+		}
+		// If array is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try object
+	if typeSet["object"] {
+		var obj map[string]interface{}
+		if err := json.Unmarshal([]byte(raw), &obj); err == nil {
+			return obj
+		}
+		// If object is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// String always succeeds (or if "string" is in the type set)
+	if typeSet["string"] {
+		return raw
+	}
+
+	// If we get here, none of the types matched and string wasn't an option.
+	// We return string as a fallback. The reference implementation will attempt
+	// to parse the value as a python literal, but we purposefully don't support
+	// that
+	return raw
+}
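+
+// A few concrete outcomes of the precedence rules above (illustrative,
+// mirroring the value-parsing tests below):
+//
+//	parseValue("123", api.PropertyType{"string", "number"})   // 123 (number wins)
+//	parseValue("hello", api.PropertyType{"string", "number"}) // "hello"
+//	parseValue("null", api.PropertyType{"string"})            // nil
+//	parseValue("\n42\n", api.PropertyType{"integer"})         // 42 (newlines trimmed)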
The reference implementation will attempt + // to parse the value as a python literal, but we purposefully don't support + // that + return raw +} + +var ( + qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) + qwenXMLTagRegex = regexp.MustCompile(``) +) + +// transformToXML transforms a raw qwen tool call with xml-like tags into valid +// xml so that it can be parsed by any xml parser +func transformToXML(raw string) string { + // take the form `` and transform it to ``, taking + // care to properly escape the string that becomes the attribute value + transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { + groups := qwenTagRegex.FindStringSubmatch(match) + tag := groups[1] + var escapedValue strings.Builder + xml.EscapeText(&escapedValue, []byte(groups[2])) + return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) + }) + + // Walk the resulting string, escaping any character data that sits between the + // xml tags we just emitted + var out strings.Builder + lastIdx := 0 + for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) { + if loc[0] > lastIdx { + escapeTextNode(&out, transformed[lastIdx:loc[0]]) + } + out.WriteString(transformed[loc[0]:loc[1]]) + lastIdx = loc[1] + } + if lastIdx < len(transformed) { + escapeTextNode(&out, transformed[lastIdx:]) + } + + return out.String() +} + +// escapeTextNode escapes XML character data without altering other characters +// like newlines or tabs (which is why we don't use xml.EscapeText for this) +func escapeTextNode(sb *strings.Builder, s string) { + for _, r := range s { + switch r { + case '&': + sb.WriteString("&") + case '<': + sb.WriteString("<") + case '>': + sb.WriteString(">") + default: + sb.WriteRune(r) + } + } +} diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go new file mode 100644 index 00000000..43823e6f --- /dev/null +++ b/model/parsers/qwen3coder_test.go @@ -0,0 +1,878 @@ +package parsers + +import ( + "reflect" + "testing" + + "github.com/ollama/ollama/api" +) + +// tool creates a test tool with the given name and properties +func tool(name string, props map[string]api.ToolProperty) api.Tool { + t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}} + t.Function.Parameters.Type = "object" + t.Function.Parameters.Properties = props + return t +} + +func TestQwenParserStreaming(t *testing.T) { + type step struct { + input string + wantEvents []qwenEvent + } + + cases := []struct { + desc string + steps []step + only bool + }{ + { + desc: "simple message streamed word by word", + steps: []step{ + { + input: "hi", + wantEvents: []qwenEvent{qwenEventContent{content: "hi"}}, + }, + { + input: " there", + wantEvents: []qwenEvent{qwenEventContent{content: " there"}}, + }, + }, + }, + { + desc: "content before tool call", + steps: []step{ + { + input: "hi there", + wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}}, + }, + }, + }, + { + desc: "multiple tool calls in one message", + steps: []step{ + { + input: "before1in tool callafter1in tool call 2after2", + wantEvents: []qwenEvent{ + qwenEventContent{content: "before1"}, + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "after1"}, + qwenEventRawToolCall{raw: "in tool call 2"}, + qwenEventContent{content: "after2"}, + }, + }, + }, + }, + { + desc: "tool calls with split tags", + steps: []step{ + { + input: "beforein tool callaf", + wantEvents: []qwenEvent{ + qwenEventRawToolCall{raw: "in tool call"}, + qwenEventContent{content: "af"}, + }, + }, + { 
+					input: "ter",
+					wantEvents: []qwenEvent{
+						qwenEventContent{content: "ter"},
+					},
+				},
+			},
+		},
+		{
+			desc: "trailing whitespace between content and tool call",
+			steps: []step{
+				{
+					input: "abc\n<tool_call>def</tool_call>",
+					wantEvents: []qwenEvent{
+						qwenEventContent{content: "abc"},
+						qwenEventRawToolCall{raw: "def"},
+					},
+				},
+			},
+		},
+		{
+			desc: "trailing whitespace between tool call and content",
+			steps: []step{
+				{
+					input: "<tool_call>abc</tool_call>\ndef",
+					wantEvents: []qwenEvent{
+						qwenEventRawToolCall{raw: "abc"},
+						qwenEventContent{content: "def"},
+					},
+				},
+			},
+		},
+		{
+			desc: "empty content before tool call",
+			steps: []step{
+				{
+					input: "\n<tool_call>abc</tool_call>",
+					wantEvents: []qwenEvent{
+						qwenEventRawToolCall{raw: "abc"},
+					},
+				},
+			},
+		},
+		{
+			desc: "partial tool open tag fakeout",
+			steps: []step{
+				{
+					input: "abc\n<tool
+<parameter=location>
+San Francisco
+</parameter>
+<parameter=unit>
+celsius
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get_current_temperature",
+					Arguments: map[string]any{
+						"location": "San Francisco",
+						"unit":     "celsius",
+					},
+				},
+			},
+		},
+		{
+			name:  "names with spaces",
+			tools: []api.Tool{},
+			rawToolCall: `<function=get current temperature>
+<parameter=location with spaces>
+San Francisco
+</parameter>
+<parameter=unit with spaces>
+celsius
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get current temperature",
+					Arguments: map[string]any{
+						"location with spaces": "San Francisco",
+						"unit with spaces":     "celsius",
+					},
+				},
+			},
+		},
+		// this mirrors the reference implementation's behavior, but it's unclear
+		// whether quoted names ever actually occur. If they do, we should probably
+		// strip the quotes instead; this test just documents the current behavior
+		// and checks that we don't get xml errors
+		{
+			name:  "names with quotes",
+			tools: []api.Tool{},
+			rawToolCall: `<function="get current temperature">
+<parameter="location with spaces">
+San Francisco
+</parameter>
+<parameter="unit with spaces">
+"celsius"
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "\"get current temperature\"",
+					Arguments: map[string]any{
+						"\"location with spaces\"": "San Francisco",
+						"\"unit with spaces\"":     "\"celsius\"",
+					},
+				},
+			},
+		},
+		{
+			name: "tool call with typed parameters",
+			tools: []api.Tool{
+				tool("calculate", map[string]api.ToolProperty{
+					"x":       {Type: api.PropertyType{"number"}},
+					"y":       {Type: api.PropertyType{"integer"}},
+					"enabled": {Type: api.PropertyType{"boolean"}},
+					"items":   {Type: api.PropertyType{"array"}},
+				}),
+			},
+			rawToolCall: `<function=calculate>
+<parameter=x>
+3.14
+</parameter>
+<parameter=y>
+42
+</parameter>
+<parameter=enabled>
+true
+</parameter>
+<parameter=items>
+["a", "b", "c"]
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "calculate",
+					Arguments: map[string]any{
+						"x":       3.14,
+						"y":       42,
+						"enabled": true,
+						"items":   []any{"a", "b", "c"},
+					},
+				},
+			},
+		},
+		// regression test for ampersands in parameter values breaking xml parsing
+		{
+			name:  "ampersands in parameter values",
+			tools: []api.Tool{},
+			rawToolCall: `<function=exec>
+<parameter=command>
+ls && echo "done"
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "exec",
+					Arguments: map[string]any{
+						"command": "ls && echo \"done\"",
+					},
+				},
+			},
+		},
+		{
+			name:  "angle brackets in parameter values",
+			tools: []api.Tool{},
+			rawToolCall: `<function=exec>
+<parameter=command>
+ls && echo "a > b and a < b"
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "exec",
+					Arguments: map[string]any{
+						"command": "ls && echo \"a > b and a < b\"",
+					},
+				},
+			},
+		},
+	}
+
+	for i, step := range steps {
+		gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
+		if err != nil {
+			t.Errorf("step %d (%s): %v", i, step.name, err)
+		}
+		if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
+			t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
+		}
+	}
+}
+
+func 
TestQwenToolCallValueParsing(t *testing.T) { + cases := []struct { + desc string + raw string + paramType api.PropertyType + want any + }{ + { + desc: "default string value (no type specified)", + paramType: api.PropertyType{}, + raw: "some-string", + want: "some-string", + }, + { + desc: "trim a single leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\nsome-string\n", + want: "some-string", + }, + { + desc: "trim at most one leading and trailing newline", + paramType: api.PropertyType{}, + raw: "\n\nsome-string\n\n", + want: "\nsome-string\n", + }, + { + desc: "newline really has to be the first character to be trimmed", + paramType: api.PropertyType{}, + raw: " \nsome-string\n ", + want: " \nsome-string\n ", + }, + { + desc: "numeric type", + paramType: api.PropertyType{"number"}, + raw: "123", + want: 123, + }, + // Integer parsing tests + { + desc: "integer type", + paramType: api.PropertyType{"integer"}, + raw: "42", + want: 42, + }, + { + desc: "negative integer", + paramType: api.PropertyType{"integer"}, + raw: "-100", + want: -100, + }, + { + desc: "zero integer", + paramType: api.PropertyType{"integer"}, + raw: "0", + want: 0, + }, + { + desc: "integer with leading zeros", + paramType: api.PropertyType{"integer"}, + raw: "007", + want: 7, + }, + { + desc: "large integer", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Just beyond int32 max + want: int64(2147483648), + }, + // Float/number parsing tests + { + desc: "float type", + paramType: api.PropertyType{"number"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "negative float", + paramType: api.PropertyType{"number"}, + raw: "-273.15", + want: -273.15, + }, + { + desc: "float without decimal part", + paramType: api.PropertyType{"number"}, + raw: "100.0", + want: 100, + }, + { + desc: "scientific notation positive", + paramType: api.PropertyType{"number"}, + raw: "1.23e5", + want: 123000, // Will be int since it has no decimal part + }, + { + desc: "scientific notation negative", + paramType: api.PropertyType{"number"}, + raw: "1.5e-3", + want: 0.0015, + }, + { + desc: "very small float", + paramType: api.PropertyType{"number"}, + raw: "0.00000001", + want: 0.00000001, + }, + // String parsing tests + { + desc: "explicit string type", + paramType: api.PropertyType{"string"}, + raw: "hello world", + want: "hello world", + }, + { + desc: "string with special characters", + paramType: api.PropertyType{"string"}, + raw: "/usr/local/bin/test-file_v2.0.sh", + want: "/usr/local/bin/test-file_v2.0.sh", + }, + { + desc: "string with quotes", + paramType: api.PropertyType{"string"}, + raw: `He said "hello" to me`, + want: `He said "hello" to me`, + }, + { + desc: "multiline string", + paramType: api.PropertyType{"string"}, + raw: "line one\nline two\nline three", + want: "line one\nline two\nline three", + }, + { + desc: "empty string", + paramType: api.PropertyType{"string"}, + raw: "", + want: "", + }, + { + desc: "string that looks like a number", + paramType: api.PropertyType{"string"}, + raw: "12345", + want: "12345", + }, + // Boolean parsing tests + { + desc: "boolean true", + paramType: api.PropertyType{"boolean"}, + raw: "true", + want: true, + }, + { + desc: "boolean false", + paramType: api.PropertyType{"boolean"}, + raw: "false", + want: false, + }, + { + desc: "boolean case insensitive true", + paramType: api.PropertyType{"boolean"}, + raw: "True", + want: true, + }, + { + desc: "boolean case insensitive false", + paramType: api.PropertyType{"boolean"}, + raw: "FALSE", + want: false, 
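// A sketch of the coercion rules these cases pin down (illustrative calls,
// mirroring entries in this table rather than adding new behavior): parseValue
// tries the declared type first and falls back rather than failing, e.g.
//
//	parseValue("42", api.PropertyType{"integer"})   // -> int 42
//	parseValue("oops", api.PropertyType{"integer"}) // -> string "oops" (fallback)
//	parseValue("yes", api.PropertyType{"boolean"})  // -> false (fallback, see below)
//
// so malformed model output degrades to a usable value instead of aborting
// the whole tool call.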
+ }, + // Null parsing tests + { + desc: "null value lowercase", + paramType: api.PropertyType{"string"}, + raw: "null", + want: nil, + }, + { + desc: "null value case insensitive", + paramType: api.PropertyType{"integer"}, + raw: "NULL", + want: nil, + }, + // Array parsing tests + { + desc: "array of strings", + paramType: api.PropertyType{"array"}, + raw: `["foo", "bar", "baz"]`, + want: []any{"foo", "bar", "baz"}, + }, + { + desc: "array of numbers", + paramType: api.PropertyType{"array"}, + raw: `[1, 2.5, 3]`, + want: []any{float64(1), 2.5, float64(3)}, + }, + { + desc: "array of mixed types", + paramType: api.PropertyType{"array"}, + raw: `["string", 123, true, null]`, + want: []any{"string", float64(123), true, nil}, + }, + { + desc: "empty array", + paramType: api.PropertyType{"array"}, + raw: `[]`, + want: []any{}, + }, + // Object parsing tests + { + desc: "simple object", + paramType: api.PropertyType{"object"}, + raw: `{"key": "value", "number": 42}`, + want: map[string]any{"key": "value", "number": float64(42)}, + }, + { + desc: "nested object", + paramType: api.PropertyType{"object"}, + raw: `{"outer": {"inner": "value"}}`, + want: map[string]any{"outer": map[string]any{"inner": "value"}}, + }, + { + desc: "empty object", + paramType: api.PropertyType{"object"}, + raw: `{}`, + want: map[string]any{}, + }, + // Error cases and fallback behavior + { + desc: "invalid integer falls back to string", + paramType: api.PropertyType{"integer"}, + raw: "not-a-number", + want: "not-a-number", + }, + { + desc: "invalid float falls back to string", + paramType: api.PropertyType{"number"}, + raw: "3.14.159", + want: "3.14.159", + }, + { + desc: "invalid boolean falls back to false", + paramType: api.PropertyType{"boolean"}, + raw: "yes", + want: false, + }, + { + desc: "invalid JSON array falls back to string", + paramType: api.PropertyType{"array"}, + raw: "[1, 2, unclosed", + want: "[1, 2, unclosed", + }, + { + desc: "invalid JSON object falls back to string", + paramType: api.PropertyType{"object"}, + raw: `{"key": unclosed`, + want: `{"key": unclosed`, + }, + // Edge cases + { + desc: "integer overflow should use int64", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Beyond int32 max + want: int64(2147483648), + }, + { + desc: "float with many decimal places", + paramType: api.PropertyType{"number"}, + raw: "3.141592653589793", + want: 3.141592653589793, + }, + { + desc: "string with JSON-like content", + paramType: api.PropertyType{"string"}, + raw: `{"this": "is", "just": "a string"}`, + want: `{"this": "is", "just": "a string"}`, + }, + { + desc: "whitespace-only string", + paramType: api.PropertyType{"string"}, + raw: " ", + want: " ", + }, + // Unknown parameter (no type specified in tools) + { + desc: "parameter not in tool definition defaults to string", + paramType: api.PropertyType{}, + raw: "some value", + want: "some value", + }, + // Union type tests + { + desc: "string or number union - valid number", + paramType: api.PropertyType{"string", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "string or number union - non-numeric string", + paramType: api.PropertyType{"string", "number"}, + raw: "hello", + want: "hello", + }, + { + desc: "number or string union - valid number (order shouldn't matter)", + paramType: api.PropertyType{"number", "string"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "integer or null union - valid integer", + paramType: api.PropertyType{"integer", "null"}, + raw: "123", + want: 123, + }, + { + desc: "integer or null 
union - null value",
+			paramType: api.PropertyType{"integer", "null"},
+			raw:       "null",
+			want:      nil,
+		},
+		{
+			desc:      "null or integer union - null value (order shouldn't matter)",
+			paramType: api.PropertyType{"null", "integer"},
+			raw:       "null",
+			want:      nil,
+		},
+		{
+			desc:      "boolean or string union - valid boolean",
+			paramType: api.PropertyType{"boolean", "string"},
+			raw:       "true",
+			want:      true,
+		},
+		{
+			desc:      "boolean or string union - non-boolean becomes string",
+			paramType: api.PropertyType{"boolean", "string"},
+			raw:       "yes",
+			want:      "yes",
+		},
+		{
+			desc:      "string or boolean union - valid boolean (precedence test)",
+			paramType: api.PropertyType{"string", "boolean"},
+			raw:       "false",
+			want:      false, // Should be boolean, not string "false"
+		},
+		{
+			desc:      "integer or number union - integer value",
+			paramType: api.PropertyType{"integer", "number"},
+			raw:       "42",
+			want:      42,
+		},
+		{
+			desc:      "integer or number union - float value",
+			paramType: api.PropertyType{"integer", "number"},
+			raw:       "42.5",
+			want:      42.5,
+		},
+		{
+			desc:      "number or integer union - integer value (precedence test)",
+			paramType: api.PropertyType{"number", "integer"},
+			raw:       "42",
+			want:      42, // Should try integer first due to precedence
+		},
+		{
+			desc:      "array or object union - valid array",
+			paramType: api.PropertyType{"array", "object"},
+			raw:       `[1, 2, 3]`,
+			want:      []any{float64(1), float64(2), float64(3)},
+		},
+		{
+			desc:      "array or object union - valid object",
+			paramType: api.PropertyType{"array", "object"},
+			raw:       `{"key": "value"}`,
+			want:      map[string]any{"key": "value"},
+		},
+		{
+			desc:      "object or array union - valid array (precedence test)",
+			paramType: api.PropertyType{"object", "array"},
+			raw:       `[1, 2, 3]`,
+			want:      []any{float64(1), float64(2), float64(3)},
+		},
+		{
+			desc:      "complex multi-type union - null",
+			paramType: api.PropertyType{"string", "number", "boolean", "null"},
+			raw:       "null",
+			want:      nil,
+		},
+		{
+			desc:      "complex multi-type union - boolean",
+			paramType: api.PropertyType{"string", "number", "boolean", "null"},
+			raw:       "true",
+			want:      true,
+		},
+		{
+			desc:      "complex multi-type union - number",
+			paramType: api.PropertyType{"string", "number", "boolean", "null"},
+			raw:       "3.14",
+			want:      3.14,
+		},
+		{
+			desc:      "complex multi-type union - string",
+			paramType: api.PropertyType{"string", "number", "boolean", "null"},
+			raw:       "hello",
+			want:      "hello",
+		},
+		{
+			desc:      "integer string union - integer string becomes integer",
+			paramType: api.PropertyType{"integer", "string"},
+			raw:       "123",
+			want:      123,
+		},
+		{
+			desc:      "string integer union - integer string becomes integer (precedence)",
+			paramType: api.PropertyType{"string", "integer"},
+			raw:       "123",
+			want:      123, // Integer has higher precedence than string
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := parseValue(tc.raw, tc.paramType)
+			if !reflect.DeepEqual(got, tc.want) {
+				t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
+			}
+		})
+	}
+}
+
+func TestQwenXMLTransform(t *testing.T) {
+	cases := []struct {
+		desc string
+		raw  string
+		want string
+	}{
+		{
+			desc: "simple example",
+			raw: `<function=get_current_temperature>
+<parameter=location>
+San Francisco
+</parameter>
+<parameter=unit>
+celsius
+</parameter>
+</function>`,
+			want: `<function name="get_current_temperature">
+<parameter name="location">
+San Francisco
+</parameter>
+<parameter name="unit">
+celsius
+</parameter>
+</function>`,
+		},
+		// even though quotes aren't expected in these tags, we have these tests to
+		// make sure they're escaped so they don't blow up the xml parser in case
+		// they happen
+		{
+			desc: "names with quotes",
+			raw: `<function="get current temperature">
+<parameter="location with spaces">
+San Francisco
+</parameter>
+<parameter="unit with spaces">
+celsius
+</parameter>
+</function>`,
+			want: `<function name="&#34;get current temperature&#34;">
+<parameter name="&#34;location with spaces&#34;">
+San Francisco
+</parameter>
+<parameter name="&#34;unit with spaces&#34;">
+celsius
+</parameter>
+</function>`,
+		},
+		{
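// The case below exists because a bare "&" is not legal XML character data;
// without the escapeTextNode pass, a parameter value containing "&" would make
// any conforming XML parser reject the whole tool call. The expected output
// pins the "&" -> "&amp;" rewrite (and nothing else) for such values.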
desc: "ampersands in parameter values", + raw: ` + + San Francisco & San Jose + + `, + want: ` + + San Francisco & San Jose + + `, + }, + } + + for _, tc := range cases { + got := transformToXML(tc.raw) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + cases := []struct { + desc string + s string + want int + }{ + {desc: "no whitespace", s: "abc", want: 0}, + {desc: "trailing whitespace", s: "abc ", want: 1}, + {desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, + {desc: "only whitespace", s: " \n ", want: 4}, + {desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.s) + if got != tc.want { + t.Errorf("got %d, want %d", got, tc.want) + } + } +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go new file mode 100644 index 00000000..df3b3a45 --- /dev/null +++ b/model/renderers/qwen3coder.go @@ -0,0 +1,217 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/ollama/ollama/api" +) + +var ( + imStartTag = "<|im_start|>" + imEndTag = "<|im_end|>" +) + +// renderAdditionalKeys renders all JSON fields except the ones in handledKeys +// This follows the same approach from the reference implementation, which gives +// a particular key ordering +func renderAdditionalKeys(obj any, handledKeys map[string]bool) string { + data, err := json.Marshal(obj) + if err != nil { + return "" + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return "" + } + + var sb strings.Builder + for key, value := range m { + if handledKeys[key] { + continue + } + + // Check if value is a map or array (needs JSON serialization) + switch v := value.(type) { + case map[string]any, []any: + jsonBytes, _ := json.Marshal(v) + // TODO(drifkin): it would be nice to format the JSON here similarly to + // python's default json.dumps behavior (spaces after commas and colons). + // This would let us be byte-for-byte compatible with the reference + // implementation for most common inputs + jsonStr := string(jsonBytes) + sb.WriteString("\n<" + key + ">" + jsonStr + "") + case nil: + continue + default: + // Simple types, convert to string + sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "") + } + } + + return sb.String() +} + +func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) { + var sb strings.Builder + + // filter out system messages and choose the first (if any) to win + var systemMessage string + var filteredMessages []api.Message + for _, message := range messages { + if message.Role != "system" { + filteredMessages = append(filteredMessages, message) + continue + } + + if systemMessage == "" { + systemMessage = message.Content + } + } + + if systemMessage != "" || len(tools) > 0 { + sb.WriteString(imStartTag + "system\n") + + // if we have tools but no system message, match the reference implementation by providing a default system message + if systemMessage == "" { + systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
+		}
+
+		sb.WriteString(systemMessage)
+
+		if len(tools) > 0 {
+			sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
+			sb.WriteString("<tools>")
+			for _, tool := range tools {
+				sb.WriteString("\n<function>")
+				sb.WriteString("\n<name>" + tool.Function.Name + "</name>")
+				if tool.Function.Description != "" {
+					sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
+				}
+				sb.WriteString("\n<parameters>")
+
+				for name, prop := range tool.Function.Parameters.Properties {
+					sb.WriteString("\n<parameter>")
+					sb.WriteString("\n<name>" + name + "</name>")
+
+					if len(prop.Type) > 0 {
+						// TODO(!!!)(drifkin): we should match the reference implementation for
+						// more complex types here instead of using this format
+						sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
+					}
+
+					if prop.Description != "" {
+						sb.WriteString("\n<description>" + prop.Description + "</description>")
+					}
+
+					// Render any additional keys not already handled
+					handledKeys := map[string]bool{
+						"type":        true,
+						"description": true,
+					}
+					sb.WriteString(renderAdditionalKeys(prop, handledKeys))
+
+					sb.WriteString("\n</parameter>")
+				}
+
+				// Render extra keys for parameters (everything except 'type' and 'properties')
+				paramHandledKeys := map[string]bool{
+					"type":       true,
+					"properties": true,
+				}
+				sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
+
+				sb.WriteString("\n</parameters>")
+				sb.WriteString("\n</function>")
+			}
+			sb.WriteString("\n</tools>")
+			sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
+		}
+
+		sb.WriteString(imEndTag + "\n")
+	}
+
+	for i, message := range filteredMessages {
+		lastMessage := i == len(filteredMessages)-1
+		prefill := lastMessage && message.Role == "assistant"
+		switch message.Role {
+		case "assistant":
+			if len(message.ToolCalls) > 0 {
+				sb.WriteString(imStartTag + "assistant\n")
+				if message.Content != "" {
+					sb.WriteString(message.Content + "\n")
+				}
+				for _, toolCall := range message.ToolCalls {
+					sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
+					for name, value := range toolCall.Function.Arguments {
+						valueStr := formatToolCallArgument(value)
+						sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
+					}
+					sb.WriteString("\n</function>\n</tool_call>")
+				}
+				sb.WriteString("<|im_end|>\n")
+			} else {
+				sb.WriteString(imStartTag + "assistant\n")
+				sb.WriteString(message.Content)
+				if !prefill {
+					sb.WriteString(imEndTag + "\n")
+				}
+			}
+		case "tool":
+			// consecutive tool responses should share a single `user` block, but
+			// each response gets its own <tool_response> tags
+
+			// only start a new user block if this is the first tool response
+			if i == 0 || filteredMessages[i-1].Role != "tool" {
+				sb.WriteString(imStartTag + "user\n")
+			}
+
+			sb.WriteString("<tool_response>\n")
+			sb.WriteString(message.Content)
+			sb.WriteString("\n</tool_response>\n")
+
+			// close the user block only if this is the last tool response
+			if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
+				sb.WriteString(imEndTag + "\n")
+			}
+		default:
+			sb.WriteString(imStartTag + message.Role + "\n")
+			sb.WriteString(message.Content)
+			sb.WriteString(imEndTag + "\n")
+		}
+
+		if lastMessage && !prefill {
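// How the turn boundaries land (derived from the branches above): a
// conversation ending in a user or tool message gets a trailing
// "<|im_start|>assistant\n" so generation continues in the assistant role,
// while a trailing assistant message (the prefill case) is deliberately left
// without "<|im_end|>" so the model completes it in place.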
sb.WriteString(imStartTag + "assistant\n") + } + } + + return sb.String(), nil +} + +func formatToolCallArgument(value any) string { + if value == nil { + return "null" + } + + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + } + + if reflect.TypeOf(value) != nil { + kind := reflect.TypeOf(value).Kind() + if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array { + if marshalled, err := json.Marshal(value); err == nil { + return string(marshalled) + } + } + } + + return fmt.Sprintf("%v", value) +} diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go new file mode 100644 index 00000000..4aaa066d --- /dev/null +++ b/model/renderers/qwen3coder_test.go @@ -0,0 +1,338 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestQwen3CoderRenderer(t *testing.T) { + tests := []struct { + name string + msgs []api.Message + tools []api.Tool + expected string + }{ + { + name: "basic", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +Hello, how are you?<|im_end|> +<|im_start|>assistant +`, + }, + { + name: "with tools and response", + msgs: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in San Francisco?"}, + { + Role: "assistant", + Content: "I'll check the weather in San Francisco for you.", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{ + "unit": "fahrenheit", + }, + }, + }, + }, + }, + {Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"}, + {Role: "user", Content: "That sounds nice! What about New York?"}, + }, + tools: []api.Tool{ + {Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Required: []string{"unit"}, + Properties: map[string]api.ToolProperty{ + "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"}, + // TODO(drifkin): add multiple params back once we have predictable + // order via some sort of ordered map type (see + // ) + /* + "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"}, + */ + }, + }, + }}, + }, + expected: `<|im_start|>system +You are a helpful assistant with access to tools. 
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>get_weather</name>
+<description>Get the current weather in a given location</description>
+<parameters>
+<parameter>
+<name>unit</name>
+<type>string</type>
+<description>The unit of temperature</description>
+<enum>["celsius","fahrenheit"]</enum>
+</parameter>
+<required>["unit"]</required>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+What is the weather like in San Francisco?<|im_end|>
+<|im_start|>assistant
+I'll check the weather in San Francisco for you.
+
+<tool_call>
+<function=get_weather>
+<parameter=unit>
+fahrenheit
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
+</tool_response>
+<|im_end|>
+<|im_start|>user
+That sounds nice! What about New York?<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			name: "parallel tool calls",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant with access to tools."},
+				{Role: "user", Content: "call double(1) and triple(2)"},
+				{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
+					{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
+					{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
+				}},
+				{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
+				{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
+			},
+			tools: []api.Tool{
+				{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+					"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
+				}}}},
+				{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+					"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
+				}}}},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>double</name>
+<description>Double a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to double</description>
+</parameter>
+</parameters>
+</function>
+<function>
+<name>triple</name>
+<description>Triple a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to triple</description>
+</parameter>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+call double(1) and triple(2)<|im_end|>
+<|im_start|>assistant
+I'll call double(1) and triple(2) for you.
+
+<tool_call>
+<function=double>
+<parameter=number>
+1
+</parameter>
+</function>
+</tool_call>
+<tool_call>
+<function=triple>
+<parameter=number>
+2
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"number": 2}
+</tool_response>
+<tool_response>
+{"number": 6}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			name: "prefill",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant."},
+				{Role: "user", Content: "Tell me something interesting."},
+				{Role: "assistant", Content: "I'll tell you something interesting about cats"},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Tell me something interesting.<|im_end|>
+<|im_start|>assistant
+I'll tell you something interesting about cats`,
+		},
+		{
+			name: "complex tool call arguments should remain json encoded",
+			msgs: []api.Message{
+				{Role: "user", Content: "call tool"},
+				{Role: "assistant", ToolCalls: []api.ToolCall{
+					{Function: api.ToolCallFunction{
+						Name: "echo",
+						Arguments: map[string]any{
+							"payload": map[string]any{"foo": "bar"},
+						},
+					}},
+				}},
+				{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
+			},
+			expected: `<|im_start|>user
+call tool<|im_end|>
+<|im_start|>assistant
+
+<tool_call>
+<function=echo>
+<parameter=payload>
+{"foo":"bar"}
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"payload": {"foo": "bar"}}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestFormatToolCallArgument(t *testing.T) {
+	tests := []struct {
+		name     string
+		arg      any
+		expected string
+	}{
+		{
+			name: "string",
+			arg:  "foo",
+			// notice no quotes around the string
+			expected: "foo",
+		},
+		{
+			name:     "map",
+			arg:      map[string]any{"foo": "bar"},
+			expected: "{\"foo\":\"bar\"}",
+		},
+		{
+			name:     "number",
+			arg:      1,
+			expected: "1",
+		},
+		{
+			name:     "boolean",
+			arg:      true,
+			expected: "true",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := formatToolCallArgument(tt.arg)
+			if got != tt.expected {
+				t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
+			}
+		})
+	}
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
new file mode 100644
index 00000000..2dfb51e4
--- /dev/null
+++ b/model/renderers/renderer.go
@@ -0,0 +1,26 @@
+package renderers
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/api"
+)
+
+type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
+
+func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+	renderer := rendererForName(name)
+	if renderer == nil {
+		return "", fmt.Errorf("unknown renderer %q", name)
+	}
+	return renderer(msgs, tools, think)
+}
+
+func rendererForName(name string) rendererFunc {
+	switch name {
+	case "qwen3-coder":
+		return Qwen3CoderRenderer
+	default:
+		return nil
+	}
+}
diff --git a/model/sentencepiece.go b/model/sentencepiece.go
index 827ce00d..db07beee 100644
--- a/model/sentencepiece.go
+++ b/model/sentencepiece.go
@@ -12,18 +12,18 @@ import (
 
 const spmWhitespaceSep = "▁"
 
-type SentencePieceModel struct {
+type SentencePiece struct {
 	maxTokenLen int
 	vocab       *Vocabulary
 }
 
-var _ TextProcessor = (*SentencePieceModel)(nil)
+var _ TextProcessor = (*SentencePiece)(nil)
 
-func (spm SentencePieceModel) Vocabulary() *Vocabulary {
+func (spm SentencePiece) Vocabulary() *Vocabulary {
 	return spm.vocab
 }
 
-func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
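// Call sites migrate one-for-one with this rename; a minimal sketch using a
// hypothetical two-token vocabulary (not taken from the patch):
//
//	vocab := &model.Vocabulary{Values: []string{"<unk>", "▁hello"}, Scores: []float32{0, 0}}
//	spm := model.NewSentencePiece(vocab) // previously: model.NewSentencePieceModel(vocab)
//	ids, err := spm.Encode("hello", true)
//
// The behavior is unchanged; only the type and constructor names lose the
// redundant "Model" suffix.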
+func NewSentencePiece(vocab *Vocabulary) SentencePiece { logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) counter := map[int]int{} @@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel { "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], "max token len", maxTokenLen) - return SentencePieceModel{ + return SentencePiece{ maxTokenLen: maxTokenLen, vocab: vocab, } } -func (spm SentencePieceModel) Is(id int32, special Special) bool { +func (spm SentencePiece) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } -func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) { +func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range spm.vocab.SpecialVocabulary() { id := spm.vocab.Encode(special) @@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} { return item } -func (spm SentencePieceModel) Decode(ids []int32) (string, error) { +func (spm SentencePiece) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { data := spm.vocab.Decode(id) diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go index 50ac2678..8f4570c1 100644 --- a/model/sentencepiece_test.go +++ b/model/sentencepiece_test.go @@ -12,7 +12,7 @@ import ( "github.com/ollama/ollama/convert/sentencepiece" ) -func loadSentencePieceVocab(t *testing.T) SentencePieceModel { +func loadSentencePieceVocab(t *testing.T) SentencePiece { t.Helper() bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) @@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { } } - return NewSentencePieceModel(&v) + return NewSentencePiece(&v) } func TestSentencePieceEncode(t *testing.T) { @@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) { }) } -func TestSentencePieceModelDecodeByteTokens(t *testing.T) { +func TestSentencePieceDecodeByteTokens(t *testing.T) { vocab := &Vocabulary{ Values: []string{ "normal", @@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) { Scores: []float32{0, 0, 0, 0, 0}, } - spm := NewSentencePieceModel(vocab) + spm := NewSentencePiece(vocab) tests := []struct { name string diff --git a/model/wordpiece.go b/model/wordpiece.go new file mode 100644 index 00000000..e8d5e848 --- /dev/null +++ b/model/wordpiece.go @@ -0,0 +1,167 @@ +package model + +import ( + "fmt" + "iter" + "strings" + "unicode" + + "github.com/ollama/ollama/logutil" +) + +type WordPiece struct { + vocab *Vocabulary +} + +// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. +// this differs from original word piece which uses "##" to indicate subwords. +const ggmlPrefix = "▁" + +var wordPieceReplacer = strings.NewReplacer( + " .", ".", + " ?", "?", + " !", "!", + " ,", ",", + " ' ", "'", + " n't", "n't", + " 'm", "'m", + " do not", " don't", + " 's", "'s", + " 've", "'ve", + " 're", "'re", +) + +// Decode implements TextProcessor. 
+func (wpm WordPiece) Decode(ids []int32) (string, error) { + var sb strings.Builder + for i, id := range ids { + if id < 0 || int(id) >= len(wpm.vocab.Values) { + return "", fmt.Errorf("invalid token id: %d", id) + } + + var separator string + piece := wpm.vocab.Values[id] + if i > 0 && + (strings.HasPrefix(piece, ggmlPrefix) || + (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { + separator = " " + } + + sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) + } + + return sb.String(), nil +} + +// words splits a string into words, treating CJK characters as separate words. +// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models. +func (wpm WordPiece) words(s string) iter.Seq[string] { + return func(yield func(string) bool) { + runes := make([]rune, 0, len(s)*3) + for _, r := range s { + switch { + case r >= 0x4E00 && r <= 0x9FFF, + r >= 0x3400 && r <= 0x4DBF, + r >= 0x20000 && r <= 0x2A6DF, + r >= 0x2A700 && r <= 0x2B73F, + r >= 0x2B740 && r <= 0x2B81F, + r >= 0x2B820 && r <= 0x2CEAF, + r >= 0xF900 && r <= 0xFAFF, + r >= 0x2F800 && r <= 0x2FA1F: + runes = append(runes, ' ', r, ' ') + default: + runes = append(runes, r) + } + } + + for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) { + // split on but keep punctuation + var start int + for start < len(w) { + end := strings.IndexFunc(w[start:], unicode.IsPunct) + if end < 0 { + end = len(w) - start + } else if end == 0 { + end = 1 + } + + if !yield(w[start : start+end]) { + return + } + + start += end + } + } + } +} + +// Encode implements TextProcessor. +func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { + var ids []int32 + + // TODO: use [UNK] from config + unk := wpm.vocab.Encode("[UNK]") + for word := range wpm.words(s) { + var start int + var pieces []int32 + for start < len(word) { + end := len(word) + + var piece int32 + for start < end { + subword := word[start:end] + if start == 0 { + subword = ggmlPrefix + subword + } + + // TODO: some models might not want [ToLower] + piece = wpm.vocab.Encode(strings.ToLower(subword)) + if piece >= 0 { + break + } + + end-- + } + + if piece < 0 { + // Unknown token + pieces = pieces[:0] + break + } + + pieces = append(pieces, piece) + start = end + } + + if len(pieces) > 0 { + ids = append(ids, pieces...) + } else { + ids = append(ids, unk) + } + } + + if addSpecial && len(ids) > 0 { + ids = wpm.vocab.addSpecials(ids) + } + + logutil.Trace("encoded", "string", s, "ids", ids) + return ids, nil +} + +// Is implements TextProcessor. +func (wpm WordPiece) Is(id int32, special Special) bool { + return wpm.vocab.Is(id, special) +} + +// Vocabulary implements TextProcessor. 
+func (wpm WordPiece) Vocabulary() *Vocabulary { + return wpm.vocab +} + +var _ TextProcessor = (*WordPiece)(nil) + +func NewWordPiece(vocab *Vocabulary) WordPiece { + return WordPiece{ + vocab: vocab, + } +} diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go new file mode 100644 index 00000000..258fbffc --- /dev/null +++ b/model/wordpiece_test.go @@ -0,0 +1,51 @@ +package model + +import ( + "slices" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWordPiece(t *testing.T) { + wpm := NewWordPiece( + &Vocabulary{ + Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, + AddBOS: true, + AddEOS: true, + BOS: []int32{1}, + EOS: []int32{2}, + }) + + ids, err := wpm.Encode("Hello world!", true) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { + t.Errorf("unexpected ids (-want +got):\n%s", diff) + } + + words, err := wpm.Decode(ids) + if err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} + +func TestWordPieceWords(t *testing.T) { + var wpm WordPiece + + basic := slices.Collect(wpm.words("Hey friend! How are you?!?")) + if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } + + chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) + if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { + t.Errorf("unexpected words (-want +got):\n%s", diff) + } +} diff --git a/openai/openai.go b/openai/openai.go index b6a8a95e..7ef5ac6d 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -105,16 +105,18 @@ type ChatCompletionRequest struct { Tools []api.Tool `json:"tools"` Reasoning *Reasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` + DebugRenderOnly bool `json:"_debug_render_only"` } type ChatCompletion struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage,omitempty"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage,omitempty"` + DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"` } type ChatCompletionChunk struct { @@ -141,6 +143,7 @@ type CompletionRequest struct { Temperature *float32 `json:"temperature"` TopP float32 `json:"top_p"` Suffix string `json:"suffix"` + DebugRenderOnly bool `json:"_debug_render_only"` } type Completion struct { @@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), - }}, - Usage: toUsage(r), + }}, Usage: toUsage(r), + DebugInfo: r.DebugInfo, } } @@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } return &api.ChatRequest{ - Model: r.Model, - Messages: messages, - Format: format, - Options: options, - Stream: &r.Stream, - Tools: r.Tools, - Think: think, + Model: r.Model, + Messages: messages, + Format: format, + Options: options, + Stream: &r.Stream, + Tools: r.Tools, + Think: think, + DebugRenderOnly: r.DebugRenderOnly, }, nil } @@ -648,11 +652,12 @@ func 
fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { } return api.GenerateRequest{ - Model: r.Model, - Prompt: r.Prompt, - Options: options, - Stream: &r.Stream, - Suffix: r.Suffix, + Model: r.Model, + Prompt: r.Prompt, + Options: options, + Stream: &r.Stream, + Suffix: r.Suffix, + DebugRenderOnly: r.DebugRenderOnly, }, nil } diff --git a/parser/parser.go b/parser/parser.go index e080f1bb..c2e8f981 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) req.System = c.Args case "license": licenses = append(licenses, c.Args) + case "renderer": + req.Renderer = c.Args + case "parser": + req.Parser = c.Args case "message": role, msg, _ := strings.Cut(c.Args, ": ") messages = append(messages, api.Message{Role: role, Content: msg}) @@ -320,7 +324,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter": + case "license", "template", "system", "adapter", "renderer", "parser": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -346,7 +350,7 @@ const ( var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"") ) type ParserError struct { @@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "parameter", "message": + case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message": return true default: return false diff --git a/parser/parser_test.go b/parser/parser_test.go index 7d5a808b..1524e890 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -198,6 +198,34 @@ BADCOMMAND param1 value1 } } +func TestParseFileRenderer(t *testing.T) { + input := ` +FROM foo +RENDERER renderer1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands) +} + +func TestParseFileParser(t *testing.T) { + input := ` +FROM foo +PARSER parser1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + require.NoError(t, err) + + assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands) +} + func TestParseFileMessages(t *testing.T) { cases := []struct { input string diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index 44b24613..9ed1c292 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { diff --git 
a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index f558f7b8..a3ffc3bd 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { targetFree = max(targetFree, 1) currentFree := c.numCtx - inputLen - discard := targetFree - currentFree - if discard < 0 { - discard = 0 - } - - return discard + return max(targetFree-currentFree, 0) } type ErrReprocessInputs struct { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 1081a1f5..480cfc19 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -11,7 +11,6 @@ import ( "image" "log" "log/slog" - "math" "net" "net/http" "os" @@ -32,6 +31,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" @@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { func (s *Server) run(ctx context.Context) { s.ready.Wait() - supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 + supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone var activeBatch batchState for { @@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er // Prepare the seqs and batch, but defer the input token values as we may not be ready yet var batchInputs []*input.Input + var batchOutputs []int32 var batch input.Batch resumeSeq := -1 @@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Sequences = append(batch.Sequences, seq.cache.Id) - seq.iBatch = len(batch.Outputs) - if i+1 == len(seq.inputs) { - batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) + seq.iBatch = len(batchOutputs) + if i+1 == len(seq.inputs) || seq.embeddingOnly { + batchOutputs = append(batchOutputs, int32(len(batchInputs)-1)) } logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) seq.pendingInputs = append(seq.pendingInputs, inp) @@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) + batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs)) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) if err != nil { err = fmt.Errorf("failed to build graph: %w", err) @@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) { } // sample a token - vocabSize := len(outputs) / len(activeBatch.batch.Outputs) - logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) + vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0) + logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", 
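// (Aside on the ShiftDiscard cleanup in the two cache.go hunks above: the old
//
//	discard := targetFree - currentFree
//	if discard < 0 {
//		discard = 0
//	}
//
// and the new max(targetFree-currentFree, 0) are equivalent; min and max over
// ordered types are language builtins as of Go 1.21, which is what this
// refactor and the similar backoff.go change below rely on.)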
activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) if err != nil { s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) @@ -898,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 { + if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone { http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) return } @@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Positions[i] = int32(i) } - batch.Outputs = make([]int32, s.parallel) - for i := range batch.Outputs { - batch.Outputs[i] = int32(i) - } - batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) + batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel) cache := s.model.Config().Cache if cache != nil { diff --git a/scripts/env.sh b/scripts/env.sh index 65a970bd..4f5641fd 100644 --- a/scripts/env.sh +++ b/scripts/env.sh @@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \ --build-arg=OLLAMA_FAST_BUILD \ --build-arg=CUSTOM_CPU_FLAGS \ --build-arg=GPU_RUNNER_CPU_FLAGS \ + --build-arg=PARALLEL \ --build-arg=AMDGPU_TARGETS" echo "Building Ollama" diff --git a/server/create.go b/server/create.go index bd970876..19f24ec8 100644 --- a/server/create.go +++ b/server/create.go @@ -10,8 +10,11 @@ import ( "io" "io/fs" "log/slog" + "net" "net/http" + "net/url" "os" + "path" "path/filepath" "slices" "strings" @@ -39,6 +42,14 @@ var ( ) func (s *Server) CreateHandler(c *gin.Context) { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + var r api.CreateRequest if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) { return } + config.Renderer = r.Renderer + config.Parser = r.Parser + for v := range r.Files { if !fs.ValidPath(v) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) @@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) { oldManifest, _ := ParseNamedManifest(name) var baseLayers []*layerGGML + var err error + var remote bool + if r.From != "" { - slog.Debug("create model from model name") + slog.Debug("create model from model name", "from", r.From) fromName := model.ParseName(r.From) if !fromName.IsValid() { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } + if r.RemoteHost != "" { + ru, err := remoteURL(r.RemoteHost) + if err != nil { + ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} + return + } - ctx, cancel := context.WithCancel(c.Request.Context()) - defer cancel() + config.RemoteModel = r.From + config.RemoteHost = ru + remote = true + } else { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() - baseLayers, err = parseFromModel(ctx, fromName, fn) - if err != nil { - ch <- gin.H{"error": err.Error()} + baseLayers, err = parseFromModel(ctx, fromName, fn) + if err != nil { + ch <- gin.H{"error": err.Error()} + } } } else if r.Files != nil { baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) @@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) 
{ } var adapterLayers []*layerGGML - if r.Adapters != nil { + if !remote && r.Adapters != nil { adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) if err != nil { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { @@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) { baseLayers = append(baseLayers, adapterLayers...) } - if err := createModel(r, name, baseLayers, fn); err != nil { + // Info is not currently exposed by Modelfiles, but allows overriding various + // config values + if r.Info != nil { + caps, ok := r.Info["capabilities"] + if ok { + switch tcaps := caps.(type) { + case []any: + caps := make([]string, len(tcaps)) + for i, c := range tcaps { + str, ok := c.(string) + if !ok { + continue + } + caps[i] = str + } + config.Capabilities = append(config.Capabilities, caps...) + } + } + + strFromInfo := func(k string) string { + v, ok := r.Info[k] + if ok { + val := v.(string) + return val + } + return "" + } + + vFromInfo := func(k string) float64 { + v, ok := r.Info[k] + if ok { + val := v.(float64) + return val + } + return 0 + } + + config.ModelFamily = strFromInfo("model_family") + if config.ModelFamily != "" { + config.ModelFamilies = []string{config.ModelFamily} + } + + config.BaseName = strFromInfo("base_name") + config.FileType = strFromInfo("quantization_level") + config.ModelType = strFromInfo("parameter_size") + config.ContextLen = int(vFromInfo("context_length")) + config.EmbedLen = int(vFromInfo("embedding_length")) + } + + if err := createModel(r, name, baseLayers, config, fn); err != nil { if errors.Is(err, errBadTemplate) { ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} return @@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) { streamResponse(c, ch) } +func remoteURL(raw string) (string, error) { + // Special‑case: user supplied only a path ("/foo/bar"). 
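// The rest of the normalization, as pinned down by TestRemoteURL below:
// schemeless hosts get an "http://" prefix and the default port 11434,
// "ollama.com" (with or without a scheme) is pinned to https://ollama.com:443,
// and non-root paths are kept after path.Clean.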
+ if strings.HasPrefix(raw, "/") { + return (&url.URL{ + Scheme: "http", + Host: net.JoinHostPort("localhost", "11434"), + Path: path.Clean(raw), + }).String(), nil + } + + if !strings.Contains(raw, "://") { + raw = "http://" + raw + } + + if raw == "ollama.com" || raw == "http://ollama.com" { + raw = "https://ollama.com:443" + } + + u, err := url.Parse(raw) + if err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + if u.Host == "" { + u.Host = "localhost" + } + + hostPart, portPart, err := net.SplitHostPort(u.Host) + if err == nil { + u.Host = net.JoinHostPort(hostPart, portPart) + } else { + u.Host = net.JoinHostPort(u.Host, "11434") + } + + if u.Path != "" { + u.Path = path.Clean(u.Path) + } + + if u.Path == "/" { + u.Path = "" + } + + return u.String(), nil +} + func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { switch detectModelTypeFromFiles(files) { case "safetensors": @@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { return ggml.KV{}, fmt.Errorf("no base model was found") } -func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) { - config := ConfigV2{ - OS: "linux", - Architecture: "amd64", - RootFS: RootFS{ - Type: "layers", - }, - } - +func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { var layers []Layer for _, layer := range baseLayers { if layer.GGML != nil { @@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return err } - configLayer, err := createConfigLayer(layers, config) + configLayer, err := createConfigLayer(layers, *config) if err != nil { return err } diff --git a/server/create_test.go b/server/create_test.go index 59a07ff1..061efb81 100644 --- a/server/create_test.go +++ b/server/create_test.go @@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) { }) } } + +func TestRemoteURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + hasError bool + }{ + { + name: "absolute path", + input: "/foo/bar", + expected: "http://localhost:11434/foo/bar", + hasError: false, + }, + { + name: "absolute path with cleanup", + input: "/foo/../bar", + expected: "http://localhost:11434/bar", + hasError: false, + }, + { + name: "root path", + input: "/", + expected: "http://localhost:11434/", + hasError: false, + }, + { + name: "host without scheme", + input: "example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "host with port", + input: "example.com:8080", + expected: "http://example.com:8080", + hasError: false, + }, + { + name: "full URL", + input: "https://example.com:8080/path", + expected: "https://example.com:8080/path", + hasError: false, + }, + { + name: "full URL with path cleanup", + input: "https://example.com:8080/path/../other", + expected: "https://example.com:8080/other", + hasError: false, + }, + { + name: "ollama.com special case", + input: "ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "http ollama.com special case", + input: "http://ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "URL with only host", + input: "http://example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "URL with root path cleaned", + 
input: "http://example.com/", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "invalid URL", + input: "http://[::1]:namedport", // invalid port + expected: "", + hasError: true, + }, + { + name: "empty string", + input: "", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "host with scheme but no port", + input: "http://localhost", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "complex path cleanup", + input: "/a/b/../../c/./d", + expected: "http://localhost:11434/c/d", + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := remoteURL(tt.input) + + if tt.hasError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestRemoteURL_Idempotent(t *testing.T) { + // Test that applying remoteURL twice gives the same result as applying it once + testInputs := []string{ + "/foo/bar", + "example.com", + "https://example.com:8080/path", + "ollama.com", + "http://localhost:11434", + } + + for _, input := range testInputs { + t.Run(input, func(t *testing.T) { + firstResult, err := remoteURL(input) + if err != nil { + t.Fatalf("first call failed: %v", err) + } + + secondResult, err := remoteURL(firstResult) + if err != nil { + t.Fatalf("second call failed: %v", err) + } + + if firstResult != secondResult { + t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult) + } + }) + } +} diff --git a/server/images.go b/server/images.go index 504eb95c..9466b7fb 100644 --- a/server/images.go +++ b/server/images.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/gguf" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" "github.com/ollama/ollama/thinking" @@ -73,29 +74,38 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - f, err := gguf.Open(m.ModelPath) - if err == nil { - defer f.Close() + if m.ModelPath != "" { + f, err := gguf.Open(m.ModelPath) + if err == nil { + defer f.Close() - if f.KeyValue("pooling_type").Valid() { - capabilities = append(capabilities, model.CapabilityEmbedding) + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) + } } else { - // If no embedding is specified, we assume the model supports completion - capabilities = append(capabilities, model.CapabilityCompletion) + slog.Error("couldn't open model file", "error", err) } - if f.KeyValue("vision.block_count").Valid() { - capabilities = append(capabilities, model.CapabilityVision) + } else if len(m.Config.Capabilities) > 0 { + for _, c := range m.Config.Capabilities { + capabilities = append(capabilities, model.Capability(c)) } } else { - slog.Error("couldn't open model file", "error", err) + slog.Warn("unknown capabilities for model", "model", m.Name) } if m.Template == nil { return capabilities } + builtinParser := 
parsers.ParserForName(m.Config.Parser) // Check for tools capability - if slices.Contains(m.Template.Vars(), "tools") { + if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { capabilities = append(capabilities, model.CapabilityTools) } @@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityVision) } + // Skip the thinking check if it's already set + if slices.Contains(capabilities, "thinking") { + return capabilities + } + // Check for thinking capability openingTag, closingTag := thinking.InferTags(m.Template.Template) hasTags := openingTag != "" && closingTag != "" - if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) { + isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) + if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) { capabilities = append(capabilities, model.CapabilityThinking) } @@ -198,6 +214,20 @@ func (m *Model) String() string { }) } + if m.Config.Renderer != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "renderer", + Args: m.Config.Renderer, + }) + } + + if m.Config.Parser != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "parser", + Args: m.Config.Parser, + }) + } + for k, v := range m.Options { switch v := v.(type) { case []any: @@ -236,8 +266,19 @@ type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` ModelFamilies []string `json:"model_families"` - ModelType string `json:"model_type"` - FileType string `json:"file_type"` + ModelType string `json:"model_type"` // shown as Parameter Size + FileType string `json:"file_type"` // shown as Quantization Level + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + RemoteHost string `json:"remote_host,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + + // used for remotes + Capabilities []string `json:"capabilities,omitempty"` + ContextLen int `json:"context_length,omitempty"` + EmbedLen int `json:"embedding_length,omitempty"` + BaseName string `json:"base_name,omitempty"` // required by spec Architecture string `json:"architecture"` diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go index 1f0634f7..08b4ed7f 100644 --- a/server/internal/internal/backoff/backoff.go +++ b/server/internal/internal/backoff/backoff.go @@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { // n^2 backoff timer is a little smoother than the // common choice of 2^n. - d := time.Duration(n*n) * 10 * time.Millisecond - if d > maxBackoff { - d = maxBackoff - } + d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) // Randomize the delay between 0.5-1.5 x msec, in order // to prevent accidental "thundering herd" problems. d = time.Duration(float64(d) * (rand.Float64() + 0.5)) diff --git a/server/prompt.go b/server/prompt.go index f1d8020e..56bc6303 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -11,6 +11,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/model/renderers" "github.com/ollama/ollama/template" ) @@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. 
} } - thinkVal := false - thinkLevel := "" - if think != nil { - thinkVal = think.Bool() - thinkLevel = think.String() - } - var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think) + if err != nil { return "", nil, err } - s, err := tokenize(ctx, b.String()) + s, err := tokenize(ctx, p) if err != nil { return "", nil, err } @@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } // truncate any messages that do not fit into the context window + p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think) + if err != nil { + return "", nil, err + } + + return p, images, nil +} + +func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + if m.Config.Renderer != "" { + rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) + if err != nil { + return "", err + } + return rendered, nil + } + var b bytes.Buffer thinkVal := false thinkLevel := "" @@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. thinkVal = think.Bool() thinkLevel = think.String() } - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { - return "", nil, err + if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { + return "", err } - - return b.String(), images, nil + return b.String(), nil } diff --git a/server/routes.go b/server/routes.go index 5114cb74..a08a7289 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "os/signal" "slices" @@ -28,6 +29,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" @@ -35,6 +37,7 @@ import ( "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/model/parsers" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" @@ -188,6 +191,87 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + + if req.Template == "" && m.Template.String() != "" { + req.Template = m.Template.String() + } + + if req.Options == nil { + req.Options = map[string]any{} + } + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + // update the system prompt 
from the model if one isn't already specified
+		if req.System == "" && m.System != "" {
+			req.System = m.System
+		}
+
+		if len(m.Messages) > 0 {
+			slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
+		}
+
+		fn := func(resp api.GenerateResponse) error {
+			resp.Model = origModel
+			resp.RemoteModel = m.Config.RemoteModel
+			resp.RemoteHost = m.Config.RemoteHost
+
+			data, err := json.Marshal(resp)
+			if err != nil {
+				return err
+			}
+
+			if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+				return err
+			}
+			c.Writer.Flush()
+			return nil
+		}
+
+		client := api.NewClient(remoteURL, http.DefaultClient)
+		err = client.Generate(c, &req, fn)
+		if err != nil {
+			var sErr api.AuthorizationError
+			if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+				pk, pkErr := auth.GetPublicKey()
+				if pkErr != nil {
+					slog.Error("couldn't get public key", "error", pkErr)
+					c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+					return
+				}
+				c.JSON(http.StatusUnauthorized, gin.H{
+					"error":      "unauthorized",
+					"public_key": pk,
+				})
+				return
+			}
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		return
+	}
+
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
 		s.sched.expireRunner(m)
@@ -329,10 +413,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
-		c.JSON(http.StatusOK, api.DebugTemplateResponse{
+		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model:     req.Model,
 			CreatedAt: time.Now().UTC(),
-			DebugInfo: api.DebugInfo{
+			DebugInfo: &api.DebugInfo{
 				RenderedTemplate: prompt,
 				ImageCount:       len(images),
 			},
@@ -348,6 +432,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				OpeningTag: openingTag,
 				ClosingTag: closingTag,
 			}
+			if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
+				thinkingState.AddContent(openingTag)
+			}
 		}
 	}
 
@@ -488,7 +575,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	}
 
 	truncate := true
-
 	if req.Truncate != nil && !*req.Truncate {
 		truncate = false
 	}
@@ -551,11 +637,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
 	if len(tokens) > ctxLen {
 		if !truncate {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+			c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
+			return
+		}
+
+		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+			ctxLen--
+		}
+
+		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+			ctxLen--
+		}
+
+		slog.Info("truncating embedding input to context length", "ctxLen", ctxLen, "tokenCount", len(tokens))
+		if ctxLen <= 0 {
+			// return error if the truncated input would be empty or just special tokens
+			c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
 			return
 		}
 		tokens = tokens[:ctxLen]
+
 		s, err = r.Detokenize(c.Request.Context(), tokens)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -922,6 +1024,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		ModifiedAt: manifest.fi.ModTime(),
 	}
 
+	if m.Config.RemoteHost != "" {
+		resp.RemoteHost = m.Config.RemoteHost
+		resp.RemoteModel = m.Config.RemoteModel
+
+		if m.Config.ModelFamily != "" {
+			resp.ModelInfo 
= make(map[string]any) + resp.ModelInfo["general.architecture"] = m.Config.ModelFamily + + if m.Config.BaseName != "" { + resp.ModelInfo["general.basename"] = m.Config.BaseName + } + + if m.Config.ContextLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen + } + + if m.Config.EmbedLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen + } + } + } + var params []string cs := 30 for k, v := range m.Options { @@ -952,6 +1076,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() + // skip loading tensor information if this is a remote model + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + return resp, nil + } + kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) if err != nil { return nil, err @@ -1028,11 +1157,13 @@ func (s *Server) ListHandler(c *gin.Context) { // tag should never be masked models = append(models, api.ListModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: m.fi.ModTime(), + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + RemoteModel: cf.RemoteModel, + RemoteHost: cf.RemoteHost, + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), Details: api.ModelDetails{ Format: cf.ModelFormat, Family: cf.ModelFamily, @@ -1292,6 +1423,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + r.POST("/api/me", s.WhoamiHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1488,6 +1622,49 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +func (s *Server) WhoamiHandler(c *gin.Context) { + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + user, err := client.Whoami(c) + if err != nil { + slog.Error(err.Error()) + } + c.JSON(http.StatusOK, user) +} + +func (s *Server) SignoutHandler(c *gin.Context) { + encodedKey := c.Param("encodedKey") + + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + err = client.Signout(c, encodedKey) + if err != nil { + slog.Error(err.Error()) + if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { + c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + c.JSON(http.StatusOK, nil) +} + func (s *Server) PsHandler(c *gin.Context) { models := []api.ProcessModelResponse{} @@ -1544,21 +1721,34 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // expire the runner - if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { - model, err := GetModel(req.Model) - if err != nil { - switch { - case os.IsNotExist(err): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case 
err.Error() == errtypes.InvalidModelNameErrMsg:
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			}
-			return
+	name := model.ParseName(req.Model)
+	if !name.IsValid() {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	name, err := getExistingName(name)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	}
+
+	m, err := GetModel(req.Model)
+	if err != nil {
+		switch {
+		case os.IsNotExist(err):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+		case err.Error() == errtypes.InvalidModelNameErrMsg:
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
-		return
+		return
+	}
+
+	// expire the runner
+	if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		s.sched.expireRunner(model)
+		s.sched.expireRunner(m)
 
 		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
@@ -1570,6 +1760,80 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+		origModel := req.Model
+
+		remoteURL, err := url.Parse(m.Config.RemoteHost)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
+			slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
+			c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
+			return
+		}
+
+		req.Model = m.Config.RemoteModel
+		if req.Options == nil {
+			req.Options = map[string]any{}
+		}
+
+		msgs := append(m.Messages, req.Messages...)
+		if len(req.Messages) > 0 && req.Messages[0].Role != "system" && m.System != "" {
+			msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
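+			// msgs now begins with the stored system prompt, followed by any
+			// model-embedded messages and then the request's own messages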
+		}
+		msgs = filterThinkTags(msgs, m)
+		req.Messages = msgs
+
+		for k, v := range m.Options {
+			if _, ok := req.Options[k]; !ok {
+				req.Options[k] = v
+			}
+		}
+
+		fn := func(resp api.ChatResponse) error {
+			resp.Model = origModel
+			resp.RemoteModel = m.Config.RemoteModel
+			resp.RemoteHost = m.Config.RemoteHost
+
+			data, err := json.Marshal(resp)
+			if err != nil {
+				return err
+			}
+
+			if _, err = c.Writer.Write(append(data, '\n')); err != nil {
+				return err
+			}
+			c.Writer.Flush()
+			return nil
+		}
+
+		client := api.NewClient(remoteURL, http.DefaultClient)
+		err = client.Chat(c, &req, fn)
+		if err != nil {
+			var sErr api.AuthorizationError
+			if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
+				pk, pkErr := auth.GetPublicKey()
+				if pkErr != nil {
+					slog.Error("couldn't get public key", "error", pkErr)
+					c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
+					return
+				}
+				c.JSON(http.StatusUnauthorized, gin.H{
+					"error":      "unauthorized",
+					"public_key": pk,
+				})
+				return
+			}
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		return
+	}
+
 	caps := []model.Capability{model.CapabilityCompletion}
 	if len(req.Tools) > 0 {
 		caps = append(caps, model.CapabilityTools)
@@ -1578,17 +1842,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		caps = append(caps, model.CapabilityThinking)
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1617,10 +1870,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	msgs = filterThinkTags(msgs, m)
 
+	var builtinParser parsers.Parser
+	if m.Config.Parser != "" {
+		builtinParser = parsers.ParserForName(m.Config.Parser)
+	}
+
 	var harmonyMessageHandler *harmony.HarmonyMessageHandler
 	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
 
-	useHarmony := shouldUseHarmony(m)
+	useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
 	processedTools := req.Tools
 
 	if useHarmony {
@@ -1650,10 +1908,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
-		c.JSON(http.StatusOK, api.DebugTemplateResponse{
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:     req.Model,
 			CreatedAt: time.Now().UTC(),
-			DebugInfo: api.DebugInfo{
+			DebugInfo: &api.DebugInfo{
 				RenderedTemplate: prompt,
 				ImageCount:       len(images),
 			},
@@ -1713,6 +1971,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 	}
 
+	// TODO(drifkin): fold this as much as possible into the generic m.Config.Parser logic
 	if useHarmony {
 		content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
 		res.Message.Content = content
@@ -1739,6 +1998,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			ch <- res
 		}
 
+		return
+	} else if builtinParser != nil {
+		slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
+
+		content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
+		if err != nil {
+			ch <- gin.H{"error": 
err.Error()}
+			return
+		}
+
+		res.Message.Content = content
+		res.Message.Thinking = thinking
+		res.Message.ToolCalls = toolCalls
+
+		if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+			slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
+			ch <- res
+		} else {
+			slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
+		}
+
 		return
 	}
diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index 3b3d9910..189ef040 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -11,6 +11,7 @@ import (
 	"net/http/httptest"
 	"os"
 	"path/filepath"
+	"reflect"
 	"slices"
 	"strings"
 	"testing"
@@ -20,6 +21,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/fs/ggml"
+	"github.com/ollama/ollama/types/model"
 )
 
 var stream bool = false
@@ -615,6 +617,74 @@ func TestCreateTemplateSystem(t *testing.T) {
 	})
 }
 
+func TestCreateAndShowRemoteModel(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	var s Server
+
+	w := createRequest(t, s.CreateHandler, api.CreateRequest{
+		Model:      "test",
+		From:       "bob",
+		RemoteHost: "https://ollama.com",
+		Info: map[string]any{
+			"capabilities":       []string{"completion", "tools", "thinking"},
+			"model_family":       "gptoss",
+			"context_length":     131072,
+			"embedding_length":   2880,
+			"quantization_level": "MXFP4",
+			"parameter_size":     "20.9B",
+		},
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	}
+
+	var resp api.ShowResponse
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatal(err)
+	}
+
+	expectedDetails := api.ModelDetails{
+		ParentModel:       "",
+		Format:            "",
+		Family:            "gptoss",
+		Families:          []string{"gptoss"},
+		ParameterSize:     "20.9B",
+		QuantizationLevel: "MXFP4",
+	}
+
+	if !reflect.DeepEqual(resp.Details, expectedDetails) {
+		t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
+	}
+
+	expectedCaps := []model.Capability{
+		model.Capability("completion"),
+		model.Capability("tools"),
+		model.Capability("thinking"),
+	}
+
+	if !slices.Equal(resp.Capabilities, expectedCaps) {
+		t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
+	}
+
+	ctxlen, ok := resp.ModelInfo["gptoss.context_length"].(float64)
+	if !ok || int(ctxlen) != 131072 {
+		t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
+	}
+
+	embedlen, ok := resp.ModelInfo["gptoss.embedding_length"].(float64)
+	if !ok || int(embedlen) != 2880 {
+		t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
+	}
+}
+
 func TestCreateLicenses(t *testing.T) {
 	gin.SetMode(gin.TestMode)
diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go
index f04a1da9..6507284e 100644
--- a/server/routes_debug_test.go
+++ b/server/routes_debug_test.go
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
 		t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
 	}
 
-	var response api.DebugTemplateResponse
+	var response api.GenerateResponse
 	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { 
t.Fatalf("failed to unmarshal response: %v", err) } @@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) { t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) } - var response api.DebugTemplateResponse + var response api.ChatResponse if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { t.Fatalf("failed to unmarshal response: %v", err) } diff --git a/server/routes_test.go b/server/routes_test.go index 87b52663..bb7e2b7c 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) { t.Fatalf("failed to create model: %v", err) } - if err := createModel(r, modelName, baseLayers, fn); err != nil { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + + if err := createModel(r, modelName, baseLayers, config, fn); err != nil { t.Fatal(err) } } diff --git a/server/sched.go b/server/sched.go index c501c0e8..74aa406a 100644 --- a/server/sched.go +++ b/server/sched.go @@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // (if any). Returns whether the scheduler needs to evict a model to make this one fit. func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { - numParallel := int(envconfig.NumParallel()) - if numParallel < 1 { - numParallel = 1 - } + numParallel := max(int(envconfig.NumParallel()), 1) // Embedding models should always be loaded with parallel=1 if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {