Compare commits

...

82 Commits

Author SHA1 Message Date
likelovewant
04431b50fa fix 2025-09-28 12:37:28 +08:00
羊撅撅
c47154c08d fix: correct condition for AMDGPU_TARGETS filtering logic (#12412) 2025-09-26 11:38:47 -07:00
Patrick Devine
b04e46da3e bugfix: restore the current runOptions if loading fails in the CLI (#12402)
There are two bugs when using `/load <model>` for a model that doesn't exist, namely:
  1. it will not restore the current model settings if the current model is a thinking model; and
  2. it will crash is the current model is a non-thinking model

This bug fix saves the current runOptions and then restores them if the model load
doesn't happen. It also fixes the crash happening for non-thinking models.
2025-09-25 18:30:45 -07:00
Devon Rifkin
34efbbd3f0 Merge pull request #12417 from ollama/drifkin/qwen3-coder-unicode
parsers: fix unicode handling for qwen3-coder
2025-09-25 15:56:34 -07:00
Devon Rifkin
05ba4ca1f4 parsers: fix unicode handling for qwen3-coder
When trimming whitespace at the end of every chunk, we were iterating
backwards over the string byte-by-byte instead of rune-by-rune.

As an example of how this can cause corruption, suppose we have the
multi-byte character  (`"\u2705"`), which is represented in utf-8 as
the three bytes `0xE2 0x9C 0x85`. It happens that `0x85` is NEL, which
passes `unicode.IsSpace()`. Because we were iterating byte-by-byte, this
caused us to mistakenly slice in the middle of the rune, removing `0x85`
and leaving `0xE2 0x9C`, which beyond being the incorrect place to
slice, is not even a valid utf-8 character.

`trailingWhitespaceLen()` was modified to count from the end in a
rune-aware way. Tests with various multibyte unicode characters were
also added.


Fixes: #12414
2025-09-25 15:47:46 -07:00
Patrick Devine
5a56ff3cf0 cli: add device signin flow when doing ollama push (#12405) 2025-09-25 15:04:43 -07:00
Gabe Goodhart
2fba04b5fb tools: handle the case where a tool call sends "arguments" or "parameters" as a serialized json string (#12413) 2025-09-25 14:37:39 -07:00
Grace
fbd82ba5bb Grace/deepseek v3 migration (#12385)
* init deepseek model file

* temp removal of flash attention implementation

* shapes and proper, can make a pass

* query, key, value have good cosine similarity, but the max diff is a bit high

* Attention block is working! ** with eager for now, have not added the mask line

* Attention block is working! ** with eager for now, have not added the mask line

* working MoE at around 0.95 cosine sim

* added cosine similarity function

* Starting end to end structure

* Trying (and failing) to get rope to work, going to test full thing on tater

* running on tater36... just not the right outputs

* we have the right values for rope... but its still not working?

* chnage Extrapolation Factor to 1

* removed adding residuals twice, removed normalization from shared expert, refactored Norms (Attention, MLP) to be outside the (Attention, MLP) blocks and in the Transformer block instead, add cache setLayer

* Temporary modelfiles for cpu

* change kpass intermediate step to kv, two layer outputs [0,1] look fine

* this calls for 16 chicken nuggets

* whoops

* cleaning up code

* delete stuff we dont need

* getting rid of debug statements for llama cpp

* working with long contexts

* fix long context view error

* reverting some changes I made for files that are not apart of pr

* Added proper tokenizer for deeepseek3

* clean up model and go test

* remove Modelfile

* not passing the tests

* whoops

* how to pass the ci tests

* resolving some of the comments

* rename

* linted and renamed deepseek3 -> deepseek2

* remove name go

* addressed changes - main change was adopting qwen3 naming scheme

* I cannot with linters

* clean up logs

* clean up logs

---------

Co-authored-by: Grace Guo <graceguo@Graces-MBP.localdomain>
Co-authored-by: Grace Guo <graceguo@Graces-MacBook-Pro.local>
Co-authored-by: graceguo <graceguo@tater36.localdomain>
2025-09-24 15:19:47 -07:00
Michael Yang
2e742544bf prefer ollama engine for qwen3moe (#12374) 2025-09-24 11:21:32 -07:00
Devon Rifkin
bbb195a6ff Merge pull request #12393 from ollama/drifkin/fix-built-ins
harmony: don't sanitize built-ins
2025-09-23 23:45:31 -07:00
Devon Rifkin
fd88cd7cb0 harmony: don't sanitize built-ins
In #11910 we started sanitizing function names, but we accidentally were
modifying built-ins like `browser.open` to `browser_open`. This was
removing the special prompt rendering for built-ins, but this wasn't
immediately apparent since the models seem to be reasonably good at
remembering the built-ins even when presented with these slightly
renamed version. This fix prevents built-ins from ever being renamed.
2025-09-23 23:34:55 -07:00
Michael Yang
e1979c571a fix: leaf alt name (#12390)
a leaf node with an alternative name gets all its alternatives names
added into the same branch rather than creating branches themselves
2025-09-23 17:50:53 -07:00
Michael Yang
bf78ed6ee9 add pre:, suf: to tags (#12274) 2025-09-23 16:08:57 -07:00
Michael Yang
a40d427bce multi-regexp pretokenizer (#12325) 2025-09-23 13:21:47 -07:00
Patrick Devine
64883e3c4c auth: fix problems with the ollama keypairs (#12373)
* auth: fix problems with the ollama keypairs

This change adds several fixes including:
  - reading in the pubkey files correctly
  - fixing the push unit test to create a keypair file in a temp directory
  - not return 500 errors for normal status error
2025-09-22 23:20:20 -07:00
Devon Rifkin
41efdd4048 Merge pull request #12339 from ollama/drifkin/harmony-refactor-to-builtin
harmony: remove special casing in routes.go
2025-09-22 13:13:40 -07:00
Daniel Hiltgen
c23e6f4cae tests: add single threaded history test (#12295)
* tests: add single threaded history test

Also tidies up some existing tests to handle more model output variation

* test: add support for testing specific architectures
2025-09-22 11:23:14 -07:00
jmorganca
af060eb250 docs: update cloud.md for cloud models 2025-09-22 13:09:17 -03:00
jmorganca
ae5c33008e docs: move turbo.md to cloud.md 2025-09-22 13:09:17 -03:00
likelovewant
000a3ec8b9 Merge branch 'ollama:main' into main 2025-09-21 10:33:39 +08:00
Devon Rifkin
3677842ff1 Merge pull request #12358 from ollama/drifkin/qwen3-coder-ampersands
parsers: fix `&`s in qwen3coder parameter values
2025-09-20 12:40:33 -07:00
Devon Rifkin
242df70a75 parsers: fix &s in qwen3coder parameter values
In <https://github.com/ollama/ollama/issues/12357> we that the model
will output tool calls such as

```
<function=shell>
<parameter=command>
pwd && ls -la
</parameter>
</function>
```

We parse this using the approach of transforming into valid xml and then
using an xml parser. While we do transform the function and parameter
names, we weren't escaping the parameter values (which in this example
are invalid since `pwd && ls -la` contains unescaped ampersands).

This has been fixed by first transforming the tags in the same way, and
then walking the transformed string and escaping the text in between the
tags. This also fixes a case where `<` in the middle of a parameter
value would cause an xml parse failure.

Fixes: #12357
2025-09-20 12:11:38 -07:00
Patrick Devine
dba39b2eee gemma: fix rope scaling for qat models (#12348)
* gemma: fix rope scaling for qat models

* gofumpt yourself
2025-09-19 15:04:40 -07:00
Michael Yang
9f3a37fd36 fix: model load for unsupported embedding models (#12311)
with #12181, there's now support for embeddings in ollama engine.
this is done by mutating the architecture and adding _embed when it
detects an embedding model. however this introduced a bug where if
an embedding model was run based on an existing ollama engine model
without an embedding implementation, e.g. llama4, it will pass the
initial arch support check but fail when actually loaded.

there's currently two entrypoints to creating a model. previously this
second entrypoint was necessary because calling model.New would also
load the model. since #11818, this is no longer th case so merge them
to reduce complexity
2025-09-18 16:11:08 -07:00
Michael Yang
7460259eb3 feat: qwen3 embed (#12301)
* cleanup

* use pooling.TypeNone

* pooling test

* qwen3 embed
2025-09-18 15:50:32 -07:00
Jeffrey Morgan
22ccdd74c2 server: add unauthorized error to remote chat handler (#12338) 2025-09-18 15:40:31 -07:00
Daniel Hiltgen
0c3d0e7533 build: avoid unbounded parallel builds (#12319)
With the addition of cuda v13, on a clean setup, the level of parallelism
was causing docker desktop to become overwhelmed and compilers
were crashing.  This limits to 8 parallel per build stage, with the ability
to override if you have many more cores available.
2025-09-18 14:57:01 -07:00
Devon Rifkin
e7f56ef3d8 harmony: remove special casing in routes.go
Now that we have a built-in parser abstraction, which was introduced in
<https://github.com/ollama/ollama/pull/12248>, we can modify our harmony
parser to match this and then get rid of nearly all of the
harmony-specific logic in routes.go. We do have a small amount of
code that turns the parser on by default if the architecture matches and
no other built-in parser was provided.

The built-in parser interface was modified in order to handle harmony's
prefill and tool name translation requirements.
2025-09-18 14:55:59 -07:00
Patrick Devine
eb0a5d4459 auth: check the permissions on the private key to see if it's readable (#12336) 2025-09-18 14:34:34 -07:00
Michael Yang
ceac416ec2 fix(integration): check truncated length (#12337) 2025-09-18 14:00:21 -07:00
Patrick Devine
2717dce6fe convert: convert bf16 vision weights to fp16 (#12324)
This change moves back to converting bf16 vision weights to fp16,
specifically if they start with the name "v." (such as v.blk.0.attn_k.weight).

This fixes a bug where converted images are failing because they are trying
to call `im2col` which doesn't have a bf16 kernel in ggml.
2025-09-17 17:43:17 -07:00
frob
9b8187b487 server: skip parsing initial <think> if provided in the prompt for /api/generate (#12289) 2025-09-17 16:39:04 -07:00
Patrick Devine
8b894933a7 engine: add remote proxy (#12307) 2025-09-17 14:40:53 -07:00
Daniel Hiltgen
9c5bf342bc fix: multi-cuda version skew (#12318)
Ensure that in a version skewed multi-cuda setup we use the lowest version for all GPUs
2025-09-17 13:05:09 -07:00
Michael Yang
564b558c92 fix(llama): other llama flavours (#12308)
* fix(llama): rope scale

* spm llama

* skip moe models

* cleanup
2025-09-17 12:12:21 -07:00
Michael Yang
a417ac97ee prefer ollama engine for qwen3 (#12310) 2025-09-17 09:48:21 -07:00
russcoss
05d53457af refactor: use the built-in max/min to simplify the code (#12280)
Signed-off-by: russcoss <russcoss@outlook.com>
2025-09-16 17:14:21 -07:00
Michael Yang
b225508c9b logutil: fix source field (#12279) 2025-09-16 16:18:07 -07:00
Devon Rifkin
fa1c987a29 Merge pull request #12248 from ollama/drifkin/qwen3-coder-parsing
add qwen3-coder tool support
2025-09-16 10:21:43 -07:00
Michael Yang
ad95d5b30b use split activations when possible (#12293)
* use ggml_*_split activations when possible

* forward qkv
2025-09-16 09:51:19 -07:00
Michael Yang
c253433d68 embed: cleanup (#12299)
* cleanup

* use pooling.TypeNone

* pooling test
2025-09-16 09:48:42 -07:00
Beshoy Girgis
a1cff89b30 fix: fix CUDA detection for older GPUs (#12300)
Prioritize GPU compute capability over driver version to ensure
Pascal GPUs (CC 6.1) use compatible CUDA v12 libraries instead of v13.
2025-09-16 07:47:06 -07:00
Daniel Hiltgen
93c64ea1b1 doc: show how to clear the cgo cache (#12298) 2025-09-15 15:45:35 -07:00
Michael Yang
3f6642f6fc model: implement bert in ollama engine (#9080)
* fix truncate

* s/SentencePieceModel/SentencePiece/

* bert

* wordpiece

* refactor pooling

* more tokenizers

* normalize embeddings
2025-09-15 15:35:59 -07:00
Michael Yang
6f7117145f batch: use tensors for outputs (#12185)
this cleans up the model interface slightly without too much impact in
other areas
2025-09-15 14:33:06 -07:00
Devon Rifkin
472feec2ff address comments 2025-09-15 11:46:25 -07:00
Devon Rifkin
47991940d4 add qwen3-coder tool support
The format qwen3-coder uses is relatively unique, both in rendering and
in parsing. To implement parsing, I wrote a custom parser in similar
style to harmony. For the rendering, I found that the logic would be
much more difficult to follow in a template, so I introduced the concept
of a built-in renderer that uses go code, rather than a template to
generate prompts.

I set us up for future built-in parsers and renderers by making it so
they can be specified in a Modelfile like so:

```
RENDERER "qwen3-coder"
PARSER "qwen3-coder"
```

These need to be provided explicitly because the architecture alone is
not enough to understand what format the model expects to receive, and
what format we expect it to output (e.g., qwen3-coder is `qwen3moe`,
which includes other qwen3-family models as well)

I haven't converted harmony to be one of these "built-ins" yet, since
some of it is in flux with the changes @ParthSareen has been making to
move harmony to the runner. It is likely that many other built-ins will
need to move to the runner as well, but I'm able to slightly defer that
decision since qwen3-coder doesn't have thinking (and therefore doesn't
need to be in the runner to make structured outputs work). I expect to
unify harmony with this approach very soon.

Whether a particular model supports tools or thinking was previously
inferred from templates, but without a template we now also use the
parser itself to declare what it supports. If we have future models that
re-use the same parsing format, but have different capabilities, we'll
want to parameterize them and give them different names to be specified
as a `PARSER`.

Misc changes:

- I worked on the renderer by diffing outputs from the reference
  implementation and ours. To make it easier to do this, I extended
  <https://github.com/ollama/ollama/pull/11875> to also support
  returning the prompt via the openai compat layer
2025-09-15 11:33:47 -07:00
likelovewant
9f3f80891d Merge branch 'ollama:main' into main 2025-09-13 10:45:51 +08:00
jmorganca
92b96d54ef Revert "runner: move harmony to runner (#12052)"
This reverts commit 1a558f98e2.
2025-09-12 20:40:14 -03:00
jmorganca
9d56e63dbf Revert "runner: simplify parser entrypoints in runner (#12233)"
This reverts commit 8d6fffaead.
2025-09-12 20:40:14 -03:00
tc-mb
053092185e Fix image cannot be seen with slice image on llama engine
Ollama's recent engine update, llama.cpp, caused all models requiring a slice schema to not display images. As a result, the value of numTokens isn't always the length of the sliced ​​image embed, but rather the end length of the schema. This causes the image embed to not be correctly included during all slice processing.
2025-09-12 16:25:12 -07:00
Daniel Hiltgen
44a6792873 tests: tighten up a few flaky tests (#12271)
Sometimes the context test results are pure emoji's
Thanksgiving has too much variability, so swap for a more straight forward prompt.
2025-09-12 13:59:34 -07:00
Daniel Hiltgen
e4ce68311a cuda: remove compression for better compatibility (#12259)
This retains compatibility with driver 531 and up at the trade-off of space.
2025-09-12 07:59:14 -07:00
Jesse Gross
26214125e8 ollamarunner: Suppress stack trace during memory allocation
Allocation failures can be a normal part of new memory estimates, so
we shouldn't print a stack trace in this case.
2025-09-11 14:30:31 -07:00
Daniel Hiltgen
61fb912ca4 CI: fix windows cuda build (#12246)
* ci: adjust cuda component list

v13 has a different breakdown of the components required to build ollama

* review comments
2025-09-11 12:25:26 -07:00
Jesse Gross
aba1575315 llm: Don't try to load split vision models in the Ollama engine
If a model with a split vision projector is loaded in the Ollama
engine, the projector will be ignored and the model will hallucinate
a response. Instead, fallback and try to load the model in the llama
engine.
2025-09-11 11:41:55 -07:00
Jesse Gross
eb10390de9 llm: Enable new memory estimates by default
New memory estimates (see #11090 for more information) are now
enabled automatically for all models running on the Ollama engine,
improving both stability and performance through more accurate sizing
and allocation. Models running on the llama engine will continue to
use the original style of memory estimation.
2025-09-11 11:21:53 -07:00
Michael Yang
feb18cd710 feat: add dimensions field to embed requests (#12242)
* feat: add field to truncate embeddings

* add openai embeddings for dimensions
2025-09-11 10:36:10 -07:00
fengyuchuanshen
8a7e2055d2 cmd: use slices.Contains to simplify code (#12249) 2025-09-11 09:57:31 -07:00
Jesse Gross
29ddfc2cab ggml: Disable flash attention for gemma2
Our new engine implementation of gemma2 doesn't support flash
attention, which means that it also doesn't support KV cache
quantization. Currently, it is possible to turn these two on,
which will result in a crash.
2025-09-10 16:40:45 -07:00
Jesse Gross
71cb86af3e llm: Remove unneeded warning with flash attention enabled
If flash attention is enabled without KV cache quanitization, we will
currently always get this warning:
level=WARN source=server.go:226 msg="kv cache type not supported by model" type=""
2025-09-10 16:40:45 -07:00
CarbonatedWater.org
5198956372 docs: add ollama-co2 to community integrations (#12230) 2025-09-10 16:37:10 -07:00
Daniel Hiltgen
17a023f34b Add v12 + v13 cuda support (#12000)
* Add support for upcoming NVIDIA Jetsons

The latest Jetsons with JetPack 7 are moving to an SBSA compatible model and
will not require building a JetPack specific variant.

* cuda: bring back dual versions

This adds back dual CUDA versions for our releases,
with v11 and v13 to cover a broad set of GPUs and
driver versions.

* win: break up native builds in build_windows.ps1

* v11 build working on windows and linux

* switch to cuda v12.8 not JIT

* Set CUDA compression to size

* enhance manual install linux docs
2025-09-10 12:05:18 -07:00
Parth Sareen
8d6fffaead runner: simplify parser entrypoints in runner (#12233) 2025-09-10 11:24:42 -07:00
Parth Sareen
20b53eaa72 tests: add tool calling integration test (#12232) 2025-09-09 14:01:11 -07:00
Daniel Hiltgen
6745182885 tests: reduce stress on CPU to 2 models (#12161)
* tests: reduce stress on CPU to 2 models

This should avoid flakes due to systems getting overloaded with 3 (or more) models running concurrently

* tests: allow slow systems to pass on timeout

If a slow system is still streaming a response, and the response
will pass validation, don't fail just because the system is slow.

* test: unload embedding models more quickly
2025-09-09 09:32:15 -07:00
Kashyap Tanuku
f810ec741c readme: add Clueless to community integrations (#12188) 2025-09-08 21:31:29 -07:00
Jesse Gross
e119783e66 llm: Clamp batch size to context size
The context must always be able to store the current batch, so
if the user requests a small context then we should also shrink
the batch to match. This also fixes the TestLongInputContext
test on the new engine. (The old engine already has this behavior.)
2025-09-08 20:40:11 -07:00
Parth Sareen
1a558f98e2 runner: move harmony to runner (#12052) 2025-09-08 15:07:59 -07:00
Gabe Goodhart
7b91c9ce51 Hybrid and recurrent memory estimates (#12186)
This PR updates the memory size estimate logic to better handle recurrent and hybrid-recurrent models which are currently being badly overestimated because the default logic assumes full attention for all layers.

The logic for the sizing of the recurrent layers comes from the llama.cpp implementation

        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2025-09-08 14:53:22 -07:00
Daniel Hiltgen
950d33aa30 docs: show how to debug nvidia init failures (#12216)
This debug setting can help troubleshoot obscure initialization failures.
2025-09-08 11:39:00 -07:00
Michael Yang
9714e38dd0 fix: nil pointer dereference if cache is nil (#12215) 2025-09-08 09:53:59 -07:00
frob
4378ae4ffa parser: don't check the file type of safetensors to prevent false negatives. (#12176)
* Don't check the file type of safetensor to prevent false negatives.

---------

Co-authored-by: Patrick Devine <patrick@infrahq.com>
2025-09-05 16:27:40 -07:00
likelovewant
501cb38b8c Merge branch 'ollama:main' into main 2025-09-05 17:58:44 +08:00
Michael Yang
5994e8e8fd embedding gemma model (#12181)
* ollama: add embeddings
2025-09-04 09:09:07 -07:00
likelovewant
59e3a35203 Merge branch 'ollama:main' into main 2025-09-04 19:34:11 +08:00
Michael Yang
b3e6120736 more logutil.Trace (#12177) 2025-09-03 17:24:39 -07:00
Michael Yang
fb92b61754 logutil: add Trace and TraceContext helpers (#12110) 2025-09-02 13:09:12 -07:00
Jesse Gross
8149a3c86e llm: Avoid underflow in free memory logging
If a GPU's free memory is less than the reserved amount, we might get
an underflow. Since it is an unsigned uint64, we print this as a large
number rather than the more correct 0. This only affects logging, the
actual layout code already handles this correctly.

Bug #12138
2025-09-02 12:30:26 -07:00
Daniel Hiltgen
0cc90a8186 harden uncaught exception registration (#12120) 2025-09-02 09:43:55 -07:00
pxwanglu
e42300f25b ml: fix struct field name in comment (#12123) 2025-08-31 16:26:11 -07:00
alpha-nerd-nomyo
66e73809a1 readme: add NOMYO Router to community integrations (#12129) 2025-08-31 13:49:10 -07:00
112 changed files with 6513 additions and 976 deletions

View File

@@ -65,14 +65,36 @@ jobs:
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
cuda-version: '12.8'
flags: ''
runner_dir: 'cuda_v12'
- os: windows
arch: amd64
preset: 'CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runner_dir: 'cuda_v13'
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -96,7 +118,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
}
@@ -138,7 +160,7 @@ jobs:
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
@@ -232,7 +254,7 @@ jobs:
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -46,7 +46,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,8 +78,17 @@ jobs:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
@@ -102,7 +111,8 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path

1
.gitignore vendored
View File

@@ -8,6 +8,7 @@
dist
build
.cache
.gocache
*.exe
.idea
test_data

View File

@@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
@@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
@@ -99,10 +99,12 @@ check_language(HIP)
if(CMAKE_HIP_COMPILER)
set(HIP_PLATFORM "amd")
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|120[01])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
find_package(hip REQUIRED)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|90[012]|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[0123]|120[01])$")
endif()
if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()

View File

@@ -18,6 +18,14 @@
"name": "CUDA",
"inherits": [ "Default" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
@@ -26,6 +34,14 @@
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],
@@ -72,11 +88,21 @@
"configurePreset": "CUDA",
"targets": [ "ggml-cuda" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 12"
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],

View File

@@ -1,6 +1,7 @@
# vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
@@ -34,26 +35,51 @@ ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel 8
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION
@@ -61,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \
&& cmake --build --parallel --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel 8
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION
@@ -72,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \
&& cmake --build --parallel --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel 8
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
@@ -92,10 +120,14 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6

View File

@@ -435,6 +435,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
### Cloud
@@ -624,6 +626,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
### Supported backends

View File

@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
return nil
}
if resp.StatusCode == http.StatusUnauthorized {
authError := AuthorizationError{StatusCode: resp.StatusCode}
json.Unmarshal(body, &authError)
return authError
}
apiError := StatusError{StatusCode: resp.StatusCode}
err := json.Unmarshal(body, &apiError)
@@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
SigninURL string `json:"signin_url,omitempty"`
}
bts := scanner.Bytes()
@@ -222,7 +229,13 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode >= http.StatusBadRequest {
if response.StatusCode == http.StatusUnauthorized {
return AuthorizationError{
StatusCode: response.StatusCode,
Status: response.Status,
SigninURL: errorResponse.SigninURL,
}
} else if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
@@ -428,3 +441,21 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
// Signout will signout a client for a local ollama server.
func (c *Client) Signout(ctx context.Context) error {
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
}
// Disconnect will disconnect an ollama instance from ollama.com.
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
}
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
var resp UserResponse
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -11,6 +11,8 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
@@ -36,6 +38,19 @@ func (e StatusError) Error() string {
}
}
type AuthorizationError struct {
StatusCode int
Status string
SigninURL string `json:"signin_url"`
}
func (e AuthorizationError) Error() string {
if e.Status != "" {
return e.Status
}
return "something went wrong, please see the ollama server logs for details"
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte
@@ -313,13 +328,29 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
// Model is the model name that generated the response.
Model string `json:"model"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
// Message contains the message or part of a message from the model.
Message Message `json:"message"`
// Done specifies if the response is complete.
Done bool `json:"done"`
// DoneReason is the reason the model stopped generating text.
DoneReason string `json:"done_reason,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
Metrics
}
@@ -329,13 +360,6 @@ type DebugInfo struct {
ImageCount int `json:"image_count,omitempty"`
}
// DebugTemplateResponse is returned when _debug_render_only is set to true
type DebugTemplateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
DebugInfo DebugInfo `json:"_debug_info"`
}
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
@@ -388,8 +412,12 @@ type EmbedRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Truncate truncates the input to fit the model's max sequence length.
Truncate *bool `json:"truncate,omitempty"`
// Dimensions truncates the output embedding to the specified dimension.
Dimensions int `json:"dimensions,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
}
@@ -427,18 +455,47 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`
Stream *bool `json:"stream,omitempty"`
// Model is the model name to create.
Model string `json:"model"`
// Stream specifies whether the response is streaming; it is true by default.
Stream *bool `json:"stream,omitempty"`
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
Quantize string `json:"quantize,omitempty"`
From string `json:"from,omitempty"`
Files map[string]string `json:"files,omitempty"`
Adapters map[string]string `json:"adapters,omitempty"`
Template string `json:"template,omitempty"`
License any `json:"license,omitempty"`
System string `json:"system,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
Messages []Message `json:"messages,omitempty"`
// From is the name of the model or file to use as the source.
From string `json:"from,omitempty"`
// RemoteHost is the URL of the upstream ollama API for the model (if any).
RemoteHost string `json:"remote_host,omitempty"`
// Files is a map of files include when creating the model.
Files map[string]string `json:"files,omitempty"`
// Adapters is a map of LoRA adapters to include when creating the model.
Adapters map[string]string `json:"adapters,omitempty"`
// Template is the template used when constructing a request to the model.
Template string `json:"template,omitempty"`
// License is a string or list of strings for licenses.
License any `json:"license,omitempty"`
// System is the system prompt for the model.
System string `json:"system,omitempty"`
// Parameters is a map of hyper-parameters which are applied to the model.
Parameters map[string]any `json:"parameters,omitempty"`
// Messages is a list of messages added to the model before chat and generation requests.
Messages []Message `json:"messages,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -476,8 +533,12 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
@@ -536,12 +597,14 @@ type ProcessResponse struct {
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
Name string `json:"name"`
Model string `json:"model"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
}
// ProcessModelResponse is a single model description in [ProcessResponse].
@@ -565,6 +628,12 @@ type GenerateResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
@@ -588,6 +657,8 @@ type GenerateResponse struct {
Metrics
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
}
// ModelDetails provides details about a model.
@@ -600,6 +671,18 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
// UserResponse provides information about a user.
type UserResponse struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`

View File

@@ -18,21 +18,13 @@ import (
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
func GetPublicKey() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
@@ -59,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath()
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -47,6 +47,8 @@ import (
"github.com/ollama/ollama/version"
)
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
@@ -56,10 +58,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
return
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
@@ -288,7 +288,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
Think: opts.Think,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
if r.RemoteModel != "" && opts.ShowConnect {
p.StopAndClear()
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
}
}
return nil
})
}
func StopHandler(cmd *cobra.Command, args []string) error {
@@ -309,9 +319,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
}
format, err := cmd.Flags().GetString("format")
@@ -369,6 +380,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
prompts = append([]string{string(in)}, prompts...)
opts.ShowConnect = false
opts.WordWrap = false
interactive = false
}
@@ -435,6 +447,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if sErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, sErr.SigninURL)
}
return nil
}
return err
}
@@ -455,6 +476,59 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
user, err := client.Whoami(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err
}
if user != nil && user.Name != "" {
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
fmt.Println()
return nil
}
return nil
}
func SignoutHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
err = client.Signout(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You are not signed in to ollama.com")
fmt.Println()
return nil
} else {
return err
}
}
fmt.Println("You have signed out of ollama.com")
fmt.Println()
return nil
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -466,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
return err
}
n := model.ParseName(args[0])
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
_, err := client.Whoami(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to push models to ollama.com.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err
}
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
@@ -502,12 +595,12 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
if strings.Contains(err.Error(), "access denied") {
errStr := strings.ToLower(err.Error())
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
return err
@@ -541,7 +634,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
var size string
if m.RemoteModel != "" {
size = "-"
} else {
size = format.HumanBytes(m.Size)
}
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
}
}
@@ -626,8 +726,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
KeepAlive: &api.Duration{Duration: 0},
}
if err := loadOrUnloadModel(cmd, opts); err != nil {
if !strings.Contains(err.Error(), "not found") {
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
}
}
@@ -738,12 +838,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch})
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
var paramStr string
if resp.Details.ParameterSize != "" {
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else {
rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
@@ -991,6 +1115,52 @@ type runOptions struct {
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
ShowConnect bool
}
func (r runOptions) Copy() runOptions {
var messages []api.Message
if r.Messages != nil {
messages = make([]api.Message, len(r.Messages))
copy(messages, r.Messages)
}
var images []api.ImageData
if r.Images != nil {
images = make([]api.ImageData, len(r.Images))
copy(images, r.Images)
}
var opts map[string]any
if r.Options != nil {
opts = make(map[string]any, len(r.Options))
for k, v := range r.Options {
opts[k] = v
}
}
var think *api.ThinkValue
if r.Think != nil {
cThink := *r.Think
think = &cThink
}
return runOptions{
Model: r.Model,
ParentModel: r.ParentModel,
Prompt: r.Prompt,
Messages: messages,
WordWrap: r.WordWrap,
Format: r.Format,
System: r.System,
Images: images,
Options: opts,
MultiModal: r.MultiModal,
KeepAlive: r.KeepAlive,
Think: think,
HideThinking: r.HideThinking,
ShowConnect: r.ShowConnect,
}
}
type displayResponseState struct {
@@ -1546,6 +1716,22 @@ func NewCLI() *cobra.Command {
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
signinCmd := &cobra.Command{
Use: "signin",
Short: "Sign in to ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SigninHandler,
}
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SignoutHandler,
}
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
@@ -1640,6 +1826,8 @@ func NewCLI() *cobra.Command {
stopCmd,
pullCmd,
pushCmd,
signinCmd,
signoutCmd,
listCmd,
psCmd,
copyCmd,

View File

@@ -3,10 +3,12 @@ package cmd
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"testing"
"time"
@@ -304,6 +306,8 @@ func TestDeleteHandler(t *testing.T) {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
errPayload := `{"error":"model '%s' not found"}`
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
}
return
}
@@ -346,7 +350,7 @@ func TestDeleteHandler(t *testing.T) {
}
err := DeleteHandler(cmd, []string{"test-model-not-found"})
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
}
}
@@ -488,9 +492,35 @@ func TestPushHandler(t *testing.T) {
w.(http.Flusher).Flush()
}
},
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
},
{
name: "not signed in push",
modelName: "notsignedin-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "unauthorized",
"signin_url": "https://somethingsomething",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedOutput: "You need to be signed in to push",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
@@ -499,12 +529,17 @@ func TestPushHandler(t *testing.T) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied",
"error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
})
if err != nil {
t.Fatal(err)
}
},
"/api/me": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
},
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
@@ -522,6 +557,10 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
initializeKeypair()
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
@@ -557,7 +596,7 @@ func TestPushHandler(t *testing.T) {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
@@ -915,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) {
})
}
}
func TestRunOptions_Copy(t *testing.T) {
// Setup test data
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
originalThink := &api.ThinkValue{Value: "test reasoning"}
original := runOptions{
Model: "test-model",
ParentModel: "parent-model",
Prompt: "test prompt",
Messages: []api.Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi there"},
},
WordWrap: true,
Format: "json",
System: "system prompt",
Images: []api.ImageData{
[]byte("image1"),
[]byte("image2"),
},
Options: map[string]any{
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
},
MultiModal: true,
KeepAlive: originalKeepAlive,
Think: originalThink,
HideThinking: false,
ShowConnect: true,
}
// Test the copy
copied := original.Copy()
// Test 1: Verify the copy is not the same instance
if &copied == &original {
t.Error("Copy should return a different instance")
}
// Test 2: Verify all fields are copied correctly
tests := []struct {
name string
got interface{}
want interface{}
}{
{"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel},
{"Prompt", copied.Prompt, original.Prompt},
{"WordWrap", copied.WordWrap, original.WordWrap},
{"Format", copied.Format, original.Format},
{"System", copied.System, original.System},
{"MultiModal", copied.MultiModal, original.MultiModal},
{"HideThinking", copied.HideThinking, original.HideThinking},
{"ShowConnect", copied.ShowConnect, original.ShowConnect},
}
for _, tt := range tests {
if !reflect.DeepEqual(tt.got, tt.want) {
t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want)
}
}
// Test 3: Verify Messages slice is deeply copied
if len(copied.Messages) != len(original.Messages) {
t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages))
}
if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] {
t.Error("Messages should be different instances")
}
// Modify original to verify independence
if len(original.Messages) > 0 {
originalContent := original.Messages[0].Content
original.Messages[0].Content = "modified"
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Messages should be independent after copy")
}
// Restore for other tests
original.Messages[0].Content = originalContent
}
// Test 4: Verify Images slice is deeply copied
if len(copied.Images) != len(original.Images) {
t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images))
}
if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] {
t.Error("Images should be different instances")
}
// Modify original to verify independence
if len(original.Images) > 0 {
originalImage := original.Images[0]
original.Images[0] = []byte("modified")
if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" {
t.Error("Images should be independent after copy")
}
// Restore for other tests
original.Images[0] = originalImage
}
// Test 5: Verify Options map is deeply copied
if len(copied.Options) != len(original.Options) {
t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options))
}
if len(copied.Options) > 0 && &copied.Options == &original.Options {
t.Error("Options map should be different instances")
}
// Modify original to verify independence
if len(original.Options) > 0 {
originalTemp := original.Options["temperature"]
original.Options["temperature"] = 0.9
if copied.Options["temperature"] == 0.9 {
t.Error("Options should be independent after copy")
}
// Restore for other tests
original.Options["temperature"] = originalTemp
}
// Test 6: Verify KeepAlive pointer is copied (shallow copy)
if copied.KeepAlive != original.KeepAlive {
t.Error("KeepAlive pointer should be the same (shallow copy)")
}
// Test 7: Verify Think pointer creates a new instance
if original.Think != nil && copied.Think == original.Think {
t.Error("Think should be a different instance")
}
if original.Think != nil && copied.Think != nil {
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
}
}
// Test 8: Test with zero values
zeroOriginal := runOptions{}
zeroCopy := zeroOriginal.Copy()
if !reflect.DeepEqual(zeroCopy, zeroOriginal) {
fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy)
t.Error("Copy of zero value should equal original zero value")
}
}
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
// Test with empty slices and maps
original := runOptions{
Messages: []api.Message{},
Images: []api.ImageData{},
Options: map[string]any{},
}
copied := original.Copy()
if copied.Messages == nil {
t.Error("Empty Messages slice should remain empty, not nil")
}
if copied.Images == nil {
t.Error("Empty Images slice should remain empty, not nil")
}
if copied.Options == nil {
t.Error("Empty Options map should remain empty, not nil")
}
if len(copied.Messages) != 0 {
t.Error("Empty Messages slice should remain empty")
}
if len(copied.Images) != 0 {
t.Error("Empty Images slice should remain empty")
}
if len(copied.Options) != 0 {
t.Error("Empty Options map should remain empty")
}
}
func TestRunOptions_Copy_NilPointers(t *testing.T) {
// Test with nil pointers
original := runOptions{
KeepAlive: nil,
Think: nil,
}
copied := original.Copy()
if copied.KeepAlive != nil {
t.Error("Nil KeepAlive should remain nil")
}
if copied.Think != nil {
t.Error("Nil Think should remain nil")
}
}
func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
tests := []struct {
name string
think *api.ThinkValue
}{
{"nil Think", nil},
{"bool true", &api.ThinkValue{Value: true}},
{"bool false", &api.ThinkValue{Value: false}},
{"string value", &api.ThinkValue{Value: "reasoning text"}},
{"int value", &api.ThinkValue{Value: 42}},
{"nil value", &api.ThinkValue{Value: nil}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
original := runOptions{Think: tt.think}
copied := original.Copy()
if tt.think == nil {
if copied.Think != nil {
t.Error("Nil Think should remain nil")
}
return
}
if copied.Think == nil {
t.Error("Non-nil Think should not become nil")
return
}
if copied.Think == original.Think {
t.Error("Think should be a different instance")
}
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
}
})
}
}
func TestRunOptions_Copy_Independence(t *testing.T) {
// Test that modifications to original don't affect copy
originalThink := &api.ThinkValue{Value: "original"}
original := runOptions{
Model: "original-model",
Messages: []api.Message{{Role: "user", Content: "original"}},
Options: map[string]any{"key": "value"},
Think: originalThink,
}
copied := original.Copy()
// Modify original
original.Model = "modified-model"
if len(original.Messages) > 0 {
original.Messages[0].Content = "modified"
}
original.Options["key"] = "modified"
if original.Think != nil {
original.Think.Value = "modified"
}
// Verify copy is unchanged
if copied.Model == "modified-model" {
t.Error("Copy Model should not be affected by original modification")
}
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
t.Error("Copy Messages should not be affected by original modification")
}
if copied.Options["key"] == "modified" {
t.Error("Copy Options should not be affected by original modification")
}
if copied.Think != nil && copied.Think.Value == "modified" {
t.Error("Copy Think should not be affected by original modification")
}
}

View File

@@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("Usage:\n /load <modelname>")
continue
}
origOpts := opts.Copy()
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue
}
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
opts = origOpts.Copy()
continue
}
if strings.Contains(err.Error(), "does not support thinking") {

View File

@@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
normalizeEmbeddings bool
PoolingType uint32
}
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string
for _, m := range modules {
if m.Type == "sentence_transformers.models.Pooling" {
switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path
break
case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
}
}
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

View File

@@ -96,7 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if st.dtype == "BF16" && kind != tensorKindFP32 {
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
})
}
}
func TestSafetensorKind(t *testing.T) {
tests := []struct {
name string
st safetensor
expected uint32
}{
{
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindBF16,
},
{
name: "BF16 dtype with v. prefix should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "v.weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindFP16,
},
{
name: "BF16 dtype with FP32 base kind should return FP32",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10}, // will default to FP32
},
dtype: "BF16",
},
expected: tensorKindFP32,
},
{
name: "Non-BF16 dtype should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "FP16",
},
expected: tensorKindFP16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.st.Kind()
if result != tt.expected {
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
}
})
}
}

