mirror of https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
Merge branch 'ollama:main' into main
@@ -65,7 +65,7 @@ continuation of the sentence:
 Examples:
 
 llm/backend/mlx: support the llama architecture
-CONTRIBUTING: provide clairity on good commit messages, and bad
+CONTRIBUTING: provide clarity on good commit messages, and bad
 
 Bad Examples:
 
@@ -432,6 +432,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
 - [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
 - [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
+- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
 
 ### Cloud
 
@@ -1137,6 +1137,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		if errors.Is(err, context.Canceled) {
 			return nil, nil
 		}
 
+		// this error should ideally be wrapped properly by the client
+		if strings.Contains(err.Error(), "upstream error") {
+			p.StopAndClear()
+			fmt.Println("An error occurred while processing your message. Please try again.")
+			fmt.Println()
+			return nil, nil
+		}
 		return nil, err
 	}
 
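The added branch intercepts unwrapped "upstream error" strings and recovers the prompt instead of aborting the session. A minimal standalone sketch of the same triage pattern (the error strings here are illustrative assumptions, not Ollama's actual messages):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"strings"
)

// handle mirrors the control flow above: a canceled context ends the turn
// quietly, an unwrapped "upstream error" is reported and suppressed, and
// anything else propagates to the caller.
func handle(err error) error {
	if errors.Is(err, context.Canceled) {
		return nil
	}
	// stopgap until the client wraps upstream failures in a typed error
	if strings.Contains(err.Error(), "upstream error") {
		fmt.Println("An error occurred while processing your message. Please try again.")
		return nil
	}
	return err
}

func main() {
	fmt.Println(handle(context.Canceled))                  // <nil>
	fmt.Println(handle(errors.New("upstream error: 500"))) // prints message, then <nil>
	fmt.Println(handle(errors.New("disk full")))           // disk full
}
```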
@@ -385,9 +385,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		case "modelfile":
 			fmt.Println(resp.Modelfile)
 		case "parameters":
+			fmt.Println("Model defined parameters:")
 			if resp.Parameters == "" {
-				fmt.Println("No parameters were specified for this model.")
+				fmt.Println("  No additional parameters were specified for this model.")
 			} else {
+				for _, l := range strings.Split(resp.Parameters, "\n") {
+					fmt.Printf("  %s\n", l)
+				}
+			}
+			fmt.Println()
 			if len(opts.Options) > 0 {
 				fmt.Println("User defined parameters:")
 				for k, v := range opts.Options {
@@ -395,9 +401,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				}
 				fmt.Println()
 			}
-			fmt.Println("Model defined parameters:")
-			fmt.Println(resp.Parameters)
-		}
 		case "system":
 			switch {
 			case opts.System != "":
@@ -11,14 +11,13 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
+	"maps"
 	"os"
 	"path/filepath"
 	"slices"
 	"strings"
 	"testing"
 
-	"golang.org/x/exp/maps"
-
 	"github.com/ollama/ollama/fs/ggml"
 )
 
@@ -137,9 +136,7 @@ func TestConvertModel(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	keys := maps.Keys(expect)
-	slices.Sort(keys)
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(expect)) {
 		if v, ok := actual[k]; !ok {
 			t.Errorf("missing %s", k)
 		} else if v != expect[k] {
@@ -343,9 +340,7 @@ func TestConvertAdapter(t *testing.T) {
 
 	actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
 
-	keys := maps.Keys(c.Expected)
-	slices.Sort(keys)
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
 		if v, ok := actual[k]; !ok {
 			t.Errorf("missing %s", k)
 		} else if v != c.Expected[k] {
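These hunks follow a pattern repeated throughout this commit: the golang.org/x/exp/maps dependency is replaced by the standard library `maps` package, whose `Keys` returns an iterator that `slices.Sorted` can consume in a single expression. A minimal sketch of the idiom:

```go
package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	expect := map[string]int{"b": 2, "a": 1, "c": 3}

	// Before (golang.org/x/exp/maps): Keys returned a slice that had to be
	// sorted as a separate step:
	//   keys := maps.Keys(expect)
	//   slices.Sort(keys)
	//   for _, k := range keys { ... }

	// After (stdlib, Go 1.23+): maps.Keys yields an iter.Seq[string] and
	// slices.Sorted collects and sorts it in one expression.
	for _, k := range slices.Sorted(maps.Keys(expect)) {
		fmt.Println(k, expect[k]) // a 1, b 2, c 3
	}
}
```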
@@ -8,12 +8,12 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"maps"
 	"slices"
 	"strings"
 
 	"github.com/d4l3k/go-bfloat16"
 	"github.com/x448/float16"
-	"golang.org/x/exp/maps"
 )
 
 type safetensorMetadata struct {
@@ -46,8 +46,7 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
 		return nil, err
 	}
 
-	keys := maps.Keys(headers)
-	slices.Sort(keys)
+	keys := slices.Sorted(maps.Keys(headers))
 
 	names := make(map[string]struct{}, len(keys))
 
@@ -8,11 +8,10 @@ import (
 	"fmt"
 	"io/fs"
 	"log/slog"
+	"maps"
 	"os"
 	"slices"
 	"strings"
-
-	"golang.org/x/exp/maps"
 )
 
 const (
@@ -260,11 +259,8 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 		tokens[token.ID] = token
 	}
 
-	keys := maps.Keys(tokens)
-	slices.Sort(keys)
-
 	v := Vocabulary{Model: "gpt2"}
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(tokens)) {
 		token := tokens[k]
 		v.Tokens = append(v.Tokens, token.Content)
 		v.Scores = append(v.Scores, float32(token.ID))
@@ -58,7 +58,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 	driverMajor, driverMinor, err := AMDDriverVersion()
 	if err != nil {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err)
 	}
 
 	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
@@ -118,7 +118,7 @@ To run tests, use `go test`:
 go test ./...
 ```
 
-> NOTE: In rare cirumstances, you may need to change a package using the new
+> NOTE: In rare circumstances, you may need to change a package using the new
 > "synctest" package in go1.24.
 >
 > If you do not have the "synctest" package enabled, you will not see build or
@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
 Download and extract the package:
 
 ```shell
-curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
+curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
 sudo tar -C /usr -xzf ollama-linux-amd64.tgz
 ```
 
@@ -38,7 +38,7 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
 
 ## LLM libraries
 
-Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
+Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
 
 In the server log, you will see a message that looks something like this (varies from release to release):
 
go.mod (2 changes)
@@ -71,7 +71,7 @@ require (
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/crypto v0.36.0
-	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
+	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
 	golang.org/x/net v0.38.0 // indirect
 	golang.org/x/sys v0.31.0
 	golang.org/x/term v0.30.0
@@ -20,11 +20,21 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 // The mask is of shape history size, batch size
 type Causal struct {
 	DType ml.DType
-	windowSize int32
+
+	// swaWindowSize is the number of tokens that will be included in the mask
+	// during attention operations. swaMemorySize is the number of tokens that
+	// will be retained in memory for partial prefix caching. Set to math.MaxInt32
+	// for unlimited or if sliding window attention is not being used.
+	swaWindowSize int32
+	swaMemorySize int32
 
 	chunkSize int32
 
 	opts CausalOptions
 
+	// maxBatch is the largest batch that we might receive
+	maxBatch int
+
 	// config controls mostly backend-specific optimizations
 	config *ml.CacheConfig
 
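To make the new fields concrete: swaWindowSize bounds what the attention mask lets a query see, while swaMemorySize bounds what stays cached at all, so memory can trail behind the mask for partial prefix caching. A standalone sketch of the two comparisons used later in this file (simplified to positions only):

```go
package main

import "fmt"

// visible reports whether a cached token at cellPos may attend to a query at
// queryPos under a sliding window of size window; retained reports whether
// the cell is kept in memory at all under memory. (Assumption: both follow
// the comparisons used in the diff above.)
func visible(cellPos, queryPos, window int32) bool  { return cellPos >= queryPos-window }
func retained(cellPos, queryPos, memory int32) bool { return cellPos >= queryPos-memory }

func main() {
	const window, memory = int32(4), int32(8)
	query := int32(10)
	for pos := int32(0); pos <= query; pos++ {
		fmt.Printf("pos %2d: retained=%-5v visible=%v\n",
			pos, retained(pos, query, memory), visible(pos, query, window))
	}
	// Positions 2..5 are retained for prefix caching but masked out of
	// attention; positions 6..10 are both retained and visible.
}
```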
@@ -85,7 +95,6 @@ type cellRange struct {
 
 func NewCausalCache(shift shiftFn) *Causal {
 	return &Causal{
-		windowSize: math.MaxInt32,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
 		keys:       make(map[int]ml.Tensor),
@@ -95,7 +104,18 @@ func NewCausalCache(shift shiftFn) *Causal {
 
 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		windowSize:    windowSize,
+		swaWindowSize: windowSize,
+		shiftFn:       shift,
+		ctxs:          make(map[int]ml.Context),
+		keys:          make(map[int]ml.Tensor),
+		values:        make(map[int]ml.Tensor),
+	}
+}
+
+func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
+	return &Causal{
+		swaWindowSize: windowSize,
+		swaMemorySize: memorySize,
 		shiftFn:       shift,
 		ctxs:          make(map[int]ml.Context),
 		keys:          make(map[int]ml.Tensor),
@@ -105,7 +125,6 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 
 func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		windowSize: math.MaxInt32,
 		chunkSize:  chunkSize,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -135,11 +154,25 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 		c.config.MaskDType = ml.DTypeF32
 	}
 
+	if c.swaWindowSize == 0 {
+		c.swaWindowSize = math.MaxInt32
+	}
+	if c.swaMemorySize == 0 {
+		c.swaMemorySize = c.swaWindowSize
+	}
+	if int(c.swaMemorySize) > capacity {
+		c.swaMemorySize = math.MaxInt32
+	}
+
+	if c.swaMemorySize < c.swaWindowSize {
+		panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
+	}
+
 	var cacheSize int
-	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
+	if c.swaMemorySize == math.MaxInt32 {
 		cacheSize = maxSequences * capacity
 	} else {
-		cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
+		cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
 	}
 	cacheSize = roundUp(cacheSize, c.config.CachePadding)
 	c.cells = make([]cacheCell, cacheSize)
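The zero-value defaults above keep the plain NewCausalCache and NewSWACache constructors working unchanged: an unset window becomes effectively infinite, and an unset memory collapses to the window size. The sizing rule that follows can be checked by hand; a small sketch (rounding to the backend's cache padding omitted):

```go
package main

import (
	"fmt"
	"math"
)

// cacheSize mirrors the sizing rule in Init: unlimited memory reserves
// maxSequences*capacity cells, while a bounded sliding-window memory reserves
// maxSequences*memory plus room for one in-flight batch.
func cacheSize(maxSequences, capacity, maxBatch int, swaMemorySize int32) int {
	if swaMemorySize == math.MaxInt32 {
		return maxSequences * capacity
	}
	return maxSequences*int(swaMemorySize) + maxBatch
}

func main() {
	// Example values assumed: 2 sequences, 4096-token capacity, batches of 512.
	fmt.Println(cacheSize(2, 4096, 512, math.MaxInt32)) // 8192 (full causal)
	fmt.Println(cacheSize(2, 4096, 512, 1024))          // 2560 (SWA memory of 1024)
}
```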
@@ -147,6 +180,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 	c.DType = dtype
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
+	c.maxBatch = maxBatch
 }
 
 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -183,7 +217,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 		return err
 	}
 
-	c.curCellRange = newRange()
 	for i, pos := range batch.Positions {
 		seq := batch.Sequences[i]
 
@@ -194,19 +227,12 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
 				seqRange = newRange()
 			}
 
-			if c.curLoc+i > seqRange.max {
-				seqRange.max = c.curLoc + i
-			}
-			if seqRange.max > c.curCellRange.max {
-				c.curCellRange.max = seqRange.max
-			}
-
-			if c.curLoc+i < seqRange.min {
-				seqRange.min = c.curLoc + i
-			}
-			if seqRange.min < c.curCellRange.min {
-				c.curCellRange.min = seqRange.min
-			}
+			seqRange.min = min(seqRange.min, c.curLoc+i)
+			c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
+			seqRange.max = max(seqRange.max, c.curLoc+i)
+			c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
 
 			c.cellRanges[seq] = seqRange
 		}
 	} else {
@@ -248,7 +274,16 @@ func (c *Causal) findStartLoc() (int, error) {
 }
 
 func (c *Causal) updateSlidingWindow() {
-	if c.windowSize == math.MaxInt32 {
+	c.curCellRange = newRange()
+
+	if c.swaMemorySize == math.MaxInt32 {
+		for _, seq := range c.curSequences {
+			if seqRange, ok := c.cellRanges[seq]; ok {
+				c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
+				c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
+			}
+		}
+
 		return
 	}
 
@@ -278,12 +313,16 @@ func (c *Causal) updateSlidingWindow() {
 
 		for i := oldRange.min; i <= oldRange.max; i++ {
 			if slices.Contains(c.cells[i].sequences, seq) {
-				if c.cells[i].pos < pos-c.windowSize {
+				if c.cells[i].pos < pos-c.swaMemorySize {
 					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
 				} else {
 					newRange.min = min(newRange.min, i)
 					newRange.max = max(newRange.max, i)
 				}
+				if c.cells[i].pos >= pos-c.swaWindowSize {
+					c.curCellRange.min = min(c.curCellRange.min, i)
+					c.curCellRange.max = max(c.curCellRange.max, i)
+				}
 			}
 		}
 
@@ -323,7 +362,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
 				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-				c.cells[j].pos < c.curPositions[i]-c.windowSize {
+				c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
@@ -481,6 +520,8 @@ func (c *Causal) defrag() {
 
 		c.cellRanges[seq] = seqRange
 	}
+
+	c.updateSlidingWindow()
 }
 
 func (c *Causal) SetLayer(layer int) {
@@ -606,7 +647,7 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
 }
 
 func (c *Causal) CanResume(seq int, pos int32) bool {
-	if c.windowSize == math.MaxInt32 {
+	if c.swaMemorySize == math.MaxInt32 {
 		return true
 	}
 
@@ -628,8 +669,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
 		return false
 	}
 
-	lastWindowStart := max(0, last-c.windowSize)
-	posWindowStart := max(0, pos-c.windowSize)
+	lastWindowStart := max(0, last-c.swaMemorySize)
+	posWindowStart := max(0, pos-c.swaWindowSize)
 
 	return posWindowStart >= lastWindowStart
 }
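CanResume now compares where retained memory begins behind the last cached position with where the attention window begins at the requested position; resuming is safe only when the window is fully backed by memory. A worked check using the window and memory sizes from TestCanResumeSWAMem further down:

```go
package main

import "fmt"

// canResume mirrors the comparison above: resuming at pos is possible when
// the window starting at pos lies entirely within the memory retained behind
// the last cached position.
func canResume(last, pos, window, memory int32) bool {
	lastWindowStart := max(0, last-memory)
	posWindowStart := max(0, pos-window)
	return posWindowStart >= lastWindowStart
}

func main() {
	const window, memory = int32(4), int32(5)
	last := int32(7) // highest cached position in the test below
	for pos := int32(0); pos <= last; pos++ {
		fmt.Printf("CanResume(pos=%d) = %v\n", pos, canResume(last, pos, window, memory))
	}
	// Only positions 6 and 7 resume: their windows start at >= 2, which is
	// where the retained memory behind position 7 begins.
}
```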
@@ -639,21 +680,34 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		return ErrNotSupported
 	}
 
-	ctx := c.backend.NewContext()
-	defer ctx.Close()
-
 	seqRange := c.cellRanges[seq]
-	size := seqRange.max - seqRange.min + 1
 
+	for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
+		size := min(seqRange.max-start+1, c.maxBatch)
 		offsets := make([]int32, size)
 
+		var batchFirst, batchLast int
+		batchFirst = -1
 		for i := range offsets {
-			cell := c.cells[seqRange.min+i]
+			cell := c.cells[start+i]
 
 			if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
 				offsets[i] = offset
+				if batchFirst < 0 {
+					batchFirst = i
+				}
+				batchLast = i
 			}
 		}
 
+		if batchFirst < 0 {
+			continue
+		}
+
+		offsets = offsets[batchFirst : batchLast+1]
+
+		ctx := c.backend.NewContext()
 		kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
 
 		for i, key := range c.keys {
@@ -665,14 +719,15 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			numKVHeads := key.Dim(1)
 			rowSize := key.Stride(2)
 
-			key = key.View(ctx, rowSize*seqRange.min,
+			key = key.View(ctx, rowSize*(start+batchFirst),
 				kHeadDim, key.Stride(1),
 				numKVHeads, key.Stride(2),
-				size,
+				len(offsets),
 			)
 
 			roped, err := c.shiftFn(ctx, i, key, kShift)
 			if err != nil {
+				ctx.Close()
 				return err
 			}
 
@@ -680,6 +735,8 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		}
 
 		ctx.Compute()
+		ctx.Close()
+	}
 
 	return nil
 }
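The shift now walks a sequence's cell range in maxBatch-sized chunks, trims each chunk to the first and last cells that actually need shifting, and opens one short-lived backend context per chunk. The chunking arithmetic in isolation (example values assumed):

```go
package main

import "fmt"

func main() {
	// Assumed example: cells 3..17 need visiting, in batches of 4.
	seqMin, seqMax, maxBatch := 3, 17, 4
	for start := seqMin; start <= seqMax; start += maxBatch {
		// size is capped by both the batch limit and the remaining range.
		size := min(seqMax-start+1, maxBatch)
		fmt.Printf("chunk [%d, %d)\n", start, start+size)
	}
	// chunk [3, 7) / chunk [7, 11) / chunk [11, 15) / chunk [15, 18)
}
```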
@@ -60,6 +60,8 @@ func TestSWA(t *testing.T) {
 
 	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
+	x := float32(math.Inf(-1))
+
 	tests := []testCase{
 		{
 			name: "FirstBatch",
@@ -69,7 +71,12 @@ func TestSWA(t *testing.T) {
 			pos:           []int32{0, 1, 2, 3},
 			expected:      []float32{1, 2, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, 0, 0, x,
+				x, x, 0, 0,
+			},
 		},
 		{
 			name: "SecondBatch",
@@ -79,7 +86,53 @@ func TestSWA(t *testing.T) {
 			pos:           []int32{4, 5},
 			expected:      []float32{5, 6, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
+			expectedMask: []float32{
+				0, x, x, 0,
+				0, 0, x, x,
+			},
+		},
+	}
+
+	testCache(t, backend, cache, tests)
+}
+
+func TestSWAMem(t *testing.T) {
+	backend := &testBackend{}
+	cache := NewSWAMemCache(1, 3, nil)
+	defer cache.Close()
+
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
+
+	x := float32(math.Inf(-1))
+
+	tests := []testCase{
+		{
+			name:          "FirstBatch",
+			in:            []float32{1, 2, 3, 4},
+			inShape:       []int{1, 1, 4},
+			seqs:          []int{0, 0, 0, 0},
+			pos:           []int32{0, 1, 2, 3},
+			expected:      []float32{1, 2, 3, 4},
+			expectedShape: []int{1, 1, 4},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, 0, 0, x,
+				x, x, 0, 0,
+			},
+		},
+		{
+			name:          "SecondBatch",
+			in:            []float32{5, 6},
+			inShape:       []int{1, 1, 2},
+			seqs:          []int{0, 0},
+			pos:           []int32{4, 5},
+			expected:      []float32{4, 5, 6},
+			expectedShape: []int{1, 1, 3},
+			expectedMask: []float32{
+				0, 0, x,
+				x, 0, 0,
+			},
 		},
 	}
 
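The expected masks in these tests follow directly from the sliding-window rule: a query at position p may attend to a cached position c only when c <= p and c >= p - window. A quick standalone check that reproduces the FirstBatch mask (same row-major layout as buildMask):

```go
package main

import (
	"fmt"
	"math"
)

// buildMask reproduces the tests' expected masks: each row i covers query
// position positions[i] over the cached positions, with -Inf where attention
// is disallowed.
func buildMask(positions []int32, window int32) []float32 {
	n := len(positions)
	mask := make([]float32, n*n)
	for i, p := range positions {
		for j, c := range positions {
			if c > p || c < p-window {
				mask[i*n+j] = float32(math.Inf(-1))
			}
		}
	}
	return mask
}

func main() {
	// window 1, as in TestSWAMem's FirstBatch above
	for i, row := 0, buildMask([]int32{0, 1, 2, 3}, 1); i < 4; i++ {
		fmt.Println(row[i*4 : i*4+4]) // matches the expectedMask rows
	}
}
```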
@@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) {
 	}
 }
 
+func TestCanResumeSWAMem(t *testing.T) {
+	backend := &testBackend{}
+	windowSize := int32(4)
+	memSize := int32(5)
+	cache := NewSWAMemCache(windowSize, memSize, nil)
+	defer cache.Close()
+
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
+
+	context := backend.NewContext()
+	defer context.Close()
+
+	err := cache.StartForward(context, input.Batch{
+		Positions: []int32{0, 1, 2, 3, 4, 5},
+		Sequences: []int{0, 0, 0, 0, 0, 0},
+	}, false)
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
+	cache.Put(context, tensor, tensor)
+
+	// shift window by adding position 6
+	err = cache.StartForward(context, input.Batch{
+		Positions: []int32{6, 7},
+		Sequences: []int{0, 0},
+	}, false)
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
+	cache.Put(context, tensor, tensor)
+
+	// only the latest position has overlapping windows
+	if cache.CanResume(0, 0) {
+		t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 1) {
+		t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 2) {
+		t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 3) {
+		t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 4) {
+		t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 5) {
+		t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
+	}
+	if !cache.CanResume(0, 6) {
+		t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
+	}
+	if !cache.CanResume(0, 7) {
+		t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
+	}
+}
+
 type testBackend struct {
 	ml.Backend
 }
@@ -16,10 +16,10 @@ ggml-ci
  2 files changed, 67 insertions(+), 14 deletions(-)
 
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index ee4f2dcb..f20f5615 100644
+index a9eeebc6..110c9ece 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
+@@ -489,6 +489,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
      GGML_METAL_KERNEL_TYPE_COS,
      GGML_METAL_KERNEL_TYPE_NEG,
      GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -27,7 +27,7 @@ index ee4f2dcb..f20f5615 100644
      GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
      GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
      GGML_METAL_KERNEL_TYPE_ARGMAX,
-@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
+@@ -1436,6 +1437,7 @@ @implementation GGMLMetalClass
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -52,7 +52,7 @@ index 64fb4ff4..5b9a0fe3 100644
  static __device__ __forceinline__ float warp_reduce_max(float x) {
  #pragma unroll
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 4c829153..9e64e5ae 100644
+index d6960174..2b9fabf4 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
 @@ -35,6 +35,7 @@
llama/patches/0021-Enable-CUDA-Graphs-for-gemma3n.patch (new file, 50 lines)
@@ -0,0 +1,50 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Oliver Simons <osimons@nvidia.com>
+Date: Tue, 22 Jul 2025 11:02:28 +0200
+Subject: [PATCH] Enable CUDA Graphs for gemma3n.
+
+Similar to
+https://github.com/ggml-org/llama.cpp/pull/14741,
+though ollama has a slightly different model graph
+than llama.cpp which requires different workaround
+checks.
+---
+ ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++----
+ 1 file changed, 12 insertions(+), 4 deletions(-)
+
+diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
+index 2b9fabf4..28ccf4be 100644
+--- a/ggml/src/ggml-cuda/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda/ggml-cuda.cu
+@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
+     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+
++    const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
++    const std::string gemma3n_node_name = "node_";
++
+     for (int i = 0; i < cgraph->n_nodes; i++) {
+         ggml_tensor * node = cgraph->nodes[i];
+
+@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
+ #endif
+         }
+
+-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+-            // disable CUDA graphs for batch size > 1 for now.
+-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
++        // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
++        // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
++        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
++            && node->ne[2] == 1
++            && node->ne[3] == 1
++            && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
++            && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
++            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+             use_cuda_graph = false;
+ #ifndef NDEBUG
+-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
++            GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+ #endif
+         }
+
llama/patches/0022-BF16-macos-version-guard.patch (new file, 27 lines)
@@ -0,0 +1,27 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Daniel Hiltgen <daniel@ollama.com>
+Date: Wed, 30 Jul 2025 08:43:46 -0700
+Subject: [PATCH] BF16 macos version guard
+
+Only enable BF16 on supported MacOS versions (v14+)
+---
+ ggml/src/ggml-metal/ggml-metal.m | 6 +++++-
+ 1 file changed, 5 insertions(+), 1 deletion(-)
+
+diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
+index 110c9ece..ab46f6e3 100644
+--- a/ggml/src/ggml-metal/ggml-metal.m
++++ b/ggml/src/ggml-metal/ggml-metal.m
+@@ -89,7 +89,11 @@
+     ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+
+ #if defined(GGML_METAL_USE_BF16)
+-    ctx->use_bfloat = ctx->has_bfloat;
++    if (@available(macOS 14.0, *)) {
++        ctx->use_bfloat = ctx->has_bfloat;
++    } else {
++        ctx->use_bfloat = false;
++    }
+ #else
+     ctx->use_bfloat = false;
+ #endif
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu (vendored, 16 changes)
@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
+    const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+    const std::string gemma3n_node_name = "node_";
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+        // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+        // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+            && node->ne[2] == 1
+            && node->ne[3] == 1
+            && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+            && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
             use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+            GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }
 
@@ -89,7 +89,11 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
     ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
 
 #if defined(GGML_METAL_USE_BF16)
+    if (@available(macOS 14.0, *)) {
         ctx->use_bfloat = ctx->has_bfloat;
+    } else {
+        ctx->use_bfloat = false;
+    }
 #else
     ctx->use_bfloat = false;
 #endif
@@ -203,10 +203,9 @@ func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions
 	coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
 	coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
 
-	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-	predictions := coefficients.Mulmat(ctx, hiddenStates)
-	predictions = predictions.Add(ctx, hiddenStates)
-	return predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+	predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
+	predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+	return predictions.Add(ctx, hiddenStates)
 }
 
 func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -842,8 +842,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 	resp.Parameters = strings.Join(params, "\n")
 
+	if len(req.Options) > 0 {
+		if m.Options == nil {
+			m.Options = make(map[string]any)
+		}
 		for k, v := range req.Options {
-			if _, ok := req.Options[k]; ok {
 			m.Options[k] = v
 		}
 	}
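The nil-map guard matters because a model with no Modelfile options leaves m.Options nil, and assigning into a nil map panics; the old existence check inside the loop was always true and is dropped. A minimal sketch of the merge (hypothetical helper name):

```go
package main

import "fmt"

// mergeOptions overlays request options onto model options, allocating the
// destination map lazily as the handler above now does.
func mergeOptions(model, req map[string]any) map[string]any {
	if len(req) > 0 {
		if model == nil {
			model = make(map[string]any) // avoid assignment to nil map
		}
		for k, v := range req {
			model[k] = v
		}
	}
	return model
}

func main() {
	var model map[string]any // nil: the model defined no options
	model = mergeOptions(model, map[string]any{"temperature": 0.2})
	fmt.Println(model) // map[temperature:0.2]
}
```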
@@ -16,6 +16,7 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"slices"
 	"sort"
 	"strings"
 	"testing"
@@ -82,19 +83,6 @@ func createTestFile(t *testing.T, name string) (string, string) {
 	return f.Name(), digest
 }
 
-// equalStringSlices checks if two slices of strings are equal.
-func equalStringSlices(a, b []string) bool {
-	if len(a) != len(b) {
-		return false
-	}
-	for i := range a {
-		if a[i] != b[i] {
-			return false
-		}
-	}
-	return true
-}
-
 type panicTransport struct{}
 
 func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
@@ -447,7 +435,7 @@ func TestRoutes(t *testing.T) {
 				"stop \"foo\"",
 				"top_p 0.9",
 			}
-			if !equalStringSlices(params, expectedParams) {
+			if !slices.Equal(params, expectedParams) {
 				t.Errorf("expected parameters %v, got %v", expectedParams, params)
 			}
 			paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
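With "slices" imported above, the hand-rolled helper is redundant: slices.Equal performs the same length check and element-wise comparison for any comparable element type.

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	a := []string{"stop \"foo\"", "top_p 0.9"}
	b := []string{"stop \"foo\"", "top_p 0.9"}
	fmt.Println(slices.Equal(a, b))                 // true
	fmt.Println(slices.Equal(a, []string{"top_p"})) // false: lengths differ
}
```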
@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"errors"
 	"io"
+	"maps"
 	"math"
 	"slices"
 	"strings"
@@ -14,7 +15,6 @@ import (
 	"text/template/parse"
 
 	"github.com/agnivade/levenshtein"
-	"golang.org/x/exp/maps"
 
 	"github.com/ollama/ollama/api"
 )
@@ -157,9 +157,7 @@ func (t *Template) Vars() []string {
 		set[strings.ToLower(n)] = struct{}{}
 	}
 
-	vars = maps.Keys(set)
-	slices.Sort(vars)
-	return vars
+	return slices.Sorted(maps.Keys(set))
 }
 
 type Values struct {
tools/tools.go (113 changes)
@@ -120,16 +120,14 @@ func (p *Parser) parseToolCall() *api.ToolCall {
 		return nil
 	}
 
-	// only look for arguments after the tool name if the tool has parameters
-	// TODO (jmorganca): while probably uncommon, this doesn't support
-	// parsing arguments before the tool name, which may be needed in the future
-	args := map[string]any{}
-	if len(tool.Function.Parameters.Properties) > 0 {
-		var i int
-		if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
-			return nil
+	var args map[string]any
+	if found, i := findArguments(p.buffer); found == nil {
+		return nil
+	} else {
+		args = found
+		if i > end {
+			end = i
 		}
-		end += i
 	}
 
 	tc := &api.ToolCall{
@@ -217,93 +215,70 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
 // objects for functions that have all-optional parameters
 // e.g. `{"name": "get_conditions", "arguments": {}}` will work but
 // `{"name": "get_conditions"}` will not currently work
-func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) {
+func findArguments(buffer []byte) (map[string]any, int) {
 	if len(buffer) == 0 {
 		return nil, 0
 	}
 
 	var braces int
 	var start int = -1
-	var end int
-	var object []byte
-
-	// find any outer json object
 	for i, c := range buffer {
 		if c == '{' {
-			braces++
-			if start == -1 {
+			if braces == 0 {
 				start = i
 			}
-		}
-
-		if c == '}' {
-			if start != -1 {
-				braces--
-				if braces == 0 {
-					end = i + 1
-					object = buffer[start:end]
-					break
-				}
-			}
-		}
-	}
-
-	if braces > 0 {
-		return nil, 0
-	}
-
-	var data map[string]any
-	if err := json.Unmarshal(object, &data); err != nil {
-		return nil, 0
-	}
-
-	var find func(obj any) map[string]any
-	find = func(obj any) map[string]any {
-		switch obj := obj.(type) {
-		case map[string]any:
-			valid := true
-			// check if all keys in the object exist in the tool's parameters
-			for key := range obj {
-				if _, exists := tool.Function.Parameters.Properties[key]; !exists {
-					valid = false
-					break
-				}
-			}
-
-			// check for required parameters
-			// TODO (jmorganca): this should error instead of silently failing
-			if valid {
-				for _, required := range tool.Function.Parameters.Required {
-					if _, exists := obj[required]; !exists {
-						valid = false
-						break
-					}
-				}
-			}
-
-			if valid {
-				return obj
-			}
-
-			for _, value := range obj {
-				if result := find(value); result != nil {
-					return result
-				}
-			}
-		case []any:
-			for _, item := range obj {
-				if result := find(item); result != nil {
-					return result
-				}
-			}
-		}
-
-		return nil
-	}
-
-	result := find(data)
-	if result != nil {
-		return result, end
-	}
-
+			braces++
+		} else if c == '}' && braces > 0 {
+			braces--
+			if braces == 0 && start != -1 {
+				object := buffer[start : i+1]
+
+				var data map[string]any
+				if err := json.Unmarshal(object, &data); err != nil {
+					start = -1
+					continue
+				}
+
+				var findObject func(obj map[string]any) (map[string]any, bool)
+				findObject = func(obj map[string]any) (map[string]any, bool) {
+					if _, hasName := obj["name"]; hasName {
+						if args, ok := obj["arguments"].(map[string]any); ok {
+							return args, true
+						}
+						if args, ok := obj["parameters"].(map[string]any); ok {
+							return args, true
+						}
+						return nil, true
+					}
+
+					for _, v := range obj {
+						switch child := v.(type) {
+						case map[string]any:
+							if result, found := findObject(child); found {
+								return result, true
+							}
+						case []any:
+							for _, item := range child {
+								if childObj, ok := item.(map[string]any); ok {
+									if result, found := findObject(childObj); found {
+										return result, true
+									}
+								}
+							}
+						}
+					}
+
+					return nil, false
+				}
+
+				if args, found := findObject(data); found {
+					return args, i
+				}
+
+				return data, i
+			}
+		}
+	}
+
 	return nil, 0
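The rewritten findArguments is tool-agnostic: it scans the buffer for balanced top-level JSON objects, unmarshals each candidate, and searches it for an "arguments" or "parameters" object adjacent to a "name" key, so argument order no longer matters (see the "args before name" test below). The brace-scanning core in isolation:

```go
package main

import "fmt"

// topLevelObjects returns the byte ranges of balanced top-level {...} spans,
// the same scan the parser performs before attempting json.Unmarshal on each.
func topLevelObjects(buf []byte) [][2]int {
	var spans [][2]int
	braces, start := 0, -1
	for i, c := range buf {
		if c == '{' {
			if braces == 0 {
				start = i
			}
			braces++
		} else if c == '}' && braces > 0 {
			braces--
			if braces == 0 && start != -1 {
				spans = append(spans, [2]int{start, i + 1})
			}
		}
	}
	return spans
}

func main() {
	buf := []byte(`text {"arguments": {"a": "5"}, "name": "add"} more {"x": 1}`)
	for _, s := range topLevelObjects(buf) {
		fmt.Printf("%s\n", buf[s[0]:s[1]]) // two balanced objects found
	}
}
```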
@@ -227,13 +227,6 @@ func TestParser(t *testing.T) {
 				},
 			},
 		},
-		{
-			name:    "invalid arguments",
-			inputs:  []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
-			content: "",
-			tmpl:    qwen,
-			calls:   nil,
-		},
 		{
 			name:   "empty args",
 			inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {}}</tool_call>`},
@@ -249,13 +242,6 @@ func TestParser(t *testing.T) {
 				},
 			},
 		},
-		{
-			name:    "missing required args",
-			inputs:  []string{`<tool_call>{"name": "get_temperature", "arguments": {}}</tool_call>`},
-			content: "",
-			tmpl:    qwen,
-			calls:   nil,
-		},
 		{
 			name:   "text before tool call",
 			inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
@@ -273,21 +259,6 @@ func TestParser(t *testing.T) {
 				},
 			},
 		},
-		{
-			name:    "qwen no args tool call",
-			inputs:  []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
-			content: "Let me say hello to the user. I'll use the say_hello tool ",
-			tmpl:    qwen,
-			calls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Index:     0,
-						Name:      "say_hello",
-						Arguments: api.ToolCallFunctionArguments{},
-					},
-				},
-			},
-		},
 		{
 			name:   "qwen no args with text",
 			inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
@@ -521,52 +492,6 @@ func TestParser(t *testing.T) {
 			content: "for { fmt.Println(\"hello\") }",
 			tmpl:    json,
 		},
-		{
-			name: "json no args tool call",
-			inputs: []string{
-				"{\"name\": \"say_hello\"}",
-			},
-			content: "",
-			tmpl:    json,
-			calls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Index:     0,
-						Name:      "say_hello",
-						Arguments: api.ToolCallFunctionArguments{},
-					},
-				},
-			},
-		},
-		{
-			name: "json no args no tool call",
-			inputs: []string{
-				"I'll use the say_hello tool to say hello to the user.",
-			},
-			content: "I'll use the say_hello tool to say hello to the user.",
-			tmpl:    json,
-			calls:   nil,
-		},
-
-		// TODO (jmorganca): this is a false positive, we should
-		// not be parsing this as a tool call
-		{
-			name: "json no args false positive",
-			inputs: []string{
-				`{say_hello!!!}`,
-			},
-			content: "",
-			tmpl:    json,
-			calls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Index:     0,
-						Name:      "say_hello",
-						Arguments: api.ToolCallFunctionArguments{},
-					},
-				},
-			},
-		},
 		{
 			name: "list multiple",
 			inputs: []string{
@@ -684,26 +609,6 @@ func TestParser(t *testing.T) {
 			tmpl:  list,
 			calls: nil,
 		},
-		{
-			name: "list with no arguments",
-			inputs: []string{
-				"[",
-				"{",
-				"\"name\": \"say_hello\"",
-				"}",
-			},
-			content: "",
-			tmpl:    list,
-			calls: []api.ToolCall{
-				{
-					Function: api.ToolCallFunction{
-						Index:     0,
-						Name:      "say_hello",
-						Arguments: api.ToolCallFunctionArguments{},
-					},
-				},
-			},
-		},
 		{
 			name: "tool name with collision",
 			inputs: []string{
@@ -711,7 +616,7 @@ func TestParser(t *testing.T) {
 				"{",
 				"\"name\": \"say_hello",
 				"_world\",",
-				"}",
+				"\"arguments\": {}}",
 				"}",
 			},
 			content: "",
@@ -733,13 +638,13 @@ func TestParser(t *testing.T) {
 				"{",
 				"\"name\": \"say_hello",
 				"_world\",",
-				"}",
+				"\"arguments\": {}}",
 				"</tool_call>",
 				"<tool_call>",
 				"{",
 				"\"name\": \"say_hello",
 				"\",",
-				"}",
+				"\"arguments\": {}}",
 				"</tool_call>",
 			},
 			content: "",
@@ -773,7 +678,7 @@ func TestParser(t *testing.T) {
 		{
 			name: "tool name with collision non streaming multiple",
 			inputs: []string{
-				`<tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "say_hello_world"}`,
+				`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call><tool_call>{"name": "say_hello_world", "arguments": {}}`,
 			},
 			content: "",
 			tmpl:    qwen,
@@ -797,7 +702,7 @@ func TestParser(t *testing.T) {
 		{
 			name: "tool name with collision non streaming shorter",
 			inputs: []string{
-				`<tool_call>{"name": "say_hello"}</tool_call>`,
+				`<tool_call>{"name": "say_hello", "arguments": {}}</tool_call>`,
 			},
 			content: "",
 			tmpl:    qwen,
@@ -814,7 +719,7 @@ func TestParser(t *testing.T) {
 		{
 			name: "tool name with collision non streaming longer",
 			inputs: []string{
-				`<tool_call>{"name": "say_hello_world"}</tool_call>`,
+				`<tool_call>{"name": "say_hello_world", "arguments": {}}</tool_call>`,
 			},
 			content: "",
 			tmpl:    qwen,
@@ -871,6 +776,26 @@ func TestParser(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "args before name",
+			inputs: []string{
+				`<tool_call>{"arguments": {"a": "5", "b": "10"}, "name": "add"}</tool_call>`,
+			},
+			content: "",
+			tmpl:    qwen,
+			calls: []api.ToolCall{
+				{
+					Function: api.ToolCallFunction{
+						Index: 0,
+						Name:  "add",
+						Arguments: api.ToolCallFunctionArguments{
+							"a": "5",
+							"b": "10",
+						},
+					},
+				},
+			},
+		},
 	}
 
 	for _, tt := range tests {
@@ -1167,75 +1092,25 @@ func TestFindTag(t *testing.T) {
 }
 
 func TestFindArguments(t *testing.T) {
-	tool := api.Tool{
-		Type: "function",
-		Function: api.ToolFunction{
-			Name:        "get_temperature",
-			Description: "Retrieve the temperature for a given location",
-			Parameters: struct {
-				Type       string   `json:"type"`
-				Defs       any      `json:"$defs,omitempty"`
-				Items      any      `json:"items,omitempty"`
-				Required   []string `json:"required"`
-				Properties map[string]struct {
-					Type        api.PropertyType `json:"type"`
-					Items       any              `json:"items,omitempty"`
-					Description string           `json:"description"`
-					Enum        []any            `json:"enum,omitempty"`
-				} `json:"properties"`
-			}{
-				Type: "object",
-				Properties: map[string]struct {
-					Type        api.PropertyType `json:"type"`
-					Items       any              `json:"items,omitempty"`
-					Description string           `json:"description"`
-					Enum        []any            `json:"enum,omitempty"`
-				}{
-					"format": {
-						Type:        api.PropertyType{"string"},
-						Description: "The format to return the temperature in",
-						Enum:        []any{"fahrenheit", "celsius"},
-					},
-					"location": {
-						Type:        api.PropertyType{"string"},
-						Description: "The location to get the temperature for",
-					},
-				},
-			},
-		},
-	}
-
-	tool2 := api.Tool{
-		Type: "function",
-		Function: api.ToolFunction{
-			Name:        "say_hello",
-			Description: "Say hello to the user",
-		},
-	}
-
 	tests := []struct {
 		name   string
 		buffer []byte
 		want   map[string]any
-		tool   api.Tool
 	}{
 		{
 			name:   "empty string",
 			buffer: []byte{},
 			want:   nil,
-			tool:   tool,
 		},
 		{
 			name:   "whitespace only",
 			buffer: []byte(" \n\t "),
 			want:   nil,
-			tool:   tool,
 		},
 		{
 			name:   "unbalanced braces - missing closing",
 			buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
 			want:   nil,
-			tool:   tool,
 		},
 		{
 			name: "unbalanced braces - extra closing",
@@ -1243,13 +1118,11 @@ func TestFindArguments(t *testing.T) {
 			want: map[string]any{
 				"format": "fahrenheit",
 			},
-			tool: tool,
 		},
 		{
 			name:   "invalid JSON",
 			buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
|
||||||
want: nil,
|
want: nil,
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid json",
|
name: "valid json",
|
||||||
@@ -1258,7 +1131,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid arguments with special tokens",
|
name: "valid arguments with special tokens",
|
||||||
@@ -1267,16 +1139,14 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid arguments in array",
|
name: "valid arguments in array",
|
||||||
buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
|
buffer: []byte(`[{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
|
||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nested deep",
|
name: "nested deep",
|
||||||
@@ -1285,7 +1155,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one arg",
|
name: "one arg",
|
||||||
@@ -1293,7 +1162,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "two args",
|
name: "two args",
|
||||||
@@ -1302,13 +1170,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no args",
|
|
||||||
buffer: []byte(`{"name": "say_hello"}`),
|
|
||||||
want: nil,
|
|
||||||
tool: tool2,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "deepseek",
|
name: "deepseek",
|
||||||
@@ -1316,7 +1177,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "deepseek",
|
name: "deepseek",
|
||||||
@@ -1324,13 +1184,12 @@ func TestFindArguments(t *testing.T) {
|
|||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, _ := findArguments(tt.tool, tt.buffer)
|
got, _ := findArguments(tt.buffer)
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
||||||
|
|||||||
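Net effect of the hunks above: every tool-call fixture now carries an explicit "arguments" object, the order of "arguments" relative to "name" no longer matters, and findArguments works from the raw buffer alone instead of consulting a tool schema, so the api.Tool fixtures and the per-case tool field are dropped. A minimal sketch of the new call shape, written as if it lived in the same test package (findArguments is unexported; the package name and test function name here are assumptions, and the buffer mirrors the added "args before name" fixture):

```go
package tools

import "testing"

// Hypothetical sketch, not part of the diff: exercises the new
// single-argument findArguments signature from inside the package.
func TestFindArgumentsArgsFirst(t *testing.T) {
	// Mirrors the "args before name" fixture: "arguments" precedes "name".
	buf := []byte(`{"arguments": {"a": "5", "b": "10"}, "name": "add"}`)

	// No tool-schema parameter anymore; the arguments object is
	// recovered directly from the buffer, as the fixtures above show.
	got, _ := findArguments(buf)
	if got["a"] != "5" || got["b"] != "10" {
		t.Errorf("unexpected arguments: %v", got)
	}
}
```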