Merge branch 'ollama:main' into main

This commit is contained in:
likelovewant
2025-11-22 17:42:25 +08:00
committed by GitHub
82 changed files with 5064 additions and 907 deletions

2
.gitattributes vendored
View File

@@ -15,6 +15,8 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
llama/build-info.cpp linguist-generated

View File

@@ -366,6 +366,7 @@ jobs:
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -226,12 +226,9 @@ jobs:
if: always()
run: go test -count=1 -benchtime=1x ./...
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
- uses: golangci/golangci-lint-action@v9
with:
args: --timeout 10m0s -v
only-new-issues: true
patches:
runs-on: ubuntu-latest
@@ -240,4 +237,4 @@ jobs:
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
git diff --compact-summary --exit-code

View File

@@ -1,41 +1,77 @@
run:
timeout: 5m
version: "2"
linters:
default: none
enable:
- asasalint
- bidichk
- bodyclose
- containedctx
- copyloopvar
- errcheck
- errorlint
- exptostd
- gocheckcompilerdirectives
- gofmt
- gofumpt
- gosimple
- gocritic
- govet
- ineffassign
- intrange
- makezero
- misspell
- modernize
- nilerr
- nilnil
- nolintlint
- nosprintfhostport
- perfsprint
- prealloc
- sloglint
- staticcheck
- unconvert
- unused
- usestdlibvars
- usetesting
- wastedassign
- whitespace
disable:
- usestdlibvars
- errcheck
linters-settings:
staticcheck:
checks:
- all
- -SA1019 # omit Deprecated check
severity:
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
severity: info
settings:
errcheck:
exclude-functions:
- fmt.Fprintf
perfsprint:
strconcat: false
concat-loop: false
staticcheck:
checks:
- all
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- -SA1019
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- -ST1003
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- -ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- -ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- -ST1022
usestdlibvars:
http-method: false
http-status-code: false
formatters:
enable:
- gci
- gofmt
- gofumpt
settings:
gci:
sections:
- standard
- default
- localmodule

View File

@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
### Issues that may not be accepted
@@ -43,7 +43,7 @@ Tips for proposals:
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.
see if the changes were accepted.
## Pull requests
@@ -66,7 +66,6 @@ Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands
Bad Examples:

View File

@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,6 +57,8 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -67,6 +69,8 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -78,6 +82,8 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -87,6 +93,8 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -118,6 +126,8 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \

View File

@@ -389,6 +389,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -449,6 +450,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
@@ -638,7 +640,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
@@ -646,7 +648,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
@@ -656,7 +658,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

View File

@@ -14,7 +14,7 @@ Please include the following details in your report:
## Security best practices
While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
- Regularly updating to the latest version of Ollama
- Securing access to hosted instances of Ollama

View File

@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
return
}
@@ -434,37 +434,30 @@ func openInBrowser(url string) {
}
}
// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
// parseURLScheme parses an ollama:// URL and validates it
// Supports: ollama:// (open app) and ollama://connect (OAuth)
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
parsedURL, err := url.Parse(urlSchemeRequest)
if err != nil {
return false, "", err
return false, fmt.Errorf("invalid URL: %w", err)
}
// Check if this is a connect URL
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
return true, "", nil
return true, nil
}
// Extract the UI path
path := "/"
if parsedURL.Path != "" && parsedURL.Path != "/" {
// For URLs like ollama:///settings, use the path directly
path = parsedURL.Path
} else if parsedURL.Host != "" {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
// This also handles ollama://settings/ where Windows adds a trailing slash
path = "/" + parsedURL.Host
// Allow bare ollama:// or ollama:/// to open the app
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
return false, nil
}
return false, path, nil
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
}
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
isConnect, err := parseURLScheme(urlSchemeRequest)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
return
@@ -473,6 +466,8 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage(uiPath)
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}
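The new parseURLScheme only distinguishes the bare ollama:// URL (open the app) from ollama://connect (OAuth) and rejects everything else. Below is a minimal table-driven test sketch of that behavior; it is not part of the diff above and assumes it lives in the same package as parseURLScheme, with the standard testing import.

```go
func TestParseURLScheme(t *testing.T) {
	cases := []struct {
		url       string
		isConnect bool
		wantErr   bool
	}{
		{"ollama://", false, false},        // bare scheme opens the app
		{"ollama:///", false, false},       // trailing-slash variant
		{"ollama://connect", true, false},  // host form of the OAuth URL
		{"ollama:///connect", true, false}, // path form of the OAuth URL
		{"ollama://settings", false, true}, // arbitrary paths are now rejected
	}
	for _, tc := range cases {
		isConnect, err := parseURLScheme(tc.url)
		if isConnect != tc.isConnect || (err != nil) != tc.wantErr {
			t.Errorf("parseURLScheme(%q) = (%v, %v), want (%v, wantErr=%v)",
				tc.url, isConnect, err, tc.isConnect, tc.wantErr)
		}
	}
}
```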

View File

@@ -24,27 +24,14 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -260,7 +247,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
// handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) {
isConnect, uiPath, err := parseURLScheme(urlScheme)
isConnect, err := parseURLScheme(urlScheme)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
return
@@ -147,7 +147,9 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage(uiPath)
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}

View File

@@ -15,6 +15,7 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,8 +28,6 @@ Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow

View File

@@ -0,0 +1,10 @@
// API configuration
const DEV_API_URL = "http://127.0.0.1:3001";
// Base URL for fetch API calls (can be relative in production)
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
// Full host URL for Ollama client (needs full origin in production)
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;

View File

@@ -1,4 +1,5 @@
import { Ollama } from "ollama/browser";
import { OLLAMA_HOST } from "./config";
let _ollamaClient: Ollama | null = null;
@@ -6,7 +7,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
get(_target, prop) {
if (!_ollamaClient) {
_ollamaClient = new Ollama({
host: window.location.origin,
host: OLLAMA_HOST,
});
}
const value = _ollamaClient[prop as keyof Ollama];

114
cmd/bench/README.md Normal file
View File

@@ -0,0 +1,114 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports markdown, benchstat, and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)
## Building from Source
```
go build -o bench bench.go
./bench -model gpt-oss:20b -epochs 6 -format csv
```
### Using Go Run (without building)
```
go run bench.go -model gpt-oss:20b -epochs 3
```
## Usage
### Basic Example
```
./bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|--------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 6 |
| -max-tokens | Maximum tokens for model response | 200 |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | (built-in long-story prompt) |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (markdown, benchstat, csv) | markdown |
| -output | Output file for results | "" (stdout) |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Markdown Format
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
```
| Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
```
### CSV Format
Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
```
## Metrics Explained
The tool reports four types of metrics for each model:
* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration
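To make the relationship between these durations and the reported rates concrete, here is a small self-contained Go sketch that reproduces the ns/token and tokens/sec arithmetic using the sample generate step from the markdown example above (the tool itself derives these values from the api.Metrics fields):

```go
package main

import (
	"fmt"
	"time"
)

// rates converts a step's token count and duration into the ns/token and
// tokens/sec figures reported by the tool.
func rates(count int, d time.Duration) (nsPerToken, tokensPerSec float64) {
	if count == 0 || d == 0 {
		return 0, 0
	}
	nsPerToken = float64(d.Nanoseconds()) / float64(count)
	tokensPerSec = float64(count) / d.Seconds()
	return nsPerToken, tokensPerSec
}

func main() {
	// generate step: 200 tokens in ~2.65s, as in the sample output above
	ns, tps := rates(200, 2646843954*time.Nanosecond)
	fmt.Printf("generate: %.2f ns/token, %.2f tokens/sec\n", ns, tps)
}
```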

309
cmd/bench/bench.go Normal file
View File

@@ -0,0 +1,309 @@
package main
import (
"cmp"
"context"
"flag"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
}
type Metrics struct {
Model string
Step string
Count int
Duration time.Duration
}
var once sync.Once
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
if verbose {
printHeader := func() {
fmt.Printf("sysname: %s\n", runtime.GOOS)
fmt.Printf("machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count)
}
} else {
var suffix string
if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
}
}
case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
var tokensPerSec float64
if m.Count > 0 {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
}
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkChat(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
imgData, err = readImage(*fOpt.imageFile)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
return err
}
}
if *fOpt.debug && imgData != nil {
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
return err
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
continue
}
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
continue
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
continue
}
metrics := []Metrics{
{
Model: model,
Step: "prefill",
Count: responseMetrics.PromptEvalCount,
Duration: responseMetrics.PromptEvalDuration,
},
{
Model: model,
Step: "generate",
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "load",
Count: 1,
Duration: responseMetrics.LoadDuration,
},
{
Model: model,
Step: "total",
Count: 1,
Duration: responseMetrics.TotalDuration,
},
}
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return api.ImageData(data), nil
}
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
}
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
}
flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
if len(*fOpt.models) == 0 {
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
flag.Usage()
return
}
BenchmarkChat(fOpt)
}

463
cmd/bench/bench_test.go Normal file
View File

@@ -0,0 +1,463 @@
package main
import (
"bytes"
"crypto/rand"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func createTestFlagOptions() flagOptions {
models := "test-model"
format := "benchstat"
epochs := 1
maxTokens := 100
temperature := 0.7
seed := 42
timeout := 30
prompt := "test prompt"
imageFile := ""
keepAlive := 5.0
verbose := false
debug := false
return flagOptions{
models: &models,
format: &format,
epochs: &epochs,
maxTokens: &maxTokens,
temperature: &temperature,
seed: &seed,
timeout: &timeout,
prompt: &prompt,
imageFile: &imageFile,
keepAlive: &keepAlive,
verbose: &verbose,
debug: &debug,
}
}
func captureOutput(f func()) string {
oldStdout := os.Stdout
oldStderr := os.Stderr
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
}()
r, w, _ := os.Pipe()
os.Stdout = w
os.Stderr = w
f()
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
return
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
for _, resp := range responses {
jsonData, err := json.Marshal(resp)
if err != nil {
t.Errorf("Failed to marshal response: %v", err)
return
}
w.Write(jsonData)
w.Write([]byte("\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
time.Sleep(10 * time.Millisecond) // Simulate some delay
}
}))
}
func TestBenchmarkChat_Success(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 1",
},
Done: false,
},
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 2",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
t.Errorf("Expected output to contain prefill metrics, got: %s", output)
}
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
t.Errorf("Expected output to contain generate metrics, got: %s", output)
}
if !strings.Contains(output, "ns/token") {
t.Errorf("Expected output to contain ns/token metric, got: %s", output)
}
}
func TestBenchmarkChat_ServerError(t *testing.T) {
fOpt := createTestFlagOptions()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal server error", http.StatusInternalServerError)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected error to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Couldn't chat with model") {
t.Errorf("Expected error message about chat failure, got: %s", output)
}
}
func TestBenchmarkChat_Timeout(t *testing.T) {
fOpt := createTestFlagOptions()
shortTimeout := 1 // Very short timeout
fOpt.timeout = &shortTimeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate a long delay that will cause timeout
time.Sleep(2 * time.Second)
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Chat request timed out") {
t.Errorf("Expected timeout error message, got: %s", output)
}
}
func TestBenchmarkChat_NoMetrics(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: false, // Never sends Done=true
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "ERROR: No metrics received") {
t.Errorf("Expected no metrics error message, got: %s", output)
}
}
func TestBenchmarkChat_MultipleModels(t *testing.T) {
fOpt := createTestFlagOptions()
models := "model1,model2"
epochs := 2
fOpt.models = &models
fOpt.epochs = &epochs
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
response := api.ChatResponse{
Model: req.Model,
Message: api.Message{
Role: "assistant",
Content: "test response for " + req.Model,
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
// Should be called 4 times (2 models × 2 epochs)
if callCount != 4 {
t.Errorf("Expected 4 API calls, got %d", callCount)
}
if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
t.Errorf("Expected output for both models, got: %s", output)
}
}
func TestBenchmarkChat_WithImage(t *testing.T) {
fOpt := createTestFlagOptions()
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
tmpfileName := tmpfile.Name()
fOpt.imageFile = &tmpfileName
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the request contains image data
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
t.Error("Expected request to contain images")
}
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response with image",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model") {
t.Errorf("Expected benchmark output, got: %s", output)
}
}
func TestBenchmarkChat_ImageError(t *testing.T) {
randFileName := func() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 8
result := make([]byte, length)
rand.Read(result) // Fill with random bytes
for i := range result {
result[i] = charset[result[i]%byte(len(charset))]
}
return string(result) + ".txt"
}
fOpt := createTestFlagOptions()
imageFile := randFileName()
fOpt.imageFile = &imageFile
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err == nil {
t.Error("Expected error from image reading, got nil")
}
})
if !strings.Contains(output, "ERROR: Couldn't read image") {
t.Errorf("Expected image read error message, got: %s", output)
}
}
func TestReadImage_Success(t *testing.T) {
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
imgData, err := readImage(tmpfile.Name())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if imgData == nil {
t.Error("Expected image data, got nil")
}
expected := api.ImageData(content)
if string(imgData) != string(expected) {
t.Errorf("Expected image data %v, got %v", expected, imgData)
}
}
func TestReadImage_FileNotFound(t *testing.T) {
imgData, err := readImage("nonexistentfile.jpg")
if err == nil {
t.Error("Expected error for non-existent file, got nil")
}
if imgData != nil {
t.Error("Expected nil image data for non-existent file")
}
}
func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions()
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if options["num_predict"] != *fOpt.maxTokens {
t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
}
if options["temperature"] != *fOpt.temperature {
t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
}
if options["seed"] != *fOpt.seed {
t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
}
}

View File

@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -0,0 +1,136 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}
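The expert-merge loop above relies on a naming convention: for every block, the per-expert gate/up/down projection tensors (matched with a wildcard) are collapsed into a single stacked `_exps` tensor. Since merge and mergeTensors are internal converter helpers, the following standalone sketch only illustrates the source patterns and destination names that loop generates:

```go
package main

import "fmt"

func main() {
	// Print the wildcard source pattern and stacked destination name
	// for the first two blocks, mirroring the converter's merge loop.
	for i := 0; i < 2; i++ {
		for _, p := range []string{"gate", "up", "down"} {
			src := fmt.Sprintf("blk.%d.mlp.experts.*.%s_proj.weight", i, p)
			dst := fmt.Sprintf("blk.%d.ffn_%s_exps.weight", i, p)
			fmt.Printf("%-45s -> %s\n", src, dst)
		}
	}
}
```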

View File

@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
t.name == "v.post_tile_position_embd.weight" ||
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
// these tensors are always F32
return tensorKindFP32
}

View File

@@ -96,7 +96,10 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -2,6 +2,7 @@ package discover
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
@@ -10,12 +11,21 @@ import (
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
func GetCPUMem() (memInfo, error) {
mem, err := getCPUMem()
if err != nil {
return memInfo{}, err
}
return getCPUMemByCgroups(mem), nil
}
func getCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
f, err := os.Open("/proc/meminfo")
@@ -56,6 +66,32 @@ func GetCPUMem() (memInfo, error) {
return mem, nil
}
func getCPUMemByCgroups(mem memInfo) memInfo {
total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
if err == nil {
mem.TotalMemory = total
}
used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
if err == nil {
mem.FreeMemory = mem.TotalMemory - used
}
return mem
}
func getUint64ValueFromFile(path string) (uint64, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
s := bufio.NewScanner(f)
for s.Scan() {
line := s.Text()
return strconv.ParseUint(line, 10, 64)
}
return 0, errors.New("empty file content")
}
const CpuInfoFilename = "/proc/cpuinfo"
type linuxCpuInfo struct {
@@ -74,7 +110,41 @@ func GetCPUDetails() []CPU {
return nil
}
defer file.Close()
return linuxCPUDetails(file)
cpus := linuxCPUDetails(file)
return overwriteThreadCountByLinuxCgroups(cpus)
}
func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
file, err := os.Open("/sys/fs/cgroup/cpu.max")
if err != nil {
return cpus
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if sl := strings.Split(line, " "); len(sl) == 2 {
allowedUs, err := strconv.ParseInt(sl[0], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU allowed micro secs", "error", err)
return cpus
}
unitUs, err := strconv.ParseInt(sl[1], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU unit micro secs", "error", err)
return cpus
}
threads := int(max(allowedUs/unitUs, 1))
cpu := cpus[0]
cpu.CoreCount = threads
cpu.ThreadCount = threads
return []CPU{cpu}
}
}
return cpus
}
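The cgroup v2 cpu.max file holds a quota and a period in microseconds, and the function above turns quota/period into an effective thread count. A small standalone sketch of that arithmetic is below; treating a missing quota (the literal "max") as unconstrained is an assumption of this example, the code above simply falls back to the host-derived CPUs when parsing fails.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// effectiveThreads mirrors the quota arithmetic above for a cgroup v2
// cpu.max line of the form "<quota_us> <period_us>".
func effectiveThreads(cpuMax string) int {
	fields := strings.Fields(cpuMax)
	if len(fields) != 2 || fields[0] == "max" {
		return 0 // no quota: keep the host-derived CPU count
	}
	quota, err1 := strconv.ParseInt(fields[0], 10, 64)
	period, err2 := strconv.ParseInt(fields[1], 10, 64)
	if err1 != nil || err2 != nil || period == 0 {
		return 0
	}
	return int(max(quota/period, 1))
}

func main() {
	fmt.Println(effectiveThreads("200000 100000")) // quota of 2 CPUs -> 2 threads
	fmt.Println(effectiveThreads("150000 100000")) // integer division rounds down -> 1
	fmt.Println(effectiveThreads("max 100000"))    // unconstrained -> 0 (unchanged)
}
```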
func linuxCPUDetails(file io.Reader) []CPU {

View File

@@ -65,6 +65,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
@@ -90,7 +94,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
@@ -113,7 +117,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
// In the second pass, we more deeply initialize the GPUs to weed out devices that
// aren't supported by a given library. We run this phase in parallel to speed up discovery.
// Only devices that need verification are included in this pass
slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices))
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
var wg sync.WaitGroup
@@ -121,11 +125,21 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -449,3 +463,24 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
return devices
}
func overrideWarnings() {
anyFound := false
m := envconfig.AsMap()
for _, k := range []string{
"CUDA_VISIBLE_DEVICES",
"HIP_VISIBLE_DEVICES",
"ROCR_VISIBLE_DEVICES",
"GGML_VK_VISIBLE_DEVICES",
"GPU_DEVICE_ORDINAL",
"HSA_OVERRIDE_GFX_VERSION",
} {
if e, found := m[k]; found && e.Value != "" {
anyFound = true
slog.Warn("user overrode visible devices", k, e.Value)
}
}
if anyFound {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}

View File

@@ -9,15 +9,9 @@ sidebarTitle: Cloud
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
Ollama currently supports the following cloud models, with more coming soon:
### Supported models
- `deepseek-v3.1:671b-cloud`
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `kimi-k2:1t-cloud`
- `qwen3-coder:480b-cloud`
- `glm-4.6:cloud`
- `minimax-m2:cloud`
For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud).
### Running Cloud models

View File

@@ -1,34 +1,34 @@
---
title: VS Code
title: VS Code
---
## Install
Install [VS Code](https://code.visualstudio.com/download).
Install [VS Code](https://code.visualstudio.com/download).
## Usage with Ollama
## Usage with Ollama
1. Open Copilot side bar found in top right window
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-sidebar.png"
alt="VS Code chat Sidebar"
width="75%"
/>
</div>
2. Select the model drowpdown > **Manage models**
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-models.png"
alt="VS Code model picker"
width="75%"
/>
</div>
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-sidebar.png"
alt="VS Code chat Sidebar"
width="75%"
/>
</div>
2. Select the model dropdown > **Manage models**
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-models.png"
alt="VS Code model picker"
width="75%"
/>
</div>
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-model-options.png"
alt="VS Code model options dropdown"
width="75%"
/>
</div>
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-model-options.png"
alt="VS Code model options dropdown"
width="75%"
/>
</div>

View File

@@ -111,6 +111,12 @@ components:
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
options:
$ref: "#/components/schemas/ModelOptions"
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
GenerateResponse:
type: object
properties:
@@ -150,6 +156,11 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
GenerateStreamEvent:
type: object
properties:
@@ -287,6 +298,12 @@ components:
- type: string
- type: number
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
ChatResponse:
type: object
properties:
@@ -344,6 +361,11 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
ChatStreamEvent:
type: object
properties:
@@ -706,6 +728,41 @@ components:
version:
type: string
description: Version of Ollama
TokenLogprob:
type: object
description: Log probability information for a single token alternative
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
Logprob:
type: object
description: Log probability information for a generated token
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
top_logprobs:
type: array
items:
$ref: "#/components/schemas/TokenLogprob"
description: Most likely tokens and their log probabilities at this position
ErrorResponse:
type: object
properties:
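
The additions above expose token-level log probabilities on both the generate and chat endpoints: `logprobs` and `top_logprobs` on the request, and a `logprobs` array following the `Logprob` schema on the response. A minimal sketch of exercising the new fields over plain HTTP, assuming a local server on the default port; the model name and prompt are only illustrative:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body mirroring the schema additions: logprobs enables per-token
	// log probabilities, top_logprobs asks for the 2 most likely alternatives
	// at each position. stream is disabled to get a single JSON response.
	body, _ := json.Marshal(map[string]any{
		"model":        "qwen3",
		"prompt":       "Why is the sky blue?",
		"stream":       false,
		"logprobs":     true,
		"top_logprobs": 2,
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the fields of interest; "logprobs" follows the Logprob
	// schema above (token, logprob, bytes, top_logprobs).
	var out struct {
		Response string `json:"response"`
		Logprobs []struct {
			Token   string  `json:"token"`
			Logprob float64 `json:"logprob"`
		} `json:"logprobs"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	for _, lp := range out.Logprobs {
		fmt.Printf("%q -> %.3f\n", lp.Token, lp.Logprob)
	}
}
```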

View File

@@ -249,6 +249,9 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}

View File

@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error
a.values[i] = e
} else {
discardGGUFString(llm, r)
_ = discardGGUFString(llm, r)
}
}
@@ -568,7 +568,6 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)

1
go.mod
View File

@@ -17,7 +17,6 @@ require (
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sys v0.36.0
)
require (

View File

@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools and optional last message
// Init initializes the handler with tools, optional last message, and think value
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{

View File

@@ -3,7 +3,6 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -40,18 +39,18 @@ type Causal struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch
curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -206,45 +205,47 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil
var locs []int32
if !reserve {
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
locs, err = c.findLocs()
if err != nil {
return err
}
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)
seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
c.curLoc = 0
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)
return nil
@@ -257,22 +258,20 @@ func newRange() cellRange {
}
}
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
func (c *Causal) updateSlidingWindow() {
@@ -402,145 +401,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
c.updateSlidingWindow()
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -625,18 +485,25 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
rowSize := c.values[c.curLayer].Stride(2)
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}
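
Taken together, these kvcache changes replace the contiguous-block allocator (findStartLoc plus a defrag pass) with a scatter model: findLocs returns any free cells, curLoc becomes a tensor of indices, and Put writes keys and values into those rows with SetRows, which is why the defragmentation code above can be removed. A minimal standalone sketch of the free-cell scan, using hypothetical names and a plain bool occupancy slice rather than the real cache cells:

```go
package sketch

import "fmt"

// freeLocs returns the indices of the first n unoccupied cells, in order.
// Any free cell is acceptable, so no contiguous run is required and no
// defragmentation pass is needed before storing a batch.
func freeLocs(occupied []bool, n int) ([]int32, error) {
	locs := make([]int32, 0, n)
	for i, used := range occupied {
		if !used {
			locs = append(locs, int32(i))
			if len(locs) == n {
				return locs, nil
			}
		}
	}
	return nil, fmt.Errorf("cache full: %d cells, need %d", len(occupied), n)
}
```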

File diff suppressed because it is too large

View File

@@ -80,10 +80,10 @@ func TestIssue7978(t *testing.T) {
}
}
func TestSchemaToGrammer(t *testing.T) {
func TestSchemaToGrammar(t *testing.T) {
cases := []struct {
schema string
prefix []byte // nil is check as nil
prefix []byte // nil is checked as nil
}{
{`invalid`, nil},
@@ -92,7 +92,7 @@ func TestSchemaToGrammer(t *testing.T) {
}
for _, c := range cases {
t.Run("x", func(t *testing.T) {
t.Run(c.schema, func(t *testing.T) {
g := SchemaToGrammar([]byte(c.schema))
if c.prefix == nil && g != nil {
t.Fatalf("grammar = %v, want nil", g)

View File

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++++--
ggml/src/mem_hip.cpp | 452 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++++
9 files changed, 926 insertions(+), 30 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 209 +++++++++--
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 1003 insertions(+), 30 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index f9a6587f1..03f359ae9 100644
target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c9333689f..41b00af83 100644
index c9333689f..f1a20e7fe 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -111,7 +111,7 @@ index c9333689f..41b00af83 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -243,7 +243,7 @@ index 05ff6a5a6..032dee76d 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3a6bbe564..d2c278a35 100644
index 3a6bbe564..ca02ea079 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -229,6 +229,7 @@ class vk_memory_logger;
@@ -337,7 +337,7 @@ index 3a6bbe564..d2c278a35 100644
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
@@ -548,11 +548,12 @@ index 3a6bbe564..d2c278a35 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..5a7f5d465
index 000000000..c1949b899
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,452 @@
@@ -0,0 +1,529 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
+#ifdef _WIN32
+// AMD Device Library eXtra (ADLX)
@@ -570,7 +571,6 @@ index 000000000..5a7f5d465
+// Unused function parameters are commented out to avoid unnecessary type
+// definitions.
+
+#include "ggml-impl.h"
+#include <filesystem>
+#include <mutex>
+
@@ -990,15 +990,92 @@ index 000000000..5a7f5d465
+
+#else // #ifdef _WIN32
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+#include <filesystem>
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <unistd.h>
+#include <glob.h>
+namespace fs = std::filesystem;
+
+extern "C" {
+
+// TODO Linux implementation of accurate VRAM reporting
+int ggml_hip_mgmt_init() {
+ return -1;
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ return -1;
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
+ for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
+ const char* device_file = glob_result.gl_pathv[i];
+ std::ifstream file(device_file);
+ if (!file.is_open()) {
+ std::cerr << "Failed to open sysfs node" << std::endl;
+ globfree(&glob_result);
+ return 1;
+ }
+
+ std::string line;
+ while (std::getline(file, line)) {
+ // Check for PCI_SLOT_NAME label
+ if (line.find(drmUeventPCISlotLabel) == 0) {
+ std::istringstream iss(line.substr(drmUeventPCISlotLabel.size()));
+ std::string pciSlot;
+ iss >> pciSlot;
+ if (pciSlot == std::string(id)) {
+ std::string dir = fs::path(device_file).parent_path().string();
+
+ std::string totalFile = dir + "/" + drmTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+ *free = memory - memoryUsed;
+
+ file.close();
+ globfree(&glob_result);
+ return 0;
+ }
+ }
+ }
+
+ file.close();
+ }
+ GGML_LOG_DEBUG("%s unable to find matching device\n", __func__);
+ globfree(&glob_result);
+ return 1;
+}
+
+} // extern "C"

View File

@@ -38,7 +38,7 @@ index 44ae76d66..639d551a2 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index d2c278a35..221e29509 100644
index ca02ea079..c12b069e5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -11,7 +11,7 @@ vidmem optimization.
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 221e29509..18b7cbccf 100644
index c12b069e5..76c78c2ea 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr

View File

@@ -50,7 +50,7 @@ Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 18b7cbccf..53b57c179 100644
index 76c78c2ea..7669ed206 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -488,6 +488,7 @@ struct vk_device_struct {

View File

@@ -58,7 +58,7 @@ index 639d551a2..e5c446d1d 100644
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 53b57c179..b2855b078 100644
index 7669ed206..63a762ec2 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;

View File

@@ -31,7 +31,7 @@ Add new backend tests.
6 files changed, 371 insertions(+), 117 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index b2855b078..aaf4334b5 100644
index 63a762ec2..db92a7901 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -458,6 +458,11 @@ static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {

View File

@@ -9,7 +9,7 @@ Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index aaf4334b5..3604ceb04 100644
index db92a7901..e959674d1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {

View File

@@ -20,7 +20,7 @@ Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 3604ceb04..80185d9f0 100644
index e959674d1..903050b0b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);

View File

@@ -0,0 +1,25 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Tue, 18 Nov 2025 11:13:04 -0800
Subject: [PATCH] ggml-cuda: skip large batches
cuda panics on batches larger than 1024 so mark it as unsupported to
fallback to cpu
---
ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f1a20e7fe..1a71e07c9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
+ if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
+ return false;
+ }
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@@ -0,0 +1,28 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Tue, 18 Nov 2025 09:58:23 -0800
Subject: [PATCH] win: exit instead of abort
---
ggml/src/ggml.c | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9be35c1be..923c33d05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
-
+#if defined(_WIN32)
+ fflush(stderr);
+ fflush(stdout);
+ exit(1);
+#else
abort();
+#endif
}
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

View File

@@ -173,6 +173,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
@@ -193,6 +194,7 @@ type Tensor interface {
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
@@ -207,6 +209,8 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
@@ -230,7 +234,7 @@ type Tensor interface {
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, scale float64) Tensor
ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
}
type number interface {
@@ -372,3 +376,10 @@ const (
DTypeI32
DTypeMXFP4
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)
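
The interface grows SetRows, QuickGELU, an Interpolate method driven by the new SamplingMode constants, and an extra vmla argument on ScaledDotProductAttention. A hedged sketch of the Interpolate usage, modeled on the deepseekocr position-embedding resize further down; the helper name and the choice of a square target grid are illustrative, not part of the interface:

```go
package sketch

import "github.com/ollama/ollama/ml"

// resizePositionEmbedding rescales a square position-embedding grid to a new
// spatial size with bilinear interpolation. channels is the embedding width;
// the permutes move the spatial dimensions into the layout Interpolate
// expects, mirroring the deepseekocr sam model below.
func resizePositionEmbedding(ctx ml.Context, pos ml.Tensor, target, channels int) ml.Tensor {
	if pos.Dim(1) == target {
		return pos
	}
	pos = pos.Permute(ctx, 2, 0, 1, 3)
	pos = pos.Interpolate(ctx, [4]int{target, target, channels, 1}, ml.SamplingModeBilinear)
	return pos.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
```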

View File

@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
// TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, output.bts, blocks)
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -499,7 +499,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range b.meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
for i := range tts {
@@ -1339,6 +1338,13 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1379,6 +1385,10 @@ func inferShape(t *Tensor, shape []int) {
}
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
if !C.ggml_is_contiguous(t.t) {
return t.Contiguous(ctx, shape...)
}
if slices.Contains(shape, -1) {
inferShape(t, shape)
}
@@ -1568,6 +1578,16 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
} else {
tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
@@ -1625,7 +1645,7 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, scale float64) ml.Tensor {
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {
kqMask = mask.(*Tensor).t
@@ -1642,6 +1662,16 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
}
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
if vmla != nil {
var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = vmla.Mulmat(ctx, cur)
cur = cur.Permute(ctx, 0, 2, 1, 3)
cur = cur.Contiguous(ctx)
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
@@ -1654,6 +1684,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
}
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
@@ -1711,6 +1745,23 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {
case ml.SamplingModeNearest:
mode = C.GGML_SCALE_MODE_NEAREST
case ml.SamplingModeBilinear:
mode = C.GGML_SCALE_MODE_BILINEAR
default:
panic("unsupported interpolate mode")
}
return &Tensor{
b: t.b,
t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
}
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.

View File

@@ -3513,7 +3513,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release();
return;
}
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
return false;
}
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@@ -13212,7 +13212,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release();
return;
}

View File

@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
#if defined(_WIN32)
fflush(stderr);
fflush(stdout);
exit(1);
#else
abort();
#endif
}
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

View File

@@ -1,4 +1,5 @@
#include "ggml.h"
#include "ggml-impl.h"
#ifdef _WIN32
// AMD Device Library eXtra (ADLX)
@@ -16,7 +17,6 @@
// Unused function parameters are commented out to avoid unnecessary type
// definitions.
#include "ggml-impl.h"
#include <filesystem>
#include <mutex>
@@ -436,15 +436,92 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
#else // #ifdef _WIN32
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <filesystem>
#include <sys/stat.h>
#include <dirent.h>
#include <unistd.h>
#include <glob.h>
namespace fs = std::filesystem;
extern "C" {
// TODO Linux implementation of accurate VRAM reporting
int ggml_hip_mgmt_init() {
return -1;
return 0;
}
void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
return -1;
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
for (size_t i = 0; i < glob_result.gl_pathc; ++i) {
const char* device_file = glob_result.gl_pathv[i];
std::ifstream file(device_file);
if (!file.is_open()) {
std::cerr << "Failed to open sysfs node" << std::endl;
globfree(&glob_result);
return 1;
}
std::string line;
while (std::getline(file, line)) {
// Check for PCI_SLOT_NAME label
if (line.find(drmUeventPCISlotLabel) == 0) {
std::istringstream iss(line.substr(drmUeventPCISlotLabel.size()));
std::string pciSlot;
iss >> pciSlot;
if (pciSlot == std::string(id)) {
std::string dir = fs::path(device_file).parent_path().string();
std::string totalFile = dir + "/" + drmTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t memory;
totalFileStream >> memory;
*total = memory;
std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t memoryUsed;
usedFileStream >> memoryUsed;
*free = memory - memoryUsed;
file.close();
globfree(&glob_result);
return 0;
}
}
}
file.close();
}
GGML_LOG_DEBUG("%s unable to find matching device\n", __func__);
globfree(&glob_result);
return 1;
}
} // extern "C"

View File

@@ -22,10 +22,14 @@ import (
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithSinks(ctx, query, key, value, nil, scale, cache)
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
@@ -56,7 +60,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, scale)
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
@@ -71,6 +75,11 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}

View File

@@ -237,7 +237,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
if addSpecial && len(ids) > 0 {
if addSpecial {
ids = bpe.vocab.addSpecials(ids)
}

View File

@@ -25,12 +25,15 @@ const (
// Composite returns an image with the alpha channel removed by drawing over a white background.
func Composite(img image.Image) image.Image {
dst := image.NewRGBA(img.Bounds())
white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return CompositeColor(img, white)
}
// CompositeColor returns an image with the alpha channel removed by drawing over a white background.
func CompositeColor(img image.Image, color color.Color) image.Image {
dst := image.NewRGBA(img.Bounds())
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst
}
@@ -55,6 +58,31 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
return dst
}
// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
func Pad(img image.Image, newSize image.Point, color color.Color, kernel draw.Interpolator) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
var minPoint, maxPoint image.Point
if img.Bounds().Dx() > img.Bounds().Dy() {
// landscape
height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
minPoint = image.Point{0, (newSize.Y - height) / 2}
maxPoint = image.Point{newSize.X, height + minPoint.Y}
} else {
// portrait
width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
minPoint = image.Point{(newSize.X - width) / 2, 0}
maxPoint = image.Point{minPoint.X + width, newSize.Y}
}
kernel.Scale(dst, image.Rectangle{
Min: minPoint,
Max: maxPoint,
}, img, img.Bounds(), draw.Over, nil)
return dst
}
// Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
var pixelVals []float32
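
The new Pad helper letterboxes an image into a fixed canvas while preserving its aspect ratio, and CompositeColor generalizes Composite to any background color. A small usage sketch, with the square target size and gray fill mirroring the deepseekocr preprocessor below; the function name is illustrative:

```go
package sketch

import (
	"image"
	"image/color"

	"golang.org/x/image/draw"

	"github.com/ollama/ollama/model/imageproc"
)

// letterbox flattens any alpha channel onto a black background, then fits the
// image into a 1024x1024 canvas, filling the leftover border with mid-gray.
func letterbox(img image.Image) image.Image {
	img = imageproc.CompositeColor(img, color.Gray{})
	return imageproc.Pad(img, image.Point{X: 1024, Y: 1024}, color.Gray{Y: 127}, draw.BiLinear)
}
```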

View File

@@ -156,6 +156,7 @@ func New(c fs.Config) (model.Model, error) {
)),
},
},
true,
)
default:
return nil, model.ErrUnsupportedTokenizer

View File

@@ -3,6 +3,7 @@ package deepseek2
// uses deepseek 2 architecture but written based on deepseek 3 model
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
@@ -16,6 +17,7 @@ import (
)
type Options struct {
isMLA bool
numExpertsUsed int
numExperts int
normTopKProb bool
@@ -32,8 +34,6 @@ type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
originalContextLength int
eps,
@@ -62,6 +62,9 @@ type Attention struct {
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
KB *nn.Linear `gguf:"attn_k_b"`
VB *nn.Linear `gguf:"attn_v_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
@@ -69,7 +72,7 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 { // nil {
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
@@ -88,21 +91,35 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
compressedKV.Stride(1), compressedKV.Dim(1),
)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
var attention ml.Tensor
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
key := kRot.Concat(ctx, kPass, 0)
value := kPass
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
}
attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
@@ -233,6 +250,34 @@ func New(c fs.Config) (model.Model, error) {
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
isMLA := c.Uint("attention.key_length_mla") != 0 && c.Uint("attention.value_length_mla") != 0
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "deepseek-v3":
pre = []string{
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
}
case "deepseek-llm":
// TODO: these models haven't been vetted so skip for now
// pre = []string{
// "[\r\n]",
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
// "\\s?[!-/:-~---‟ -。]+",
// "\\s+$",
// "[一-龥ࠀ-一가-퟿]+",
// "[0-9]",
// }
fallthrough
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
@@ -247,18 +292,14 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
pre...,
),
Layers: layers,
Options: &Options{
isMLA: isMLA,
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
@@ -266,13 +307,13 @@ func New(c fs.Config) (model.Model, error) {
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
qLoraRank: int(c.Uint("attention.q_lora_rank")),
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: int(c.Uint("attention.key_length")),
vHeadDim: int(c.Uint("attention.value_length")),
qkHeadDim: keyLength,
vHeadDim: valueLength,
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: keyLength - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),

View File

@@ -0,0 +1,83 @@
package deepseekocr
import (
"bytes"
"image"
"image/color"
"math"
"slices"
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ratio struct {
x, y int
}
func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
img, _, err := image.Decode(bytes.NewReader(bts))
if err != nil {
return nil, nil, nil, err
}
minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
var targetRatios []ratio
for n := minNum; n <= maxNum; n++ {
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
targetRatios = append(targetRatios, ratio{i, j})
}
}
}
}
targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
blocks := targetRatio.x * targetRatio.y
mean := imageproc.ImageNetStandardMean
std := imageproc.ImageNetStandardSTD
var patches []float32
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
for i := range blocks {
patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
draw.Draw(patch, patch.Bounds(), resized, image.Point{
X: i % (targetWidth / imageSize) * imageSize,
Y: i / (targetWidth / imageSize) * imageSize,
}, draw.Over)
patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
}
img = imageproc.CompositeColor(img, color.Gray{})
img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)
return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
[]int{targetRatio.x, targetRatio.y},
nil
}
func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
bestDiff := math.MaxFloat64
best := ratio{1, 1}
realRatio := float64(width) / float64(height)
for _, target := range targetRatios {
targetRatio := float64(target.x) / float64(target.y)
diff := math.Abs(realRatio - targetRatio)
if diff < bestDiff {
bestDiff = diff
best = target
} else if diff == bestDiff {
if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
best = target
}
}
}
return best
}

View File

@@ -0,0 +1,192 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
Text *textModel
ImageNewline ml.Tensor `gguf:"mm.image_newline"`
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`
Projector *nn.Linear `gguf:"mm.layers"`
}
func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
patches, original, crop, err := ProcessImage(ctx, bts)
if err != nil {
return nil, err
}
var outputs []ml.Tensor
if true { // TODO: local features if sum(patches) != 0
samOutputs := m.Sam.Forward(ctx, patches)
visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
localOutputs = m.Projector.Forward(ctx, localOutputs)
hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)
outputs = append(outputs, localOutputs)
}
samOutputs := m.Sam.Forward(ctx, original)
visionOutputs := m.Vision.Forward(ctx, original, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
globalOutputs = m.Projector.Forward(ctx, globalOutputs)
hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)
outputs = append(outputs, globalOutputs, m.ViewSeperator)
return []input.Multimodal{
{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
outputs := make([]*input.Input, 0, len(inputs))
for i := range inputs {
if inputs[i].Multimodal == nil {
outputs = append(outputs, inputs[i])
continue
}
t := inputs[i].Multimodal[0].Tensor
outputs = append(outputs, &input.Input{
Token: 128815,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1) - 1,
})
outputs = slices.Grow(outputs, t.Dim(1)-1)
outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
}
return outputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, mm := range batch.Multimodal {
t := mm.Multimodal[0].Tensor
ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
}
hiddenStates := inputsEmbeds
for i, block := range m.Text.Blocks {
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Text.Blocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
}
hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
return m.Text.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
textBlocks := make([]textBlock, c.Uint("block_count"))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
for i := range textBlocks {
if i >= leadingDenseBlockCount {
textBlocks[i].FeedForward = &textMoe{}
} else {
textBlocks[i].FeedForward = &textMLP{}
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Text: &textModel{
Blocks: textBlocks,
Options: textOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeBase: c.Float("rope.freq_base", 10_000),
ropeScale: c.Float("rope.scaling.factor", 1.0),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
},
},
Vision: &visionModel{
Blocks: make([]visionBlock, c.Uint("vision.block_count")),
Options: visionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.head_count")),
imageSize: int(c.Uint("vision.image_size", 224)),
patchSize: int(c.Uint("vision.patch_size", 14)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
},
},
Sam: &samModel{
Blocks: make([]samBlock, c.Uint("sam.block_count")),
Options: samOptions{
hiddenSize: int(c.Uint("sam.embedding_length")),
numHeads: int(c.Uint("sam.head_count")),
eps: c.Float("sam.attention.layer_norm_epsilon", 1e-6),
globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
},
},
}
m.Cache = kvcache.NewCausalCache(m.Text.Shift)
return &m, nil
})
}

View File

@@ -0,0 +1,225 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type samModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd"`
Blocks []samBlock `gguf:"blk"`
Neck *samNeck `gguf:"neck"`
Net2 *nn.Conv2D `gguf:"net_2"`
Net3 *nn.Conv2D `gguf:"net_3"`
Options samOptions
}
func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
source := m.PositionEmbedding.Dim(1)
target := hiddenStates.Dim(2)
if source != target {
positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
return m.PositionEmbedding
}
func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
if m.PositionEmbedding != nil {
hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
}
for i, block := range m.Blocks {
var windowSize int
if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
windowSize = 14
}
hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
}
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
return hiddenStates
}
type samOptions struct {
hiddenSize,
numHeads int
eps float32
globalAttentionLayers []int32
}
func (o samOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type samBlock struct {
Norm1 *nn.LayerNorm `gguf:"norm1"`
Attention *samAttention `gguf:"attn"`
Norm2 *nn.LayerNorm `gguf:"norm2"`
FeedForward *samMLP `gguf:"mlp"`
}
func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
var pw, ph int
if windowSize > 0 {
pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
if pw > 0 || ph > 0 {
hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
}
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
}
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
if windowSize > 0 {
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type samAttention struct {
QKV *nn.Linear `gguf:"qkv"`
Output *nn.Linear `gguf:"proj"`
RelativePosition *struct {
Height ml.Tensor `gguf:"h"`
Width ml.Tensor `gguf:"w"`
} `gguf:",pre:rel_pos_"`
}
func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
s := make([]int32, qn*kn)
for i := range qn {
for j := range kn {
q := i * max(kn/qn, 1)
k := j * max(qn/kn, 1)
s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
}
}
return ctx.Input().FromInts(s, qn*kn)
}
func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
maxRelativeDistance := 2*max(qn, kn) - 1
if positions.Dim(1) != maxRelativeDistance {
// linear interpolation kernel not available so approx. with bilinear interpolation
positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
}
rc := relativeCoordinates(ctx, qn, kn)
return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}
func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
qh, qw := qn[0], qn[1]
kh, kw := kn[0], kn[1]
rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)
query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
return rh, rw
}
func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
qkv := m.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
}
type samMLP struct {
Lin1 *nn.Linear `gguf:"lin1"`
Lin2 *nn.Linear `gguf:"lin2"`
}
func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}
type LayerNorm2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
u := x.Mean(ctx)
d := x.Sub(ctx, u)
s := d.Sqr(ctx).Mean(ctx)
x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
type samNeck struct {
C1 *nn.Conv2D `gguf:"0"`
LN1 *LayerNorm2D `gguf:"1"`
C2 *nn.Conv2D `gguf:"2"`
LN2 *LayerNorm2D `gguf:"3"`
}
func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}

View File

@@ -0,0 +1,140 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
type textModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Blocks []textBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Options textOptions
}
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}
type textOptions struct {
hiddenSize,
numHeads,
numKVHeads,
numExperts,
numExpertsUsed int
ropeBase,
ropeScale,
eps float32
}
func (o textOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}
type textBlock struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *textAttention
MLPNNorm *nn.RMSNorm `gguf:"ffn_norm"`
FeedForward textFeedForward
}
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type textAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
query := m.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)
key := m.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
value := m.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}
type textFeedForward interface {
Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}
type textMoe struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
SharedExperts *textMLP `gguf:",suf:_shexp"`
}
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
indices := scores.TopK(ctx, opts.numExpertsUsed)
weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)
experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
experts = m.Down.Forward(ctx, experts, indices)
experts = experts.Mul(ctx, weights)
expert := func(i int) ml.Tensor {
return experts.View(
ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
)
}
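// sum the contributions of the top-k routed experts, then add the shared expert path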
routedStates := expert(0)
for i := 1; i < opts.numExpertsUsed; i++ {
routedStates = routedStates.Add(ctx, expert(i))
}
sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
return routedStates.Add(ctx, sharedStates)
}
type textMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hiddenStates)
}

View File

@@ -0,0 +1,117 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type visionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
ClassEmbedding ml.Tensor `gguf:"class_embd"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
Blocks []visionBlock `gguf:"blk"`
Options visionOptions
}
func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
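// when the runtime patch grid differs from the pretrained grid, bilinearly interpolate the learned position embeddings, keeping the class token embedding unchanged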
numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)
source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
if source != target {
newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)
positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
}
return positionEmbeds
}
func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
if patchEmbeds == nil {
patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
}
patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))
hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
}
return hiddenStates
}
type visionOptions struct {
hiddenSize,
numHeads int
eps float32
imageSize, patchSize int
}
func (o visionOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type visionBlock struct {
Norm1 *nn.LayerNorm `gguf:"layer_norm1"`
Attention *visionAttention `gguf:"self_attn"`
Norm2 *nn.LayerNorm `gguf:"layer_norm2"`
FeedForward *visionMLP `gguf:"mlp"`
}
func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type visionAttention struct {
QKV *nn.Linear `gguf:"qkv_proj"`
Output *nn.Linear `gguf:"out_proj"`
}
func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
qkv := m.QKV.Forward(ctx, t)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}
type visionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}

View File

@@ -3,6 +3,7 @@ package models
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/deepseekocr"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
@@ -11,6 +12,7 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"

View File

@@ -0,0 +1,170 @@
package nomicbert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
type Options struct {
hiddenSize int
numHeads int
headDim int
eps float32
poolingType pooling.Type
normalize bool
ropeFreqBase float32
}
// EncoderLayer is a single transformer encoder layer
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
type Attention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
Output *nn.Linear `gguf:"attn_output"`
}
type MLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
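// a single token type is used, so the first row of the type embedding table is added to every token embedding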
typeEmbed := m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1)
hiddenStates = hiddenStates.Add(ctx, typeEmbed)
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, positions, &m.Options)
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
qkv := a.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim, opts.numHeads*3, batchSize)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hidden := m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hidden)
}
func New(c fs.Config) (model.Model, error) {
hiddenSize := int(c.Uint("embedding_length"))
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
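// prefer the CLS token id and fall back to the BOS token id (cmp.Or returns the first non-zero value)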
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", false),
ropeFreqBase: c.Float("rope.freq_base", 1000.0),
},
}, nil
}
func init() {
model.Register("nomic-bert", New)
model.Register("nomic-bert_embed", New)
}

319
model/parsers/cogito.go Normal file
View File

@@ -0,0 +1,319 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type CogitoParserState int
const (
CogitoCollectingThinking CogitoParserState = iota
CogitoCollectingContent
CogitoCollectingToolCalls
CogitoCollectingToolOutput
)
const (
cogitoThinkingCloseTag = "</think>"
cogitoToolCallsBeginTag = "<tool▁calls▁begin>"
cogitoToolCallsEndTag = "<tool▁calls▁end>"
cogitoToolCallBeginTag = "<tool▁call▁begin>"
cogitoToolCallEndTag = "<tool▁call▁end>"
cogitoToolSepTag = "<tool▁sep>"
cogitoToolOutputBeginTag = "<tool▁output▁begin>"
cogitoToolOutputEndTag = "<tool▁output▁end>"
cogitoToolOutputsBeginTag = "<tool▁outputs▁begin>"
cogitoToolOutputsEndTag = "<tool▁outputs▁end>"
)
type CogitoParser struct {
state CogitoParserState
buffer strings.Builder
}
func (p *CogitoParser) HasToolSupport() bool {
return true
}
func (p *CogitoParser) HasThinkingSupport() bool {
return true
}
func (p *CogitoParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// thinking is enabled only when the request asks for it
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if !thinkingEnabled {
p.state = CogitoCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = CogitoCollectingContent
return
}
// Note: Cogito does not think when tools are provided
if len(tools) > 0 {
p.state = CogitoCollectingContent
return
}
p.state = CogitoCollectingThinking
}
func (p *CogitoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, tools, thinkValue)
return tools
}
type cogitoEvent interface {
isCogitoEvent()
}
type cogitoEventThinkingContent struct {
content string
}
type cogitoEventContent struct {
content string
}
type cogitoEventToolCall struct {
toolCall api.ToolCall
}
func (cogitoEventThinkingContent) isCogitoEvent() {}
func (cogitoEventContent) isCogitoEvent() {}
func (cogitoEventToolCall) isCogitoEvent() {}
func (p *CogitoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case cogitoEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case cogitoEventThinkingContent:
thinkingSb.WriteString(event.content)
case cogitoEventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *CogitoParser) parseEvents() []cogitoEvent {
var all []cogitoEvent
keepLooping := true
for keepLooping {
var events []cogitoEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *CogitoParser) eat() ([]cogitoEvent, bool) {
var events []cogitoEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case CogitoCollectingThinking:
if strings.Contains(bufStr, cogitoThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, cogitoThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
if len(thinking) > 0 {
events = append(events, cogitoEventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, cogitoThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise it's thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
}
case CogitoCollectingContent:
switch {
case strings.Contains(bufStr, cogitoToolCallsBeginTag): // content[<tool▁calls▁begin>] -> tool calls
split := strings.SplitN(bufStr, cogitoToolCallsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
case strings.Contains(bufStr, cogitoToolOutputsBeginTag): // content[<tool▁outputs▁begin>] -> tool outputs
split := strings.SplitN(bufStr, cogitoToolOutputsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolOutput
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
default: // otherwise it's content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, cogitoEventContent{content: bufStr})
}
return events, false
}
case CogitoCollectingToolCalls:
if idx := strings.Index(bufStr, cogitoToolCallBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolCallBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolCallEndTag); endIdx != -1 {
toolCallContent := bufStr[startIdx : startIdx+endIdx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[startIdx+endIdx+len(cogitoToolCallEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, cogitoEventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("cogito tool call parsing failed", "error", err)
}
}
}
if idx := strings.Index(bufStr, cogitoToolCallsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolCallsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
case CogitoCollectingToolOutput:
if idx := strings.Index(bufStr, cogitoToolOutputBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolOutputBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolOutputEndTag); endIdx != -1 {
remaining := bufStr[startIdx+endIdx+len(cogitoToolOutputEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
return events, true
}
}
if idx := strings.Index(bufStr, cogitoToolOutputsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolOutputsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
}
return events, false
}
func (p *CogitoParser) parseToolCallContent(content string) (api.ToolCall, error) {
// Expected format: function<tool▁sep>tool_name\n```json\n{args}\n```
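// For example, the content
//   function<tool▁sep>get_weather\n```json\n{"location":"Paris"}\n```
// parses into a ToolCall named "get_weather" with Arguments {"location": "Paris"}.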
parts := strings.SplitN(content, cogitoToolSepTag, 2)
if len(parts) < 2 {
return api.ToolCall{}, errors.New("invalid format")
}
nameAndArgs := parts[1]
jsonStart := strings.Index(nameAndArgs, "\n```json\n")
if jsonStart == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
toolName := strings.TrimSpace(nameAndArgs[:jsonStart])
jsonContent := nameAndArgs[jsonStart+len("\n```json\n"):]
jsonEnd := strings.Index(jsonContent, "\n```")
if jsonEnd == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
argsJSON := jsonContent[:jsonEnd]
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return api.ToolCall{}, err
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: toolName,
Arguments: args,
},
}, nil
}
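A minimal sketch of driving this parser over streamed chunks, assuming only the Init and Add signatures above (the wrapper function and chunk handling are illustrative; the streaming tests below exercise the same flow):
func collectCogitoOutput(chunks []string, tools []api.Tool, think *api.ThinkValue) (string, string, []api.ToolCall, error) {
p := &CogitoParser{}
p.Init(tools, nil, think)
var contentSb, thinkingSb strings.Builder
var calls []api.ToolCall
for i, chunk := range chunks {
// done is true only for the final chunk so the parser can drain its buffer
content, thinking, toolCalls, err := p.Add(chunk, i == len(chunks)-1)
if err != nil {
return "", "", nil, err
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
calls = append(calls, toolCalls...)
}
return contentSb.String(), thinkingSb.String(), calls, nil
}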

View File

@@ -0,0 +1,565 @@
package parsers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestCogitoParser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
tools []api.Tool
lastMessage *api.Message
}{
{
name: "simple_content",
input: "This is a simple response.",
expectedContent: "This is a simple response.",
expectedThinking: "",
},
{
name: "thinking_only",
input: "This is thinking content.</think>This is response content.",
expectedContent: "This is response content.",
expectedThinking: "This is thinking content.",
},
{
name: "tool_call_simple",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "thinking_with_tool_call",
input: `I need to check the weather.</think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedContent: "I need to check the weather.</think>",
expectedThinking: "", // No thinking when tools are present (Cogito-specific behavior)
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "multiple_tool_calls",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "complex_tool_arguments",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true,"threshold":0.95},"count":42}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true, "threshold": 0.95},
"count": 42.0,
},
},
},
},
},
{
name: "tool_output_parsing",
input: `<tool▁outputs▁begin><tool▁output▁begin>{"temperature": 22, "condition": "sunny"}<tool▁output▁end><tool▁outputs▁end>`,
expectedContent: "",
expectedThinking: "",
},
{
name: "thinking_with_multiline_content",
input: `This is line 1
This is line 2
This is line 3</think>Final response here.`,
expectedContent: "Final response here.",
expectedThinking: "This is line 1\nThis is line 2\nThis is line 3",
},
{
name: "no_thinking_simple",
input: "This is content.",
expectedContent: "This is content.",
expectedThinking: "",
},
{
name: "prefill_content_only",
input: "Continuing from previous content.",
expectedContent: "Continuing from previous content.",
lastMessage: &api.Message{
Role: "assistant",
Content: "Previous content",
},
},
{
name: "prefill_with_thinking",
input: "Continuing thinking</think>Continuing content.",
expectedContent: "Continuing content.",
expectedThinking: "Continuing thinking",
lastMessage: &api.Message{
Role: "assistant",
},
},
// Edge cases
{
name: "nested_think_tags_in_thinking",
input: "I'm thinking <think>nested</think> more thinking</think>Final content.",
expectedContent: "more thinking</think>Final content.",
expectedThinking: "I'm thinking <think>nested",
},
{
name: "multiple_think_close_tags",
input: "First thinking</think>Content</think>More content.",
expectedContent: "Content</think>More content.",
expectedThinking: "First thinking",
},
{
name: "empty_thinking_content",
input: "</think>Just content here.",
expectedContent: "</think>Just content here.",
expectedThinking: "",
},
{
name: "thinking_disabled_with_think_tags",
input: "Content with </think> tags should be treated as content.",
expectedContent: "Content with </think> tags should be treated as content.",
expectedThinking: "",
lastMessage: &api.Message{
Role: "assistant",
Content: "existing", // Forces non-thinking mode
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// enable thinking only for tests that expect thinking output
hasThinking := tt.expectedThinking != ""
parser := &CogitoParser{} // the parser always supports thinking
parser.Init(tt.tools, tt.lastMessage, &api.ThinkValue{Value: hasThinking}) // whether it is used depends on the request
content, thinking, toolCalls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_Streaming(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
chunks := []string{
"This is ",
"thinking content",
".</think>This is ",
"content.<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test_tool\n```json\n{\"arg\":\"value\"}\n```<tool▁call▁end><tool▁calls▁end>",
}
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range chunks {
done := i == len(chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
expectedContent := "This is content."
expectedThinking := "This is thinking content."
expectedToolCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test_tool",
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
},
},
},
}
if finalContent.String() != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
}
if finalThinking.String() != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
}
func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
hasThinkingSupport bool
}{
{
name: "split_thinking_tag",
chunks: []string{
"This is thinking content</thi",
"nk>This is content.",
},
expectedContent: "This is content.",
expectedThinking: "This is thinking content",
hasThinkingSupport: true,
},
{
name: "split_tool_calls_begin_tag_conservative_parsing",
chunks: []string{
"Content before<tool▁calls▁beg",
"in><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
},
// Parser is conservative - treats incomplete tags as content
expectedContent: "Content before<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
expectedToolCalls: nil,
hasThinkingSupport: false,
},
{
name: "thinking_disabled_with_split_tags",
chunks: []string{
"Content with </thi",
"nk> should be treated as content.",
},
expectedContent: "Content with </think> should be treated as content.",
expectedThinking: "",
hasThinkingSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: tt.hasThinkingSupport})
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
if finalContent.String() != tt.expectedContent {
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_HasToolSupport(t *testing.T) {
parser := &CogitoParser{}
if !parser.HasToolSupport() {
t.Error("CogitoParser should support tools")
}
}
func TestCogitoParser_Init(t *testing.T) {
parser := &CogitoParser{}
tools := []api.Tool{
{Function: api.ToolFunction{Name: "test_tool"}},
}
lastMessage := &api.Message{Role: "assistant", Content: "previous"}
returnedTools := parser.Init(tools, lastMessage, nil)
if len(returnedTools) != len(tools) {
t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
}
}
func TestCogitoParser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call_standard_format",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "valid_tool_call_complex_args",
content: `function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true},"count":42}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true},
"count": 42.0,
},
},
},
expectError: false,
},
{
name: "valid_tool_call_empty_args",
content: `function<tool▁sep>no_args_tool
` + "```json\n" + `{}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "no_args_tool",
Arguments: api.ToolCallFunctionArguments{},
},
},
expectError: false,
},
{
name: "missing_separator",
content: `functionget_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_function_type",
content: `not_function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_start",
content: `function<tool▁sep>get_weather{"location":"Paris"}` + "```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_end",
content: `function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}`,
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_json",
content: `function<tool▁sep>get_weather` + "```json\n" + `{location:Paris}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "empty_function_type",
content: `<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "tool_with_spaces_in_name",
content: `function<tool▁sep> get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "tool_with_multiline_json",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{
"location": "Paris",
"units": "metric"
}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
"units": "metric",
},
},
},
expectError: false,
},
{
name: "tool_with_nested_objects",
content: `function<tool▁sep>complex_tool
` + "```json\n" + `{"nested":{"deep":{"value":123}}}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "complex_tool",
Arguments: api.ToolCallFunctionArguments{
"nested": map[string]any{
"deep": map[string]any{
"value": 123.0,
},
},
},
},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -6,9 +6,9 @@ import (
)
type Parser interface {
// Init initializes the parser with tools and optional last message for chat prefill
// Init initializes the parser with tools, optional last message for chat prefill, and think value
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
@@ -52,6 +52,8 @@ func ParserForName(name string) Parser {
return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
default:
return nil
}
@@ -59,7 +61,7 @@ func ParserForName(name string) Parser {
type PassthroughParser struct{}
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
return tools // passthrough doesn't modify tools
}

View File

@@ -10,7 +10,7 @@ type mockParser struct {
name string
}
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
return tools
}

View File

@@ -43,7 +43,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
@@ -432,7 +432,7 @@ func transformToXML(raw string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
_ = xml.EscapeText(&escapedValue, []byte(groups[2])) // error is always nil for strings.Builder
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})

View File

@@ -54,7 +54,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
p.state = CollectingThinkingContent
}
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools

View File

@@ -198,7 +198,7 @@ func TestQwen3VLNonThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -515,7 +515,7 @@ func TestQwenOldParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -822,7 +822,7 @@ func TestQwen3VLNonThinkingToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

View File

@@ -205,7 +205,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
// parser.state = CollectingThinkingContent
for i, step := range tc.steps {
@@ -386,7 +386,7 @@ func TestQwen3VLParserState(t *testing.T) {
for _, tc := range cases {
parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
parser.Init(nil, tc.last)
parser.Init(nil, tc.last, nil)
if parser.state != tc.wantState {
t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
}
@@ -437,7 +437,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -500,7 +500,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -523,7 +523,7 @@ func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
// last message is assistant with content ⇒ start in CollectingContent
last := &api.Message{Role: "assistant", Content: "has content"}
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
type step struct {
input string
@@ -750,7 +750,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -859,7 +859,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, tc.prefillMsg)
parser.Init([]api.Tool{}, tc.prefillMsg, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

129
model/renderers/cogito.go Normal file
View File

@@ -0,0 +1,129 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type CogitoRenderer struct {
isThinking bool
}
func (r *CogitoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
defaultPrompt := "You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco."
// thinking is enabled only when the model supports it and the request asks for it
enableThinking := r.isThinking && (thinkValue != nil && thinkValue.Bool())
var systemPrompt string
var conversationMessages []api.Message
if len(messages) > 0 && messages[0].Role == "system" {
systemPrompt = messages[0].Content
conversationMessages = messages[1:]
} else {
conversationMessages = messages
}
var finalSystemPrompt string
if enableThinking {
finalSystemPrompt = "Enable deep thinking subroutine.\n\n" + defaultPrompt
if systemPrompt != "" {
finalSystemPrompt += "\n\n" + systemPrompt + "\n\n"
}
} else {
finalSystemPrompt = defaultPrompt
if systemPrompt != "" {
finalSystemPrompt += "\n\n" + systemPrompt
}
}
if len(tools) > 0 {
if finalSystemPrompt != "" {
finalSystemPrompt += "\nYou have the following functions available:\n"
} else {
finalSystemPrompt = "You have the following functions available:\n"
}
for _, tool := range tools {
toolJSON, _ := json.MarshalIndent(tool, "", " ") // TODO(gguo): double check json format
finalSystemPrompt += "```json\n" + string(toolJSON) + "\n```\n"
}
}
sb.WriteString("<begin▁of▁sentence>" + finalSystemPrompt)
outputsOpen := false
isLastUser := false
for i, message := range conversationMessages {
switch message.Role {
case "user":
isLastUser = true
sb.WriteString("<User>" + message.Content + "<Assistant>")
case "assistant":
isLastUser = false
if len(message.ToolCalls) > 0 {
if message.Content != "" {
sb.WriteString(message.Content)
}
sb.WriteString("<tool▁calls▁begin>")
for j, toolCall := range message.ToolCalls {
sb.WriteString("<tool▁call▁begin>function<tool▁sep>" + toolCall.Function.Name)
argsJSON, _ := json.Marshal(toolCall.Function.Arguments)
sb.WriteString("\n```json\n" + string(argsJSON) + "\n```")
sb.WriteString("<tool▁call▁end>")
if j < len(message.ToolCalls)-1 {
sb.WriteString("\n")
}
}
sb.WriteString("<tool▁calls▁end><end▁of▁sentence>")
} else {
sb.WriteString(message.Content + "<end▁of▁sentence>")
}
case "tool":
isLastUser = false
if !outputsOpen {
sb.WriteString("<tool▁outputs▁begin>")
outputsOpen = true
}
sb.WriteString("<tool▁output▁begin>" + message.Content + "<tool▁output▁end>")
hasNextTool := i+1 < len(conversationMessages) && conversationMessages[i+1].Role == "tool"
if hasNextTool {
sb.WriteString("\n")
} else {
sb.WriteString("<tool▁outputs▁end>")
outputsOpen = false
}
}
}
if outputsOpen {
sb.WriteString("<tool▁outputs▁end>")
}
if !isLastUser {
sb.WriteString("<Assistant>")
}
if enableThinking {
sb.WriteString("<think>\n")
}
return sb.String(), nil
}
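A minimal sketch of calling the renderer directly, assuming only the Render signature above (the wrapper function is illustrative; the exact output for each case is pinned down by the tests below):
func renderCogitoPrompt() (string, error) {
// isThinking marks the model variant as thinking-capable; the request decides whether thinking is used
r := &CogitoRenderer{isThinking: true}
messages := []api.Message{{Role: "user", Content: "Hello"}}
// with thinking requested, the rendered prompt ends with "<Assistant><think>\n"
return r.Render(messages, nil, &api.ThinkValue{Value: true})
}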

View File

@@ -0,0 +1,491 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestCogitoRenderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant.<User>Hello, how are you?<Assistant>`,
},
{
name: "conversation with assistant response",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What is the capital of France?<Assistant>The capital of France is Paris.<end▁of▁sentence><User>Fantastic!<Assistant>`,
},
{
name: "thinking enabled without system",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant><think>
`,
},
{
name: "thinking enabled with system",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant.
<User>Hello, how are you?<Assistant><think>
`,
},
{
name: "thinking disabled",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant>`,
},
{
name: "with tools",
messages: []api.Message{
{Role: "user", Content: "What's the weather like?"},
},
thinkValue: &api.ThinkValue{Value: false},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
},
Required: []string{"location"},
},
},
},
},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"required": [
"location"
],
"properties": {
"location": {
"type": "string",
"description": "City name"
}
}
}
}
}
` + "```\n" + `<User>What's the weather like?<Assistant>`,
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "I'll check the weather in Paris for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What's the weather in Paris?<Assistant>I'll check the weather in Paris for you.<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
{Role: "tool", Content: "Temperature: 22°C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What's the weather in Paris?<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><tool▁outputs▁begin><tool▁output▁begin>Temperature: 22°C, Sunny<tool▁output▁end><tool▁outputs▁end><Assistant>`,
},
{
name: "multiple tool responses",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
},
{Role: "tool", Content: "Paris: 22°C, Sunny"},
{Role: "tool", Content: "London: 18°C, Cloudy"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Get weather for Paris and London<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><tool▁outputs▁begin><tool▁output▁begin>Paris: 22°C, Sunny<tool▁output▁end>
<tool▁output▁begin>London: 18°C, Cloudy<tool▁output▁end><tool▁outputs▁end><Assistant>`,
},
{
name: "thinking with tools",
messages: []api.Message{
{Role: "user", Content: "What's the weather like?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
},
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"required": [
"location"
],
"properties": {
"location": {
"type": "string",
"description": "City name"
}
}
}
}
}
` + "```\n" + `<User>What's the weather like?<Assistant><think>
`,
},
// test cases based on the Cogito chat template
{
name: "single_turn_thinking_false",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>`,
},
{
name: "single_turn_thinking_true",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant><think>
`,
},
{
name: "multi_turn_thinking_false",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>Hi there!<end▁of▁sentence><User>How are you?<Assistant>`,
},
{
name: "multi_turn_thinking_true",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>Hi there!<end▁of▁sentence><User>How are you?<Assistant><think>
`,
},
{
name: "multi_with_system_thinking_false",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Start"},
{Role: "assistant", Content: "Okay"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant<User>Start<Assistant>Okay<end▁of▁sentence><Assistant>`,
},
{
name: "multi_with_system_thinking_true",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Start"},
{Role: "assistant", Content: "Okay"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant
<User>Start<Assistant>Okay<end▁of▁sentence><Assistant><think>
`,
},
{
name: "multi_with_system2_thinking_false",
messages: []api.Message{
{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
{Role: "user", Content: "Give me a short introduction to LLMs."},
{Role: "assistant", Content: "Arrr! I'm a pirate"},
{Role: "user", Content: "Tell me more about LLMs."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a pirate chatbot who always responds in pirate speak!<User>Give me a short introduction to LLMs.<Assistant>Arrr! I'm a pirate<end▁of▁sentence><User>Tell me more about LLMs.<Assistant>`,
},
{
name: "multi_with_system2_thinking_true",
messages: []api.Message{
{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
{Role: "user", Content: "Give me a short introduction to LLMs."},
{Role: "assistant", Content: "Arrr! I'm a pirate"},
{Role: "user", Content: "Tell me more about LLMs."},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a pirate chatbot who always responds in pirate speak!
<User>Give me a short introduction to LLMs.<Assistant>Arrr! I'm a pirate<end▁of▁sentence><User>Tell me more about LLMs.<Assistant><think>
`,
},
// tools
{
name: "tool_calls_only_no_content",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Get weather for Paris<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "complex_tool_arguments",
messages: []api.Message{
{Role: "user", Content: "Process complex data"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2", "item3"},
"config": map[string]any{
"enabled": true,
"threshold": 0.95,
"tags": []string{"important", "urgent"},
},
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Process complex data<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process_data
` + "```json\n" + `{"config":{"enabled":true,"tags":["important","urgent"],"threshold":0.95},"items":["item1","item2","item3"]}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "empty_messages",
messages: []api.Message{
{Role: "system", Content: ""},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant><end▁of▁sentence><Assistant>`,
},
{
name: "thinking_with_empty_assistant_content",
messages: []api.Message{
{Role: "user", Content: "Think about this"},
{Role: "assistant", Content: ""},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Think about this<Assistant><end▁of▁sentence><Assistant><think>
`,
},
{
name: "multiple_system_messages",
messages: []api.Message{
{Role: "system", Content: "First instruction"},
{Role: "system", Content: "Second instruction"},
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
First instruction<User>Hello<Assistant>`,
},
{
name: "special_characters_in_content",
messages: []api.Message{
{Role: "user", Content: "What about <|special|> tokens and \"quotes\"?"},
{Role: "assistant", Content: "They're handled normally in content."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What about <|special|> tokens and "quotes"?<Assistant>They're handled normally in content.<end▁of▁sentence><Assistant>`,
},
{
name: "long_conversation_multiple_rounds",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
{Role: "user", Content: "How are you?"},
{Role: "assistant", Content: "Good, thanks!"},
{Role: "user", Content: "What's the weather?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hi<Assistant>Hello!<end▁of▁sentence><User>How are you?<Assistant>Good, thanks!<end▁of▁sentence><User>What's the weather?<Assistant>`,
},
}
renderer := &CogitoRenderer{isThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -56,6 +56,9 @@ func rendererForName(name string) Renderer {
case "qwen3-vl-thinking":
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
return renderer
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
return renderer
default:
return nil
}

View File

@@ -181,7 +181,7 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
if addSpecial && len(ids) > 0 {
if addSpecial {
ids = spm.vocab.addSpecials(ids)
}

View File

@@ -45,7 +45,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
if v.AddBOS && len(v.BOS) > 0 {
if slices.Contains(v.BOS, ids[0]) {
if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
@@ -54,7 +54,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
}
if v.AddEOS && len(v.EOS) > 0 {
if slices.Contains(v.BOS, ids[len(ids)-1]) {
if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}

View File

@@ -1,8 +1,12 @@
package model
import "testing"
import (
	"testing"
	"github.com/google/go-cmp/cmp"
)
func TestVocabulary_SpecialVocabulary(t *testing.T) {
func TestSpecialVocabulary(t *testing.T) {
vocab := &Vocabulary{
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
@@ -14,3 +18,90 @@ func TestVocabulary_SpecialVocabulary(t *testing.T) {
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
}
}
func TestAddSpecialVocabulary(t *testing.T) {
cases := []struct {
name string
vocab *Vocabulary
input []int32
want []int32
}{
{
name: "add bos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add bos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{0, 2, 3, 4},
want: []int32{0, 0, 2, 3, 4},
},
{
name: "add eos",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{2, 3, 4, 1},
},
{
// TODO(mxyng): this is to match previous behaviour
name: "add eos when already present",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: false,
AddEOS: true,
},
input: []int32{2, 3, 4, 1},
want: []int32{2, 3, 4, 1, 1},
},
{
name: "add both",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: true,
},
input: []int32{2, 3, 4},
want: []int32{0, 2, 3, 4, 1},
},
{
name: "add bos to empty inputs",
vocab: &Vocabulary{
BOS: []int32{0},
EOS: []int32{1},
AddBOS: true,
AddEOS: false,
},
input: []int32{},
want: []int32{0},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := tt.vocab.addSpecials(tt.input)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("no match (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -10,7 +10,8 @@ import (
)
type WordPiece struct {
vocab *Vocabulary
vocab *Vocabulary
lowercase bool
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
@@ -114,8 +115,10 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if wpm.lowercase {
subword = strings.ToLower(subword)
}
piece = wpm.vocab.Encode(subword)
if piece >= 0 {
break
}
@@ -140,7 +143,7 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
if addSpecial && len(ids) > 0 {
if addSpecial {
ids = wpm.vocab.addSpecials(ids)
}
@@ -160,8 +163,9 @@ func (wpm WordPiece) Vocabulary() *Vocabulary {
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
vocab: vocab,
vocab: vocab,
lowercase: lowercase,
}
}

View File

@@ -15,7 +15,9 @@ func TestWordPiece(t *testing.T) {
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
},
true, // lowercase
)
ids, err := wpm.Encode("Hello world!", true)
if err != nil {

View File

@@ -549,7 +549,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Note the chunksum stream
// interuption, but do not cancel
// interruption, but do not cancel
// in-flight downloads. We can still
// make progress on them. Once they are
// done, ErrIncomplete will be returned

View File

@@ -4,9 +4,7 @@
// # Manifests
//
// A manifest is a JSON object that describes a model. The JSON object has a
// single field "layers" which is a list of layers that make up the model. Each
// layer has the following fields:
//
// single field "layers" which is a list of layers that make up the model.
// A layer is a single, logical unit of a model. Layers are stored in the cache
// as files with the name of the digest of the layer. Layers are pushed and
// pulled from the registry as blobs.
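The doc comment above now describes the manifest only in prose. A tiny sketch of the shape it implies follows; the types and the per-layer fields other than "layers" itself are hypothetical stand-ins for illustration, not definitions taken from this package:

package main

import (
	"encoding/json"
	"fmt"
)

// layer and manifest are illustrative only; the real package defines its own
// types, and the assumption here is merely that each layer is digest-addressed.
type layer struct {
	Digest string `json:"digest"`
	Size   int64  `json:"size"`
}

type manifest struct {
	Layers []layer `json:"layers"`
}

func main() {
	m := manifest{Layers: []layer{{Digest: "sha256:0123abcd", Size: 4096}}}
	b, _ := json.MarshalIndent(m, "", "  ")
	fmt.Println(string(b))
}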

View File

@@ -175,7 +175,6 @@ func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType,
origTensors := orig.Tensors().Items()
outputTensors := make([]*fsggml.Tensor, len(origTensors))
for i, tensor := range origTensors {
tensor := tensor
newType := newType(tensor, kv, qs, newFileType)
newTensor := &fsggml.Tensor{
Name: tensor.Name,
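The dropped "tensor := tensor" line here (and the "tt := tt" removed in the parser test further down) is the pre-Go 1.22 idiom for capturing a per-iteration copy of the loop variable. Since Go 1.22, a range loop declares a fresh variable on each iteration, so the shadowing copy is redundant and the copyloopvar/modernize linters flag it. A standalone illustration of the new semantics:

package main

import "fmt"

// With Go 1.22+ semantics each closure captures its own iteration's value,
// so this prints 1 2 3 without needing `v := v` inside the loop body.
func main() {
	var fns []func()
	for _, v := range []int{1, 2, 3} {
		fns = append(fns, func() { fmt.Println(v) })
	}
	for _, fn := range fns {
		fn()
	}
}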

View File

@@ -340,7 +340,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil)
builtinParser.Init(nil, nil, req.Think)
}
}
@@ -2051,7 +2051,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
lastMessage = &msgs[len(msgs)-1]
}
// Initialize parser and get processed tools
processedTools = builtinParser.Init(req.Tools, lastMessage)
processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
}
}

View File

@@ -219,7 +219,6 @@ func TestParse(t *testing.T) {
}
for _, tt := range validCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()