View File

@@ -16,7 +16,7 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
func cudaVariant(gpuInfo CudaGPUInfo) string {
func cudaVariant(gpuInfos []CudaGPUInfo) string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".")
@@ -43,14 +43,22 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
}
return "sbsa"
}
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11"
// Check GPU compute capability FIRST, lowest common denominator if multi-gpu
for _, gpuInfo := range gpuInfos {
if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
return "v12"
}
}
return "v12"
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
return "v12"
}
return "v13"
}

View File

@@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor
variant := cudaVariant(gpuInfo)
// Start with our bundled libraries
if variant != "" {
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err == nil {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
}
}
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.Variant = variant
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
unsupportedGPUs = append(unsupportedGPUs,
@@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
variant := cudaVariant(cudaGPUs)
var variantPath string
// Start with our bundled libraries
if variant != "" {
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err != nil {
variantPath = ""
}
}
for i := range cudaGPUs {
cudaGPUs[i].Variant = variant
if variantPath != "" {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
}
}
}
// Intel

View File

@@ -1708,6 +1708,7 @@ Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
### Examples

40
docs/cloud.md Normal file
View File

@@ -0,0 +1,40 @@
# Cloud
| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).
## Cloud Models
[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldnt fit on a personal computer.
Ollama currently supports the following cloud models, with more coming soon:
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `deepseek-v3.1:671b-cloud`
- `qwen3-coder:480b-cloud`
### Get started
To run a cloud model, open the terminal and run:
```
ollama run gpt-oss:120b-cloud
```
To run cloud models with integrations that work with Ollama, first download the cloud model:
```
ollama pull qwen3-coder:480b-cloud
```
Then sign in to Ollama:
```
ollama signin
```
Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
## Cloud API access
Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).

