Compare commits

...

180 Commits

Author SHA1 Message Date
likelovewant
17bb5ea679 Merge branch 'ollama:main' into main 2025-03-23 12:10:05 +08:00
Blake Mizerany
ce929984a3 server/internal/client/ollama: fix file descriptor management in Pull (#9931)
Close chunked writers as soon as downloads complete, rather than
deferring closure until Pull exits. This prevents exhausting file
descriptors when pulling many layers.

Instead of unbounded defers, use a WaitGroup and background goroutine
to close each chunked writer as soon as its downloads finish.

Also rename 'total' to 'received' for clarity.
2025-03-21 16:16:38 -07:00
Michael Yang
4b34930a31 Merge pull request #9897 from ollama/mxyng/chunk-load
ml/backend/ggml: load tensors in 128KiB chunks
2025-03-21 14:47:13 -07:00
Michael Yang
74bd09652d ml/backend/ggml: load tensors in 32KiB chunks 2025-03-21 14:43:52 -07:00
Bruce MacDonald
fb6252d786 benchmark: performance of running ollama server (#8643) 2025-03-21 13:08:20 -07:00
Blake Mizerany
c794fef2f2 server/internal/client/ollama: persist through chunk download errors (#9923) 2025-03-21 13:03:43 -07:00
Parth Sareen
00ebda8cc4 Revert "parser: remove role validation from Modelfile parser" (#9917)
This reverts commit ffbfe833da.
2025-03-21 12:38:09 -07:00
Parth Sareen
d14ce75b95 docs: update final response for /api/chat stream (#9919) 2025-03-21 12:35:47 -07:00
Jesse Gross
2d6eac9084 kvcache: Optimize sliding window attention
Currently sliding window attention allocates and uses the full
context size and just masks out any tokens that are outside of the
window. However, we really only need (roughly) the sliding window
size.

At large context sizes this improves two things:
 - Memory allocated - since the fully context size is allocated up front,
   memory requirements drop substantially. On Gemma3:4b with a 32k
   context window, total memory usage (including weights and non-sliding
   layers) drops from ~20GB to ~8GB.
 - Computation - ranges that are completely outside of the sliding
   window are now removed from the tensors that are returned from the
   cache rather than simply being masked out. This results in more
   efficient processing, scaling with the size of the context that
   has actually been used.

Notable, this does not update the scheduler for any model to be aware of
the smaller memory requirements. This is difficult for Gemma3 because
the layers are heterogeneous between sliding and non-sliding attention.
As a result, while actual memory consumption will be reduced, the
scheduler will over-estimate the requirements of the model. This means
that splitting between GPUs or GPUs and CPUs will still be suboptimal.

Bug #9730
2025-03-21 11:20:19 -07:00
Jesse Gross
3ed7ad3ab3 kvcache: Pass granular cache size into implementations
Currently the runner computes the kv size needed and creates a
cache of that size. This is the context size times number of
parallel sequences.

Cache implementations can make better decisions about their memory
usage, so instead pass in the required capacity, number of sequences
and maximum batch size. For now, the causal cache just uses this to
compute the size in the same way as before.
2025-03-21 11:20:19 -07:00
Patrick Devine
6d1103048e fix: show correct bool value for kv in verbose show information (#9928) 2025-03-21 11:13:54 -07:00
Jesse Gross
0ff28758b3 ollamarunner: Provide mechanism for backends to report loading progress
This enables the runner to report progress back to the Ollama server,
both for showing status to the user and also to prevent the server
from killing the runner if it thinks things have stalled.

Most of the infrastructure was already there, this extends it to
be available to the backends.
2025-03-21 10:44:26 -07:00
Jesse Gross
d3e9ca3eda kvcache: Account for source tensors in defrag operation count
Defragging the KV cache can generate a lot of operations, so we
need to be careful that we don't overflow the number that the graph
can support. We currently account for all of the nodes that we add
to the graph for each move but we also need to include the original
cache tensors as well.

Fixes #9904
2025-03-21 10:42:19 -07:00
Jesse Gross
0fbfcf3c9c model: Pass input tensor instead of raw data to models
Rather than directly giving the input data to models, we can
pass a tensor instead. In the short term, this saves some duplicated
code.

Longer term, we will want to overlap setting up the next batch with
processing of the current one. In this case, we will only have the
shape of tensor but it will not be loaded with data at the time of
graph generation. By passing only a tensor to models now, we set up
this possibility and prevent them from relying on data that they won't
have in the future.

Although the same could be done for Positions and Outputs, in some
cases we either need the raw input data or don't use them at all.
Therefore, for now we leave them as they are and allow models to
convert them to tensors as needed.
2025-03-20 13:28:13 -07:00
Jesse Gross
0c220935bd input: Rename Options to Batch
Options is no longer very descriptive of this struct.
2025-03-20 13:28:13 -07:00
rylativity
ffbfe833da parser: remove role validation from Modelfile parser (#9874)
* updates parser/parser.go to allow arbitrary roles in Modelfile MESSAGE blocks
2025-03-20 13:11:17 -07:00
Parth Sareen
42a14f7f63 sample: add error handling for empty logits (#9740) 2025-03-20 11:11:18 -07:00
Patrick Devine
f8c3dbe5b5 templates: add autotemplate for gemma3 (#9880)
This change allows the gemma3 template to be autodetected during `ollama
create`.
2025-03-20 00:15:30 -07:00
Jesse Gross
b078dd157c gemma2: Remove second call to Rows
Looks like a merge conflict that broke the model.
2025-03-19 17:28:49 -07:00
Blake Mizerany
2ddacd7516 server/internal/client/ollama: confirm all chunksums were received (#9893)
If the chunksums response is missing a chunk, the client should fail
the download. This changes the client to check that all bytes are
accounted for in the chunksums response.

It is possible there are overlaps or gaps in the chunksums response and
so the size is not the only thing left to check, but this provides
enough coverage for now. We may want to check that chunks are contiguous
later.
2025-03-19 14:59:57 -07:00
Jeffrey Morgan
da0e345200 ml: use input context for extracting outputs (#9875) 2025-03-18 18:08:19 -07:00
Bruce MacDonald
df94175a0f ggml: return error on failure to read tensor data (#9872)
When converting a ggml model if there is a failure to read tensor data a nil error value was being returned. It should be assigned to the actual error from reading.
2025-03-18 16:51:33 -07:00
Bruce MacDonald
61a8825216 convert: return name of unsupported architecture (#9862)
When a model's architecture cannot be converted return the name of the unsupported arch in the error message.
2025-03-18 10:38:28 -07:00
likelovewant
a69a1e6e63 Merge remote-tracking branch 'upstream/main' 2025-03-18 18:09:35 +08:00
Michael Yang
021dcf089d Merge pull request #9824 from ollama/mxyng/sched
conditionally enable parallel pipelines
2025-03-17 15:41:37 -07:00
Jesse Gross
bf24498b1e ollamarunner: Check for minBatch of context space when shifting
Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.
2025-03-17 15:33:16 -07:00
Bruce MacDonald
95e271d98f runner: remove cache prompt flag from ollama runner (#9826)
We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.
2025-03-17 15:11:15 -07:00
Jeffrey Morgan
364629b8d6 ml/backend/ggml: allocate memory with malloc when loading model (#9822) 2025-03-17 13:32:40 -07:00
Parth Sareen
108fe02165 sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
2025-03-17 11:24:18 -07:00
Michael Yang
4561fff36e conditionally enable parallel pipelines 2025-03-17 09:46:07 -07:00
Daniel Hiltgen
50b5962042 Add support for ROCm gfx1151 (#9773) 2025-03-17 09:33:57 -07:00
likelovewant
457576739f Merge branch 'ollama:main' into main 2025-03-17 14:58:37 +08:00
Louis Beaumont
e27e4a3c1b readme: add screenpipe to community integrations (#9786) 2025-03-16 21:56:42 -04:00
zeo
088514bbd4 readme: add Ellama to list of community integrations (#9800) 2025-03-16 21:54:43 -04:00
Patrick Devine
2c8b484643 fix: correctly save in interactive mode (#9788)
This fixes the case where a FROM line in previous modelfile points to a
file which may/may not be present in a different ollama instance. We
shouldn't be relying on the filename though and instead just check if
the FROM line was instead a valid model name and point to that instead.
2025-03-15 12:09:02 -07:00
Blake Mizerany
8294676150 server/internal/client/ollama: set User-Agent for registry client (#9775)
This sets the agent header in DefaultRegistry to include the version of
the client, OS, and architecture in the previous format, with a minor
twist.

Note: The version is obtained from the build info, instead of the
version in version.Version, which should not longer be necessary, but we
can remove in a future commit. Using the build info is more accurate and
also provides extra build information if the build is not tagged, and if
it is "dirty". Previously, the version was just "0.0.0" with no other
helpful information. The ollama.com registry and others handle this
swimmingly.
2025-03-14 18:33:07 -07:00
Patrick Devine
ef378ad673 gemma3 quantization (#9776) 2025-03-14 17:41:07 -07:00
Daniel Hiltgen
2d2247e59e Align versions for local builds (#9635)
Darwin was using a different pattern for the version string
than linux or windows.
2025-03-14 15:44:08 -07:00
Jesse Gross
7bf793a600 gemma3: Allow multiple image in a single input
Previously processing multiple images in a batch would trigger
segfaults so sending images together was disabled as a way to
mitigate this. The trigger was processing one image on the CPU
and one on the GPU.

This can no longer happen:
 - The vision encoder is now on the GPU so both images would be
   processed on the GPU.
 - We require images to be fully contained in a batch and each
   image including its special tokens is over half the batch size.
   As a result, we will never get two images in the same batch.

Fixes #9731
2025-03-14 15:38:54 -07:00
Jesse Gross
282bfaaa95 ollamarunner: Use a separate context per multimodal input
Currently there is a single context per sequence, shared all by
all multimodal inputs. Since we build a vision encoder graph per
image, with a large number of inputs we can eventually hit the
maximum number of graph nodes per context.

This changes to use a separate context for each image, ensuring
that available resource limits are consistent.
2025-03-14 15:38:54 -07:00
Jesse Gross
9679f40146 ml: Allow models to constrain inputs to a single batch
Models may require that a set of inputs all be processed as part
of the same batch. For example, if an image has multiple patches
with fully connected attention between them, we should not split
the batch in the middle of an image.

Fixes #9697
2025-03-14 15:38:54 -07:00
Bruce MacDonald
3892c3a703 llm: remove internal subprocess req and resp types (#9324)
This commit refactors the LLM subsystem by removing internal subprocess
request and response types. It consolidates duplicate type definitions
across the codebase, moving them to centralized locations. The change also
standardizes interfaces between components, simplifies the ServerStatusResp
struct, and moves the ParseDurationMs function to a common package. This
cleanup reduces code duplication between different runner implementations
(llamarunner and ollamarunner).
2025-03-14 15:21:53 -07:00
Blake Mizerany
4e320b8b90 server/internal/chunks: remove chunks package (#9755) 2025-03-14 08:57:59 -07:00
likelovewant
4cd0c73408 Merge branch 'ollama:main' into main 2025-03-14 13:44:39 +08:00
Blake Mizerany
eb2b22b042 server/internal/client: use chunksums for concurrent blob verification (#9746)
Replace large-chunk blob downloads with parallel small-chunk
verification to solve timeout and performance issues. Registry users
experienced progressively slowing download speeds as large-chunk
transfers aged, often timing out completely.

The previous approach downloaded blobs in a few large chunks but
required a separate, single-threaded pass to read the entire blob back
from disk for verification after download completion.

This change uses the new chunksums API to fetch many smaller
chunk+digest pairs, allowing concurrent downloads and immediate
verification as each chunk arrives. Chunks are written directly to their
final positions, eliminating the entire separate verification pass.

The result is more reliable downloads that maintain speed throughout the
transfer process and significantly faster overall completion, especially
over unstable connections or with large blobs.
2025-03-13 22:18:29 -07:00
Michael Yang
4ea4d2b189 Merge pull request #9703 from ollama/mxyng/gemma3-memory
count gemma3 vision tensors
2025-03-13 16:56:34 -07:00
Michael Yang
8d76fa23ef count non-repeating vision layers 2025-03-13 16:53:29 -07:00
Bradley Erickson
74b44fdf8f docs: Add OLLAMA_ORIGINS for browser extension support (#9643) 2025-03-13 16:35:20 -07:00
Michael Yang
65b88c544f fix divide by zero 2025-03-13 16:35:00 -07:00
Michael Yang
a422ba39c9 roughly count gemma3 graph
the largest operation is by far (q @ k) so just count that for
simplicity
2025-03-13 16:35:00 -07:00
Michael Yang
d2ec22371e count all vision tensors 2025-03-13 16:35:00 -07:00
Michael Yang
033cec232a count gemma3 vision tensors 2025-03-13 16:34:42 -07:00
Michael Yang
543240fb5f Merge pull request #9741 from ollama/mxyng/visionless
fix: error if image requested without vision model
2025-03-13 15:03:25 -07:00
Patrick Devine
4bed739259 add verbose mode to the show command (#9640)
Add metadata and tensor information to the show command to be able to
see more information about a model. This outputs the same data as
shown on the model details page on ollama.com
2025-03-13 14:24:27 -07:00
Patrick Devine
80c7ce381b fix: change default context size for gemma3 (#9744) 2025-03-13 13:59:19 -07:00
Michael Yang
ccfd41c4f0 Merge pull request #9742 from ollama/mxyng/engine-error-embeddings
fix: error on models that don't support embeddings
2025-03-13 13:12:33 -07:00
Michael Yang
3e102b7dad Update model/model.go
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2025-03-13 13:11:52 -07:00
Michael Yang
ec46f3286c engine: error on embeddings; not currently implemented 2025-03-13 11:40:55 -07:00
Michael Yang
5e2e0b46b1 fix: error if image requested without vision model 2025-03-13 10:52:09 -07:00
Michael Yang
45a13b1dec Merge pull request #9688 from Shane-XB-Qian/debug_mistype_lld
ollama-debug.c: correct mistype
2025-03-13 10:12:44 -07:00
Parth Sareen
5c0b663969 sample: separate softmax and temperature transforms (#9732) 2025-03-13 09:53:27 -07:00
shane.xb.qian
30d7a59ba8 ollama-debug.c: change 'ld' to 'PRIi64'
* macOS has different definition per info from @mxyng
2025-03-13 17:10:37 +08:00
ParthSareen
4aeb67ef4c sample: do all sorting in topK 2025-03-12 11:59:17 -07:00
ParthSareen
3ba91634c1 sample: simplify top_k=0 sorting 2025-03-12 11:59:17 -07:00
ParthSareen
1b7433b71e sample: use container/heap for top_k 2025-03-12 11:59:17 -07:00
Bruce MacDonald
a70820daa0 models/gemma3: remove final logit softcap (#9692)
Softcap isn't in the whitepaper/implementation for the language model so we should remove it. There is no discernible difference in output with it removed.
2025-03-12 10:17:57 -07:00
Shane-XB-Qian
6b45b1d6b4 cli: adding support ctrl-n/p like general cli (#9136)
Signed-off-by: shane.xb.qian <shane.qian@foxmail.com>
2025-03-12 08:51:56 -07:00
shane.xb.qian
85ab552028 ollama-debug.c: correct mistype
Signed-off-by: shane.xb.qian <shane.qian@foxmail.com>
2025-03-12 22:32:30 +08:00
likelovewant
c3945aaa1d Merge branch 'ollama:main' into main 2025-03-12 15:00:44 +08:00
frob
b3af953a55 cli: don't exit for invalid model during /load. (#9576)
Co-authored-by: Richard Lyons <frob@cloudstaff.com>
2025-03-11 23:42:53 -07:00
likelovewant
3a65093078 remove extra setting 2025-03-12 14:40:55 +08:00
Michael
ad4e0bf3be Adding Gemma 3 to readme (#9671) 2025-03-12 07:39:25 +01:00
likelovewant
88ab587807 Merge branch 'ollama:main' into main 2025-03-12 14:32:40 +08:00
Michael Yang
aee28501b5 Merge pull request #9661 from ollama/gemma
engine: add gemma support
2025-03-11 15:07:50 -07:00
jmorganca
83f0ec8269 all: address linter errors 2025-03-11 14:49:20 -07:00
jmorganca
c6b6938b3a kvcache: fix tests by adding AvgPool2D stub 2025-03-11 14:49:20 -07:00
jmorganca
fb4664fcec model: add more spm tokenizer tests 2025-03-11 14:49:20 -07:00
jmorganca
20e3593863 model: validate left and right pairs before merging them 2025-03-11 14:49:20 -07:00
Michael Yang
63a394068c use 2d pooling 2025-03-11 14:49:20 -07:00
Daniel Hiltgen
ab39e08eb9 llm: auto detect models that require Ollama Engine (#1) 2025-03-11 14:49:20 -07:00
jmorganca
11bfa62796 add trailing \n\n after <end_of_image> to match reference implementation 2025-03-11 14:49:20 -07:00
jmorganca
f63e62e546 reduce kernel size, add TODO for loading from config 2025-03-11 14:49:20 -07:00
jmorganca
65b0f329d1 Revert "Allow models to force a new batch"
This reverts commit c7eae586b899083acebcd9b3847b89ea78c2850c.
2025-03-11 14:49:20 -07:00
Jesse Gross
06007c0a18 Allow models to force a new batch
This is useful for a few things:
 - Work around bugs, such as having 2 images in one batch
 - Keep the image in a single batch for fully connected attention
 - Improve performance by not evaluating embeddings multiple times
2025-03-11 14:49:20 -07:00
Jesse Gross
a8e83a7654 Disable causal attention based on batch index
Currently we are using positions, which are relative to a
sequence and may not be unique.
2025-03-11 14:49:20 -07:00
Jesse Gross
475005504e Restrict Gemma to a single image per request 2025-03-11 14:49:20 -07:00
Jesse Gross
2c40c4d35e Fix follow up images and images split across batches 2025-03-11 14:49:19 -07:00
Michael Yang
e95278932b use non-causal mask only for image positions 2025-03-11 14:49:19 -07:00
Michael Yang
9d2a20a763 use non-causal mask for inputs with images 2025-03-11 14:49:19 -07:00
Patrick Devine
2e54d72fc3 fix gemma3 1b conversion 2025-03-11 14:49:19 -07:00
Michael Yang
6b32a2d549 compat with upstream gguf 2025-03-11 14:49:19 -07:00
Michael Yang
c5cbe4fc2a fallback to cpu 2025-03-11 14:49:19 -07:00
Michael Yang
f888912870 fix vision encoder 2025-03-11 14:49:19 -07:00
Michael Yang
9e4642e9b3 ollama debug tensor 2025-03-11 14:49:19 -07:00
Michael Yang
6b0486c216 duplicate token_embd to output 2025-03-11 14:49:19 -07:00
Michael Yang
d368c039f0 skip repacking vision tensors 2025-03-11 14:49:19 -07:00
Patrick Devine
9b54267e69 fix configs 2025-03-11 14:49:19 -07:00
Michael Yang
46bb0169c4 update model 2025-03-11 14:49:19 -07:00
Michael Yang
8934324b72 use fast attention 2025-03-11 14:49:18 -07:00
Jesse Gross
0e886595bf Fix tests and drift from main 2025-03-11 14:49:18 -07:00
Patrick Devine
c62861f4fa fix conversion 2025-03-11 14:49:18 -07:00
Michael Yang
0df1800436 set non-causal attention 2025-03-11 14:49:18 -07:00
Patrick Devine
631fecc6d9 temporary work around for converting spm 2025-03-11 14:49:18 -07:00
Jesse Gross
4346c2409d fix drift from main 2025-03-11 14:49:18 -07:00
Michael Yang
4b037a97dc add gemma vision encoder 2025-03-11 14:49:17 -07:00
Patrick Devine
5f74d1fd47 gemma2 impl 2025-03-11 14:35:08 -07:00
Daniel Hiltgen
4dcf80167a Build release for windows with local script (#9636) 2025-03-11 08:34:20 -07:00
Michael Yang
26a26998fb Merge pull request #9590 from ollama/mxyng/dump-pad
fix: pad tensor item if ge zero
2025-03-10 16:34:55 -07:00
Michael Yang
9926eae015 fix: pad tensor item if ge zero
this produces a nicer output since both positive and negative values
produces the same width
2025-03-10 16:18:12 -07:00
Vincent Koc
8585b7b151 docs: add opik to observability integrations (#9626) 2025-03-10 16:15:10 -07:00
Parth Sareen
7e34f4fbfa sample: add numerical stability to temperature/softmax transform (#9631) 2025-03-10 14:43:53 -07:00
Michael Yang
fe776293f7 Merge pull request #9569 from dwt/patch-1
Better WantedBy declaration
2025-03-10 14:09:37 -07:00
frob
d8a5d96b98 docs: Add OLLAMA_CONTEXT_LENGTH to FAQ. (#9545) 2025-03-10 11:02:54 -07:00
Xiaowei Zhu
757668c42f docs: add SwiftChat (#9540) 2025-03-10 11:01:09 -07:00
Sam
96ec8afd09 docs(tool): add mcp-llm (#9537) 2025-03-10 09:52:02 -07:00
Jeffrey Morgan
e093db92c4 sample: temporarily use grammars for constrained generation in new engine (#9586) 2025-03-10 16:17:39 +01:00
Jesse Gross
a1cda80bcb model: Update encoder cache to use multimodal input processing handler
The encoder cache needs to know the position of images in the input
stream so that it knows when to delete them. Previously images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use that instead.

Breaking batches was also a way to simulate a cross attention mask
for mllama. However, given that it only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
2025-03-09 17:05:26 -07:00
likelovewant
642a2496fe Merge branch 'ollama:main' into main 2025-03-09 13:49:03 +08:00
Jesse Gross
4614fafae0 ollamarunner: Don't panic for unimplemented features at runtime.
It's ok to fail on startup but we shouldn't panic during runtime
based on user input. Downgrade the panic to a warning.
2025-03-08 18:58:18 -08:00
Jesse Gross
4100ed7bdd ml: Add support for quantized KV cache
Similar to the llama engine, quantizing the KV cache requires
flash attention to be enabled through the Ollama server.
2025-03-07 18:43:39 -08:00
Jesse Gross
f52b2615ef kvcache: Set context for shift offsets 2025-03-07 18:43:39 -08:00
Jesse Gross
25f9b152f9 ggml-backend: Ensure allocation meet backend requirements
Backends can impose additional alignment requirements on buffer sizes.
We should ensure that we meet these or allocations can fail.
2025-03-07 18:43:39 -08:00
Jesse Gross
6da8b6a879 kvcache: Support non-causal attention
Models can disable causality for all or part of their processing
while continuing to store data in the KV cache.
2025-03-07 18:39:27 -08:00
Jesse Gross
0daaaef8c9 ollamarunner: Quiet debug logging and panic on unimplemented features
Debug logging of every token has previously caused test timeouts
on slower machines.
2025-03-07 18:38:02 -08:00
Jesse Gross
98272fbd58 additional review comments 2025-03-07 14:08:21 -08:00
Michael Yang
b27e8f3f10 ml/backend/ggml: use backend buffer type
this ensures the tensor is created on the right buffer type for backends
such as cpu
2025-03-07 14:08:21 -08:00
Michael Yang
45df786f09 comments 2025-03-07 14:08:21 -08:00
Michael Yang
daaf42e4a4 ml/backend/ggml: clean up 2025-03-07 14:08:21 -08:00
Michael Yang
2dc60d4620 ml/backend/ggml: offload vision to cpu
temporary until tensor loading can accurately account for vision models
2025-03-07 14:08:21 -08:00
Michael Yang
b5312f30e8 ml/backend/ggml: handle tensor split 2025-03-07 14:08:21 -08:00
Michael Yang
26c2e0bd35 ml/backend/ggml: handle user specified cpu offloading 2025-03-07 14:08:21 -08:00
Michael Yang
bf920883d5 ml/backend/ggml: set cpu n_threads 2025-03-07 14:08:21 -08:00
Michael Yang
58b9ec1f6b kvcache: update tests 2025-03-07 14:08:21 -08:00
Michael Yang
7bae7fa5ce ml/backend/ggml: create tensor on specific backend
some tensors should be created on specific backends to reduce number of
copies and improve performance
2025-03-07 14:08:21 -08:00
Michael Yang
764e199d67 kvcache: create cache ctx per layer
each cache layer creates and maintains its own context instead of using
a large context for all layers
2025-03-07 14:08:21 -08:00
Michael Yang
bfce55db3d model: load non-repeated tensors into multiple backends
some tensors are expected to be used in repeating layers but are not
themselves repeated. this change copies these tensors into the same
backends as their repeating counterparts to minimize copying tensors
between backends
2025-03-07 14:08:21 -08:00
Michael Yang
bab6f34dc0 ml/backend/ggml: update model loading for hybrid/multi backends
use a similar strategy as llama.cpp for deciding where tensors should be
allocated. this will be improved later to be aware of usable memory
before assigning the tensor
2025-03-07 14:08:21 -08:00
Parth Sareen
0682dae027 sample: improve ollama engine sampler performance (#9374)
This change bring in various interface cleanups along with greatly improving the performance of the sampler.

Tested with llama3.2 on local machine.
Improves performance from ~ 70 tokens/s -> 135 tokens/s with topK(40) enabled.
Without topK performance is ~ 110 tokens/s
2025-03-07 12:37:48 -08:00
Breaker
1f6986e919 readme: add QwQ to the supported models list (#9565) 2025-03-07 09:30:07 -08:00
Jeffrey Morgan
4289c74359 llama: fix kv loading on snowflake-arctic-embed models (#9536) 2025-03-07 09:25:34 -08:00
‮rekcäH nitraM‮
25248f4bd5 Better WantedBy declaration
The problem with default.target is that it always points to the target that is currently started. So if you boot into single user mode or the rescue mode still Ollama tries to start.

I noticed this because either tried (and failed) to start all the time during a system update, where Ollama definitely is not wanted.
2025-03-07 10:26:31 +01:00
likelovewant
e82001c122 fix the min error 2025-03-07 12:22:51 +08:00
Jesse Gross
a7e63b82be ollamarunner: Improve multimodal input handling
Various vision models have different requirements for how they
receive their inputs. For example:
 - Mllama wants images together with text and the image embeddings
   don't themselves have positions or get stored in the main KV cache
 - Llava-style models feed in embeddings similar to tokens and
   images correspond to a varying number of tokens in the cache.

In addition, the strategy for providing inputs must support batching
and multiple sequences, which are managed by the runner. At the same
time, we want to keep data handling fully in the model so that new
architectures are not bottlenecked by runner code which does not
understand their particular requirements.

This provides a method for models to edit the input stream so that
it meets their needs while still being in a format that the runner
understands. This allows the runner to avoid special processing
for different models.

In addition, this fixes a regression where non-vision models may
try to incorrectly interpret images.
2025-03-06 16:54:16 -08:00
Jesse Gross
b70fc4d51e model: Don't unconditionally add special tokens
We sometimes tokenize partial strings. For example, with
multimodal inputs, we split the input string around the images
and then tokenize each piece. In these cases, we should only add
the special tokens on the first piece.
2025-03-06 16:54:16 -08:00
Blake Mizerany
e2252d0fc6 server/internal/registry: take over pulls from server package (#9485)
This commit replaces the old pull implementation in the server package
with the new, faster, more robust pull implementation in the registry
package.

The new endpoint, and now the remove endpoint too, are behind the
feature gate "client2" enabled only by setting the OLLAMA_EXPERIMENT
environment variable include "client2".

Currently, the progress indication is wired to perform the same as the
previous implementation to avoid making changes to the CLI, and because
the status reports happen at the start of the download, and the end of
the write to disk, the progress indication is not as smooth as it could
be. This is a known issue and will be addressed in a future change.

This implementation may be ~0.5-1.0% slower in rare cases, depending on
network and disk speed, but is generally MUCH faster and more robust
than the its predecessor in all other cases.
2025-03-05 14:48:18 -08:00
Daniel Hiltgen
cae5d4d4ea Win: doc new rocm zip file (#9367)
To stay under the 2G github artifact limit, we're splitting ROCm
out like we do on linux.
2025-03-05 14:11:21 -08:00
likelovewant
d80ea37d36 Merge branch 'ollama:main' into main 2025-03-05 13:40:11 +08:00
Michael Yang
05a01fdecb ml/backend/ggml: consolidate system info logging
- output backend system info when initializing the backend. this ensures
  this information is always present without needing to be called
  explicitly
- convert to structured logging
- enumerate devices rather than backends since devices are ordered
- track device indices grouped by device name
2025-03-04 15:14:31 -08:00
aritra saha
8fe6f69f28 docs: add granite-3.2 to the readme 2025-03-04 11:10:56 -08:00
Daniel Hiltgen
1fdb351c37 New engine: vision models and auto-fallback (#9113)
* Include unified vision layers in memory prediction

For newer vision models with a single gguf, include
the projection estimates.

* Adjust CLI to handle both styles of vision model metadata

* Wire up new tokenizers for new engine

If we're loading the new engine, utilize the new model
text processor instead of calling into cgo wrappers for
llama.cpp.  This also cleans up some tech debt from the
older tokenization flow for the C++ server which was
no longer used.

This also adjusts the grammar handling logic to pass
through to the new engine instead of utilizing the cgo
schema to grammar call.

* Lay foundation for auto selection of new engine
2025-03-04 09:03:46 -08:00
Blake Mizerany
7a01ad7614 server/internal/registry: reintroduce pruning on model deletion (#9489)
This reintroduces aggressive pruning on model deletion as a temporary
measure until a more controlled garbage collection (GC) mechanism is
implemented.

Issues with the current approach:

1. Users may accidentally delete a model (`ollama rm llama3.3` instead
   of `ollama rm llama3.2`), requiring a full re-download unless another
   model references the same blobs.

2. Users may assume a deleted model is still referenced elsewhere, but
   due to prior updates or deletions, the references no longer exist,
   leading to unnecessary re-downloads.

Soon, we should implement a structured GC mechanism to retain
unreferenced blobs for a configurable period before removal, which will
run on "ollama rm" and other commands we deem appropriate.

Users that want to immediately remove unreferenced blobs can use a new
prune command that will allow them to specify the age and class of blobs
to remove.

Example usage:

    # Run basic blob GC
    $ ollama prune

    # Remove unreferenced blobs older than 7 days
    $ ollama prune --age 7d

    # Remove all blobs, referenced or not, older than 7 days (and their manifests?)
    $ ollama prune --age 7d --all

    # Remove all unreferenced blobs immediately
    $ ollama prune --age 0 --all

    # Remove all blobs
    $ ollama prune --age 0 --all

This should provide a safer and more predictable cleanup process.
2025-03-03 19:11:16 -08:00
Blake Mizerany
55ab9f371a server/.../backoff,syncs: don't break builds without synctest (#9484)
Previously, developers without the synctest experiment enabled would see
build failures when running tests in some server/internal/internal
packages using the synctest package. This change makes the transition to
use of the package less painful but guards the use of the synctest
package with build tags.

synctest is enabled in CI. If a new change will break a synctest
package, it will break in CI, even if it does not break locally.

The developer docs have been updated to help with any confusion about
why package tests pass locally but fail in CI.
2025-03-03 16:45:40 -08:00
KindBrave
fefbf8f74b docs: add Ollama Android Chat community integration 2025-03-03 16:38:32 -08:00
Michael Yang
b428ddd796 docker: use go version from go.mod 2025-03-03 13:02:02 -08:00
Michael Yang
ba7d31240e fix: own lib/ollama directory
expand backend loading error handling to catch more problems and log
them instead of panicing
2025-03-03 13:01:18 -08:00
CYJiang
d25efe3954 cmd: add default err return for stop (#9458) 2025-03-03 12:13:41 -08:00
Mark
36dfb906bb docs: don't use self-closing tag for anchor element (#9456) 2025-03-03 11:56:34 -08:00
aritra saha
a6f0f908b9 docs: update phi3-mini to phi4-mini (#9424)
* Update README.md

removed phi 3 mini and added phi4-mini

* Update README.md

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2025-03-03 11:09:21 -08:00
İbrahim Çetin
3b1ddb2b3a docs: add reins to community integrations (#9411) 2025-03-03 11:06:30 -08:00
Jeffrey Morgan
1579c4f06d build: install binutils alongside gcc in Dockerfile (#9475) 2025-03-03 01:20:49 -08:00
Blake Mizerany
3519dd1c6e server/internal/client/ollama: hold DiskCache on Registry (#9463)
Previously, using a Registry required a DiskCache to be passed in for
use in various methods. This was a bit cumbersome, as the DiskCache is
required for most operations, and the DefaultCache is used in most of
those cases. This change makes the DiskCache an optional field on the
Registry struct.

This also changes DefaultCache to initialize on first use. This is to
not burden clients with the cost of creating a new cache per use, or
having to hold onto a cache for the lifetime of the Registry.

Also, slip in some minor docs updates for Trace.
2025-03-02 20:55:44 -08:00
Jeffrey Morgan
e41c4cbea7 build: install ccache manually in Dockerfile (#9464)
Reverts ccache installation to be done manually via curl instead of
using the dnf package manager as this has side effects of prepending
ccache's install directory to the front of the PATH
2025-03-02 16:48:31 -08:00
Blake Mizerany
ee048b76d4 server/internal/client/ollama: handle extended names in client/ollama (#9454)
The extended name format is a superset of the name format that only the
client needs to know about, not the server or other dependents of the
name package, so move the split logic into the client package.

Also, take advantage of knowing about the extended name format to allow
the client to use the extended name format when unlinking to verify they
are unlinking the manifest with the content they intend.
2025-03-02 13:30:41 -08:00
Soulter
af68d60a58 readme: add AstrBot to community integrations (#9442) 2025-03-01 21:58:34 -08:00
likelovewant
92731dfc6f Merge branch 'ollama:main' into main 2025-03-02 13:45:52 +08:00
Jesse Gross
21aa666a1e ml: Enable support for flash attention
The GGML flash attention kernel has specific requirements for
padding and permutation. This adds support to the KV cache
for conforming to these requirements so that flash attention
can be enabled.

Flash attention can be used in the same situations as the llama
engine and is enabled by the user in the same way.
2025-03-01 20:53:23 -08:00
Jesse Gross
ee141cc821 ml: Empty tensor constructor for tensors
In cases where we allocate a tensor and then fully overwrite it with
copied data, it is wasteful to first zero out the memory.
2025-03-01 20:53:23 -08:00
Jesse Gross
55e5776c44 ggml-backend: Store parent backend as part of tensor
It can be important for a tensor to know what backend it came from -
for example, to know if flash attention is enabled.
2025-03-01 20:53:23 -08:00
Jesse Gross
854a9195f3 attention: Remove unnecessary contiguous operations
Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity
rules for mulmat and the Contiguous call can be simply removed.

Value requires a different permutation and does require Contiguous.
However, we can use the copy into the cache as a way to perform this
without further overhead.

To support this and avoid unexpected tensor shapes that are seen by
models, we need tighter integration between attention, cache
and backend. Future optimization will also likely need this structure
 - for example, flash attention has special padding requirements in
the cache and other backends may have their own needs.

This further contains the operations that go into attention so that
these and other optimizations can be handled transparently. Models
that have special requirements for attention can still implement
their own version of it.
2025-03-01 20:53:23 -08:00
Jeffrey Morgan
96a97adf9b build: use correct GGML_HIP_NO_VMM compiler definition for ggml-hip (#9451) 2025-03-01 17:00:31 -08:00
Jeffrey Morgan
e75c6126e9 build: set GGML_CUDA_NO_VMM for ggml-hip target (#9449) 2025-03-01 14:02:19 -08:00
Blake Mizerany
cda6f5c66c server/internal/internal/names: validate names (#9400)
This commit is a step towards a goal to make names less ceremonial
outside of the registry client. Clients of the registry package can
treat names as opaque strings, and the registry package will handle
parsing, validating, and normalizing names.

Ideally we end up with the names package tucked away in an internal
package for good. We'll see how things go.

Also, this package name is not permanent. This another step in the
on-going process of refactoring the server code, and at some point it
will most likely be renamed/moved.
2025-03-01 13:15:14 -08:00
likelovewant
1f7de23036 Merge branch 'ollama:main' into main 2025-03-01 15:42:15 +08:00
Bruce MacDonald
bebb6823c0 server: validate local path on safetensor create (#9379)
More validation during the safetensor creation process.
Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths
Add comprehensive test coverage for various paths
No functionality changes for valid inputs - existing workflows remain unaffected
Leverages Go 1.24's new os.Root functionality for secure containment
2025-02-28 16:10:43 -08:00
Michael Yang
31e472baa4 runner: defer context cancel
defer the cancel to guarantee it runs
2025-02-28 22:27:28 +00:00
Michael Yang
657685e85d fix: replace deprecated functions 2025-02-28 21:29:34 +00:00
Jeffrey Morgan
a14912858e build: add compute capability 12.0 to CUDA 12 preset (#9426)
Focuses initial Blackwell support on compute capability 12.0
which includes the 50x series of GeForce cards. In the future
additional compute capabilities may be added
2025-02-28 13:12:31 -08:00
Blake Mizerany
eed11ded30 server/.../safetensors: fix offsets and include all model parts (#9427)
Also, require the -as flag to be set when importing a model. This
prevents the confusing error message "invalid name".

Also, allow short names to be used when importing a model and
auto-complete the name with the default mask.
2025-02-28 13:08:10 -08:00
Michael Yang
b42aba40ed cuda: enable flash attention
ggml added an option to disable flash attention so explicitly enable it
2025-02-28 19:40:34 +00:00
王贺
25885e5335 docs: Add 1Panel to Community Integrations (#9312) 2025-02-28 09:53:03 -08:00
123 changed files with 7109 additions and 2704 deletions

View File

@@ -23,6 +23,7 @@ set(GGML_SCHED_MAX_COPIES 4)
set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
@@ -85,21 +86,6 @@ if(CMAKE_CUDA_COMPILER)
)
endif()
set(AMDGPU_TARGETS
gfx1010:xnack-
gfx1011
gfx1012:xnack-
gfx1030
gfx1031
gfx1032
gfx1034
gfx1035
gfx1100
gfx1101
gfx1102
gfx1103
gfx1150
CACHE STRING "List of AMDGPU targets to build for")
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX ""
CACHE STRING
@@ -112,7 +98,7 @@ if(CMAKE_HIP_COMPILER)
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011|1012(:xnack-)|103[0-6]|110[0-3]|1150)$")
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|1201)$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()
@@ -121,9 +107,11 @@ if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
endif()
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES

View File

@@ -28,7 +28,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
}
},
{
@@ -56,7 +56,7 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx803;gfx902;gfx1011;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1012:xnack-;"
"AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
}
}
],

View File

@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS build
ARG GOVERSION=1.23.4
RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1

View File

@@ -1,5 +1,5 @@
<div align="center">
  <a href="https://ollama.com" />
  <a href="https://ollama.com">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>
@@ -76,6 +76,11 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | -------------------------------- |
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
| QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
@@ -86,10 +91,7 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
@@ -97,7 +99,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -297,6 +299,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -408,6 +411,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
### Cloud
@@ -451,6 +459,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Apple Vision Pro
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
- [Enchanted](https://github.com/AugustDev/enchanted)
### Database
@@ -528,10 +537,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
### Extensions & Plugins
@@ -577,12 +589,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

View File

@@ -349,6 +349,7 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"`
}
@@ -361,9 +362,9 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
@@ -467,6 +468,13 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -18,6 +18,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync/atomic"
@@ -34,7 +35,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner"
@@ -256,6 +256,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
return err
}
return nil
}
@@ -338,10 +339,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
// TODO(jessegross): We should either find another way to know if this is
// a vision model or remove the logic. Also consider that other modalities will
// need different behavior anyways.
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}
opts.ParentModel = info.Details.ParentModel
if interactive {
@@ -562,8 +569,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
if boolErr != nil {
return errors.New("error retrieving flags")
}
@@ -601,7 +609,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
}
req := api.ShowRequest{Name: args[0]}
req := api.ShowRequest{Name: args[0], Verbose: verbose}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
@@ -624,10 +632,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil
}
return showInfo(resp, os.Stdout)
return showInfo(resp, verbose, os.Stdout)
}
func showInfo(resp *api.ShowResponse, w io.Writer) error {
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w)
@@ -684,6 +692,47 @@ func showInfo(resp *api.ShowResponse, w io.Writer) error {
})
}
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -1190,6 +1239,7 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]",
@@ -1274,7 +1324,6 @@ func NewCLI() *cobra.Command {
runnerCmd := &cobra.Command{
Use: "runner",
Short: llama.PrintSystemInfo(),
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:])

View File

@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -68,6 +68,60 @@ func TestShowInfo(t *testing.T) {
embedding length 0
quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -89,7 +143,7 @@ func TestShowInfo(t *testing.T) {
stop you
stop up
temperature 99`,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -126,7 +180,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0),
},
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -159,7 +213,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey!
Weigh anchor!
`,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -188,7 +242,7 @@ Weigh anchor!
QuantizationLevel: "FP16",
},
License: license,
}, &b); err != nil {
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -707,3 +761,132 @@ func TestCreateHandler(t *testing.T) {
})
}
}
func TestNewCreateRequest(t *testing.T) {
tests := []struct {
name string
from string
opts runOptions
expected *api.CreateRequest
}{
{
"basic test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "",
Prompt: "You are a fun AI agent",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
},
},
{
"parent model as filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "/some/file/like/etc/passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"parent model as windows filepath test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "D:\\some\\file\\like\\etc\\passwd",
Messages: []api.Message{},
WordWrap: true,
},
&api.CreateRequest{
From: "mymodel",
Model: "newmodel",
},
},
{
"options test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
Options: map[string]any{
"temperature": 1.0,
},
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
Parameters: map[string]any{
"temperature": 1.0,
},
},
},
{
"messages test",
"newmodel",
runOptions{
Model: "mymodel",
ParentModel: "parentmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
WordWrap: true,
},
&api.CreateRequest{
From: "parentmodel",
Model: "newmodel",
System: "You are a fun AI agent",
Messages: []api.Message{
{
Role: "user",
Content: "hello there!",
},
{
Role: "assistant",
Content: "hello to you!",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := NewCreateRequest(tt.from, tt.opts)
if !cmp.Equal(actual, tt.expected) {
t.Errorf("expected output %#v, got %#v", tt.expected, actual)
}
})
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
type MultilineState int
@@ -195,6 +196,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
return err
}
continue
@@ -343,7 +348,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] {
case "info":
_ = showInfo(resp, os.Stderr)
_ = showInfo(resp, false, os.Stderr)
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
@@ -455,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
parentModel := opts.ParentModel
modelName := model.ParseName(parentModel)
if !modelName.IsValid() {
parentModel = ""
}
req := &api.CreateRequest{
Name: name,
From: cmp.Or(opts.ParentModel, opts.Model),
Model: name,
From: cmp.Or(parentModel, opts.Model),
}
if opts.System != "" {

View File

@@ -13,8 +13,13 @@ import (
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {
@@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":
@@ -194,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return errors.New("unsupported architecture")
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
if err := json.Unmarshal(bts, conv); err != nil {
@@ -213,7 +220,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
}
vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch {
case vocabSize == 0:
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {

View File

@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "_norm.weight") {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}

142
convert/convert_gemma3.go Normal file
View File

@@ -0,0 +1,142 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type gemma3Model struct {
gemmaModel
Architecture string
TextModel struct {
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
SlidingWindow uint32 `json:"sliding_window"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
kv["gemma3.block_count"] = numBlocks
var (
numHeads uint32
numKVHeads uint32
)
switch numBlocks {
case gemma4BLayerCount:
numHeads = 8
numKVHeads = 4
case gemma12BLayerCount:
numHeads = 16
numKVHeads = 8
case gemma27BLayerCount:
numHeads = 32
numKVHeads = 16
default:
numHeads = p.NumAttentionHeads
numKVHeads = p.NumKeyValueHeads
}
kv["gemma3.attention.head_count"] = numHeads
kv["gemma3.attention.head_count_kv"] = numKVHeads
switch p.Architecture {
case "Gemma3ForCausalLM":
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
}
if p.MultiModalTokensPerImage > 0 {
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
}
return kv
}
func (p *gemma3Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"vision_model.vision_model.embeddings", "v",
"vision_model.vision_model", "v",
"language_model.", "",
"model.layers", "blk",
"encoder.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}

View File

@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"io/fs"
"log/slog"
"os"
"reflect"
"slices"
"google.golang.org/protobuf/proto"
@@ -15,6 +17,8 @@ import (
)
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
slog.Debug("using spm vocabulary")
ast, err := parseAdditionalSpecialTokens(fsys)
if err != nil {
return nil, err
@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
v.Types = append(v.Types, int32(t))
default:
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
if slices.Contains(ast, piece.GetPiece()) {
// temporary fix to handle gemma3 broken configs
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
}
for _, t := range ast {
if t.Content == piece.GetPiece() {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
break
}
}
v.Types = append(v.Types, tt)
}
}
@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return cmp.Compare(i.id, j.id)
})
n := len(v.Tokens)
for i, t := range ts {
if t.id != i+n {
return nil, fmt.Errorf("invalid token id: %d", t.id)
for _, t := range ts {
if t.id < len(v.Tokens) {
if v.Tokens[t.id] == t.content {
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
continue
}
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
}
if t.id != len(v.Tokens) {
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
}
v.Tokens = append(v.Tokens, t.content)
@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
return &v, nil
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
type specialToken struct {
Content string `json:"content"`
Lstrip bool `json:"lstrip"`
Normalized bool `json:"normalized"`
Rstrip bool `json:"rstrip"`
SingleWord bool `json:"single_word"`
}
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
f, err := fsys.Open("special_tokens_map.json")
if errors.Is(err, os.ErrNotExist) {
return nil, nil
@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
defer f.Close()
var m struct {
AdditionalSpecialTokens []string `json:"additional_special_tokens"`
AdditionalSpecialTokens any `json:"additional_special_tokens"`
}
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
return m.AdditionalSpecialTokens, nil
var ast []specialToken
switch st := m.AdditionalSpecialTokens.(type) {
case []string:
for _, s := range st {
ast = append(ast, specialToken{Content: s})
}
case []any:
for _, s := range st {
// marshal and unmarshal the object to get the special token
tMap := s.(map[string]any)
data, err := json.Marshal(tMap)
if err != nil {
return nil, err
}
var token specialToken
err = json.Unmarshal(data, &token)
if err != nil {
return nil, err
}
ast = append(ast, token)
}
default:
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
}
slog.Debug("spm tokenizer", "additional tokens", ast)
return ast, nil
}

View File

@@ -300,10 +300,10 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
})
continue
}
minVer, err := strconv.Atoi(RocmComputeMajorMin)
if err != nil {
slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
}
//minVer, err := strconv.Atoi(RocmComputeMajorMin)
//if err != nil {
// slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
//}
// if int(major) < minVer {
// reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch)
// slog.Warn(reason, "gpu", gpuID)

View File

@@ -558,6 +558,10 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -118,6 +118,35 @@ To run tests, use `go test`:
go test ./...
```
> NOTE: In rare cirumstances, you may nedd to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or
> test failures resulting from your change(s), if any, locally, but CI will
> break.
>
> If you see failures in CI, you can either keep pushing changes to see if the
> CI build passes, or you can enable the "synctest" package locally to see the
> failures before pushing.
>
> To enable the "synctest" package for testing, run the following command:
>
> ```shell
> GOEXPERIMENT=synctest go test ./...
> ```
>
> If you wish to enable synctest for all go commands, you can set the
> `GOEXPERIMENT` environment variable in your shell profile or by using:
>
> ```shell
> go env -w GOEXPERIMENT=synctest
> ```
>
> Which will enable the "synctest" package for all go commands without needing
> to set it for all shell sessions.
>
> The synctest package is not required for production builds.
## Library detection
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:

View File

@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens.
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
To change this when using `ollama run`, use `/set parameter`:
@@ -187,6 +187,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored?

View File

@@ -75,7 +75,7 @@ RestartSec=3
Environment="PATH=$PATH"
[Install]
WantedBy=default.target
WantedBy=multi-user.target
```
Then start the service:

View File

@@ -81,9 +81,11 @@ help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
and GPU library dependencies for Nvidia and AMD. This allows for embedding
Ollama in existing applications, or running it as a system service via `ollama
serve` with tools such as [NSSM](https://nssm.cc/).
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
same directory. This allows for embedding Ollama in existing applications, or
running it as a system service via `ollama serve` with tools such as
[NSSM](https://nssm.cc/).
> [!NOTE]
> If you are upgrading from a prior version, you should remove the old directories first.

View File

@@ -124,6 +124,19 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func (kv KV) OllamaEngineRequired() bool {
return kv.Architecture() == "gemma3"
}
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
@@ -314,6 +327,10 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface {
Name() string
Decode(io.ReadSeeker) (model, error)
@@ -476,7 +493,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2":
case "gemma", "gemma2", "gemma3":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -565,6 +582,56 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
return
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)

3
go.mod
View File

@@ -24,7 +24,7 @@ require (
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0
gonum.org/v1/gonum v0.15.0
golang.org/x/tools v0.30.0
)
require (
@@ -44,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)

2
go.sum
View File

@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
}
func TestIntegrationSplitBatch(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6

View File

@@ -4,6 +4,7 @@ import (
"errors"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
var (
@@ -29,10 +30,26 @@ type Cache interface {
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Close closes the cache and frees resources associated with it
Close()
@@ -40,7 +57,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, positions []int32, seqs []int) error
StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -8,6 +8,7 @@ import (
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -19,9 +20,13 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
opts CausalOptions
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -39,6 +44,12 @@ type Causal struct {
// locations in the cache that are needed for this batch
curCellRange cellRange
// curSequences is the sequences corresponding to this pass's entries in the cache
curSequences []int
// curPositions is the positions corresponding to this pass's entries in the cache
curPositions []int32
// ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences
@@ -52,8 +63,8 @@ type Causal struct {
shiftFn shiftFn
backend ml.Backend
cacheCtx ml.Context
keys, values []ml.Tensor
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
type cacheCell struct {
@@ -67,28 +78,81 @@ type cellRange struct {
}
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
return &Causal{
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{windowSize: windowSize, shiftFn: shift}
return &Causal{
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *Causal) Close() {
c.cacheCtx.Close()
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
c.curBatchSize = len(positions)
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
c.opts.Except = nil
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
@@ -101,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
}
c.curCellRange = newRange()
for i, pos := range positions {
seq := seqs[i]
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -127,7 +191,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx, positions, seqs)
c.curMask, err = c.buildMask(ctx)
return err
}
@@ -154,39 +218,138 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
}
func roundDown(length, pad int) int {
return (length / pad) * pad
}
func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
}
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do padding, which is required for flash attention
len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
enabled := !slices.Contains(c.opts.Except, i)
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
}
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
if obj == nil {
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
ctx.Forward(srcView.Copy(ctx, dstView))
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
@@ -210,7 +373,8 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v).
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
@@ -219,7 +383,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := ctx.MaxTensors() / (6 * layers)
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
@@ -238,8 +402,7 @@ func (c *Causal) defrag() {
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
@@ -263,8 +426,7 @@ func (c *Causal) defrag() {
}
if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
@@ -293,47 +455,107 @@ func (c *Causal) defrag() {
}
func (c *Causal) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular index in a batch
Except []int
}
// SetCausal disables causal mask generation for a particular range of indicies in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
}
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
c.curMask.Dim(0),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
if c.curBatchSize != key.Dim(2) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
)
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -379,7 +601,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
}
kShift, err := ctx.FromIntSlice(offsets, len(offsets))
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
@@ -389,9 +611,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}
key = key.View(ctx, key.Stride(2)*seqRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)

View File

@@ -6,6 +6,7 @@ import (
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
type testCase struct {
@@ -24,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -57,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF32, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
name: "SlidingWindow",
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
@@ -70,6 +71,16 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
@@ -80,7 +91,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -115,7 +126,7 @@ func TestRemove(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -180,7 +191,7 @@ func TestDefrag(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -228,7 +239,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
tests := []testCase{
{
@@ -269,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, test.pos, test.seqs)
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}
@@ -303,13 +314,17 @@ func (b *testBackend) NewContext() ml.Context {
return &testContext{}
}
func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
func (b *testBackend) SystemInfo() string {
return "not implemented"
}
type testContext struct{}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
@@ -322,8 +337,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
@@ -342,11 +361,15 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}
func (c *testContext) Input() ml.Context { return c }
func (c *testContext) Output() ml.Context { return c }
func (c *testContext) Layer(int) ml.Context { return c }
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) MaxTensors() int {
func (c *testContext) MaxGraphNodes() int {
return 10
}
@@ -391,7 +414,7 @@ func (t *testTensor) Floats() []float32 {
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -428,11 +451,19 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
@@ -468,7 +499,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
context := &testContext{}
view := context.Zeros(t.dtype, s...).(*testTensor)
view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view
@@ -482,6 +513,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}

View File

@@ -1,7 +1,10 @@
package kvcache
import (
"fmt"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Encoder cache stores K and V tensors that are position independent
@@ -11,6 +14,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -30,36 +36,63 @@ type EncoderCache struct {
encoderPos int32
// ** cache data storage **
cacheCtx ml.Context
keys, values []ml.Tensor
backend ml.Backend
ctxs map[int]ml.Context
keys, values map[int]ml.Tensor
}
func NewEncoderCache() *EncoderCache {
return &EncoderCache{}
return &EncoderCache{
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.cacheCtx = backend.NewContext()
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.backend = backend
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
for _, ctx := range c.ctxs {
ctx.Close()
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
// The image is always in the first position
c.curPos = positions[0]
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
}
return nil
}
func (c *EncoderCache) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
@@ -75,9 +108,20 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}
if _, ok := c.ctxs[c.curLayer]; !ok {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
}
if _, ok := c.values[c.curLayer]; !ok {
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
}
ctx.Forward(

View File

@@ -4,6 +4,7 @@ import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Wrapper cache is a container for multiple types of caches,
@@ -22,9 +23,15 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
for _, cache := range c.caches {
cache.Init(backend, dtype, capacity)
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
}
}
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
}
}
@@ -34,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, positions, seqs)
err := cache.StartForward(ctx, batch)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range positions {
_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
}
}
return err

View File

@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{

View File

@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,

View File

@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3:
{
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:

View File

@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
// don't quantize vision stuff
quantize &= name.find("v.blk.") == std::string::npos;
quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
quantize &= name.find("v.position_embedding.weight") == std::string::npos;
quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN

View File

@@ -21,18 +21,6 @@ package llama
extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
return COMP_CLANG;
#elif defined(__GNUC__)
return COMP_GCC;
#else
return UNKNOWN_COMPILER;
#endif
}
*/
import "C"
@@ -72,19 +60,6 @@ func BackendInit() {
C.llama_backend_init()
}
func PrintSystemInfo() string {
var compiler string
switch C.get_compiler() {
case C.COMP_UNKNOWN:
compiler = "cgo(unknown_compiler)"
case C.COMP_GCC:
compiler = "cgo(gcc)"
case C.COMP_CLANG:
compiler = "cgo(clang)"
}
return C.GoString(C.llama_print_system_info()) + compiler
}
func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
@@ -262,7 +237,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
}
m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
if m.c == nil {
return nil, fmt.Errorf("unable to load model: %s", modelPath)
}
@@ -270,13 +245,27 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
return &m, nil
}
func LoadVocabFromFile(path string) (*Vocab, error) {
mp := C.CString(path)
defer C.free(unsafe.Pointer(mp))
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
if v.c == nil {
return nil, fmt.Errorf("unable to load vocab: %s", path)
}
return &v, nil
}
func FreeVocab(vocab *Vocab) {
C.llama_free_vocab(vocab.c)
}
func FreeModel(model *Model) {
C.llama_free_model(model.c)
C.llama_model_free(model.c)
}
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
c := Context{
c: C.llama_new_context_with_model(model.c, params.c),
c: C.llama_init_from_model(model.c, params.c),
numThreads: int(params.c.n_threads),
}
if c.c == nil {
@@ -287,15 +276,15 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
}
func (m *Model) NumVocab() int {
return int(C.llama_n_vocab(m.Vocab()))
return int(C.llama_vocab_n_tokens(m.Vocab()))
}
func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}
func (m *Model) AddBOSToken() bool {
return bool(C.llama_add_bos_token(m.Vocab()))
return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
@@ -318,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
return nil
}
type Vocab struct {
c *C.struct_llama_vocab
}
func (m *Model) Vocab() *C.struct_llama_vocab {
return C.llama_model_get_vocab(m.c)
}
@@ -478,7 +471,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
}
func (m *Model) NEmbd() int {
return int(C.llama_n_embd(m.c))
return int(C.llama_model_n_embd(m.c))
}
func Quantize(infile, outfile string, ftype uint32) error {
@@ -694,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
}
return buf[:n]
}
type Sampler struct {
c *C.struct_llama_sampler
}
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
cGrammar := C.CString(grammar)
cRoot := C.CString("root")
defer C.free(unsafe.Pointer(cGrammar))
defer C.free(unsafe.Pointer(cRoot))
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
return sampler
}
func (s *Sampler) Accept(token int32) {
C.llama_sampler_accept(s.c, C.llama_token(token))
}
type TokenData struct {
Id int32
Logit float32
}
func (s *Sampler) Apply(tokens []TokenData) {
tds := make([]C.struct_llama_token_data, len(tokens))
for i, token := range tokens {
tds[i] = C.struct_llama_token_data{
id: C.int32_t(token.Id),
logit: C.float(token.Logit),
p: C.float(0.0),
}
}
tda := &C.llama_token_data_array{
data: (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
size: C.size_t(len(tokens)),
selected: C.int64_t(-1),
sorted: C.bool(false),
}
var pinner runtime.Pinner
pinner.Pin(&tds[0])
defer pinner.Unpin()
C.llama_sampler_apply(s.c, tda)
for i := range tokens {
tokens[i].Logit = float32(tds[i].logit)
}
}

View File

@@ -1,69 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 11 Feb 2025 14:06:36 -0800
Subject: [PATCH] try/catch backend load
---
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 98d5e14d..1c19129a 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- if (handle) {
+ try {
+ if (entry.is_regular_file()) {
+ std::wstring filename = entry.path().filename().wstring();
+ std::wstring ext = entry.path().extension().wstring();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path().wstring();
}
}
}
+ } catch (const std::exception & e) {
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
}
}
}

View File

@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
Subject: [PATCH] use std::filesystem::path instead of wstring
---
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
1 file changed, 58 insertions(+), 86 deletions(-)
ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
1 file changed, 88 insertions(+), 111 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 1c19129a..c854e6bb 100644
index 98d5e14d..799af5f3 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -66,26 +66,6 @@
@@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
- best_path = entry.path().wstring();
+ best_path = entry.path();
}
if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
}
- if (handle) {
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- }
+
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path();
}
}
} catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {
@@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}
@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}
+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
+ try {
+ ggml_backend_load_best(name, silent, user_search_path);
+ } catch (const std::exception & e) {
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
+ }
+}
+
void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif
- ggml_backend_load_best("blas", silent, dir_path);
- ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
- ggml_backend_load_best("kompute", silent, dir_path);
- ggml_backend_load_best("metal", silent, dir_path);
- ggml_backend_load_best("rpc", silent, dir_path);
- ggml_backend_load_best("sycl", silent, dir_path);
- ggml_backend_load_best("vulkan", silent, dir_path);
- ggml_backend_load_best("opencl", silent, dir_path);
- ggml_backend_load_best("musa", silent, dir_path);
- ggml_backend_load_best("cpu", silent, dir_path);
+ ggml_backend_try_load_best("blas", silent, dir_path);
+ ggml_backend_try_load_best("cann", silent, dir_path);
+ ggml_backend_try_load_best("cuda", silent, dir_path);
+ ggml_backend_try_load_best("hip", silent, dir_path);
+ ggml_backend_try_load_best("kompute", silent, dir_path);
+ ggml_backend_try_load_best("metal", silent, dir_path);
+ ggml_backend_try_load_best("rpc", silent, dir_path);
+ ggml_backend_try_load_best("sycl", silent, dir_path);
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
+ ggml_backend_try_load_best("opencl", silent, dir_path);
+ ggml_backend_try_load_best("musa", silent, dir_path);
+ ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {

View File

@@ -0,0 +1,64 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Wed, 5 Mar 2025 17:41:07 -0800
Subject: [PATCH] fix string arr kv loading
---
ggml/include/gguf.h | 1 +
ggml/src/gguf.cpp | 7 +++++--
src/llama-vocab.cpp | 2 +-
3 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
index 79ee2020..3efb22f0 100644
--- a/ggml/include/gguf.h
+++ b/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" {
// get raw pointer to the first element of the array with the given key_id
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
// get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index ab13669c..f75b923f 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].data.size();
+}
+
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index c7ff28be..7a185443 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN

View File

@@ -0,0 +1,33 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Sun, 9 Mar 2025 14:44:16 -0700
Subject: [PATCH] ollama debug tensor
---
ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2f606d82..ec60e8fc 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"
+#include "ollama-debug.h"
+
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
+#ifdef OLLAMA_DEBUG
+ ollama_debug(node, true);
+#endif
+
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

View File

@@ -0,0 +1,113 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Patrick Devine <patrick@infrahq.com>
Date: Fri, 14 Mar 2025 16:33:23 -0700
Subject: [PATCH] gemma3 quantization
---
src/llama-arch.cpp | 19 +++++++++++++++++++
src/llama-arch.h | 1 +
src/llama-model.cpp | 7 +++++++
src/llama-quant.cpp | 9 +++++++++
4 files changed, 36 insertions(+)
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index b6f20286..b443fcd3 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -37,6 +37,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
+ { LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@@ -804,6 +805,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
+ {
+ LLM_ARCH_GEMMA3,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ },
+ },
{
LLM_ARCH_STARCODER2,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index ec742224..aad92a5d 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -41,6 +41,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
+ LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ab1a07d1..70183041 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_GEMMA3:
+ {
+ } break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
+ case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 6eb1da08..d2f3a510 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// This used to be a regex, but <regex> has an extreme cost to compile times.
bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+ // don't quantize vision stuff
+ quantize &= name.find("v.blk.") == std::string::npos;
+
+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos;
+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos;
+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.position_embedding.weight") == std::string::npos;
+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos;
+
// quantize only 2D and 3D tensors (experts)
quantize &= (ggml_n_dims(tensor) >= 2);

View File

@@ -2,6 +2,9 @@
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "llama-model.h"
#include "llama-model-loader.h"
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
return 0;
}
}
struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
llama_vocab * vocab = new llama_vocab();
try {
const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
std::vector<std::string> splits = {};
llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
vocab->load(ml, kv);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
return nullptr;
}
return vocab;
}
void llama_free_vocab(struct llama_vocab * vocab) {
delete vocab;
}

View File

@@ -35,6 +35,9 @@ extern "C"
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
struct llama_vocab * llama_load_vocab_from_file(const char * fname);
void llama_free_vocab(struct llama_vocab * vocab);
#ifdef __cplusplus
}
#endif

View File

@@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
if projectorWeights == 0 && projectorGraph == 0 {
projectorWeights, projectorGraph = f.VisionGraphSize()
}
layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
@@ -215,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
}
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -373,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
),

View File

@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
type LlamaServer interface {
@@ -54,8 +55,15 @@ type llmServer struct {
options api.Options
numParallel int
modelPath string
modelLock sync.Mutex // Temporary until we switch fully to Go server
model *llama.Model // If non-nil, the runner is a new Go server
// llamaModel is an instance of the cgo llama.cpp model definition
// nil if this server is running the new engine
llamaModel *llama.Model
llamaModelLock sync.Mutex
// textProcessor handles text encoding/decoding for the model in the Ollama engine
// nil if this server is running the llama.cpp based engine
textProcessor model.TextProcessor
estimate MemoryEstimate
totalLayers uint64
@@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
systemInfo := discover.GetSystemInfo()
systemTotalMemory := systemInfo.System.TotalMemory
systemFreeMemory := systemInfo.System.FreeMemory
@@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
slog.Info("offload", "", estimate)
params := []string{
"--model", model,
"--model", modelPath,
"--ctx-size", strconv.Itoa(opts.NumCtx),
"--batch-size", strconv.Itoa(opts.NumBatch),
}
@@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
defaultThreads := systemInfo.GetOptimalThreadCount()
if opts.NumThread > 0 {
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
@@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}
slog.Debug("compatible gpu libraries", "compatible", compatible)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var llamaModel *llama.Model
var textProcessor model.TextProcessor
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
textProcessor, err = model.NewTextProcessor(modelPath)
if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
}
}
if len(projectors) > 0 && llamaModel != nil {
params = append(params, "--mmproj", projectors[0])
}
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
@@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := []string{"runner"}
if envconfig.NewEngine() {
if textProcessor != nil {
// New engine
// TODO - if we have failure to load scenarios, add logic to retry with the old runner
finalParams = append(finalParams, "--ollama-engine")
}
finalParams = append(finalParams, params...)
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
// finally, add the root library path
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
s := &llmServer{
port: port,
cmd: exec.Command(exe, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
modelPath: model,
estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
gpus: gpus,
done: make(chan error, 1),
port: port,
cmd: exec.Command(exe, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
modelPath: modelPath,
llamaModel: llamaModel,
textProcessor: textProcessor,
estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
gpus: gpus,
done: make(chan error, 1),
}
s.cmd.Env = os.Environ()
@@ -377,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
}
slog.Info("starting llama server", "cmd", s.cmd.String())
slog.Info("starting llama server", "cmd", s.cmd)
if envconfig.Debug() {
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
@@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
err := fmt.Errorf("error starting runner: %v %s", err, msg)
if len(compatible) == 0 {
if llamaModel != nil {
llama.FreeModel(llamaModel)
}
return nil, err
}
@@ -442,7 +470,7 @@ const ( // iota is reset to 0
ServerStatusError
)
func (s ServerStatus) ToString() string {
func (s ServerStatus) String() string {
switch s {
case ServerStatusReady:
return "llm server ready"
@@ -457,12 +485,9 @@ func (s ServerStatus) ToString() string {
}
}
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
SlotsProcessing int `json:"slots_processing"`
Error string `json:"error"`
Progress float32 `json:"progress"`
type ServerStatusResponse struct {
Status ServerStatus `json:"status"`
Progress float32 `json:"progress"`
}
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
@@ -474,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
}
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState)
}
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
@@ -499,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
return ServerStatusError, fmt.Errorf("read health request: %w", err)
}
var status ServerStatusResp
if err := json.Unmarshal(body, &status); err != nil {
var ssr ServerStatusResponse
if err := json.Unmarshal(body, &ssr); err != nil {
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
}
switch status.Status {
case "ok":
return ServerStatusReady, nil
case "no slot available":
return ServerStatusNoSlotsAvailable, nil
case "loading model":
s.loadProgress = status.Progress
return ServerStatusLoadingModel, nil
switch ssr.Status {
case ServerStatusLoadingModel:
s.loadProgress = ssr.Progress
return ssr.Status, nil
case ServerStatusReady, ServerStatusNoSlotsAvailable:
return ssr.Status, nil
default:
return ServerStatusError, fmt.Errorf("server error: %+v", status)
return ssr.Status, fmt.Errorf("server error: %+v", ssr)
}
}
@@ -588,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
status, _ := s.getServerStatus(ctx)
if lastStatus != status && status != ServerStatusReady {
// Only log on status changes
slog.Info("waiting for server to become available", "status", status.ToString())
slog.Info("waiting for server to become available", "status", status)
}
switch status {
case ServerStatusReady:
@@ -602,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress))
stallTimer = time.Now().Add(stallDuration)
} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
slog.Debug("model load completed, waiting for server to become available", "status", status)
stallTimer = time.Now().Add(stallDuration)
fullyLoaded = true
}
@@ -643,63 +666,26 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"`
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
}
type CompletionResponse struct {
Content string
DoneReason string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Content string `json:"content"`
DoneReason string `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
"n_predict": req.Options.NumPredict,
"n_keep": req.Options.NumKeep,
"main_gpu": req.Options.MainGPU,
"temperature": req.Options.Temperature,
"top_k": req.Options.TopK,
"top_p": req.Options.TopP,
"min_p": req.Options.MinP,
"typical_p": req.Options.TypicalP,
"repeat_last_n": req.Options.RepeatLastN,
"repeat_penalty": req.Options.RepeatPenalty,
"presence_penalty": req.Options.PresencePenalty,
"frequency_penalty": req.Options.FrequencyPenalty,
"mirostat": req.Options.Mirostat,
"mirostat_tau": req.Options.MirostatTau,
"mirostat_eta": req.Options.MirostatEta,
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"cache_prompt": true,
}
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -707,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// these as "not set".
break
case `"json"`:
request["grammar"] = grammarJSON
req.Grammar = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
@@ -718,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
req.Grammar = string(g)
}
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -742,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %s", status.ToString())
return fmt.Errorf("unexpected server status: %s", status)
}
// Handling JSON marshaling with special characters unescaped.
@@ -750,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
@@ -801,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
evt = line
}
var c completion
var c CompletionResponse
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
@@ -825,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
})
}
if c.Stop {
doneReason := "stop"
if c.StoppedLimit {
doneReason = "length"
}
fn(CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
if c.Done {
fn(c)
return nil
}
}
@@ -886,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
return nil, fmt.Errorf("unexpected server status: %s", status)
}
data, err := json.Marshal(EmbeddingRequest{Content: input})
@@ -933,64 +912,25 @@ type TokenizeResponse struct {
}
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
s.modelLock.Lock()
defer s.modelLock.Unlock()
if s.model != nil {
return s.model.Tokenize(content, false, true)
}
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
if s.llamaModel != nil {
return s.llamaModel.Tokenize(content, false, true)
}
data, err := json.Marshal(TokenizeRequest{Content: content})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("encode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do encode request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
if s.model == nil {
slog.Debug("new runner detected, loading model for cgo tokenization")
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
}
s.model = m
if s.textProcessor != nil {
tokens, err := s.textProcessor.Encode(content, false)
if err != nil {
return nil, err
}
return s.model.Tokenize(content, false, true)
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read encode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var encoded TokenizeResponse
if err := json.Unmarshal(body, &encoded); err != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err)
}
return encoded.Tokens, nil
// not reached
return nil, fmt.Errorf("no tokenizer configured")
}
type DetokenizeRequest struct {
@@ -1002,80 +942,38 @@ type DetokenizeResponse struct {
}
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
s.modelLock.Lock()
defer s.modelLock.Unlock()
if s.model != nil {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()
if s.llamaModel != nil {
var resp string
for _, token := range tokens {
resp += s.model.TokenToPiece(token)
resp += s.llamaModel.TokenToPiece(token)
}
return resp, nil
}
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return "", fmt.Errorf("decode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("do decode request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
if s.model == nil {
slog.Debug("new runner detected, loading model for cgo tokenization")
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return "", err
}
s.model = m
if s.textProcessor != nil {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
var resp string
for _, token := range tokens {
resp += s.model.TokenToPiece(token)
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return resp, nil
return content, nil
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read decode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm decode error: %s", body)
return "", fmt.Errorf("%s", body)
}
var decoded DetokenizeResponse
if err := json.Unmarshal(body, &decoded); err != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err)
}
return decoded.Content, nil
// not reached
return "", fmt.Errorf("no tokenizer configured")
}
func (s *llmServer) Close() error {
s.modelLock.Lock()
if s.model != nil {
llama.FreeModel(s.model)
s.model = nil
s.llamaModelLock.Lock()
if s.llamaModel != nil {
llama.FreeModel(s.llamaModel)
s.llamaModel = nil
}
s.modelLock.Unlock()
s.llamaModelLock.Unlock()
if s.cmd != nil {
slog.Debug("stopping llama server")
@@ -1112,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
}
return 0
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return dur
}

View File

@@ -2,9 +2,11 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
"slices"
"strconv"
"strings"
)
@@ -18,17 +20,51 @@ type Config interface {
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface {
Config() Config
Get(name string) Tensor
NewContext() Context
SystemInfo() string
NewContextSize(size int) Context
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -40,11 +76,14 @@ type BackendParams struct {
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -52,23 +91,33 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
backends[name] = f
}
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(f, params)
return backend(ctx, f, params)
}
return nil, fmt.Errorf("unsupported backend")
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Forward(...Tensor) Context
Compute(...Tensor)
MaxTensors() int
MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating input tensors
Input() Context
// Output returns a context appropriate for creating output tensors
Output() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
}
type Tensor interface {
@@ -91,8 +140,10 @@ type Tensor interface {
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
@@ -102,6 +153,7 @@ type Tensor interface {
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor
@@ -116,6 +168,10 @@ type Tensor interface {
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
@@ -169,8 +225,8 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16:
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
@@ -195,16 +251,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
}
shape := t.Shape()
slices.Reverse(shape)
var sb strings.Builder
var f func([]int, int)
f = func(dims []int, stride int) {
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
fmt.Fprint(&sb, "[")
defer func() { fmt.Fprint(&sb, "]") }()
sb.WriteString("[")
defer func() { sb.WriteString("]") }()
for i := 0; i < dims[0]; i++ {
if i >= items && i < dims[0]-items {
fmt.Fprint(&sb, "..., ")
sb.WriteString("..., ")
// skip to next printable element
skip := dims[0] - 2*items
if len(dims) > 1 {
@@ -219,9 +276,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
}
} else {
fmt.Fprint(&sb, fn(s[stride+i]))
text := fn(s[stride+i])
if len(text) > 0 && text[0] != '-' {
sb.WriteString(" ")
}
sb.WriteString(text)
if i < dims[0]-1 {
fmt.Fprint(&sb, ", ")
sb.WriteString(", ")
}
}
}
@@ -237,5 +299,7 @@ const (
DTypeOther DType = iota
DTypeF32
DTypeF16
DTypeQ80
DTypeQ40
DTypeI32
)

File diff suppressed because it is too large Load Diff

View File

@@ -114,6 +114,7 @@ extern "C" {
// get raw pointer to the first element of the array with the given key_id
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
// get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);

View File

@@ -0,0 +1,11 @@
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
void ollama_debug(const struct ggml_tensor *tensor, bool verbose);
#ifdef __cplusplus
}
#endif

View File

@@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path();
}
int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path();
}
}
} catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -533,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}
static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
try {
ggml_backend_load_best(name, silent, user_search_path);
} catch (const std::exception & e) {
GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
}
}
void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif
ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);
ggml_backend_load_best("sycl", silent, dir_path);
ggml_backend_load_best("vulkan", silent, dir_path);
ggml_backend_load_best("opencl", silent, dir_path);
ggml_backend_load_best("musa", silent, dir_path);
ggml_backend_load_best("cpu", silent, dir_path);
ggml_backend_try_load_best("blas", silent, dir_path);
ggml_backend_try_load_best("cann", silent, dir_path);
ggml_backend_try_load_best("cuda", silent, dir_path);
ggml_backend_try_load_best("hip", silent, dir_path);
ggml_backend_try_load_best("kompute", silent, dir_path);
ggml_backend_try_load_best("metal", silent, dir_path);
ggml_backend_try_load_best("rpc", silent, dir_path);
ggml_backend_try_load_best("sycl", silent, dir_path);
ggml_backend_try_load_best("vulkan", silent, dir_path);
ggml_backend_try_load_best("opencl", silent, dir_path);
ggml_backend_try_load_best("musa", silent, dir_path);
ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {

View File

@@ -0,0 +1,6 @@
//go:build debug
package cpu
// #cgo CPPFLAGS: -DOLLAMA_DEBUG
import "C"

View File

@@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"
#include "ollama-debug.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
#ifdef OLLAMA_DEBUG
ollama_debug(node, true);
#endif
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

View File

@@ -7,6 +7,20 @@ package ggml
// #include <stdlib.h>
// #include "ggml-backend.h"
// extern void sink(int level, char *text, void *user_data);
// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
/*
typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
static COMPILER compiler_name(void) {
#if defined(__clang__)
return COMPILER_CLANG;
#elif defined(__GNUC__)
return COMPILER_GNUC;
#else
return COMPILER_UNKNOWN;
#endif
}
*/
import "C"
import (
@@ -16,6 +30,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@@ -90,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() {
visited[abspath] = struct{}{}
}
}
slog.Info("system", "", system{})
})
type system struct{}
func (system) LogValue() slog.Value {
var attrs []slog.Attr
names := make(map[string]int)
for i := range C.ggml_backend_dev_count() {
r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))
func() {
fName := C.CString("ggml_backend_get_features")
defer C.free(unsafe.Pointer(fName))
if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
var features []any
for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
features = append(features, C.GoString(f.name), C.GoString(f.value))
}
name := C.GoString(C.ggml_backend_reg_name(r))
attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
names[name] += 1
}
}()
}
switch C.compiler_name() {
case C.COMPILER_CLANG:
attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
case C.COMPILER_GNUC:
attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
default:
attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
}
return slog.GroupValue(attrs...)
}

View File

@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
return ctx->kv[key_id].data.size();
}
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}

116
ml/backend/ggml/ggml/src/ollama-debug.c vendored Normal file
View File

@@ -0,0 +1,116 @@
#include <string.h>
#include <inttypes.h>
#include "ollama-debug.h"
static int mul(int64_t *dims, int ndims) {
int result = 1;
for (int i = 0; i < ndims; i++) {
result *= dims[i];
}
return result;
}
static void repeat(char c, int n) {
for (int i = 0; i < n; i++) {
fprintf(stderr, "%c", c);
}
}
static void print_tensor(const void *tensor, void (*cb)(const void *, int),
int shape,
int64_t *dims, int ndims, int stride,
int nitems, int pad) {
fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems;
if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip;
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
i += skip - 1;
} else if (ndims > 1) {
print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
nitems, pad);
stride += mul(dims + 1, ndims - 1);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
} else {
cb(tensor, stride + i);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
}
}
}
fprintf(stderr, "]");
}
static void print_tensor_f16(const void *tensor, int i) {
float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_f32(const void *tensor, int i) {
float value = ((const float *)tensor)[i];
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_i32(const void *tensor, int i) {
int32_t value = ((const int32_t *)tensor)[i];
fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
}
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]);
if (!verbose) {
return;
}
for (int i = 0; i < indent; i++) {
fprintf(stderr, " ");
}
switch (tensor->type) {
case GGML_TYPE_F16:
print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_F32:
print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_I32:
print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
default:
fprintf(stderr, "<unsupported type>\n");
return;
}
fprintf(stderr, "\n");
}
void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
ollama_debug_tensor(tensor, verbose, ">>> ", 4);
for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
char src[8];
const int n = snprintf(src, sizeof(src), " src%d ", i);
if (n >= sizeof(src)) {
src[sizeof(src) - 1] = '\0';
}
ollama_debug_tensor(tensor->src[i], verbose, src, 4);
}
}

View File

@@ -0,0 +1,7 @@
//go:build !debug
package ggml
func Threads(n int) int {
return n
}

View File

@@ -0,0 +1,7 @@
//go:build debug
package ggml
func Threads(_ int) int {
return 1
}

View File

@@ -3,6 +3,7 @@ package nn
import (
"fmt"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
@@ -11,40 +12,50 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
var mask ml.Tensor
if cache != nil {
key, value, mask = cache.Get(ctx)
}
if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}
if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)

58
model/input/input.go Normal file
View File

@@ -0,0 +1,58 @@
package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
// SameBatch forces the following number of tokens to be processed
// in a single batch, breaking and extending batches as needed.
// Useful for things like images that must be processed in one
// shot.
SameBatch int
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
}

View File

@@ -1,9 +1,9 @@
package model
import (
"context"
"errors"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log/slog"
@@ -16,23 +16,52 @@ import (
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/model/input"
)
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Positions []int32
Sequences []int
Outputs []int32
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
Images []image.Image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
type config struct {
Cache kvcache.Cache
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) (any, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
//
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
// in the order that the user provided them. Each element of the slice will be
// either a single token or single multimodal object.
//
// The model must ensure that inputs are stored according to how they will be
// processed and stored in the cache. For example, Llava-style models should insert
// placeholder tokens equal to the feature size of the corresponding image with
// the image itself attached to and split across these tokens. When Forward is called
// a partial subset of these tokens may be submitted according to the batch size.
//
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]input.Input) ([]input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -41,6 +70,10 @@ type Base struct {
config
}
type config struct {
Cache kvcache.Cache
}
// Backend returns the underlying backend that will run the model
func (m *Base) Backend() ml.Backend {
return m.b
@@ -50,14 +83,6 @@ func (m *Base) Config() config {
return m.config
}
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, Options) (ml.Tensor, error)
Backend() ml.Backend
Config() config
}
var models = make(map[string]func(ml.Config) (Model, error))
// Register registers a model constructor for the given architecture
@@ -70,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r, params)
b, err := ml.NewBackend(ctx, r, params)
if err != nil {
return nil, err
}
@@ -100,6 +125,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}
func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, _, err := fs.Decode(r, -1)
if err != nil {
return nil, err
}
return getTextProcessor(meta.KV())
}
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, fmt.Errorf("%v is not a TextProcessor", m)
}
return tp, nil
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
@@ -226,24 +281,30 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(opts.Positions) < 1 {
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
err := cache.StartForward(ctx, batch)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, opts)
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
}

View File

@@ -3,12 +3,15 @@ package model
import (
"reflect"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
func TestParseTags(t *testing.T) {
@@ -134,3 +137,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fs.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
}
models["dummy"] = func(ml.Config) (Model, error) {
return notTextProcessorModel{}, nil
}
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
}
}
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented")
}
func (notTextProcessorModel) Backend() ml.Backend {
panic("unimplemented")
}
func (notTextProcessorModel) Config() config {
panic("unimplemented")
}

View File

@@ -0,0 +1,214 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
largeModelScaling bool
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
const (
gemma27BLayerCount = 46
)
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true
}
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
}
func init() {
model.Register("gemma2", New)
}

View File

@@ -0,0 +1,158 @@
package gemma3
import (
"bytes"
"image"
"math"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*VisionModel `gguf:"v,vision"`
*TextModel
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection"`
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return visionOutputs, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.(ml.Tensor)
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
}

View File

@@ -0,0 +1,214 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
func newTextModel(c ml.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
}
}
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -0,0 +1,127 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := make([]int32, numPatches)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -0,0 +1,58 @@
package gemma3
import (
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}

View File

@@ -1,16 +1,18 @@
package llama
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
@@ -29,6 +31,10 @@ type Model struct {
}
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -60,43 +66,38 @@ func New(c ml.Config) (model.Model, error) {
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
@@ -138,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)

View File

@@ -1,10 +1,18 @@
package mllama
import (
"bytes"
"encoding/binary"
"fmt"
"hash/fnv"
"image"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
@@ -25,6 +33,10 @@ const (
)
func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -43,65 +55,107 @@ func New(c ml.Config) (model.Model, error) {
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
return &m, nil
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Images != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
if err != nil {
return nil, err
if len(batch.Multimodal) > 0 {
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
}
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -10,36 +10,31 @@ import (
)
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@@ -47,7 +42,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil
}
type TextMLP struct {
@@ -107,7 +106,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value, mask ml.Tensor
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -119,16 +118,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)
@@ -191,8 +197,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
}
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32

View File

@@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused

View File

@@ -1,6 +1,8 @@
package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -4,6 +4,7 @@ import (
"cmp"
"iter"
"log/slog"
"slices"
"strings"
"sync"
@@ -18,8 +19,17 @@ const (
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(string) ([]int32, error)
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
}
@@ -27,11 +37,11 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Scores []uint32
Scores []float32
Merges []string
BOS, EOS int32
AddBOS, AddEOS bool
BOS, EOS, EOT int32
AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once
special []string
@@ -48,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS
return id == v.EOS || id == v.EOT
default:
return false
}
@@ -76,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == 3 {
if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}
@@ -144,7 +156,7 @@ type merge struct {
runes []rune
}
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@@ -177,7 +189,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue
}
@@ -201,7 +212,6 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue
}
@@ -275,14 +285,13 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
}
}
}
}
}
if len(ids) > 0 {
if addSpecial && len(ids) > 0 {
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
@@ -329,6 +338,5 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

246
model/process_text_spm.go Normal file
View File

@@ -0,0 +1,246 @@
package model
import (
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
slog.Debug("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
slog.Debug("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
return ids, nil
}
type candidate struct {
a, b int
score float32
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

View File

@@ -0,0 +1,118 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t))
default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePieceModel(preTokenizer, &v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}

View File

@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) {
t.Parallel()
ids, err := tokenizer.Encode("hello world")
ids, err := tokenizer.Encode("hello world", true)
if err != nil {
t.Error(err)
}
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s)
}
ids, err = tokenizer.Encode("hello <|end_of_text|>")
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
if err != nil {
t.Error(err)
}
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
}
for s, want := range cases {
ids, err := tokenizer.Encode(s)
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Error(err)
}
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
}
for _, want := range cases {
ids, err := tokenizer.Encode(want)
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Error(err)
}
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
}
for s, want := range cases {
ids, err := tokenizer.Encode(s)
ids, err := tokenizer.Encode(s, true)
if err != nil {
t.Fatal(err)
}
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
_, err := tokenizer.Encode(string(bts))
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
})
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts))
ids, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
}

BIN
model/testdata/gemma2/tokenizer.model vendored Normal file

Binary file not shown.

View File

@@ -116,19 +116,9 @@ func (i *Instance) Readline() (string, error) {
switch r {
case KeyUp:
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
i.historyPrev(buf, &currentLineBuf)
case KeyDown:
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(currentLineBuf)
}
}
i.historyNext(buf, &currentLineBuf)
case KeyLeft:
buf.MoveLeft()
case KeyRight:
@@ -185,6 +175,10 @@ func (i *Instance) Readline() (string, error) {
esc = true
case CharInterrupt:
return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
case CharNext:
i.historyNext(buf, &currentLineBuf)
case CharLineStart:
buf.MoveToStart()
case CharLineEnd:
@@ -246,6 +240,24 @@ func (i *Instance) HistoryDisable() {
i.History.Enabled = false
}
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
*currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
}
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(*currentLineBuf)
}
}
}
func NewTerminal() (*Terminal, error) {
fd := os.Stdin.Fd()
termios, err := SetRawMode(fd)

View File

@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/runner/common"
)
@@ -99,7 +100,7 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
@@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
@@ -229,7 +230,7 @@ type Server struct {
image *ImageContext
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
status llm.ServerStatus
// current progress on loading the model
progress float32
@@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var samplingParams llama.SamplingParams
samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
// Extract options from the CompletionRequest
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TypicalP: req.Options.TypicalP,
Temp: req.Options.Temperature,
RepeatLastN: req.Options.RepeatLastN,
PenaltyRepeat: req.Options.RepeatPenalty,
PenaltyFreq: req.Options.FrequencyPenalty,
PenaltyPresent: req.Options.PresencePenalty,
Mirostat: req.Options.Mirostat,
MirostatTau: req.Options.MirostatTau,
MirostatEta: req.Options.MirostatEta,
Seed: uint32(req.Options.Seed),
Grammar: req.Grammar,
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
@@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numDecoded,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
@@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -879,7 +794,7 @@ func (s *Server) loadModel(
panic(err)
}
s.status = ServerStatusReady
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -931,14 +846,13 @@ func Execute(args []string) error {
slog.Info("starting go runner")
llama.BackendInit()
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
server := &Server{
batchSize: *batchSize,
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: ServerStatusLoadingModel,
status: llm.ServerStatusLoadingModel,
}
var tensorSplitFloats []float32
@@ -968,13 +882,14 @@ func Execute(args []string) error {
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
@@ -994,6 +909,5 @@ func Execute(args []string) error {
return err
}
cancel()
return nil
}

View File

@@ -5,12 +5,12 @@ import (
"fmt"
"log/slog"
"math"
"reflect"
"time"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type InputCache struct {
@@ -31,27 +31,26 @@ type InputCache struct {
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
slots[i] = InputCacheSlot{Id: i}
}
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
}
return &InputCache{
numCtx: kvSize / int32(numSlots),
numCtx: numCtx,
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
@@ -62,9 +61,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType {
switch s {
case "q8_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ80
case "q4_0":
panic("kv cache quantization not yet implemented")
return ml.DTypeQ40
default:
return ml.DTypeF16
}
@@ -83,7 +82,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []input
Inputs []input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -92,7 +91,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -110,10 +109,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()
@@ -143,7 +138,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -166,7 +161,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -202,7 +197,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest)
oldestSlot.Inputs = make([]input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +207,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input, b []input) int32 {
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
var count int32
for i := range a {
@@ -220,7 +215,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break
}
if !reflect.DeepEqual(a[i], b[i]) {
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break
}

View File

@@ -4,6 +4,8 @@ import (
"image"
"testing"
"time"
"github.com/ollama/ollama/model/input"
)
func TestCountCommon(t *testing.T) {
@@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []input
t2 []input
t1 []input.Input
t2 []input.Input
expected int32
}{
{
name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input{{token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input{{image: imgA}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input{{token: 1}, {image: imgA}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
t1: []input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input{},
t2: []input{},
t1: []input.Input{},
t2: []input.Input{},
expected: 0,
},
}
@@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input
prompt []input.Input
longest expected
best expected
}{
@@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 2}},
prompt: []input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input{{token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 2}, {token: 3}},
prompt: []input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
@@ -289,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
})
}
}
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}

View File

@@ -1,13 +1,12 @@
package ollamarunner
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"image"
"hash/maphash"
"log"
"log/slog"
"net"
@@ -25,30 +24,33 @@ import (
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
_ "github.com/ollama/ollama/model/models"
)
// input is an element of the prompt to process, either a token or an image
type input struct {
token int32
image image.Image
type contextList struct {
list []ml.Context
}
type Sequence struct {
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctxs *contextList
// batch index
iBatch int
// prompt inputs left to evaluate
inputs []input
inputs []input.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -97,12 +99,12 @@ type NewSequenceParams struct {
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
inputs, err := s.inputs(prompt, images)
inputs, ctxs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -113,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = int32(len(inputs))
}
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -128,6 +133,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctxs: ctxs,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -146,28 +152,38 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
var inputs []input.Input
var parts []string
var matches [][]string
// TODO(jessegross): This can sometimes trigger for matching text in the
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
if visionModel {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
var contexts contextList
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
for _, ctx := range ctxs {
ctx.Close()
}
}, contexts.list)
postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part)
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, err
return nil, nil, err
}
for _, t := range tokens {
inputs = append(inputs, input{token: t})
inputs = append(inputs, input.Input{Token: t})
}
// image - decode and store
@@ -183,19 +199,34 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
}
if imageIndex < 0 {
return nil, fmt.Errorf("invalid image index: %d", n)
return nil, nil, fmt.Errorf("invalid image index: %d", n)
}
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
ctx := s.model.Backend().NewContext()
contexts.list = append(contexts.list, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, err
return nil, nil, err
}
inputs = append(inputs, input{image: image})
s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
return inputs, nil
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return nil, nil, err
}
}
return inputs, &contexts, nil
}
type Server struct {
@@ -207,7 +238,7 @@ type Server struct {
model model.Model
// status for external health reporting - loading, ready to serve, etc.
status ServerStatus
status llm.ServerStatus
// current progress on loading the model
progress float32
@@ -236,8 +267,15 @@ type Server struct {
// KV cache
cache *InputCache
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
// vocab is a llama.cpp vocab required for gammar-based
// constrained generation (json mode, structured outputs)
// TODO: this is temporary until Ollama sampling supports
// constrained generation
vocab *sample.Vocab
}
func (s *Server) allNil() bool {
@@ -310,85 +348,80 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var options model.Options
imgSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
var batchInputs []int32
var batch input.Batch
for i, seq := range s.seqs {
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, "limit")
s.removeSequence(i, "limit")
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input{}
seq.cache.Inputs = []input.Input{}
}
for i, input := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
batchSize := s.batchSize
for j, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
}
if i >= s.batchSize {
if len(seq.pendingInputs)+minBatch > batchSize {
break
}
// TODO(jessegross): Image inputs need to be rethought - it's
// it doesn't work well for different types of models or multiple sequences
if input.image != nil {
if len(seq.pendingInputs) != len(options.Images) {
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
if imgSeq != seqIdx && imgSeq != -1 {
s.nextSeq = seqIdx
break
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
imgSeq = seqIdx
options.Images = append(options.Images, input.image)
seq.pendingInputs = append(seq.pendingInputs, input)
continue
}
options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(options.Outputs)
if i+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
batchInputs = append(batchInputs, inp.Token)
if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
}
seq.pendingInputs = append(seq.pendingInputs, input)
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if len(options.Inputs) == 0 {
if len(batchInputs) == 0 {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, options)
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -403,7 +436,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{}
seq.pendingInputs = []input.Input{}
}
// don't sample prompt processing
@@ -422,12 +455,13 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "")
continue
}
// sample a token
vocabSize := len(logits) / len(options.Outputs)
vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
@@ -449,7 +483,7 @@ func (s *Server) processBatch() error {
return err
}
seq.inputs = []input{{token: token}}
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -496,75 +530,18 @@ func (s *Server) processBatch() error {
return nil
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
AspectRatioID int `json:"aspect_ratio_id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -575,11 +552,30 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var grammar *sample.Grammar
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
}
}
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
})
if err != nil {
@@ -601,7 +597,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -628,7 +624,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -639,15 +635,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numPredicted,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -658,102 +656,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
Status: s.status,
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
@@ -772,6 +678,7 @@ func (m *multiLPath) String() string {
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -781,19 +688,19 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath, params)
s.model, err = model.New(ctx, mpath, params)
if err != nil {
panic(err)
}
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
s.vocab = sample.NewVocab(mpath)
// TODO(jessegross): LoRA loading
if lpath.String() != "" {
panic("loras are not yet implemented")
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
panic(err)
}
@@ -807,7 +714,7 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
s.status = ServerStatusReady
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -818,7 +725,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
@@ -859,11 +766,10 @@ func Execute(args []string) error {
server := &Server{
batchSize: *batchSize,
status: ServerStatusLoadingModel,
status: llm.ServerStatusLoadingModel,
}
// TODO(jessegross): Parameters that need to be implemented:
// flash-attn
// no-mmap
// mlock
@@ -878,33 +784,42 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
FlashAttention: *flashAttention,
}
server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
mux := http.NewServeMux()
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
// TODO: support embeddings
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)
httpServer := http.Server{
Handler: mux,
@@ -916,6 +831,5 @@ func Execute(args []string) error {
return err
}
cancel()
return nil
}

View File

@@ -3,118 +3,224 @@ package sample
import (
"errors"
"math"
"math/rand/v2"
"slices"
"sync"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
"github.com/ollama/ollama/llama"
)
type Sampler interface {
Sample([]float32) (int32, error)
// token represents information about a single token during sampling
type token struct {
id int32 // The token's unique identifier
value float32 // The raw logit or probability from the model
}
type weighted struct {
src rand.Source
transforms []Transform
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
}
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
func Weighted(seed *uint64, transforms ...Transform) Sampler {
var src rand.Source
if seed != nil {
src = rand.NewSource(*seed)
}
return weighted{src: src, transforms: transforms}
}
func (s weighted) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}
logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
for i, logit := range logits64 {
if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
}
}
if len(logitsCopy) == 0 {
return -1, errors.New("no valid logits found for weighed sampling")
}
probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighted sampler failed, no valid token found")
}
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling")
return -1, errors.New("sample: no logits provided to sample")
}
maxIdx := 0
tokens := make([]token, len(logits))
for i := range logits {
if logits[i] > logits[maxIdx] {
maxIdx = i
tokens[i].id = int32(i)
tokens[i].value = logits[i]
}
t, err := s.sample(tokens)
if err != nil {
return -1, err
}
if s.grammar != nil {
// optimization: first check if the max logit is accepted by the grammar
// if the max logit is rejected, apply the grammar to all logits (slower)
top := []token{t}
s.grammar.Apply(top)
if !math.IsInf(float64(top[0].value), -1) {
s.grammar.Accept(top[0].id)
return top[0].id, nil
}
// since .sample has side effects of modifying the tokens
// we need to reset them before applying the grammar and
// sampling again
for i := range logits {
tokens[i].id = int32(i)
tokens[i].value = logits[i]
}
s.grammar.Apply(tokens)
t, err = s.sample(tokens)
if err != nil {
return -1, err
}
s.grammar.Accept(t.id)
}
return t.id, nil
}
// greedy returns the highest probability token from the tokens
func greedy(tokens []token) token {
max := tokens[0]
for i := 1; i < len(tokens); i++ {
if tokens[i].value > max.value {
max = tokens[i]
}
}
return int32(maxIdx), nil
return max
}
// sample returns the highest probability token from the tokens
// given sampler parameters. It also has side effects of modifying the tokens
func (s *Sampler) sample(tokens []token) (token, error) {
if s.temperature == 0 {
return greedy(tokens), nil
}
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
// scale and normalize the tokens in place
temperature(tokens, s.temperature)
softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
var r float32
if s.rng != nil {
r = s.rng.Float32()
} else {
r = rand.Float32()
}
// Calculate cumulative sum of probabilities
var sum float32
for i := range tokens {
sum += tokens[i].value
tokens[i].value = sum
}
r *= tokens[len(tokens)-1].value
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
if token.value < target {
return -1
}
return 1
})
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
if temperature == 0 {
return Greedy(), nil
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
if temperature < 0.0 {
temperature = 0.0
}
if temperature < 0 || temperature > 2 {
return nil, errors.New("temperature must be between 0 and 2")
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}
transforms := []Transform{Temperature(temperature)}
if topK != 0 {
if topK <= 0 {
return nil, errors.New("topK must be greater than 0")
}
transforms = append(transforms, TopK(topK))
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
if topP != 0 {
if topP < 0 || topP >= 1 {
return nil, errors.New("topP must be between 0 and 1")
}
transforms = append(transforms, TopP(topP))
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
}
if minP != 0 {
if minP < 0 || minP >= 1 {
return nil, errors.New("minP must be between 0 and 1")
}
transforms = append(transforms, MinP(minP))
}
if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
return Weighted(nil, transforms...), nil
}
type Grammar struct {
vocab *Vocab
grammar string
sampler *llama.Sampler
}
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
v, err := vocab.Load()
if err != nil {
return nil, err
}
return &Grammar{
vocab: vocab,
grammar: grammar,
sampler: llama.NewGrammarSampler(v, grammar),
}, nil
}
func (g *Grammar) Apply(tokens []token) {
tds := make([]llama.TokenData, len(tokens))
for i, token := range tokens {
tds[i].Id = token.id
tds[i].Logit = token.value
}
g.sampler.Apply(tds)
for i := range tokens {
tokens[i].value = tds[i].Logit
}
}
func (g *Grammar) Accept(token int32) {
g.sampler.Accept(token)
}
type Vocab struct {
once sync.Once
vocab *llama.Vocab
err error
path string
}
func NewVocab(path string) *Vocab {
return &Vocab{path: path}
}
// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
v.once.Do(func() {
vocab, err := llama.LoadVocabFromFile(v.path)
if err != nil {
v.err = err
return
}
v.vocab = vocab
})
return v.vocab, v.err
}

View File

@@ -0,0 +1,92 @@
package sample
import (
"fmt"
"math/rand"
"testing"
)
func BenchmarkWeightedSampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
configs := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
}{
{"Greedy", 0, -1, 0, 0, -1},
{"Temperature", 0.8, -1, 0, 0, -1},
{"TopK", 0.8, 50, 0, 0, -1},
{"TopP", 0.8, -1, 0.9, 0, -1},
{"MinP", 0.8, -1, 0, 0.05, -1},
{"WithSeed", 0.8, 50, 0, 0, 42},
}
// Fixed size for common vocab size
size := 128000
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler.Sample(logits)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
func BenchmarkGreedySampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
}

View File

@@ -4,12 +4,12 @@ import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWeighted(t *testing.T) {
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
return
@@ -19,194 +19,49 @@ func TestWeighted(t *testing.T) {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
if err == nil {
t.Error("expected error for no valid tokens, got index", got)
}
seed := uint64(42)
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// With seed 42, we expect a consistent sample
want = int32(3) // This will be deterministic due to the seed
want = int32(3) // Should pick highest probability with this r value
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
}
type testTransform struct {
id int
callOrder *[]int
}
func (ts *testTransform) Apply(logits []float64) []float64 {
if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id)
}
return logits
}
func TestSample(t *testing.T) {
input := []float32{1, 2, 3, 4}
var callOrder []int
mock1 := &testTransform{
id: 1,
callOrder: &callOrder,
}
mock2 := &testTransform{
id: 2,
callOrder: &callOrder,
}
mock3 := &testTransform{
id: 3,
callOrder: &callOrder,
}
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}
func TestNewSampler(t *testing.T) {
tests := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
wantErr bool
}{
{
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{
name: "temperature",
temperature: 0.5,
wantErr: false,
},
{
name: "invalid temperature negative",
temperature: -1,
wantErr: true,
},
{
name: "invalid temperature too high",
temperature: 2.1,
wantErr: true,
},
{
name: "top k",
topK: 10,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
},
{
name: "top p",
topP: 0.9,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "min p",
minP: 0.2,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid min p negative",
minP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "default values",
temperature: 0.8,
topK: 40,
topP: 0.9,
minP: 0.0,
seed: 0,
wantErr: false,
},
{
name: "all zeroes",
temperature: 0.0,
topK: 0,
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: false, // all zeroes means no transforms
},
{
name: "all transforms",
temperature: 0.8,
topK: 50,
topP: 0.95,
minP: 0.1,
seed: 42,
wantErr: false,
},
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
if (err != nil) != tt.wantErr {
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
}
})
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
}
func BenchmarkSample(b *testing.B) {
transforms := []Transform{
Temperature(0.5),
TopK(10),
TopP(0.9),
MinP(0.2),
}
samplers := map[string]Sampler{
"Greedy": Greedy(),
"Weighted": Weighted(nil, transforms...),
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
}
// Generate random logits for benchmarking
logits := make([]float32, 1<<16)
for i := range logits {
logits[i] = rand.Float32()
@@ -215,9 +70,9 @@ func BenchmarkSample(b *testing.B) {
for name, s := range samplers {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
if _, err := s.Sample(logits); err != nil {
b.Error(err)
b.Fatalf("error sampling: %v", err)
}
}
})

View File

@@ -1,120 +1,130 @@
package sample
import (
"cmp"
"container/heap"
"math"
"slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
)
type Transform interface {
Apply([]float64) []float64
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type tokenHeap []token
func (h tokenHeap) Len() int { return len(h) }
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *tokenHeap) Push(x any) {
*h = append(*h, x.(token))
}
// TODO(parthsareen): potentially cache softmax values
func softmax(logits []float64) []float64 {
var sum float64
probs := make([]float64, len(logits))
for i, v := range logits {
probs[i] = math.Exp(v)
sum += probs[i]
}
for i := range probs {
probs[i] /= sum
}
return probs
func (h *tokenHeap) Pop() any {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}
type Temperature float64
func (t Temperature) Apply(logits []float64) []float64 {
temp := math.Max(float64(t), 1e-7)
// subtracting max logit to avoid under/overflow
maxLogit := slices.Max(logits)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return logits
}
type logitMap struct {
index int
logit float64
}
type TopK int
// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) []float64 {
if int(k) >= len(logits) {
return logits
}
q := pq.NewWith(func(a, b logitMap) int {
return -cmp.Compare(a.logit, b.logit)
})
for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}
validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}
for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
// softmax applies normalization to the logits
func softmax(ts []token) {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
return logits
}
type TopP float64
func (p TopP) Apply(logits []float64) []float64 {
probs := softmax(logits)
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
// Compute exp(x - max)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
})
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}
}
var sum float64
for i, idx := range indices {
sum += probs[idx]
if sum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.Inf(-1)
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 {
slices.SortFunc(ts, func(a, b token) int {
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
break
}
})
return ts
}
return logits
}
type MinP float64
// Initialize min-heap with first k elements
h := make(tokenHeap, k)
copy(h, ts[:k])
heap.Init(&h)
func (p MinP) Apply(logits []float64) []float64 {
probs := softmax(logits)
threshold := slices.Max(probs) * float64(p)
for i, prob := range probs {
if prob < threshold {
logits[i] = math.Inf(-1)
// Process remaining elements
for i := k; i < len(ts); i++ {
if ts[i].value > h[0].value {
heap.Pop(&h)
heap.Push(&h, ts[i])
}
}
return logits
// Convert heap to sorted slice in descending order
result := make([]token, len(h))
for i := k - 1; i >= 0; i-- {
result[i] = heap.Pop(&h).(token)
}
return result
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
// Find cutoff index where cumulative sum exceeds p
var sum float32
for i, t := range ts {
sum += t.value
if sum > float32(p) {
return ts[:i+1]
}
}
return ts
}
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
maxProb := ts[0].value
threshold := maxProb * p
for i, t := range ts {
if t.value < threshold {
return ts[:i]
}
}
return ts
}

View File

@@ -4,77 +4,354 @@ import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestTemperature(t *testing.T) {
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
want := []float64{-4, -10, 0, -14, -6, -12, -8}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
tokens := make([]token, len(values))
for i, v := range values {
tokens[i] = token{
id: int32(i),
value: v,
}
}
return tokens
}
// Helper to compare logit slices
func compareLogits(t *testing.T, name string, want []float32, got []token) {
t.Helper()
if len(want) != len(got) {
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
return
}
for i := range want {
if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
}
}
}
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
tokens := toTokens(input)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, tokens)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, tokens)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, tokens)
}
func TestSoftmax(t *testing.T) {
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("probs mismatch (-want +got):\n%s", diff)
}
}
func TestTopK(t *testing.T) {
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
tests := []struct {
name string
input []float32
expected []float32
}{
{
name: "correctness softmax",
input: []float32{1, -2, 3, 0},
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
},
{
name: "normal distribution",
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
},
{
name: "single value",
input: []float32{1.0},
},
{
name: "identical values",
input: []float32{0.9, 0.9, 0.9},
},
{
name: "large values",
input: []float32{1000.0, 2000.0, 3000.0},
},
{
name: "small values",
input: []float32{1e-6, 2e-6, 3e-6},
},
{
name: "negative values",
input: []float32{-1.0, -2.0, -3.0},
},
{
name: "mixed values",
input: []float32{-100.0, 0.0, 100.0},
},
}
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens := toTokens(tt.input)
softmax(tokens)
want = []float64{-3, -2, -1, 0, 1, 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, tokens)
return
}
func TestTopP(t *testing.T) {
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func TestMinP(t *testing.T) {
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
func BenchmarkTransform(b *testing.B) {
transforms := map[string]Transform{
"Temperature": Temperature(0.5),
"TopK": TopK(10),
"TopP": TopP(0.9),
"MinP": MinP(0.2),
}
logits := make([]float64, 1<<16)
for i := range logits {
logits[i] = rand.Float64()
}
for name, transform := range transforms {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
transform.Apply(logits)
// Check probabilities sum to 1
var sum float32
for _, token := range tokens {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
}
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
})
}
}
func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
tokens = topK(tokens, 5)
if len(tokens) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, tokens)
tokens = toTokens(input)
tokens = topK(tokens, 20)
if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
}
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, -1)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, 0)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
}
}
func TestTopP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4}
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
softmax(tokens)
tokens = topK(tokens, 20)
// Test with very high p value
got := topP(tokens, 1.0)
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
got := minP(tokens, 1.0)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
// Test with normal p value
got = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
}
}
func BenchmarkTransforms(b *testing.B) {
// Generate random logits
tokens := make([]token, 1<<16)
for i := range tokens {
tokens[i] = token{
id: int32(i),
value: rand.Float32(),
}
}
tokensCopy := make([]token, len(tokens))
b.Run("Temperature", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
temperature(tokensCopy, 0.5)
}
})
b.Run("Softmax", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
softmax(tokensCopy)
}
})
b.Run("TopK", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 10)
}
})
b.Run("TopP", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topP(tokensCopy, 0.9)
}
})
b.Run("MinP", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = minP(tokensCopy, 0.2)
}
})
b.Run("SortTokens", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 200000)
}
})
}

View File

@@ -8,7 +8,7 @@ usage() {
exit 1
}
export VERSION=${VERSION:-$(git describe --tags --dirty)}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
export CGO_CPPFLAGS='-mmacosx-version-min=11.3'

View File

@@ -80,13 +80,14 @@ function checkEnv() {
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset CPU --parallel $script:JOBS
& cmake --build --preset CPU --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -101,7 +102,7 @@ function buildOllama() {
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --parallel $script:JOBS
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -112,7 +113,7 @@ function buildOllama() {
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --parallel $script:JOBS
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -131,7 +132,7 @@ function buildOllama() {
$env:HIPCXX=""
$env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH=""
& cmake --build --preset "ROCm" --parallel $script:JOBS
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}

View File

@@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
@@ -34,6 +35,7 @@ var (
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
errFilePath = errors.New("file path must be relative")
)
func (s *Server) CreateHandler(c *gin.Context) {
@@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}
for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
}
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
@@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
if errors.Is(err, badReq) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
@@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
defer os.RemoveAll(tmpDir)
// Set up a root to validate paths
root, err := os.OpenRoot(tmpDir)
if err != nil {
return nil, err
}
defer root.Close()
for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
@@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
if err != nil {
return nil, err
}
defer bin.Close()
f, _, err := ggml.Decode(bin, 0)
if err != nil {

106
server/create_test.go Normal file
View File

@@ -0,0 +1,106 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestConvertFromSafetensors(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}
// Create a safetensors compatible file with empty JSON content
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")
model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)
tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
// be in either of the following forms:
//
// @<digest>
// <name>
// <name>@<digest>
// <name>
//
// If a digest is provided, it is returned as is and nothing else happens.
@@ -160,8 +160,6 @@ func debugger(err *error) func(step string) {
// hashed is passed to a PutBytes call to ensure that the manifest is in the
// blob store. This is done to ensure that future calls to [Get] succeed in
// these cases.
//
// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
func (c *DiskCache) Resolve(name string) (Digest, error) {
name, digest := splitNameDigest(name)
if digest != "" {
@@ -279,18 +277,6 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@@ -341,7 +327,9 @@ func (c *DiskCache) GetFile(d Digest) string {
return absJoin(c.dir, "blobs", filename)
}
// Links returns a sequence of links in the cache in lexical order.
// Links returns a sequence of link names. The sequence is in lexical order.
// Names are converted from their relative path form to their name form but are
// not guaranteed to be valid. Callers should validate the names before using.
func (c *DiskCache) Links() iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for path, err := range c.links() {
@@ -414,12 +402,14 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
}
type checkWriter struct {
d Digest
size int64
n int64
h hash.Hash
d Digest
f *os.File
err error
h hash.Hash
w io.Writer // underlying writer; set by creator
n int64
err error
testHookBeforeFinalWrite func(*os.File)
}
@@ -435,6 +425,10 @@ func (w *checkWriter) seterr(err error) error {
// underlying writer is guaranteed to be the last byte of p as verified by the
// hash.
func (w *checkWriter) Write(p []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
_, err := w.h.Write(p)
if err != nil {
return 0, w.seterr(err)
@@ -453,7 +447,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
if nextSize > w.size {
return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
}
n, err := w.f.Write(p)
n, err := w.w.Write(p)
w.n += int64(n)
return n, w.seterr(err)
}
@@ -493,10 +487,12 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
d: out,
size: size,
h: sha256.New(),
f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
@@ -532,11 +528,6 @@ func splitNameDigest(s string) (name, digest string) {
var errInvalidName = errors.New("invalid name")
func nameToPath(name string) (_ string, err error) {
if strings.Contains(name, "@") {
// TODO(bmizerany): HACK: Fix names.Parse to validate.
// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
return "", errInvalidName
}
n := names.Parse(name)
if !n.IsFullyQualified() {
return "", errInvalidName
@@ -547,8 +538,7 @@ func nameToPath(name string) (_ string, err error) {
func absJoin(pp ...string) string {
abs, err := filepath.Abs(filepath.Join(pp...))
if err != nil {
// Likely a bug bug or a bad OS problem. Just panic.
panic(err)
panic(err) // this should never happen
}
return abs
}

73
server/internal/cache/blob/chunked.go vendored Normal file
View File

@@ -0,0 +1,73 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
)
// Chunk represents a range of bytes in a blob.
type Chunk struct {
Start int64
End int64
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest
size int64
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
name := c.GetFile(d)
info, err := os.Stat(name)
if err == nil && info.Size() == size {
return &Chunker{}, nil
}
f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
if err != nil {
return nil, err
}
return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r to the blob at the given offset,
// merging the data with the existing blob. It returns an error if any. As a
// special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
if c.f == nil {
return nil
}
cw := &checkWriter{
d: d,
size: chunk.Size(),
h: sha256.New(),
f: c.f,
w: io.NewOffsetWriter(c.f, chunk.Start),
}
_, err := io.CopyN(cw, r, chunk.Size())
if err != nil && errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
}
// Close closes the underlying file.
func (c *Chunker) Close() error {
return c.f.Close()
}

View File

@@ -63,6 +63,10 @@ func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4])
}
func (d Digest) Sum() [32]byte {
return d.sum
}
func (d Digest) Compare(other Digest) int {
return slices.Compare(d.sum[:], other.sum[:])
}

View File

@@ -1,78 +0,0 @@
package chunks
import (
"fmt"
"iter"
"strconv"
"strings"
)
type Chunk struct {
Start, End int64
}
func New(start, end int64) Chunk {
return Chunk{start, end}
}
// ParseRange parses a string in the form "unit=range" where unit is a string
// and range is a string in the form "start-end". It returns the unit and the
// range as a Chunk.
func ParseRange(s string) (unit string, _ Chunk, _ error) {
unit, r, _ := strings.Cut(s, "=")
if r == "" {
return unit, Chunk{}, nil
}
c, err := Parse(r)
if err != nil {
return "", Chunk{}, err
}
return unit, c, err
}
// Parse parses a string in the form "start-end" and returns the Chunk.
func Parse(s string) (Chunk, error) {
startStr, endStr, _ := strings.Cut(s, "-")
start, err := strconv.ParseInt(startStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid start: %v", err)
}
end, err := strconv.ParseInt(endStr, 10, 64)
if err != nil {
return Chunk{}, fmt.Errorf("invalid end: %v", err)
}
if start > end {
return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
}
return Chunk{start, end}, nil
}
// Of returns a sequence of contiguous Chunks of size chunkSize that cover
// the range [0, size), in order.
func Of(size, chunkSize int64) iter.Seq[Chunk] {
return func(yield func(Chunk) bool) {
for start := int64(0); start < size; start += chunkSize {
end := min(start+chunkSize-1, size-1)
if !yield(Chunk{start, end}) {
break
}
}
}
}
// Count returns the number of Chunks of size chunkSize needed to cover the
// range [0, size).
func Count(size, chunkSize int64) int64 {
return (size + chunkSize - 1) / chunkSize
}
// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
return c.End - c.Start + 1
}
// String returns the string representation of the Chunk in the form
// "{start}-{end}".
func (c Chunk) String() string {
return fmt.Sprintf("%d-%d", c.Start, c.End)
}

View File

@@ -1,65 +0,0 @@
package chunks
import (
"slices"
"testing"
)
func TestOf(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want []Chunk
}{
{0, 1, nil},
{1, 1, []Chunk{{0, 0}}},
{1, 2, []Chunk{{0, 0}}},
{2, 1, []Chunk{{0, 0}, {1, 1}}},
{10, 9, []Chunk{{0, 8}, {9, 9}}},
}
for _, tt := range cases {
got := slices.Collect(Of(tt.total, tt.chunkSize))
if !slices.Equal(got, tt.want) {
t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
}
}
}
func TestSize(t *testing.T) {
cases := []struct {
c Chunk
want int64
}{
{Chunk{0, 0}, 1},
{Chunk{0, 1}, 2},
{Chunk{3, 4}, 2},
}
for _, tt := range cases {
got := tt.c.Size()
if got != tt.want {
t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
}
}
}
func TestCount(t *testing.T) {
cases := []struct {
total int64
chunkSize int64
want int64
}{
{0, 1, 0},
{1, 1, 1},
{1, 2, 1},
{2, 1, 2},
{10, 9, 2},
}
for _, tt := range cases {
got := Count(tt.total, tt.chunkSize)
if got != tt.want {
t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
}
}
}

View File

@@ -19,13 +19,17 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -33,19 +37,16 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
"github.com/ollama/ollama/server/internal/internal/syncs"
_ "embed"
)
// Errors
var (
// ErrManifestNotFound is returned when a manifest is not found in the
// ErrModelNotFound is returned when a manifest is not found in the
// cache or registry.
ErrManifestNotFound = errors.New("manifest not found")
ErrModelNotFound = errors.New("model not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
@@ -53,38 +54,41 @@ var (
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
ErrNameInvalid = errors.New("invalid or missing name")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 128 << 20
// DefaultMaxChunkSize is the default maximum size of a chunk to
// download. It is configured based on benchmarks and aims to strike a
// balance between download speed and memory usage.
DefaultMaxChunkSize = 8 << 20
DefaultChunkingThreshold = 64 << 20
)
// DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
home, _ := os.UserHomeDir()
home = cmp.Or(home, ".")
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
}
// Error is the standard error returned by Ollama APIs. It can represent a
@@ -109,7 +113,18 @@ type Error struct {
}
func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
var b strings.Builder
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
}
if e.Message != "" {
b.WriteString(": ")
b.WriteString(e.Message)
}
return b.String()
}
func (e *Error) LogValue() slog.Value {
@@ -147,17 +162,30 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
var defaultName = func() names.Name {
n := names.Parse("registry.ollama.ai/library/_:latest")
const DefaultMask = "registry.ollama.ai/library/_:latest"
var defaultMask = func() names.Name {
n := names.Parse(DefaultMask)
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
panic("default mask is not fully qualified")
}
return n
}()
// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
@@ -182,24 +210,35 @@ type Registry struct {
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// Clients that want "unlimited" streams should set this to a large
// number.
// A negative value means no limit.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
}
// NameMask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
NameMask string
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
@@ -219,6 +258,7 @@ func DefaultRegistry() (*Registry, error) {
}
var rc Registry
rc.UserAgent = UserAgent()
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, err
@@ -234,78 +274,53 @@ func DefaultRegistry() (*Registry, error) {
return &rc, nil
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
// parseName parses name using [names.ParseExtended] and then merges the name with the
// default name, and checks that the name is fully qualified. If a digest is
// present, it parse and returns it with the other fields as their zero values.
//
// It returns an error if the name is not fully qualified, or if the digest, if
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
maskName := defaultName
if mask != "" {
maskName = names.Parse(mask)
if !maskName.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
}
}
scheme, n, ds := names.ParseExtended(s)
if !n.IsValid() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
n = names.Merge(n, maskName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
// The name check is deferred until after the digest check because we
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
scheme = cmp.Or(scheme, "https")
return scheme, n, d, nil
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
)
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
// Large downloads require a writter stream, so ensure we have at least
// two streams to avoid a deadlock.
return max(n, 2)
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
// chunkSizeFor returns the chunk size for a layer of the given size. If the
// size is less than or equal to the max chunking threshold, the size is
// returned; otherwise, the max chunk size is returned.
func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
c, err := r.cache()
if err != nil {
return err
}
m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -328,7 +343,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name, r.NameMask)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -354,7 +369,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
n.Model(),
l.Digest,
)
res, err := r.doOK(ctx, "POST", startURL, nil)
res, err := r.send(ctx, "POST", startURL, nil)
if err != nil {
return err
}
@@ -378,7 +393,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
}
req.ContentLength = l.Size
res, err = doOK(r.client(), req)
res, err = sendRequest(r.client(), req)
if err == nil {
res.Body.Close()
}
@@ -398,7 +413,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
n.Model(),
n.Tag(),
)
res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil {
res.Body.Close()
}
@@ -414,6 +429,22 @@ func canRetry(err error) bool {
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
l *Layer
r io.Reader
update func(l *Layer, n int64, err error)
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(r.l, int64(n), nil)
return
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
@@ -421,159 +452,135 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name, r.NameMask)
if err != nil {
return err
}
func (r *Registry) Pull(ctx context.Context, name string) error {
m, err := r.Resolve(ctx, name)
if err != nil {
return err
}
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
c, err := r.cache()
if err != nil {
return err
}
t := traceFromContext(ctx)
// TODO(bmizerany): work to remove the need to do this
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
}
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx)
for _, l := range layers {
t.update(l, 0, nil)
expected += l.Size
}
var received atomic.Int64
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
if exists(l) {
for _, l := range layers {
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
received.Add(l.Size)
t.update(l, l.Size, ErrCached)
continue
}
blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
req, err := r.newRequest(ctx, "GET", blobURL, nil)
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
t.update(l, 0, nil)
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
}
if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error {
res, err := doOK(r.client(), req)
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
received.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
}
wg.Done()
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
err = c.Put(l.Digest, res.Body, l.Size)
if err == nil {
t.update(l, l.Size, nil)
}
return err
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
} else {
q := syncs.NewRelayReader()
g.Go(func() (err error) {
defer func() { q.CloseWithError(err) }()
return c.Put(l.Digest, q, l.Size)
})
var progress atomic.Int64
// We want to avoid extra round trips per chunk due to
// redirects from the registry to the blob store, so
// fire an initial request to get the final URL and
// then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0")
res, err := doOK(r.client(), req)
if err != nil {
return err
}
res.Body.Close()
req = res.Request.WithContext(req.Context())
streamNo := 0
tws := make([]*bufio.Writer, r.maxStreams()-1)
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
ticket := q.Take()
bufIdx := streamNo % len(tws)
streamNo++
g.Go(func() (err error) {
defer func() {
if err != nil {
q.CloseWithError(err)
}
ticket.Close()
t.update(l, progress.Load(), err)
}()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := func() error {
req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := doOK(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tw := tws[bufIdx]
if tw == nil {
tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
tws[bufIdx] = tw
}
tw.Reset(ticket)
defer tw.Reset(nil) // release ticket
_, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil {
return maybeUnexpectedEOF(err)
}
if err := tw.Flush(); err != nil {
return err
}
total := progress.Add(chunk.Size())
if total >= l.Size {
q.Close()
}
return nil
}()
if !canRetry(err) {
return err
}
}
return nil
})
}
}
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
_, n, _, err := parseName(name, r.NameMask)
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
c, err := r.cache()
if err != nil {
return false, err
}
@@ -585,9 +592,10 @@ type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"`
}
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
// For legacy reasons, we still have to download the config layer.
Config *Layer `json:"config"`
}
// Layer returns the layer with the given
// digest, or nil if not found.
@@ -615,10 +623,9 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config *Layer `json:"config"`
Config Layer `json:"config"`
}{
M: M(m),
Config: &Layer{Digest: emptyDigest},
M: M(m),
}
return json.Marshal(v)
}
@@ -648,14 +655,18 @@ type Layer struct {
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.NameMask)
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
@@ -664,7 +675,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
data, err := os.ReadFile(c.GetFile(d))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
}
return nil, err
}
@@ -677,7 +688,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name, r.NameMask)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -687,7 +698,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
}
res, err := r.doOK(ctx, "GET", manifestURL, nil)
res, err := r.send(ctx, "GET", manifestURL, nil)
if err != nil {
return nil, err
}
@@ -704,6 +715,123 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
return m, nil
}
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
//
// Example:
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
//
// The blobURL is the URL to download the chunks from.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
@@ -712,7 +840,7 @@ func (r *Registry) client() *http.Client {
}
// newRequest constructs a new request, ready to use, with the given method,
// url, and body, presigned with client Key and UserAgent.
// url, and body, pre-signed with client [Key] and [UserAgent].
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
@@ -731,11 +859,17 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
return req, nil
}
// doOK makes a request with the given client and request, and returns the
// sendRequest makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("request error %s: %w", r.URL, err)
}
}()
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.
@@ -778,20 +912,26 @@ func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
// Use the raw body if we can't parse it as an error object.
re.Message = string(out)
}
// coerce MANIFEST_UNKNOWN to ErrManifestNotFound
if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
return nil, ErrModelNotFound
}
re.Status = res.StatusCode
return nil, &re
}
return res, nil
}
// doOK is a convenience method for making a request with newRequest and
// passing it to doOK with r.client().
func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
// send is a convenience method for making a request with newRequest and
// passing it to send with r.client().
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body)
if err != nil {
return nil, err
}
return doOK(r.client(), req)
return sendRequest(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key.
@@ -854,9 +994,108 @@ func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
func maybeUnexpectedEOF(err error) error {
if errors.Is(err, io.EOF) {
return io.ErrUnexpectedEOF
}
return err
type publicError struct {
wrapped error
message string
}
func withPublicMessagef(err error, message string, args ...any) error {
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
}
func (e publicError) Error() string { return e.message }
func (e publicError) Unwrap() error { return e.wrapped }
var supportedSchemes = []string{
"http",
"https",
"https+insecure",
}
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
//
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
// message is returned.
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
return "", names.Name{}, blob.Digest{}, err
}
var d blob.Digest
if digest != "" {
var err error
d, err = blob.ParseDigest(digest)
if err != nil {
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
return "", names.Name{}, blob.Digest{}, err
}
if name == "" {
// We have can resolve a manifest from a digest only,
// so skip name validation and return the scheme and
// digest.
return scheme, names.Name{}, d, nil
}
}
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
}
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
}
end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
}
if start > end {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
}
return blob.Chunk{Start: start, End: end}, nil
}

View File

@@ -2,6 +2,7 @@ package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
@@ -16,14 +17,36 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
@@ -56,22 +79,23 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
@@ -84,14 +108,15 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
}
rc := &Registry{
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
Transport: recordRoundTripper(upstreamRegistry),
},
}
link := func(name string, manifest string) {
_, n, _, err := parseName(name, rc.NameMask)
n, err := r.parseName(name)
if err != nil {
panic(err)
}
@@ -122,7 +147,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link("invalid", "!!!!!")
return rc, c
return r, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
@@ -145,84 +170,61 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
return d
}
func TestRegistryPushInvalidNames(t *testing.T) {
rc, c := newClient(t, nil)
cases := []struct {
name string
err error
}{
{"", ErrNameInvalid},
{"@", ErrNameInvalid},
{"@x", blob.ErrInvalidDigest},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// Create a new registry and push a new image.
err := rc.Push(t.Context(), c, tt.name, nil)
if !errors.Is(err, tt.err) {
t.Errorf("err = %v; want %v", err, tt.err)
}
})
}
}
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
}
func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
err := rc.Push(t.Context(), "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
rc, _ := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
got := rc.Push(ctx, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -230,7 +232,7 @@ func TestPushInvalid(t *testing.T) {
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
@@ -258,35 +260,35 @@ func TestPushExistsAtRemote(t *testing.T) {
check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil)
err := rc.Push(ctx, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
err = rc.Push(ctx, c, "single", nil)
err = rc.Push(ctx, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -294,14 +296,14 @@ func TestPushLocationError(t *testing.T) {
}
func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
@@ -317,20 +319,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
err := rc.Push(t.Context(), "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -344,8 +346,8 @@ func checkNotExist(t *testing.T, err error) {
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
rc, _ := newClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
@@ -360,10 +362,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
err := rc.Pull(t.Context(), "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -386,18 +388,18 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal(c, "model")
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model")
err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal(c, "model")
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -422,7 +424,7 @@ func TestRegistryPullNotCached(t *testing.T) {
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
@@ -445,10 +447,10 @@ func TestRegistryPullCached(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, c, "single")
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{6}
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
@@ -458,30 +460,30 @@ func TestRegistryPullCached(t *testing.T) {
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -534,7 +536,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
// Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed")
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
@@ -551,54 +553,6 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
}
func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
return
}
if strings.Contains(r.URL.Path, "/blobs/") {
rng := r.Header.Get("Range")
if rng == "" {
http.Error(w, "missing range", http.StatusBadRequest)
return
}
_, c, err := chunks.ParseRange(r.Header.Get("Range"))
if err != nil {
panic(err)
}
io.WriteString(w, "remote"[c.Start:c.End+1])
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
})
// Force chunking by setting the threshold to less than the size of the
// layer.
rc.ChunkingThreshold = 3
rc.MaxChunkSize = 3
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
if err != nil {
t.Errorf("update %v %d %v", d, n, err)
}
reads = append(reads, n)
},
})
err := rc.Pull(ctx, c, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
if !slices.Equal(reads, want) {
t.Errorf("reads = %v; want %v", reads, want)
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@@ -622,13 +576,13 @@ func TestInsecureSkipVerify(t *testing.T) {
}))
defer s.Close()
const name = "ollama.com/library/insecure"
const name = "library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verifiction failure", err)
t.Errorf("err = %v; want cert verification failure", err)
}
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
@@ -724,3 +678,191 @@ func TestErrorUnmarshal(t *testing.T) {
})
}
}
// TestParseNameErrors tests that parseName returns errors messages with enough
// detail for users to debug naming issues they may encounter. Previous to this
// test, the error messages were not very helpful and each problem was reported
// as the same message.
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{}
var r Registry
for _, tt := range cases {
_, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
if err != nil && !strings.Contains(err.Error(), tt.want) {
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
}
}
}
func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}
if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}
// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
// confirm linked
_, err := rc.ResolveLocal("single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
// unlink
_, err = rc.Unlink("single")
testutil.Check(t, err)
// confirm unlinked
_, err = rc.ResolveLocal("single")
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err)
}
})
t.Run("not found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
}
if ok {
t.Error("expected not found")
}
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}

Some files were not shown because too many files have changed in this diff Show More