Merge branch 'ollama:main' into main

Authored by likelovewant on 2024-11-29 12:50:59 +08:00; committed by GitHub.
15 changed files with 584 additions and 75 deletions

View File

@@ -364,6 +364,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
@@ -522,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
@@ -536,3 +539,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Observability
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

View File

@@ -39,6 +39,7 @@ import (
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@@ -558,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
@@ -568,7 +571,16 @@ func PushHandler(cmd *cobra.Command, args []string) error {
return err
}
p.Stop()
spinner.Stop()
destination := n.String()
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
}
fmt.Printf("\nYou can find your model at:\n\n")
fmt.Printf("\t%s\n", destination)
return nil
}
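
The new output prints a browsable location once a push to the public registry succeeds. A minimal standalone sketch of that destination logic, assuming the default registry host resolves under .ollama.ai (the case exercised by TestPushHandler below):

package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Same URL derivation as the PushHandler change above, applied to a bare model name.
	n := model.ParseName("test-model")
	destination := n.String()
	if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
		destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
	}
	fmt.Println(destination) // https://ollama.com/test-model
}
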

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -369,3 +370,127 @@ func TestGetModelfileName(t *testing.T) {
})
}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful push",
modelName: "test-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
var req api.PushRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}
// Simulate progress updates
responses := []api.ProgressResponse{
{Status: "preparing manifest"},
{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)
return
}
http.Error(w, "not found", http.StatusNotFound)
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err := PushHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}

View File

@@ -514,7 +514,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)
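
For reference, a small runnable check of the tightened pattern: SVG paths no longer count as images, while JPG/JPEG/PNG still match case-insensitively. The sample paths are made up:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`)
	input := `open /docs/report one.JPEG then ./logo.svg`
	fmt.Printf("%q\n", re.FindAllString(input, -1)) // ["/docs/report one.JPEG"]; the .svg path is skipped
}
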

View File

@@ -12,44 +12,45 @@ import (
func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input)
assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg")
assert.Contains(t, res[4], "five.JPG")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbtween")
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
// Windows style paths
input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
`
res = extractFileNames(input)
assert.Len(t, res, 10)
assert.NotContains(t, res, "inbtween")
assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.svg")
assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
}

View File

@@ -105,7 +105,7 @@ make apply-patches
**Pin to new base commit**
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
#### Applying patches

View File

@@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
return count
}
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
}
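
A quick worked example of the arithmetic above, using the same numbers as the "Shift" case in the tests that follow:

// numCtx = 2048, numKeep = 5, inputLen = 2048
// targetFree  = (2048 - 5) / 2 = 1021
// currentFree = 2048 - 2048    = 0
// discard     = 1021 - 0       = 1021 inputs to drop from the oldest history
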
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
targetFree := (c.numCtx - numKeep) / 2
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
targetFree = max(targetFree, 1)
currentFree := c.numCtx - len(slot.Inputs)
discard := targetFree - currentFree
if discard <= 0 {
return nil

View File

@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
})
}
}
func TestShiftDiscard(t *testing.T) {
tests := []struct {
name string
numCtx int
numKeep int
inputLen int
expected int
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
inputLen: 2048,
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 2048,
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
inputLen: 2048,
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
inputLen: 5000,
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 5000,
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
inputLen: 512,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
}
})
}
}

View File

@@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
discard := len(inputs) - s.cache.numCtx
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
inputs = newInputs
}
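
A quick check of the new slicing, with made-up numbers rather than values from the tests:

// len(inputs) = 5000, s.cache.numCtx = 2048, params.numKeep = 5
// discard   = 5000 - 2048 = 2952
// newInputs = inputs[:5] + inputs[5+2952:]  ->  5 + 2043 = 2048 inputs
// so the oldest non-keep inputs are dropped and the result exactly fills the context window
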
@@ -162,10 +164,16 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
if s.image != nil {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts := re.Split(prompt, -1)
parts = re.Split(prompt, -1)
matches := re.FindAllStringSubmatch(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
for i, part := range parts {
// text - tokenize
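
The [img-N] handling itself is unchanged; it is now simply skipped when the runner has no image (vision projector) context. A tiny standalone illustration of what the regexp produces for a toy prompt:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`\[img-(\d+)\]`)
	prompt := "describe [img-0] briefly"
	fmt.Printf("%q\n", re.Split(prompt, -1))                 // ["describe " " briefly"]
	fmt.Printf("%q\n", re.FindAllStringSubmatch(prompt, -1)) // [["[img-0]" "0"]]
}
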
@@ -825,10 +833,21 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
}
}
type multiLPath []string
func (m *multiLPath) Set(value string) error {
*m = append(*m, value)
return nil
}
func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) loadModel(
params llama.ModelParams,
mpath string,
lpath string,
lpath multiLPath,
ppath string,
kvSize int,
flashAttention bool,
@@ -849,12 +868,14 @@ func (s *Server) loadModel(
panic(err)
}
if lpath != "" {
if lpath.String() != "" {
err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
for _, path := range lpath {
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
if err != nil {
panic(err)
}
}
}
if ppath != "" {
var err error
@@ -882,7 +903,6 @@ func main() {
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
lpath := flag.String("lora", "", "Path to lora layer file")
port := flag.Int("port", 8080, "Port to expose the server on")
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -892,6 +912,9 @@ func main() {
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
requirements := flag.Bool("requirements", false, "print json requirement information")
var lpaths multiLPath
flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
flag.Parse()
if *requirements {
printRequirements(os.Stdout)
@@ -938,7 +961,7 @@ func main() {
params := llama.ModelParams{
NumGpuLayers: *nGpuLayers,
MainGpu: *mainGpu,
UseMmap: !*noMmap && *lpath == "",
UseMmap: !*noMmap && lpaths.String() == "",
UseMlock: *mlock,
TensorSplit: tensorSplitFloats,
Progress: func(progress float32) {
@@ -947,7 +970,7 @@ func main() {
}
server.ready.Add(1)
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
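
multiLPath works because it satisfies the standard flag.Value interface, so flag.Var appends one entry per occurrence of -lora. A self-contained sketch of the same pattern; the sample adapter paths are placeholders:

package main

import (
	"flag"
	"fmt"
	"strings"
)

// Repeatable string flag, same shape as the multiLPath type above.
type multiLPath []string

func (m *multiLPath) Set(v string) error { *m = append(*m, v); return nil }
func (m *multiLPath) String() string     { return strings.Join(*m, ", ") }

func main() {
	var lpaths multiLPath
	flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
	flag.Parse()
	// e.g. invoked with: -lora adapter-a.gguf -lora adapter-b.gguf
	fmt.Println(lpaths.String()) // adapter-a.gguf, adapter-b.gguf
}
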

View File

@@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
// Loop through potential servers
finalErr := errors.New("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
rDir, err := runners.Refresh(build.EmbedFS)
if err != nil {
return nil, err
@@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
}
if len(adapters) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
for _, adapter := range adapters {
params = append(params, "--lora", adapters[0])
params = append(params, "--lora", adapter)
}
}
if len(projectors) > 0 {
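
With the single-adapter restriction removed above, each adapter path is forwarded to the runner as its own --lora argument. A small illustration with placeholder paths:

package main

import "fmt"

func main() {
	adapters := []string{"adapter-a.gguf", "adapter-b.gguf"} // placeholders
	var params []string
	for _, adapter := range adapters {
		params = append(params, "--lora", adapter)
	}
	fmt.Println(params) // [--lora adapter-a.gguf --lora adapter-b.gguf]
}
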

View File

@@ -200,9 +200,9 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { func toToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) toolCalls := make([]ToolCall, len(tc))
for i, tc := range r.Message.ToolCalls { for i, tc := range tc {
toolCalls[i].ID = toolCallId() toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function" toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name toolCalls[i].Function.Name = tc.Function.Name
@@ -215,7 +215,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls[i].Function.Arguments = string(args)
}
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{
Id: id,
Object: "chat.completion",
@@ -244,6 +248,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{
Id: id,
Object: "chat.completion.chunk",
@@ -252,7 +257,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{
Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content},
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string {
if len(reason) > 0 {
return &reason
@@ -571,7 +576,7 @@ type EmbedWriter struct {
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
@@ -630,7 +635,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -679,7 +684,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -704,7 +709,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -730,7 +735,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -755,7 +760,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)

View File

@@ -1076,17 +1076,15 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength
}
resp, err := (&http.Client{
c := &http.Client{
Transport: &http.Transport{
DialContext: testMakeRequestDialContext,
},
CheckRedirect: regOpts.CheckRedirect,
}).Do(req)
if err != nil {
return nil, err
}
if testMakeRequestDialContext != nil {
return resp, nil
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = testMakeRequestDialContext
c.Transport = tr
}
return c.Do(req)
}
func getValue(header, key string) string {
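
The rewritten client keeps stock transport behaviour unless a test dial hook is installed: http.DefaultTransport is cloned so its proxy, TLS and timeout defaults carry over, and only DialContext is swapped. A standalone sketch of that pattern, with a hypothetical local address standing in for the test hook:

package main

import (
	"context"
	"net"
	"net/http"
)

func main() {
	tr := http.DefaultTransport.(*http.Transport).Clone()
	tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Hypothetical: route every connection to a local test server instead of addr.
		return (&net.Dialer{}).DialContext(ctx, network, "127.0.0.1:11434")
	}
	client := &http.Client{Transport: tr}
	_ = client // use client.Do(req) as makeRequest does above
}
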

View File

@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},

View File

@@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
var hasToolCalls bool
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1492,7 +1495,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
sb.Reset()
hasToolCalls = true
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if !hasToolCalls {
res.Message.Content = sb.String()
}
ch <- res
}
}); err != nil {
ch <- gin.H{"error": err.Error()}
}

View File

@@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -25,10 +26,14 @@ type mockRunner struct {
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
if m.CompletionFn != nil {
return m.CompletionFn(ctx, r, fn)
}
fn(m.CompletionResponse)
return nil
}
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
Model: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }} {{- if .Tools }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }} {{ .Tools }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" {{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
}) })
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
if w.Code != http.StatusOK {
t.Fatalf("failed to create test-system model: %d", w.Code)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
mock.CompletionResponse = llm.CompletionResponse{
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "done",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
}
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
var errResp struct {
Error string `json:"error"`
}
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
t.Logf("Failed to decode error response: %v", err)
} else {
t.Logf("Error response: %s", errResp.Error)
}
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.ToolCalls == nil {
t.Error("expected tool calls, got nil")
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
}
})
t.Run("messages with tools (streaming)", func(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
// Simulate streaming response with multiple chunks
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Send chunks with small delays to simulate streaming
responses := []llm.CompletionResponse{
{
Content: `{"name":"get_`,
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
},
{
Content: `weather","arguments":{"location":"Seattle`,
Done: false,
PromptEvalCount: 2,
PromptEvalDuration: 1,
},
{
Content: `, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "tool_call",
PromptEvalCount: 3,
PromptEvalDuration: 1,
},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond) // Small delay between chunks
}
}
return nil
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Read and validate the streamed responses
decoder := json.NewDecoder(w.Body)
var finalToolCall api.ToolCall
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
if resp.Done {
if len(resp.Message.ToolCalls) != 1 {
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
}
finalToolCall = resp.Message.ToolCalls[0]
}
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
}
func TestGenerate(t *testing.T) {