View File

@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```
> [!NOTE]
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

View File

@@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
> [!NOTE]
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
Download and extract the package:
```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
sudo rm -rf /usr/lib/ollama
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
```

View File

@@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
## AMD GPU Discovery

View File

@@ -1,107 +0,0 @@
# Turbo
>  Turbo is preview
Ollamas [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
Currently, the following models are available in Turbo:
- `gpt-oss:20b`
- `gpt-oss:120b`
## Get started
### Ollama for macOS & Windows
Download Ollama
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
- Click on **Turbo**. Youll be prompted to create an account or sign in
### Ollamas CLI
- [Sign up](https://ollama.com/signup) for an Ollama account
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
On macOS and Linux:
```shell
cat ~/.ollama/id_ed25519.pub
```
On Windows:
```
type "%USERPROFILE%\.ollama\id_ed25519.pub"
```
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
```shell
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
```
### Ollamas Python library
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```python
from ollama import Client
client = Client(
host="https://ollama.com",
headers={'Authorization': '<api key>'}
)
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)
```
### Ollamas JavaScript library
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```typescript
import { Ollama } from 'ollama';
const ollama = new Ollama({
host: 'https://ollama.com',
headers: {
Authorization: "Bearer <api key>"
}
});
const response = await ollama.chat({
model: 'gpt-oss:120b',
messages: [{ role: 'user', content: 'Explain quantum computing' }],
stream: true
});
for await (const part of response) {
process.stdout.write(part.message.content)
}
```
### Community integrations
Turbo mode is also compatible with several community integrations.
#### Open WebUI
- Go to **settings** → **Admin settings** → **Connections**
- Under **Ollama API,** click **+**
- For the **URL** put `https://ollama.com`
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
- Click **Save**
Now, if you navigate to the model selector, Turbo models should be available under **External**.

