Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 22:33:56 +00:00)

Commit: Merge branch 'ollama:main' into main

.github/workflows/release.yaml (vendored, 14 lines changed)
@@ -225,7 +225,7 @@ jobs:
 CGO_CFLAGS=${{ env.CGO_CFLAGS }}
 CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
 outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
-cache-from: type=registry,ref=ollama/ollama:latest
+cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
 cache-to: type=inline
 - run: |
 for COMPONENT in bin/* lib/ollama/*; do
@@ -298,8 +298,8 @@ jobs:
 context: .
 platforms: ${{ matrix.os }}/${{ matrix.arch }}
 build-args: ${{ matrix.build-args }}
-outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
-cache-from: type=registry,ref=ollama/ollama:latest
+outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
+cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
 cache-to: type=inline
 - run: |
 mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
@@ -331,7 +331,7 @@ jobs:
 latest=false
 suffix=${{ matrix.suffix }}
 images: |
-ollama/ollama
+${{ vars.DOCKER_REPO }}
 tags: |
 type=ref,enable=true,priority=600,prefix=pr-,event=pr
 type=semver,pattern={{version}}
@@ -341,8 +341,8 @@ jobs:
 path: ${{ runner.temp }}
 merge-multiple: true
 - run: |
-docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
-docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
+docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
+docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
 working-directory: ${{ runner.temp }}

 # Trigger downstream release process
@@ -380,4 +380,4 @@ jobs:
 -H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
 -H "X-GitHub-Api-Version: 2022-11-28" \
 https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
--d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"publish\": \"1\"}}"
+-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
@@ -104,7 +104,7 @@ FROM ${FLAVOR} AS archive
 COPY --from=cpu dist/lib/ollama /lib/ollama
 COPY --from=build /bin/ollama /bin/ollama

-FROM ubuntu:20.04
+FROM ubuntu:24.04
 RUN apt-get update \
     && apt-get install -y ca-certificates \
     && apt-get clean \
@@ -1,6 +1,6 @@
 <div align="center">
 <a href="https://ollama.com">
-<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
+<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
 </a>
 </div>

@@ -10,7 +10,7 @@ Get up and running with large language models.

 ### macOS

-[Download](https://ollama.com/download/Ollama-darwin.zip)
+[Download](https://ollama.com/download/Ollama.dmg)

 ### Windows

@@ -616,6 +616,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
 - [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
 - [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
+- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)

 ### Supported backends
@@ -143,6 +143,7 @@ type Message struct {
     Thinking  string      `json:"thinking,omitempty"`
     Images    []ImageData `json:"images,omitempty"`
     ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+    ToolName  string      `json:"tool_name,omitempty"`
 }

 func (m *Message) UnmarshalJSON(b []byte) error {
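For Go callers, the new `ToolName` field is how a tool result is labelled when it is appended to the chat history; it marshals to the `tool_name` JSON field documented in the `docs/api.md` changes below. A minimal, illustrative sketch only, with placeholder values not taken from this commit:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// A chat history that feeds a tool result back to the model.
	// ToolName identifies which tool produced the content.
	history := []api.Message{
		{Role: "user", Content: "what is the weather in Toronto?"},
		{Role: "assistant", ToolCalls: []api.ToolCall{{
			Function: api.ToolCallFunction{
				Name:      "get_temperature",
				Arguments: api.ToolCallFunctionArguments{"city": "Toronto"},
			},
		}}},
		{Role: "tool", ToolName: "get_temperature", Content: "11 degrees celsius"},
	}
	fmt.Println("history messages:", len(history))
}
```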
docs/api.md (242 lines changed)

@@ -500,6 +500,7 @@ The `message` object has the following fields:
 - `thinking`: (for thinking models) the model's thinking process
 - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
 - `tool_calls` (optional): a list of tools in JSON that the model wants to use
+- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result

 Advanced parameters (optional):
@@ -508,13 +509,21 @@ Advanced parameters (optional):
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

+### Tool calling
+
+Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
+
+Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
+
+[See models with tool calling capabilities](https://ollama.com/search?c=tool).
+
 ### Structured outputs

 Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.

 ### Examples

-#### Chat Request (Streaming)
+#### Chat request (Streaming)

 ##### Request
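The curl examples added below exercise the new tool-calling flow over HTTP. For readers using the Go client instead, here is a rough, non-authoritative sketch of the same request. It assumes `api.ClientFromEnvironment`, the `api.Tools` slice type, and the `llama3.2` model used in the examples, and it loads the tool definition from JSON rather than spelling out the nested parameter structs:

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// The same get_weather tool as in the curl examples, unmarshalled into api.Tool.
	var weather api.Tool
	if err := json.Unmarshal([]byte(`{
	  "type": "function",
	  "function": {
	    "name": "get_weather",
	    "description": "Get the weather in a given city",
	    "parameters": {
	      "type": "object",
	      "properties": {
	        "city": {"type": "string", "description": "The city to get the weather for"}
	      },
	      "required": ["city"]
	    }
	  }
	}`), &weather); err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.ChatRequest{
		Model:    "llama3.2",
		Messages: []api.Message{{Role: "user", Content: "what is the weather in tokyo?"}},
		Tools:    api.Tools{weather},
		Stream:   &stream,
	}

	// When the model decides to call the tool, the reply carries tool_calls
	// instead of content, mirroring the JSON responses shown below.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, tc := range resp.Message.ToolCalls {
			fmt.Println("tool requested:", tc.Function.Name, tc.Function.Arguments)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```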
@@ -569,6 +578,88 @@ Final response:
 }
 ```

+#### Chat request (Streaming with tools)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in tokyo?"
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ],
+  "stream": true
+}'
+```
+
+##### Response
+
+A stream of JSON objects is returned:
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:22:19.184789Z",
+  "message": {
+    "role": "assistant",
+    "content": "",
+    "tool_calls": [
+      {
+        "function": {
+          "name": "get_weather",
+          "arguments": {
+            "city": "Tokyo"
+          }
+        },
+      }
+    ]
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model":"llama3.2",
+  "created_at":"2025-07-07T20:22:19.19314Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 182242375,
+  "load_duration": 41295167,
+  "prompt_eval_count": 169,
+  "prompt_eval_duration": 24573166,
+  "eval_count": 15,
+  "eval_duration": 115959084
+}
+```
+
 #### Chat request (No streaming)

 ##### Request
@@ -606,6 +697,74 @@ curl http://localhost:11434/api/chat -d '{
 }
 ```

+#### Chat request (No streaming, with tools)
+
+##### Request
+
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in tokyo?"
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ],
+  "stream": false
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:32:53.844124Z",
+  "message": {
+    "role": "assistant",
+    "content": "",
+    "tool_calls": [
+      {
+        "function": {
+          "name": "get_weather",
+          "arguments": {
+            "city": "Tokyo"
+          }
+        },
+      }
+    ]
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 3244883583,
+  "load_duration": 2969184542,
+  "prompt_eval_count": 169,
+  "prompt_eval_duration": 141656333,
+  "eval_count": 18,
+  "eval_duration": 133293625
+}
+```
+
 #### Chat request (Structured outputs)

 ##### Request
@@ -712,6 +871,87 @@ Final response:
 }
 ```


+#### Chat request (With history, with tools)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in Toronto?"
+    },
+    // the message from the model appended to history
+    {
+      "role": "assistant",
+      "content": "",
+      "tool_calls": [
+        {
+          "function": {
+            "name": "get_temperature",
+            "arguments": {
+              "city": "Toronto"
+            }
+          },
+        }
+      ]
+    },
+    // the tool call result appended to history
+    {
+      "role": "tool",
+      "content": "11 degrees celsius",
+      "tool_name": "get_temperature",
+    }
+  ],
+  "stream": false,
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ]
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:43:37.688511Z",
+  "message": {
+    "role": "assistant",
+    "content": "The current temperature in Toronto is 11°C."
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 890771750,
+  "load_duration": 707634750,
+  "prompt_eval_count": 94,
+  "prompt_eval_duration": 91703208,
+  "eval_count": 11,
+  "eval_duration": 90282125
+}
+
+```
+
 #### Chat request (with images)

 ##### Request
@@ -7,6 +7,8 @@ Check your compute compatibility to see if your card is supported:

 | Compute Capability | Family              | Cards |
 | ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
+| 12.0               | GeForce RTX 50xx    | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
+|                    | NVIDIA Professioal  | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
 | 9.0                | NVIDIA              | `H200` `H100` |
 | 8.9                | GeForce RTX 40xx    | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
 |                    | NVIDIA Professional | `L4` `L40` `RTX 6000` |
@@ -19,37 +19,6 @@ import (
     "github.com/ollama/ollama/format"
 )

-var (
-    started = time.Now()
-    chatModels = []string{
-        "granite3-moe:latest",
-        "granite-code:latest",
-        "nemotron-mini:latest",
-        "command-r:latest",
-        "gemma2:latest",
-        "gemma:latest",
-        "internlm2:latest",
-        "phi3.5:latest",
-        "phi3:latest",
-        // "phi:latest", // flaky, sometimes generates no response on first query
-        "stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
-        "falcon:latest",
-        "falcon2:latest",
-        "minicpm-v:latest",
-        "mistral:latest",
-        "orca-mini:latest",
-        "llama2:latest",
-        "llama3.1:latest",
-        "llama3.2:latest",
-        "llama3.2-vision:latest",
-        "qwen2.5-coder:latest",
-        "qwen:latest",
-        "solar-pro:latest",
-        "codellama:latest",
-        "nous-hermes:latest",
-    }
-)
-
 func TestModelsGenerate(t *testing.T) {
     softTimeout, hardTimeout := getTimeouts(t)
     slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
@@ -70,6 +39,13 @@ func TestModelsGenerate(t *testing.T) {
         slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
     }

+    var chatModels []string
+    if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
+        chatModels = ollamaEngineChatModels
+    } else {
+        chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
+    }
+
     for _, model := range chatModels {
         t.Run(model, func(t *testing.T) {
             if time.Now().Sub(started) > softTimeout {
integration/model_perf_test.go (new file, 266 lines)

@@ -0,0 +1,266 @@
+//go:build integration && perf
+
+package integration
+
+import (
+    "context"
+    "fmt"
+    "io/ioutil"
+    "log/slog"
+    "math"
+    "os"
+    "path/filepath"
+    "strconv"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/ollama/ollama/api"
+    "github.com/ollama/ollama/format"
+)
+
+var (
+    // Models that don't work reliably with the large context prompt in this test case
+    longContextFlakes = []string{
+        "granite-code:latest",
+        "nemotron-mini:latest",
+        "falcon:latest",  // 2k model
+        "falcon2:latest", // 2k model
+        "minicpm-v:latest",
+        "qwen:latest",
+        "solar-pro:latest",
+    }
+)
+
+// Note: this test case can take a long time to run, particularly on models with
+// large contexts. Run with -timeout set to a large value to get reasonable coverage
+// Example usage:
+//
+// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
+// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
+// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
+func TestModelsPerf(t *testing.T) {
+    softTimeout, hardTimeout := getTimeouts(t)
+    slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
+    ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    // TODO use info API eventually
+    var maxVram uint64
+    var err error
+    if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
+        maxVram, err = strconv.ParseUint(s, 10, 64)
+        if err != nil {
+            t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
+        }
+    } else {
+        slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
+    }
+
+    data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
+    if err != nil {
+        t.Fatalf("failed to open test data file: %s", err)
+    }
+    longPrompt := "summarize the following: " + string(data)
+
+    var chatModels []string
+    if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
+        chatModels = ollamaEngineChatModels
+    } else {
+        chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
+    }
+
+    for _, model := range chatModels {
+        t.Run(model, func(t *testing.T) {
+            if time.Now().Sub(started) > softTimeout {
+                t.Skip("skipping remaining tests to avoid excessive runtime")
+            }
+            if err := PullIfMissing(ctx, client, model); err != nil {
+                t.Fatalf("pull failed %s", err)
+            }
+            var maxContext int
+
+            resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
+            if err != nil {
+                t.Fatalf("show failed: %s", err)
+            }
+            arch := resp.ModelInfo["general.architecture"].(string)
+            maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
+
+            if maxVram > 0 {
+                resp, err := client.List(ctx)
+                if err != nil {
+                    t.Fatalf("list models failed %v", err)
+                }
+                for _, m := range resp.Models {
+                    // For these tests we want to exercise a some amount of overflow on the CPU
+                    if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
+                        t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
+                    }
+                }
+            }
+            slog.Info("scneario", "model", model, "max_context", maxContext)
+            loaded := false
+            defer func() {
+                // best effort unload once we're done with the model
+                if loaded {
+                    client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
+                }
+            }()
+
+            // Some models don't handle the long context data well so skip them to avoid flaky test results
+            longContextFlake := false
+            for _, flake := range longContextFlakes {
+                if model == flake {
+                    longContextFlake = true
+                    break
+                }
+            }
+
+            // iterate through a few context sizes for coverage without excessive runtime
+            var contexts []int
+            keepGoing := true
+            if maxContext > 16384 {
+                contexts = []int{4096, 8192, 16384, maxContext}
+            } else if maxContext > 8192 {
+                contexts = []int{4096, 8192, maxContext}
+            } else if maxContext > 4096 {
+                contexts = []int{4096, maxContext}
+            } else if maxContext > 0 {
+                contexts = []int{maxContext}
+            } else {
+                t.Fatal("unknown max context size")
+            }
+            for _, numCtx := range contexts {
+                if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
+                    break
+                }
+                skipLongPrompt := false
+
+                // Workaround bug 11172 temporarily...
+                maxPrompt := longPrompt
+                // If we fill the context too full with the prompt, many models
+                // quickly hit context shifting and go bad.
+                if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
+                    maxPrompt = maxPrompt[:numCtx*2]
+                }
+
+                testCases := []struct {
+                    prompt  string
+                    anyResp []string
+                }{
+                    {"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
+                    {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
+                }
+                var gpuPercent int
+                for _, tc := range testCases {
+                    if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
+                        slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
+                        continue
+                    }
+                    req := api.GenerateRequest{
+                        Model:     model,
+                        Prompt:    tc.prompt,
+                        KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
+                        Options: map[string]interface{}{
+                            "temperature": 0,
+                            "seed":        123,
+                            "num_ctx":     numCtx,
+                        },
+                    }
+                    atLeastOne := false
+                    var resp api.GenerateResponse
+
+                    stream := false
+                    req.Stream = &stream
+
+                    // Avoid potentially getting stuck indefinitely
+                    limit := 5 * time.Minute
+                    genCtx, cancel := context.WithDeadlineCause(
+                        ctx,
+                        time.Now().Add(limit),
+                        fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
+                    )
+                    defer cancel()
+
+                    err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
+                        resp = rsp
+                        return nil
+                    })
+                    if err != nil {
+                        // Avoid excessive test runs, but don't consider a failure with massive context
+                        if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
+                            slog.Warn("max context was taking too long, skipping", "error", err)
+                            keepGoing = false
+                            skipLongPrompt = true
+                            continue
+                        }
+                        t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
+                    }
+                    loaded = true
+                    for _, expResp := range tc.anyResp {
+                        if strings.Contains(strings.ToLower(resp.Response), expResp) {
+                            atLeastOne = true
+                            break
+                        }
+                    }
+                    if !atLeastOne {
+                        t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
+                    }
+                    models, err := client.ListRunning(ctx)
+                    if err != nil {
+                        slog.Warn("failed to list running models", "error", err)
+                        continue
+                    }
+                    if len(models.Models) > 1 {
+                        slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
+                    }
+                    for _, m := range models.Models {
+                        if m.Name == model {
+                            if m.SizeVRAM == 0 {
+                                slog.Info("Model fully loaded into CPU")
+                                gpuPercent = 0
+                                keepGoing = false
+                                skipLongPrompt = true
+                            } else if m.SizeVRAM == m.Size {
+                                slog.Info("Model fully loaded into GPU")
+                                gpuPercent = 100
+                            } else {
+                                sizeCPU := m.Size - m.SizeVRAM
+                                cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
+                                gpuPercent = int(100 - cpuPercent)
+                                slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
+                                keepGoing = false
+
+                                // Heuristic to avoid excessive test run time
+                                if gpuPercent < 90 {
+                                    skipLongPrompt = true
+                                }
+                            }
+                        }
+                    }
+                    fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
+                        "MODEL",
+                        "CONTEXT",
+                        "GPU PERCENT",
+                        "PROMPT COUNT",
+                        "LOAD TIME",
+                        "PROMPT EVAL TPS",
+                        "EVAL TPS",
+                    )
+                    fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
+                        model,
+                        numCtx,
+                        gpuPercent,
+                        resp.PromptEvalCount,
+                        float64(resp.LoadDuration)/1000000000.0,
+                        float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
+                        float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
+                    )
+                }
+            }
+        })
+    }
+}
integration/testdata/shakespeare.txt (vendored, new file, 124456 lines)

File diff suppressed because it is too large.
@@ -32,6 +32,48 @@ const (
     smol = "llama3.2:1b"
 )

+var (
+    started = time.Now()
+
+    // Note: add newer models at the top of the list to test them first
+    ollamaEngineChatModels = []string{
+        "gemma3n:e2b",
+        "mistral-small3.2:latest",
+        "deepseek-r1:1.5b",
+        "llama3.2-vision:latest",
+        "qwen2.5-coder:latest",
+        "qwen2.5vl:3b",
+        "qwen3:0.6b", // dense
+        "qwen3:30b",  // MOE
+        "gemma3:1b",
+        "llama3.1:latest",
+        "llama3.2:latest",
+        "gemma2:latest",
+        "minicpm-v:latest",    // arch=qwen2
+        "granite-code:latest", // arch=llama
+    }
+    llamaRunnerChatModels = []string{
+        "mistral:latest",
+        "falcon3:latest",
+        "granite3-moe:latest",
+        "command-r:latest",
+        "nemotron-mini:latest",
+        "phi3.5:latest",
+        "solar-pro:latest",
+        "internlm2:latest",
+        "codellama:latest", // arch=llama
+        "phi3:latest",
+        "falcon2:latest",
+        "gemma:latest",
+        "llama2:latest",
+        "nous-hermes:latest",
+        "orca-mini:latest",
+        "qwen:latest",
+        "stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
+        "falcon:latest",
+    }
+)
+
 func Init() {
     lifecycle.InitLogging()
 }
@@ -138,10 +138,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
     requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
     var props C.struct_ggml_backend_dev_props
     C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
-    // Bug #11211: Reporting of UUIDs is temporarily disabled due to causing segfaults
-    // This only affects debug information until the new memory management code is in place
-    // requiredMemory.CPU.UUID = C.GoString(props.uuid)
+    requiredMemory.CPU.UUID = C.GoString(props.uuid)
     requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
     requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)

@@ -158,7 +155,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
     requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
     var props C.struct_ggml_backend_dev_props
     C.ggml_backend_dev_get_props(d, &props)
-    // requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
+    requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
     requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
     requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
 }
@@ -358,6 +355,24 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
         bbs[c] = b
     }

+    // Mimic llama runner logs summarizing layers and memory
+    slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", max(0, params.NumGPULayers-1)))
+    gpuLayers := 0
+    switch C.ggml_backend_dev_type(output.d) {
+    case 0: // CPU
+        slog.Info("offloading output layer to CPU")
+    case 1: // GPU
+        slog.Info("offloading output layer to GPU")
+        gpuLayers++
+    case 2: // ACCEL
+        slog.Info("offloading output layer to ACCEL")
+    }
+    for _, layer := range layers {
+        if C.ggml_backend_dev_type(layer.d) == 1 {
+            gpuLayers++
+        }
+    }
+    slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1))
     for bs := range maps.Values(bbs) {
         slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
     }
@@ -310,21 +310,23 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 }

 // collate messages based on role. consecutive messages of the same role are merged
-// into a single message. collate also collects and returns all system messages.
+// into a single message (except for tool messages which preserve individual metadata).
+// collate also collects and returns all system messages.
 // collate mutates message content adding image tags ([img-%d]) as needed
+// todo(parthsareen): revisit for contextual image support
 func collate(msgs []api.Message) (string, []*api.Message) {
     var system []string
     var collated []*api.Message
     for i := range msgs {
-        msg := msgs[i]
-        if msg.Role == "system" {
-            system = append(system, msg.Content)
+        if msgs[i].Role == "system" {
+            system = append(system, msgs[i].Content)
         }

-        if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
-            collated[len(collated)-1].Content += "\n\n" + msg.Content
+        // merges consecutive messages of the same role into a single message (except for tool messages)
+        if len(collated) > 0 && collated[len(collated)-1].Role == msgs[i].Role && msgs[i].Role != "tool" {
+            collated[len(collated)-1].Content += "\n\n" + msgs[i].Content
         } else {
-            collated = append(collated, &msg)
+            collated = append(collated, &msgs[i])
         }
     }

@@ -163,10 +163,12 @@ func TestParse(t *testing.T) {
         {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
         {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
         {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
+        {"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
         {`{{- range .Messages }}
 {{- if eq .Role "system" }}SYSTEM:
 {{- else if eq .Role "user" }}USER:
 {{- else if eq .Role "assistant" }}ASSISTANT:
+{{- else if eq .Role "tool" }}TOOL:
 {{- end }} {{ .Content }}
 {{- end }}`, []string{"content", "messages", "role"}},
         {`{{- if .Messages }}
@@ -376,3 +378,99 @@ func TestExecuteWithSuffix(t *testing.T) {
         })
     }
 }
+
+func TestCollate(t *testing.T) {
+    cases := []struct {
+        name     string
+        msgs     []api.Message
+        expected []*api.Message
+        system   string
+    }{
+        {
+            name: "consecutive user messages are merged",
+            msgs: []api.Message{
+                {Role: "user", Content: "Hello"},
+                {Role: "user", Content: "How are you?"},
+            },
+            expected: []*api.Message{
+                {Role: "user", Content: "Hello\n\nHow are you?"},
+            },
+            system: "",
+        },
+        {
+            name: "consecutive tool messages are NOT merged",
+            msgs: []api.Message{
+                {Role: "tool", Content: "sunny", ToolName: "get_weather"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+            },
+            expected: []*api.Message{
+                {Role: "tool", Content: "sunny", ToolName: "get_weather"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+            },
+            system: "",
+        },
+        {
+            name: "tool messages preserve all fields",
+            msgs: []api.Message{
+                {Role: "user", Content: "What's the weather?"},
+                {Role: "tool", Content: "sunny", ToolName: "get_conditions"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+            },
+            expected: []*api.Message{
+                {Role: "user", Content: "What's the weather?"},
+                {Role: "tool", Content: "sunny", ToolName: "get_conditions"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+            },
+            system: "",
+        },
+        {
+            name: "mixed messages with system",
+            msgs: []api.Message{
+                {Role: "system", Content: "You are helpful"},
+                {Role: "user", Content: "Hello"},
+                {Role: "assistant", Content: "Hi there!"},
+                {Role: "user", Content: "What's the weather?"},
+                {Role: "tool", Content: "sunny", ToolName: "get_weather"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+                {Role: "user", Content: "Thanks"},
+            },
+            expected: []*api.Message{
+                {Role: "system", Content: "You are helpful"},
+                {Role: "user", Content: "Hello"},
+                {Role: "assistant", Content: "Hi there!"},
+                {Role: "user", Content: "What's the weather?"},
+                {Role: "tool", Content: "sunny", ToolName: "get_weather"},
+                {Role: "tool", Content: "72F", ToolName: "get_temperature"},
+                {Role: "user", Content: "Thanks"},
+            },
+            system: "You are helpful",
+        },
+    }
+
+    for _, tt := range cases {
+        t.Run(tt.name, func(t *testing.T) {
+            system, collated := collate(tt.msgs)
+            if diff := cmp.Diff(system, tt.system); diff != "" {
+                t.Errorf("system mismatch (-got +want):\n%s", diff)
+            }
+
+            // Compare the messages
+            if len(collated) != len(tt.expected) {
+                t.Errorf("expected %d messages, got %d", len(tt.expected), len(collated))
+                return
+            }
+
+            for i := range collated {
+                if collated[i].Role != tt.expected[i].Role {
+                    t.Errorf("message %d role mismatch: got %q, want %q", i, collated[i].Role, tt.expected[i].Role)
+                }
+                if collated[i].Content != tt.expected[i].Content {
+                    t.Errorf("message %d content mismatch: got %q, want %q", i, collated[i].Content, tt.expected[i].Content)
+                }
+                if collated[i].ToolName != tt.expected[i].ToolName {
+                    t.Errorf("message %d tool name mismatch: got %q, want %q", i, collated[i].ToolName, tt.expected[i].ToolName)
+                }
+            }
+        })
+    }
+}
@@ -134,16 +134,16 @@ func (p *Parser) parseToolCall() *api.ToolCall {
         return nil
     }

-    // only look for arguments if the tool has parameters
+    // only look for arguments after the tool name if the tool has parameters
+    // TODO (jmorganca): while probably uncommon, this doesn't support
+    // parsing arguments before the tool name, which may be needed in the future
     args := map[string]any{}
     if len(tool.Function.Parameters.Properties) > 0 {
-        if args, i = p.findArguments(*tool); args == nil {
+        if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
             return nil
         }

-        if i > end {
-            end = i
-        }
+        end += i
     }

     tc := &api.ToolCall{
@@ -160,14 +160,14 @@ func (p *Parser) parseToolCall() *api.ToolCall {
 }

 // findArguments returns the first object that appears to be
-// arguments for the provided tool, returning nil
-func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
-    if len(p.buffer) == 0 {
-        return nil, 0
-    }
-
-    // no arguments to parse
-    if len(tool.Function.Parameters.Properties) == 0 {
+// arguments for the provided tool in the provided buffer,
+// returning nil if no arguments are found.
+// TODO (jmorganca): this does not support parsing omitted arguments
+// objects for functions that have all-optional parameters
+// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
+// `{"name": "get_conditions"}` will not currently work
+func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
+    if len(buffer) == 0 {
         return nil, 0
     }

@@ -177,7 +177,7 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
     var object []byte

     // find any outer json object
-    for i, c := range p.buffer {
+    for i, c := range buffer {
         if c == '{' {
             braces++
             if start == -1 {
@@ -190,7 +190,7 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
             braces--
             if braces == 0 {
                 end = i + 1
-                object = p.buffer[start:end]
+                object = buffer[start:end]
                 break
             }
         }
@@ -202,8 +202,6 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
     }

     var data map[string]any
-
-    // not valid json
     if err := json.Unmarshal(object, &data); err != nil {
         return nil, 0
     }
@@ -212,15 +210,27 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
     find = func(obj any) map[string]any {
         switch obj := obj.(type) {
         case map[string]any:
-            found := true
+            valid := true
+            // check if all keys in the object exist in the tool's parameters
             for key := range obj {
                 if _, exists := tool.Function.Parameters.Properties[key]; !exists {
-                    found = false
+                    valid = false
                     break
                 }
             }

-            if found {
+            // check for required parameters
+            // TODO (jmorganca): this should error instead of silently failing
+            if valid {
+                for _, required := range tool.Function.Parameters.Required {
+                    if _, exists := obj[required]; !exists {
+                        valid = false
+                        break
+                    }
+                }
+            }
+
+            if valid {
                 return obj
             }

@@ -52,7 +52,8 @@ func TestParser(t *testing.T) {
             Enum        []any            `json:"enum,omitempty"`
         } `json:"properties"`
     }{
         Type: "object",
+        Required: []string{"city"},
         Properties: map[string]struct {
             Type        api.PropertyType `json:"type"`
             Items       any              `json:"items,omitempty"`
@@ -159,8 +160,23 @@ func TestParser(t *testing.T) {
            calls: nil,
        },
        {
-           name:   "missing args",
-           inputs: []string{`<tool_call>{"name": "get_conditions"}</tool_call>`},
+           name:    "empty args",
+           inputs:  []string{`<tool_call>{"name": "get_conditions", "arguments": {}}</tool_call>`},
+           content: "",
+           tmpl:    qwen,
+           calls: []api.ToolCall{
+               {
+                   Function: api.ToolCallFunction{
+                       Index:     0,
+                       Name:      "get_conditions",
+                       Arguments: api.ToolCallFunctionArguments{},
+                   },
+               },
+           },
+       },
+       {
+           name:   "missing required args",
+           inputs: []string{`<tool_call>{"name": "get_temperature", "arguments": {}}</tool_call>`},
            content: "",
            tmpl:    qwen,
            calls:   nil,
@@ -259,9 +275,9 @@ func TestParser(t *testing.T) {
            },
        },
        {
-           name:    "qwen two tool calls one with no args",
-           inputs:  []string{`Let me check the weather. <tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
-           content: "Let me check the weather. ",
+           name:    "empty args followed by args",
+           inputs:  []string{`Let me say hello and check the weather. <tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call>`},
+           content: "Let me say hello and check the weather. ",
            tmpl: qwen,
            calls: []api.ToolCall{
                {
@@ -271,6 +287,31 @@ func TestParser(t *testing.T) {
                        Arguments: api.ToolCallFunctionArguments{},
                    },
                },
+               {
+                   Function: api.ToolCallFunction{
+                       Index:     1,
+                       Name:      "get_temperature",
+                       Arguments: api.ToolCallFunctionArguments{
+                           "city":   "London",
+                           "format": "fahrenheit",
+                       },
+                   },
+               },
+           },
+       },
+       {
+           name:    "qwen empty followed by args",
+           inputs:  []string{`Let me check the weather. <tool_call>{"name": "get_conditions", "arguments": {}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
+           content: "Let me check the weather. ",
+           tmpl:    qwen,
+           calls: []api.ToolCall{
+               {
+                   Function: api.ToolCallFunction{
+                       Index:     0,
+                       Name:      "get_conditions",
+                       Arguments: api.ToolCallFunctionArguments{},
+                   },
+               },
                {
                    Function: api.ToolCallFunction{
                        Index: 1,
@@ -1035,16 +1076,19 @@ func TestFindArguments(t *testing.T) {
            },
            tool: tool,
        },
+       {
+           name:   "deepseek",
+           buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
+           want: map[string]any{
+               "location": "Tokyo",
+           },
+           tool: tool,
+       },
    }

    for _, tt := range tests {
-       parser := &Parser{
-           buffer: tt.buffer,
-           tools:  []api.Tool{tool, tool2},
-       }
-
        t.Run(tt.name, func(t *testing.T) {
-           got, _ := parser.findArguments(tool)
+           got, _ := findArguments(tt.tool, tt.buffer)

            if diff := cmp.Diff(got, tt.want); diff != "" {
                t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)