View File

@@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) {
return loadTimeout
}
func Remotes() []string {
var r []string
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
if raw == "" {
r = []string{"ollama.com"}
} else {
r = strings.Split(raw, ",")
}
return r
}
func Bool(k string) func() bool {
return func() bool {
if s := Var(k); s != "" {
@@ -185,8 +196,6 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable the new memory estimation logic
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
)
func String(s string) func() string {
@@ -272,7 +281,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCount() []uint64 {
headCountDefault := uint32(1)
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
if len(headCount) == 1 {
headCountDefault = headCount[0]
}
nLayers := int(kv.BlockCount())
if len(headCount) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCount) {
out[i] = uint64(headCountDefault)
} else {
out[i] = uint64(headCount[i])
}
}
return out
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
}
@@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
}
func (kv KV) HeadCountKV() []uint64 {
headCountKVDefault := uint32(1)
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
if len(headCountKV) == 1 {
headCountKVDefault = headCountKV[0]
}
nLayers := int(kv.BlockCount())
if len(headCountKV) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCountKV) {
out[i] = uint64(headCountKVDefault)
} else {
out[i] = uint64(headCountKV[i])
}
}
return out
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
@@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string {
return kv.String("tokenizer.chat_template")
}
// ssm architecture parameters
func (kv KV) SSMConvKernel() uint64 {
return uint64(kv.Uint("ssm.conv_kernel"))
}
func (kv KV) SSMInnerSize() uint64 {
return uint64(kv.Uint("ssm.inner_size"))
}
func (kv KV) SSMStateSize() uint64 {
return uint64(kv.Uint("ssm.state_size"))
}
func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
@@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
return slices.Min(arrVal), slices.Max(arrVal)
}
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
return []uint32{u32}
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
return u32s.values
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
dst := make([]uint32, len(i32s.values))
for i, v := range i32s.values {
if v < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
}
dst[i] = uint32(v)
}
return uint32(min), uint32(max)
return dst
}
return defaultValue, defaultValue
return []uint32{defaultValue}
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@@ -179,6 +243,8 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3",
"gemma3n",
"mistral3",
"qwen3",
"qwen3moe",
"llama4",
"mllama",
"qwen25vl",
@@ -486,7 +552,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsArr := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax()
headsKVArr := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
@@ -496,12 +564,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
// Default for models unless special-cased below. These defaults mirror the
// cache usage in llama.cpp under the assumption that models without special
// cases below will use the llamarunner and caching will be handled by the
// llama.cpp layer.
//
// This also assumes that a layer without heads or headsKV set is recurrent
// which is usually the case. Some models (eg nemotronh) use "blocks" in
// place of layers where some are MLP blocks that don't have any cache.
// Models like this will need a special case below to be accurately
// estimated.
var kvTotal uint64
kv = make([]uint64, f.KV().BlockCount())
kvSizeAttn := uint64(0)
kvSizeRecurrent := uint64(0)
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
headsL := headsArr[i]
headsKVL := headsKVArr[i]
if headsL > 0 && headsKVL > 0 {
// full attention layer
// NOTE: Assumes uniform values for all attn layers
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
kvSizeAttn += kv[i]
} else {
// recurrent layer
ssmDConv := f.KV().SSMConvKernel()
ssmDState := f.KV().SSMStateSize()
ssmDInner := f.KV().SSMInnerSize()
ssmNGroups := f.KV().SSMGroupCount()
nEmbdR := uint64(0)
if ssmDConv > 0 {
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
}
nEmbdS := ssmDState * ssmDInner
// recurrent always uses F32 in llama.cpp backend
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
kvSizeRecurrent += kv[i]
}
kvTotal += kv[i]
}
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
switch f.KV().Architecture() {
case "llama", "llama4":
@@ -759,12 +866,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {
return true
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
// gpt-oss uses attention with sinks which does not support quantized cache types
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
return cacheType == "f16"
slog.Warn("model only supports non-quantized cache types", "model", arch)
return false
}
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
@@ -774,6 +885,10 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
@@ -794,6 +909,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
case "f32":
return 4 // f32 (default for recurrent)
default:
return 2 // f16 (default)
}

View File

@@ -1,7 +1,7 @@
package harmony
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
@@ -266,6 +266,8 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
toolAccumulator *HarmonyToolCallAccumulator
convertedTools map[string]struct{}
}
// NewHarmonyMessageHandler creates a new message handler
@@ -278,6 +280,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
convertedTools: make(map[string]struct{}),
}
}
@@ -292,7 +295,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
for _, event := range events {
switch event := event.(type) {
case HarmonyEventHeaderComplete:
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
logutil.Trace("harmony event header complete", "header", event.Header)
switch event.Header.Channel {
case "analysis":
if event.Header.Recipient != "" {
@@ -315,7 +318,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
h.state = harmonyMessageState_Normal
}
case HarmonyEventContentEmitted:
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
if h.state == harmonyMessageState_Normal {
contentSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_Thinking {
@@ -385,8 +388,85 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools and optional last message
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{
MessageStartTag: "<|start|>",
MessageEndTag: "<|end|>",
HeaderEndTag: "<|message|>",
}
}
// Handle prefill for chat mode
if lastMessage != nil {
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
} else {
h.HarmonyParser.AddImplicitStart()
}
// Initialize tool accumulator
h.toolAccumulator = h.CreateToolParser()
// Process tools and return renamed versions
if len(tools) == 0 {
return tools
}
processedTools := make([]api.Tool, len(tools))
copy(processedTools, tools)
for i, tool := range processedTools {
if tool.Function.Name != "" {
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
h.convertedTools[tool.Function.Name] = struct{}{}
}
}
return processedTools
}
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
if toolContent != "" {
h.toolAccumulator.Add(toolContent)
}
// tool calls always happen one at a time, and always at the end of a message,
// so for simplicity we defer parsing them until we know we're done
if done {
toolName, raw := h.toolAccumulator.Drain()
if toolName != nil {
name := strings.TrimPrefix(*toolName, "functions.")
name = h.FunctionNameMap.OriginalFromConverted(name)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
}
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
}
}
return content, thinking, calls, nil
}
// HasToolSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasToolSupport() bool {
return true
}
// HasThinkingSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
return true
}
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
harmonyFunctionName := m.deriveName(userFunctionName)
// built-in functions should not be renamed
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
harmonyFunctionName = userFunctionName
}
m.userToHarmony[userFunctionName] = harmonyFunctionName
m.harmonyToUser[harmonyFunctionName] = userFunctionName
return harmonyFunctionName

View File

@@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) {
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
}
for i, tt := range tests {

View File

@@ -12,3 +12,6 @@ The integration tests have 2 modes of operating.
> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`

View File

@@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue? be brief",
Prompt: blueSkyPrompt,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
@@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) {
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
for _, resp := range blueSkyExpected {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
t.Errorf("none of %v found in %s", blueSkyExpected, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
@@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) {
Messages: []api.Message{
{
Role: "user",
Content: "why is the sky blue? be brief",
Content: blueSkyPrompt,
},
},
Options: map[string]interface{}{
@@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) {
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
@@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) {
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
for _, resp := range blueSkyExpected {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
t.Errorf("none of %v found in %s", blueSkyExpected, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for chat")
@@ -410,3 +408,99 @@ func TestAPIEmbeddings(t *testing.T) {
t.Errorf("zero length embedding response")
}
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
modelName := "qwen3:0.6b"
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := api.ChatRequest{
Model: modelName,
Messages: []api.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall api.ToolCall
fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
}

View File

@@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue?",
Prompt: blueSkyPrompt,
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
GenerateTestHelper(ctx, t, req, blueSkyExpected)
}
func TestUnicode(t *testing.T) {
@@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue?",
Prompt: blueSkyPrompt,
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
GenerateTestHelper(ctx, t, req, blueSkyExpected)
}

View File

@@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
// The intent is to go 1 over what can fit so we force the scheduler to thrash
targetLoadCount := 0
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
chooseModels:
for i, model := range chosenModels {
req := &api.GenerateRequest{Model: model}
slog.Info("loading", "model", model)
@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
break
}
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
break chooseModels
}
}
}
}
if targetLoadCount == len(chosenModels) {

View File

@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "Write me a story with a ton of emojis?",
Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
}
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
func TestParallelGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests()
numParallel := 2
@@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) {
wg.Wait()
}
// Send generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: rainbowPrompt,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
req.Prompt = rainbowFollowups[i]
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
}
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
func TestParallelChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests()
numParallel := 2
@@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) {
}
wg.Wait()
}
// Send generate requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
req := api.ChatRequest{
Model: smol,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
Messages: []api.Message{
{
Role: "user",
Content: rainbowPrompt,
},
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Messages = append(req.Messages,
*assistant,
api.Message{Role: "user", Content: rainbowFollowups[i]},
)
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
}
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
@@ -38,14 +39,14 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
Model: "all-minilm",
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}
res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
t.Fatal(err)
}
if len(res.Embedding) != 384 {
@@ -73,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
t.Fatal(err)
}
if len(res.Embeddings) != 1 {
@@ -111,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatalf("error: %v", err)
t.Fatal(err)
}
if len(res.Embeddings) != 2 {
@@ -155,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why",
})
if err != nil {
t.Fatal(err)
}
reqs := []testReq{
cases := []struct {
name string
request api.EmbedRequest
check func(*api.EmbedResponse, error)
}{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
name: "target truncation",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
name: "default truncate",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "explicit truncate",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, client, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
req.check(embedTestHelper(ctx, client, t, req.request))
})
}
}
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
t.Fatal(err)
}
response, err := client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
return client.Embeddings(ctx, &req)
}
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
t.Fatal(err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
return client.Embed(ctx, &req)
}

View File

@@ -4,7 +4,9 @@ package integration
import (
"context"
"fmt"
"log/slog"
"os"
"testing"
"time"
@@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
chatModels := libraryChatModels
for _, model := range chatModels {
@@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) {
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if targetArch != "" {
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
if arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
}
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0.1,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
anyResp := blueSkyExpected
// Special cases
if model == "duckdb-nsql" {
anyResp = []string{"select", "from"}

View File

@@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) {
// TODO - fiddle with context size
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
Prompt: blueSkyPrompt,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
})
}
}

View File

@@ -40,6 +40,18 @@ var (
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
doModelPerfTest(t, ollamaEngineChatModels)
} else {
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
}
}
func TestLibraryModelsPerf(t *testing.T) {
doModelPerfTest(t, libraryChatModels)
}
func doModelPerfTest(t *testing.T, chatModels []string) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) {
}
longPrompt := "summarize the following: " + string(data)
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
for _, model := range chatModels {
if !strings.Contains(model, ":") {
model = model + ":latest"
}
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
@@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) {
}
arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if targetArch != "" && arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
if maxVram > 0 {
resp, err := client.List(ctx)
@@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) {
prompt string
anyResp []string
}{
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
{blueSkyPrompt, blueSkyExpected},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
}
var gpuPercent int
for _, tc := range testCases {
@@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) {
}
}
}
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL",
"CONTEXT",
"GPU PERCENT",
"PROMPT COUNT",
"APPROX PROMPT COUNT",
"LOAD TIME",
"PROMPT EVAL TPS",
"EVAL TPS",
@@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) {
model,
numCtx,
gpuPercent,
resp.PromptEvalCount,
(resp.PromptEvalCount/10)*10,
float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),

View File

@@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) {
stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: "why is the sky blue?",
Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
@@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) {
// Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp {
for _, resp := range blueSkyExpected {
if strings.Contains(fullResp, resp) {
atLeastOne = true
t.Log(fullResp)

View File

@@ -256,13 +256,29 @@ var (
"snowflake-arctic-embed",
"snowflake-arctic-embed2",
}
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
rainbowFollowups = []string{
"Explain the physics involved in them. Be breif in your reply",
"Explain the chemistry involved in them. Be breif in your reply",
"Explain the quantum mechanics involved in them. Be breif in your reply",
"What are common myths related to them? Be brief in your reply",
"What are common fairytales related to them? Be brief in your reply",
"Can they form if there is no rain? Be breif in your reply",
"Can they form if there are no clouds? Be breif in your reply",
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"}
)
func init() {
lifecycle.InitLogging()
custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL")
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
if custom != "" {
slog.Info("setting smol test model to " + custom)
slog.Info("setting default test model to " + custom)
smol = custom
}
}
@@ -502,6 +518,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -517,21 +549,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
verify()
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for generate")
verify()
}
return context
}
@@ -552,7 +577,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
Model: smol,
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
Prompt: "how do rainbows form? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
@@ -568,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
},
},
[][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
{"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"},
{"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"},
{"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"},
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"},
}
}
@@ -599,6 +624,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -614,23 +655,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
verify()
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for chat")
verify()
}
return &api.Message{Role: role, Content: buf.String()}
}

View File

@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
}
nChunks := C.mtmd_input_chunks_size(ic)
numEmbed := llamaContext.Model().NEmbd()
lastChunkSize := 0
embed := make([][]float32, 0)
for i := range int(nChunks) {
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
lastChunkSize = numTokens
slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
// Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk")
}
}
// Get the embeddings
embed := make([][]float32, lastChunkSize)
embd := C.mtmd_get_output_embd(c.c)
if nil == embd {
return nil, errors.New("failed to get image embedding")
}
// Get the embeddings for this chunk
chunkEmbed := make([][]float32, numTokens)
chunkEmbd := C.mtmd_get_output_embd(c.c)
if nil == chunkEmbd {
continue
}
// Extend the embedding array for each token
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
rows := make([]float32, len(s))
copy(rows, s)
for i := range lastChunkSize {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
// Extend the embedding array for each token
s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
rows := make([]float32, len(s))
copy(rows, s)
for i := range numTokens {
chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
}
embed = append(embed, chunkEmbed...)
}
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
return embed, nil
}

View File

@@ -0,0 +1,28 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri, 29 Aug 2025 16:53:08 -0700
Subject: [PATCH] harden uncaught exception registration
---
ggml/src/ggml.cpp | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
index 0d388d45..f5bcb446 100644
--- a/ggml/src/ggml.cpp
+++ b/ggml/src/ggml.cpp
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
return false;
}
const auto prev{std::get_terminate()};
- GGML_ASSERT(prev != ggml_uncaught_exception);
- previous_terminate_handler = prev;
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
+ if (prev != ggml_uncaught_exception) {
+ previous_terminate_handler = prev;
+ } else {
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
+ }
std::set_terminate(ggml_uncaught_exception);
return true;
}();

View File

@@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && f.SupportsKVCacheType(requested) {
if f.SupportsKVCacheType(requested) {
kvct = requested
}
}

View File

@@ -148,7 +148,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
var textProcessor model.TextProcessor
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
textProcessor, err = model.NewTextProcessor(modelPath)
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
@@ -161,11 +165,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
}
}
newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
if newEstimates {
slog.Info("enabling new memory estimates")
}
// Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength()
if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
@@ -173,6 +172,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
opts.NumCtx = int(trainCtx)
}
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
@@ -218,7 +219,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if kvct != "" && f.SupportsKVCacheType(kvct) {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
@@ -431,7 +432,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
}
}()
if newEstimates {
if textProcessor != nil {
return &ollamaServer{llmServer: s}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
@@ -678,8 +679,12 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID,
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory),
"available", format.HumanBytes2(available),
"free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory),
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
@@ -861,7 +866,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
}
layers[i] += memory.CPU.Weights[i].Size
layers[i] += memory.CPU.Cache[i].Size
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
}
gpuLayers := ml.GPULayersList{}

View File

@@ -1,9 +1,12 @@
package logutil
import (
"context"
"io"
"log/slog"
"path/filepath"
"runtime"
"time"
)
const LevelTrace slog.Level = -8
@@ -27,3 +30,19 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
},
}))
}
type key string
func Trace(msg string, args ...any) {
TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
}
func TraceContext(ctx context.Context, msg string, args ...any) {
if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
skip, _ := ctx.Value(key("skip")).(int)
pc, _, _, _ := runtime.Caller(1 + skip)
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
record.Add(args...)
logger.Handler().Handle(ctx, record)
}
}

View File

@@ -266,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputsWeights are always located on the CPU and cannot be moved
// InputWeights are always located on the CPU and cannot be moved
InputWeights Memory
// CPU model components are located in system memory. This does not
@@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
QuickGELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor

View File

@@ -271,7 +271,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
C.ggml_set_name(tt, cname)
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
if layer == -1 {
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
for bs := range maps.Values(bbs) {
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
@@ -811,7 +811,7 @@ func (c *Context) Reserve() {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
}
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),

View File

@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
return false;
}
const auto prev{std::get_terminate()};
GGML_ASSERT(prev != ggml_uncaught_exception);
previous_terminate_handler = prev;
// GGML_ASSERT(prev != ggml_uncaught_exception);
if (prev != ggml_uncaught_exception) {
previous_terminate_handler = prev;
} else {
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
}
std::set_terminate(ggml_uncaught_exception);
return true;
}();

View File

@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}

42
ml/nn/pooling/pooling.go Normal file
View File

@@ -0,0 +1,42 @@
package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
)
func (t Type) String() string {
switch t {
case TypeMean:
return "Mean"
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
return hiddenStates
default:
panic("unknown pooling type")
}
}

View File

@@ -0,0 +1,79 @@
package pooling_test
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/discover"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn/pooling"
)
func setup(tb testing.TB, n int) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
if err := fsggml.WriteGGUF(f, fsggml.KV{
"general.architecture": "test",
"test.block_count": uint32(1),
}, []*fsggml.Tensor{
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
}); err != nil {
tb.Fatal(err)
}
var gpuLayers ml.GPULayersList
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
gpuLayers = append(gpuLayers, ml.GPULayers{
ID: gpus[0].ID,
Layers: slices.Collect(func(yield func(int) bool) {
for i := range n {
if !yield(i) {
return
}
}
}),
})
}
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
if err != nil {
tb.Fatal(err)
}
return b
}
func TestForward(t *testing.T) {
cases := map[pooling.Type][]float32{
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
}
for typ, want := range cases {
t.Run(typ.String(), func(t *testing.T) {
b := setup(t, 99)
defer b.Close()
ctx := b.NewContext()
defer ctx.Close()
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
tt = typ.Forward(ctx, tt)
ctx.Forward(tt).Compute(tt)
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
t.Error(diff)
}
})
}
}

View File

@@ -2,10 +2,10 @@ package model
import (
"cmp"
"context"
"fmt"
"iter"
"log/slog"
"slices"
"strings"
"github.com/dlclark/regexp2"
@@ -14,16 +14,28 @@ import (
)
type BytePairEncoding struct {
pre *regexp2.Regexp
vocab *Vocabulary
vocab *Vocabulary
regexps []*regexp2.Regexp
}
var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
pre: regexp2.MustCompile(pre, regexp2.None),
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
}
}
@@ -36,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool {
}
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
parts := []string{s}
for _, re := range bpe.regexps {
parts = slices.Collect(func(yield func(string) bool) {
for _, part := range parts {
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
}
}
})
}
return slices.Values(parts)
}
// fragment is a string fragment and their corresponding token IDs
@@ -202,12 +237,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
@@ -243,6 +277,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}

View File

@@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding {
}
return NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&Vocabulary{
Values: tokens,
Types: types,
Merges: merges,
},
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
)
}
@@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) {
})
}
}
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}

View File

@@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}

View File

@@ -1,7 +1,6 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@@ -22,10 +21,15 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
)
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
@@ -104,19 +108,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return nil, err
}
arch := b.Config().Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(b.Config())
m, err := modelForArch(b.Config())
if err != nil {
return nil, err
}
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
return m, nil
@@ -128,30 +125,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
m, err := modelForArch(meta.KV())
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, fmt.Errorf("%v is not a TextProcessor", m)
return nil, ErrUnsupportedTokenizer
}
return tp, nil
}
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
@@ -167,38 +172,47 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, ParseTags(tag))
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag) [][]string
fn = func(tags []Tag) (names [][]string) {
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
localNames := []string{tags[0].Name}
localNames = append(localNames, tags[0].Alternate...)
for _, localName := range localNames {
fullName := []string{localName}
nested := fn(tags[1:])
if len(nested) > 0 {
for _, rest := range nested {
names = append(names, append(fullName, rest...))
var names []string
if tags[0].name != "" {
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix)
}
}
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
if len(names) == 0 {
// current tag has no name, use child names only
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
}
} else {
names = append(names, fullName)
}
}
}
return names
return fullNames
}
names := fn(tagsCopy)
names := fn(tagsCopy, "", "")
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
logutil.Trace("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
@@ -209,9 +223,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
} else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
}
}
}
@@ -239,7 +253,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem()
}
vv = vv.Elem()
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
@@ -250,18 +264,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
}
type Tag struct {
Name string
Alternate []string
name,
// prefix and suffix are applied to child tags
prefix,
suffix string
alternatives []string
}
func ParseTags(s string) (tag Tag) {
func parseTag(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.Name = parts[0]
tag.name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok {
tag.Alternate = append(tag.Alternate, value)
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
// elevate alternative to primary if no primary given
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
}
}
}

View File

@@ -1,9 +1,9 @@
package model
import (
"errors"
"reflect"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
func TestParseTags(t *testing.T) {
@@ -23,14 +22,14 @@ func TestParseTags(t *testing.T) {
{
value: "output",
want: Tag{
Name: "output",
name: "output",
},
},
{
value: "output,alt:token_embd",
want: Tag{
Name: "output",
Alternate: []string{
name: "output",
alternatives: []string{
"token_embd",
},
},
@@ -39,8 +38,8 @@ func TestParseTags(t *testing.T) {
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
got := ParseTags(tt.value)
if diff := cmp.Diff(tt.want, got); diff != "" {
got := parseTag(tt.value)
if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
}
})
@@ -126,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Input *nn.Embedding `gguf:"input"`
Output *nn.Linear `gguf:"output,alt:input"`
Nested *nested `gguf:"nested"`
Tensor ml.Tensor `gguf:"leaf,alt:tensor"`
}
var m fakeModel
@@ -134,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
names: []string{
"input.weight",
"nested.b.weight",
"leaf",
},
}}, v.Elem()))
@@ -143,44 +144,115 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Nested: &nested{
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
},
Tensor: &fakeTensor{Name: "leaf"},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fsggml.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
type fakeBlock struct {
A *nn.Linear `gguf:"a"`
B *nn.Linear `gguf:",pre:b_"`
C *nn.Linear `gguf:",suf:_c"`
XY *nn.Linear `gguf:",pre:x_,suf:_y"`
}
models["dummy"] = func(fs.Config) (Model, error) {
return notTextProcessorModel{}, nil
type fakeModel struct {
Blocks []fakeBlock `gguf:"blk"`
}
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
m := fakeModel{
Blocks: make([]fakeBlock, 2),
}
v := reflect.ValueOf(&m)
v.Elem().Set(populateFields(Base{b: &fakeBackend{
names: []string{
"blk.0.a.weight",
"blk.0.b_weight",
"blk.0.b_bias",
"blk.0.weight_c",
"blk.0.x_weight_y",
"blk.1.a.weight",
"blk.1.b_weight",
"blk.1.b_bias",
"blk.1.weight_c",
"blk.1.x_weight_y",
},
}}, v.Elem()))
if diff := cmp.Diff(fakeModel{
Blocks: []fakeBlock{
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
},
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
},
},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
type notTextProcessorModel struct{}
func TestModelForArch(t *testing.T) {
type fakeModel struct {
Model
}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented")
}
type fakeEmbeddingModel struct {
Model
}
func (notTextProcessorModel) Backend() ml.Backend {
panic("unimplemented")
}
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
func (notTextProcessorModel) Config() config {
panic("unimplemented")
cases := []struct {
name string
config fs.Config
want any
err error
}{
{
name: "model",
config: fsggml.KV{
"general.architecture": "model",
},
want: fakeModel{},
},
{
name: "embedding",
config: fsggml.KV{
"general.architecture": "model",
"model.pooling_type": uint32(1),
},
want: fakeEmbeddingModel{},
},
{
name: "unsupported",
config: fsggml.KV{
"general.architecture": "unsupported",
},
err: ErrUnsupportedModel,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := modelForArch(tt.config)
if !errors.Is(err, tt.err) {
t.Fatal(err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
}
})
}
}

181
model/models/bert/embed.go Normal file
View File

@@ -0,0 +1,181 @@
package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
// Attention
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// MLP
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := a.Query.Forward(ctx, hiddenStates)
if a.QueryNorm != nil {
query = a.QueryNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key := a.Key.Forward(ctx, hiddenStates)
if a.KeyNorm != nil {
key = a.KeyNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
poolingType pooling.Type
eps float32
normalize bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func New(c fs.Config) (model.Model, error) {
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
//nolint:misspell
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
// support it for compatibility.
c.Uint("tokenizer.ggml.seperator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", true),
},
}, nil
}
func init() {
model.Register("bert", New)
model.Register("bert_embed", New)
}

View File

@@ -0,0 +1,324 @@
package deepseek2
// uses deepseek 2 architecture but written based on deepseek 3 model
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
numExpertsUsed int
numExperts int
normTopKProb bool
routedScalingFactor float32
kvLoraRank,
qkNopeHeadDim,
qkRopeHeadDim,
kqNopeHeadDim,
qkHeadDim int
qLoraRank int
vHeadDim int
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
originalContextLength int
eps,
ropeBase,
ropeScale float32
kqScale float64
}
func (o Options) RoPEOptions() []func(*rope.Options) {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
return []func(*rope.Options){
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(attnFactor),
}
}
type Attention struct {
Q *nn.Linear `gguf:"attn_q"`
QA *nn.Linear `gguf:"attn_q_a"`
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
QB *nn.Linear `gguf:"attn_q_b"`
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
KVB *nn.Linear `gguf:"attn_kv_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 { // nil {
query = attn.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
qPass := query.View(ctx, 0,
opts.qkNopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
opts.qkRopeHeadDim, query.Stride(1),
query.Dim(1), query.Stride(2),
query.Dim(2))
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
opts.qkRopeHeadDim, compressedKV.Stride(1),
1, compressedKV.Stride(1),
compressedKV.Dim(1))
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
kPass = attn.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
opts.vHeadDim, kv.Stride(1),
kv.Dim(1), kv.Stride(2),
kv.Dim(2)).Contiguous(ctx)
qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
query = qRot.Concat(ctx, qPass, 0)
key := kRot.Concat(ctx, kPass, 0)
attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
SharedExpert *dense `gguf:",suf:_shexp"`
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
}
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
hiddenStates = hiddenStates.SILU(ctx, upStates)
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
experts = experts.Mul(ctx, topKWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
scores = scores.Add(ctx, moe.ExpProbsBias)
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
residuals := hiddenStates
routerLogits := moe.Router.Forward(ctx, hiddenStates)
scores := routerLogits.Sigmoid(ctx)
topKIndices := moe.topKIndices(ctx, scores, opts)
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
if opts.normTopKProb {
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
return hiddenStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP MLP
}
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers {
if i < firstDenseLayerIndex {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{}
}
}
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("expert_weights_norm", true),
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
qkHeadDim: int(c.Uint("attention.key_length")),
vHeadDim: int(c.Uint("attention.value_length")),
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
kqScale: kqScale,
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseek2", New)
}

View File

@@ -24,7 +24,7 @@ type Options struct {
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
}
type MLP struct {
@@ -138,7 +138,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)

View File

@@ -0,0 +1,62 @@
package gemma3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.SentencePiece
*TextModel
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}

View File

@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
ropeScale: 1,
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
@@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase
}
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
@@ -123,7 +126,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -159,8 +162,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
@@ -191,12 +196,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
return hiddenState
}

View File

@@ -10,7 +10,7 @@ import (
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
*TextModel
}
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.ropeBaseLocal
}
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
}
active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx)
active = active.Mul(ctx, perLayerInput)
active = active.GELU(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
hiddenStates = hiddenStates.GELU(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
@@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.freq_scale", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"),

View File

@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if i == len(m.TransformerBlocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
}
hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
@@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),

View File

@@ -2,7 +2,6 @@ package llama
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
@@ -23,51 +22,80 @@ type Options struct {
type Model struct {
model.Base
model.BytePairEncoding
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
Options
}
func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer
if c.String("tokenizer.ggml.model") == "llama" {
return nil, fmt.Errorf("unsupported tokenizer: llama")
if c.Uint("expert_count") > 0 {
// TODO: support mixtures of experts
return nil, model.ErrUnsupportedModel
}
// Best effort detection of library/deepseek-coder model(s) which are incompatible
if c.String("general.name") == "deepseek-ai" {
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
var processor model.TextProcessor
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
}
switch c.String("tokenizer.ggml.model") {
case "gpt2":
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "default":
// no-op use the default bpe pretokenizer
case "qwen2":
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "refact":
pretokenizers = []string{
`\p{N}`,
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
}
case "tekken":
pretokenizers = []string{
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
default:
// use a llama-style pretokenizer
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeBase: c.Float("rope.freq_base", 1e5),
ropeScale: c.Float("rope.scaling.factor", 1),
},
}
@@ -98,8 +126,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -108,7 +136,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
}
type MLP struct {
@@ -118,7 +146,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -160,10 +188,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
@@ -176,9 +175,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
}
if opts.useQKNorm {
@@ -58,14 +58,14 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
upStates := e.Up.Forward(ctx, hiddenStates, experts)
gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
@@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
return nextStates
}
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts
SharedExpert *TextSharedExpert
SharedExpert *TextMLP `gguf:",suf:_shexp"`
}
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -196,7 +184,7 @@ func newTextModel(c fs.Config) *TextModel {
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
@@ -248,5 +236,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
}

View File

@@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
}
type MLP struct {
@@ -65,7 +65,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
},
}
}

View File

@@ -51,7 +51,7 @@ type VisionMLP struct {
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}

View File

@@ -33,7 +33,6 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
}
return key, nil
@@ -58,7 +58,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
},
}

View File

@@ -1,6 +1,8 @@
package models
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"

View File

@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -59,7 +59,7 @@ type MLP struct {
}
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
func New(c fs.Config) (model.Model, error) {
@@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) {
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
},
}

View File

@@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
TextModel: NewTextModel(c),
VisionModel: newVisionModel(c),
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
}
func init() {

View File

@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
originalContextLength: int(c.Uint("context_length", 128000)),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
},
}
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
@@ -90,7 +90,7 @@ type MLP struct {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}

View File

@@ -100,8 +100,7 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}

View File

@@ -0,0 +1,73 @@
package qwen3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.BytePairEncoding
*Model
poolingType pooling.Type
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbed(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
for i := range layers {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Model: &Model{
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
},
},
poolingType: pooling.Type(c.Uint("pooling_type")),
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}

View File

@@ -30,10 +30,10 @@ func (o Options) headDim() int {
}
type Attention struct {
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Query *nn.Linear `gguf:"attn_q"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -65,10 +65,10 @@ type MLP interface {
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.Linear `gguf:"ffn_down_exps"`
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -154,29 +151,39 @@ type Model struct {
*Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
return m.Output.Forward(ctx, hiddenStates), nil
}
// Forward implements model.Model.
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
var _ model.Model = (*Model)(nil)
@@ -193,7 +200,6 @@ func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -206,6 +212,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Layers: layers,
Options: &Options{
@@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeScale: c.Float("rope.scaling.factor", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
@@ -230,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
func init() {
model.Register("qwen3", New)
model.Register("qwen3moe", New)
model.Register("qwen3_embed", newEmbed)
}

49
model/parsers/parsers.go Normal file
View File

@@ -0,0 +1,49 @@
package parsers
import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type Parser interface {
// Init initializes the parser with tools and optional last message for chat prefill
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
HasToolSupport() bool
HasThinkingSupport() bool
}
func ParserForName(name string) Parser {
switch name {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
case "passthrough":
return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
default:
return nil
}
}
type PassthroughParser struct{}
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
return tools // passthrough doesn't modify tools
}
func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
return s, "", nil, nil
}
func (p *PassthroughParser) HasToolSupport() bool {
return false
}
func (p *PassthroughParser) HasThinkingSupport() bool {
return false
}

463
model/parsers/qwen3coder.go Normal file
View File

@@ -0,0 +1,463 @@
package parsers
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"math"
"regexp"
"strconv"
"strings"
"unicode"
"unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwenParserState int
const (
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
const (
qwenParserState_LookingForToolStart qwenParserState = iota
qwenParserState_CollectingToolContent
)
type Qwen3CoderParser struct {
state qwenParserState
acc strings.Builder
tools []api.Tool
}
func (p *Qwen3CoderParser) HasToolSupport() bool {
return true
}
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.acc.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var sb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwenEventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here. See the note below about
// `qwenEvent`s for more details
sb.WriteString(event.content)
}
}
return sb.String(), "", toolCalls, nil
}
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
var all []qwenEvent
keepLooping := true
for keepLooping {
var events []qwenEvent
events, keepLooping = eat(p)
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
}
return all
}
// we use some internal event types in order to communicate between `Add` and
// `eat`. We do this to support interleaving content and parallel tool calls in
// the parser, even though qwen3-coder isn't supposed to do this. Our API
// doesn't currently support models outputting multiple messages in a turn, so
// we wouldn't be able to represent it yet, but there's no reason to prevent the
// parser from supporting it, especially for future models if they end up using
// a similar format.
type qwenEvent interface {
isQwenEvent()
}
type qwenEventRawToolCall struct {
raw string
}
type qwenEventContent struct {
content string
}
func (qwenEventContent) isQwenEvent() {}
func (qwenEventRawToolCall) isQwenEvent() {}
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. If the parser transitions to another
// state, it may have additional events to emit on the next call, which is what
// the second return value indicates
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
var events []qwenEvent
switch p.state {
case qwenParserState_LookingForToolStart:
if strings.Contains(p.acc.String(), toolOpenTag) {
// we found a full tool open tag, so we can emit the content before the
// tag, being sure to trim any trailing whitespace
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
after := split[1]
p.acc.Reset()
p.acc.WriteString(after)
p.state = qwenParserState_CollectingToolContent
return events, true
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
// we found a partial tool open tag, so we can emit the unambiguous part,
// which is the (trailing-whitespace trimmed) content before the partial
// tool open tag
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
events = append(events, qwenEventContent{content: unambiguous})
return events, false
} else {
// we found content that is entirely not a tool call. We should withhold
// any trailing whitespace in case this is the end of the content
whitespaceLen := trailingWhitespaceLen(p.acc.String())
ambiguousStart := len(p.acc.String()) - whitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwenEventContent{content: unambiguous})
}
return events, false
}
case qwenParserState_CollectingToolContent:
if strings.Contains(p.acc.String(), toolCloseTag) {
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
// remove any whitespace between the tool call and any content after it
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.acc.Reset()
p.acc.WriteString(after)
events = append(events, qwenEventRawToolCall{raw: before})
p.state = qwenParserState_LookingForToolStart
return events, true
} else {
// note that we don't need to check the overlap here because we only plan
// on parsing the tool call once we see the full closing tag. We don't
// stream back the unparsed tool content, so there's no need to be eager
// here
return events, false
}
default:
panic("unreachable")
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
remaining := s
total := 0
for len(remaining) > 0 {
r, size := utf8.DecodeLastRuneInString(remaining)
// if it's an invalid utf8 rune, assume it isn't whitespace
if r == utf8.RuneError && size == 1 {
break
}
if !unicode.IsSpace(r) {
break
}
total += size
remaining = remaining[:len(remaining)-size]
}
return total
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`
Parameters []XMLParameter `xml:"parameter"`
}
type XMLParameter struct {
Name string `xml:"name,attr"`
Value string `xml:",chardata"`
}
// parseToolCall parses a raw tool call string into an api.ToolCall.
// The raw string follows an xml-like format, here's an example:
//
// <function=get_current_temperature>
// <parameter=location>
// San Francisco
// </parameter>
// <parameter=unit>
// celsius
// </parameter>
// </function>
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
toolCall := api.ToolCall{}
xmlString := transformToXML(raw.raw)
var functionCall XMLFunctionCall
err := xml.Unmarshal([]byte(xmlString), &functionCall)
if err != nil {
return api.ToolCall{}, err
}
toolCall.Function = api.ToolCallFunction{
Name: functionCall.Name,
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionCall.Name {
matchedTool = &tools[i]
break
}
}
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
paramType = prop.Type
}
}
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
}
return toolCall, nil
}
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
//
// For union types (multiple types in PropertyType, which we support but doesn't
// seem as though the reference parser does type coercion with those types in
// mind) we use a type precedence approach:
// 1. null - checked first regardless of declared types (matches reference implementation)
// 2. boolean - only "true"/"false" are valid booleans
// 3. integer - must parse as a whole number
// 4. number - must parse as numeric (returns int if no decimal part)
// 5. array - must parse as valid JSON array
// 6. object - must parse as valid JSON object
// 7. string - always succeeds (least specific type)
//
// This precedence ensures we return the most specific type that successfully parses,
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
func parseValue(raw string, paramType api.PropertyType) any {
// first remove a single leading newlines, and a single trailing newline (if
// they exist). This follows the reference implementation
raw = strings.TrimPrefix(raw, "\n")
raw = strings.TrimSuffix(raw, "\n")
// Check for null first (case-insensitive) - this takes precedence over any type
if strings.ToLower(raw) == "null" {
return nil
}
// If no type is specified, default to string
if len(paramType) == 0 {
return raw
}
// Check if any of the specified types match, using type precedence
// Order: boolean -> integer -> number -> array -> object -> string
typeSet := make(map[string]bool)
for _, t := range paramType {
typeSet[t] = true
}
// Try boolean first (most restrictive)
if typeSet["boolean"] {
lower := strings.ToLower(raw)
switch lower {
case "true":
return true
case "false":
return false
}
// If not a valid boolean but boolean is the only type, return false (matching reference)
if len(paramType) == 1 {
return false
}
// Otherwise try other types
}
// Try integer
if typeSet["integer"] {
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
// Return as int if it fits in int32, otherwise int64
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
// If integer is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try number (float)
if typeSet["number"] {
if f, err := strconv.ParseFloat(raw, 64); err == nil {
// If the number has no decimal part, return as int (matching reference)
if f == math.Trunc(f) {
i := int64(f)
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
return f
}
// If number is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try array
if typeSet["array"] {
var arr []any
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
return arr
}
// If array is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try object
if typeSet["object"] {
var obj map[string]any
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
return obj
}
// If object is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// String always succeeds (or if "string" is in the type set)
if typeSet["string"] {
return raw
}
// If we get here, none of the types matched and string wasn't an option
// We return string as a fallback. The reference implementation will attempt
// to parse the value as a python literal, but we purposefully don't support
// that
return raw
}
var (
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
)
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
// xml so that it can be parsed by any xml parser
func transformToXML(raw string) string {
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
// care to properly escape the string that becomes the attribute value
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})
// Walk the resulting string, escaping any character data that sits between the
// xml tags we just emitted
var out strings.Builder
lastIdx := 0
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
if loc[0] > lastIdx {
escapeTextNode(&out, transformed[lastIdx:loc[0]])
}
out.WriteString(transformed[loc[0]:loc[1]])
lastIdx = loc[1]
}
if lastIdx < len(transformed) {
escapeTextNode(&out, transformed[lastIdx:])
}
return out.String()
}
// escapeTextNode escapes XML character data without altering other characters
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
func escapeTextNode(sb *strings.Builder, s string) {
for _, r := range s {
switch r {
case '&':
sb.WriteString("&amp;")
case '<':
sb.WriteString("&lt;")
case '>':
sb.WriteString("&gt;")
default:
sb.WriteRune(r)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
package renderers
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/api"
)
var (
imStartTag = "<|im_start|>"
imEndTag = "<|im_end|>"
)
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
// This follows the same approach from the reference implementation, which gives
// a particular key ordering
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
data, err := json.Marshal(obj)
if err != nil {
return ""
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return ""
}
var sb strings.Builder
for key, value := range m {
if handledKeys[key] {
continue
}
// Check if value is a map or array (needs JSON serialization)
switch v := value.(type) {
case map[string]any, []any:
jsonBytes, _ := json.Marshal(v)
// TODO(drifkin): it would be nice to format the JSON here similarly to
// python's default json.dumps behavior (spaces after commas and colons).
// This would let us be byte-for-byte compatible with the reference
// implementation for most common inputs
jsonStr := string(jsonBytes)
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
case nil:
continue
default:
// Simple types, convert to string
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
}
}
return sb.String()
}
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
// filter out system messages and choose the first (if any) to win
var systemMessage string
var filteredMessages []api.Message
for _, message := range messages {
if message.Role != "system" {
filteredMessages = append(filteredMessages, message)
continue
}
if systemMessage == "" {
systemMessage = message.Content
}
}
if systemMessage != "" || len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
// if we have tools but no system message, match the reference implementation by providing a default system message
if systemMessage == "" {
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
}
sb.WriteString(systemMessage)
if len(tools) > 0 {
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
sb.WriteString("<tools>")
for _, tool := range tools {
sb.WriteString("\n")
sb.WriteString("<function>\n")
sb.WriteString("<name>" + tool.Function.Name + "</name>")
if tool.Function.Description != "" {
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
}
sb.WriteString("\n<parameters>")
for name, prop := range tool.Function.Parameters.Properties {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>")
if len(prop.Type) > 0 {
// TODO(!!!)(drifkin): we should match the reference implementation for
// more complex types here instead of using this format
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
}
if prop.Description != "" {
sb.WriteString("\n<description>" + prop.Description + "</description>")
}
// Render any additional keys not already handled
handledKeys := map[string]bool{
"type": true,
"description": true,
}
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
sb.WriteString("\n</parameter>")
}
// Render extra keys for parameters (everything except 'type' and 'properties')
paramHandledKeys := map[string]bool{
"type": true,
"properties": true,
}
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
sb.WriteString("\n</parameters>")
sb.WriteString("\n</function>")
}
sb.WriteString("\n</tools>")
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
}
sb.WriteString(imEndTag + "\n")
}
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
prefill := lastMessage && message.Role == "assistant"
switch message.Role {
case "assistant":
if len(message.ToolCalls) > 0 {
sb.WriteString(imStartTag + "assistant\n")
if message.Content != "" {
sb.WriteString(message.Content + "\n")
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
for name, value := range toolCall.Function.Arguments {
valueStr := formatToolCallArgument(value)
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
}
sb.WriteString("\n</function>\n</tool_call>")
}
sb.WriteString("<|im_end|>\n")
} else {
sb.WriteString(imStartTag + "assistant\n")
sb.WriteString(message.Content)
if !prefill {
sb.WriteString(imEndTag + "\n")
}
}
case "tool":
// consecutive tool responses should share a single `<im_start>user`, but
// have their own <tool_response> tags
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user\n")
}
sb.WriteString("<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
sb.WriteString(imEndTag + "\n")
}
default:
sb.WriteString(imStartTag + message.Role + "\n")
sb.WriteString(message.Content)
sb.WriteString(imEndTag + "\n")
}
if lastMessage && !prefill {
sb.WriteString(imStartTag + "assistant\n")
}
}
return sb.String(), nil
}
func formatToolCallArgument(value any) string {
if value == nil {
return "null"
}
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
}
if reflect.TypeOf(value) != nil {
kind := reflect.TypeOf(value).Kind()
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
if marshalled, err := json.Marshal(value); err == nil {
return string(marshalled)
}
}
}
return fmt.Sprintf("%v", value)
}

View File

@@ -0,0 +1,338 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestQwen3CoderRenderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "with tools and response",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in San Francisco?"},
{
Role: "assistant",
Content: "I'll check the weather in San Francisco for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"unit": "fahrenheit",
},
},
},
},
},
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
{Role: "user", Content: "That sounds nice! What about New York?"},
},
tools: []api.Tool{
{Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Required: []string{"unit"},
Properties: map[string]api.ToolProperty{
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
// TODO(drifkin): add multiple params back once we have predictable
// order via some sort of ordered map type (see
// <https://github.com/ollama/ollama/issues/12244>)
/*
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
*/
},
},
}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>get_weather</name>
<description>Get the current weather in a given location</description>
<parameters>
<parameter>
<name>unit</name>
<type>string</type>
<description>The unit of temperature</description>
<enum>["celsius","fahrenheit"]</enum>
</parameter>
<required>["unit"]</required>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
What is the weather like in San Francisco?<|im_end|>
<|im_start|>assistant
I'll check the weather in San Francisco for you.
<tool_call>
<function=get_weather>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "parallel tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "call double(1) and triple(2)"},
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
}},
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
},
tools: []api.Tool{
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
}}}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
}}}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>double</name>
<description>Double a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to double</description>
</parameter>
</parameters>
</function>
<function>
<name>triple</name>
<description>Triple a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to triple</description>
</parameter>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
call double(1) and triple(2)<|im_end|>
<|im_start|>assistant
I'll call double(1) and triple(2) for you.
<tool_call>
<function=double>
<parameter=number>
1
</parameter>
</function>
</tool_call>
<tool_call>
<function=triple>
<parameter=number>
2
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"number": 2}
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
{
name: "prefill",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Tell me something interesting."},
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Tell me something interesting.<|im_end|>
<|im_start|>assistant
I'll tell you something interesting about cats`,
},
{
name: "complex tool call arguments should remain json encoded",
msgs: []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: map[string]any{
"payload": map[string]any{"foo": "bar"},
},
}},
}},
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
},
expected: `<|im_start|>user
call tool<|im_end|>
<|im_start|>assistant
<tool_call>
<function=echo>
<parameter=payload>
{"foo":"bar"}
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestFormatToolCallArgument(t *testing.T) {
tests := []struct {
name string
arg any
expected string
}{
{
name: "string",
arg: "foo",
// notice no quotes around the string
expected: "foo",
},
{
name: "map",
arg: map[string]any{"foo": "bar"},
expected: "{\"foo\":\"bar\"}",
},
{
name: "number",
arg: 1,
expected: "1",
},
{
name: "boolean",
arg: true,
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := formatToolCallArgument(tt.arg)
if got != tt.expected {
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,26 @@
package renderers
import (
"fmt"
"github.com/ollama/ollama/api"
)
type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
renderer := rendererForName(name)
if renderer == nil {
return "", fmt.Errorf("unknown renderer %q", name)
}
return renderer(msgs, tools, think)
}
func rendererForName(name string) rendererFunc {
switch name {
case "qwen3-coder":
return Qwen3CoderRenderer
default:
return nil
}
}

View File

@@ -2,7 +2,6 @@ package model
import (
"container/heap"
"context"
"fmt"
"log/slog"
"strconv"
@@ -13,19 +12,19 @@ import (
const spmWhitespaceSep = "▁"
type SentencePieceModel struct {
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
@@ -39,21 +38,21 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
@@ -220,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
@@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
logutil.Trace("decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
return NewSentencePieceModel(&v)
return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
})
}
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePieceModel(vocab)
spm := NewSentencePiece(vocab)
tests := []struct {
name string

View File

@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS)
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
ids = append([]int32{v.BOS[0]}, ids...)
}
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS)
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
ids = append(ids, v.EOS[0])
}

167
model/wordpiece.go Normal file
View File

@@ -0,0 +1,167 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on but keep punctuation
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial && len(ids) > 0 {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
}
}

51
model/wordpiece_test.go Normal file
View File

@@ -0,0 +1,51 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -76,8 +76,9 @@ type JsonSchema struct {
}
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
Input any `json:"input"`
Model string `json:"model"`
Dimensions int `json:"dimensions,omitempty"`
}
type StreamOptions struct {
@@ -104,16 +105,18 @@ type ChatCompletionRequest struct {
Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
DebugRenderOnly bool `json:"_debug_render_only"`
}
type ChatCompletion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
}
type ChatCompletionChunk struct {
@@ -140,6 +143,7 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
DebugRenderOnly bool `json:"_debug_render_only"`
}
type Completion struct {
@@ -272,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
return nil
}(r.DoneReason),
}},
Usage: toUsage(r),
}}, Usage: toUsage(r),
DebugInfo: r.DebugInfo,
}
}
@@ -567,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
return &api.ChatRequest{
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
Think: think,
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
Tools: r.Tools,
Think: think,
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
@@ -647,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
}
return api.GenerateRequest{
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
Suffix: r.Suffix,
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: &r.Stream,
Suffix: r.Suffix,
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
@@ -1005,7 +1011,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}

View File

@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.System = c.Args
case "license":
licenses = append(licenses, c.Args)
case "renderer":
req.Renderer = c.Args
case "parser":
req.Parser = c.Args
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
@@ -246,7 +250,7 @@ func filesForModel(path string) ([]string, error) {
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if ct != contentType {
} else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
@@ -255,7 +259,8 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
@@ -319,7 +324,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter":
case "license", "template", "system", "adapter", "renderer", "parser":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -345,7 +350,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
)
type ParserError struct {
@@ -605,7 +610,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "parameter", "message":
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
return true
default:
return false

View File

@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
}
}
func TestParseFileRenderer(t *testing.T) {
input := `
FROM foo
RENDERER renderer1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
}
func TestParseFileParser(t *testing.T) {
input := `
FROM foo
PARSER parser1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
}
func TestParseFileMessages(t *testing.T) {
cases := []struct {
input string

View File

@@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
return max(targetFree-currentFree, 0)
}
type ErrReprocessInputs struct {

View File

@@ -34,8 +34,8 @@ type InputCache struct {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
if int(numCtx) < batchSize {
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
@@ -70,11 +70,9 @@ func kvCacheTypeFromStr(s string) ml.DType {
}
func (c *InputCache) Close() {
if c == nil {
return
if c != nil && c.cache != nil {
c.cache.Close()
}
c.cache.Close()
}
// Locking: Operations on InputCacheSlot (including finding one
@@ -95,7 +93,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -113,6 +111,10 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()
@@ -240,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
return max(targetFree-currentFree, 0)
}
type ErrReprocessInputs struct {

View File

@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
// Check error state
if (err != nil) != tt.wantErr {

View File

@@ -17,7 +17,6 @@ import (
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
@@ -32,6 +31,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
@@ -405,6 +405,8 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
for {
select {
@@ -418,7 +420,12 @@ func (s *Server) run(ctx context.Context) {
if err != nil {
panic(err)
}
go s.computeBatch(activeBatch)
if supportsAsync {
go s.computeBatch(activeBatch)
} else {
s.computeBatch(activeBatch)
}
}
}
}
@@ -429,12 +436,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if pendingBatch.ctx != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
@@ -460,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batchOutputs []int32
var batch input.Batch
resumeSeq := -1
@@ -542,11 +550,11 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
seq.iBatch = len(batchOutputs)
if i+1 == len(seq.inputs) || seq.embeddingOnly {
batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
}
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -560,7 +568,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
}
if len(batchInputs) == 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
@@ -569,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil {
err = fmt.Errorf("failed to build graph: %w", err)
@@ -589,14 +598,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
defer activeBatch.ctx.Close()
// Wait until inputs are ready
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
}()
@@ -626,7 +635,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
@@ -666,18 +675,19 @@ func (s *Server) computeBatch(activeBatch batchState) {
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
logits := activeBatch.modelOutput.Floats()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
outputs := activeBatch.modelOutput.Floats()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
s.mu.Lock()
defer s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
continue
@@ -689,16 +699,15 @@ func (s *Server) computeBatch(activeBatch batchState) {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
seq.embedding <- outputs
s.removeSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(logits) / len(activeBatch.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
@@ -711,7 +720,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -834,7 +843,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
@@ -890,6 +899,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embedding request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
@@ -978,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i)
}
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache
if cache != nil {
@@ -1017,9 +1083,13 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok {
panicErr = err
var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
panicErr = noMem
} else {
panic(r)
}
} else {
panic(r)
}
@@ -1206,10 +1276,7 @@ func Execute(args []string) error {
mux := http.NewServeMux()
// TODO: support embeddings
mux.HandleFunc("POST /load", server.load)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /embedding", server.embeddings)
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)

View File

@@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
return model.NewBytePairEncoding(
``,
&model.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),

View File

@@ -78,7 +78,7 @@ function checkEnv() {
}
function buildOllama() {
function buildCPU() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
@@ -90,20 +90,72 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
function buildCUDA11() {
# CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible
# 19.40 is the last compiler version that works, but recent udpates are 19.43
# So this pins to MSVC 2019 for best compatibility
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v11 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA12() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12.8")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v12 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA13() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v13")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
$env:CUDAToolkit_ROOT=$cuda
write-host "Building CUDA v13 backend libraries $cuda"
& cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildROCm() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
if ($env:HIP_PATH) {
write-host "Building ROCm backend libraries"
if (-Not (get-command -ErrorAction silent ninja)) {
@@ -129,6 +181,10 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
write-host "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -236,6 +292,10 @@ function distZip() {
checkEnv
try {
if ($($args.count) -eq 0) {
buildCPU
buildCUDA12
buildCUDA13
buildROCm
buildOllama
buildApp
gatherDependencies

View File

@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=OLLAMA_FAST_BUILD \
--build-arg=CUSTOM_CPU_FLAGS \
--build-arg=GPU_RUNNER_CPU_FLAGS \
--build-arg=PARALLEL \
--build-arg=AMDGPU_TARGETS"
echo "Building Ollama"

Some files were not shown because too many files have changed in this diff Show More