mirror of https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-22 14:53:56 +00:00
Compare commits
13 Commits
| SHA1 |
|---|
| 1c28a2d3e6 |
| 39e29ae5dd |
| 30a9f063c9 |
| ce7455a8e1 |
| e3936d4fb3 |
| 940e62772e |
| 71e6a0d0d1 |
| 2cd11ae365 |
| 52bbad12f9 |
| 30e88d7f31 |
| 2b7ed61ca2 |
| 647513a7d4 |
| a210ec74d2 |
@@ -364,6 +364,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
 - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
+- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
+- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
+- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
 - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
 - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
 - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)

@@ -522,7 +525,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
 - [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
 - [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
-- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
+- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
 - [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
 - [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
 - [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)

@@ -536,3 +539,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Observability
 
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
+- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
12
cmd/cmd.go
@@ -39,6 +39,7 @@ import (
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 

@@ -558,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	request := api.PushRequest{Name: args[0], Insecure: insecure}
+
+	n := model.ParseName(args[0])
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
 		if spinner != nil {
 			spinner.Stop()

@@ -568,7 +571,16 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	p.Stop()
 	spinner.Stop()
 
+	destination := n.String()
+	if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
+		destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
+	}
+
+	fmt.Printf("\nYou can find your model at:\n\n")
+	fmt.Printf("\t%s\n", destination)
+
 	return nil
 }
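A note on the new output path: `model.ParseName` resolves a bare name like `test-model` against the default registry host, so pushes to ollama.com print a browsable URL with any `:latest` tag trimmed. A minimal sketch of just the URL-shaping step (`destinationFor` is a hypothetical extraction for illustration; the default host value is an assumption):

```go
package main

import (
	"fmt"
	"strings"
)

// destinationFor mirrors the logic in the hunk above: hosts under
// .ollama.ai or .ollama.com map to an ollama.com page URL.
func destinationFor(host, shortest, full string) string {
	if strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com") {
		return "https://ollama.com/" + strings.TrimSuffix(shortest, ":latest")
	}
	return full
}

func main() {
	// A bare "test-model" resolves to the default registry, so the CLI
	// prints the URL asserted by TestPushHandler's expectedOutput below.
	fmt.Println(destinationFor("registry.ollama.ai", "test-model:latest",
		"registry.ollama.ai/library/test-model:latest"))
	// Output: https://ollama.com/test-model
}
```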
125
cmd/cmd_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"os"

@@ -369,3 +370,127 @@ func TestGetModelfileName(t *testing.T) {
 		})
 	}
 }
+
+func TestPushHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		modelName      string
+		serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name:      "successful push",
+			modelName: "test-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					if r.Method != http.MethodPost {
+						t.Errorf("expected POST request, got %s", r.Method)
+					}
+
+					var req api.PushRequest
+					if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+						http.Error(w, err.Error(), http.StatusBadRequest)
+						return
+					}
+
+					if req.Name != "test-model" {
+						t.Errorf("expected model name 'test-model', got %s", req.Name)
+					}
+
+					// Simulate progress updates
+					responses := []api.ProgressResponse{
+						{Status: "preparing manifest"},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
+					}
+
+					for _, resp := range responses {
+						if err := json.NewEncoder(w).Encode(resp); err != nil {
+							http.Error(w, err.Error(), http.StatusInternalServerError)
+							return
+						}
+						w.(http.Flusher).Flush()
+					}
+				},
+			},
+			expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
+		},
+		{
+			name:      "unauthorized push",
+			modelName: "unauthorized-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Set("Content-Type", "application/json")
+					w.WriteHeader(http.StatusUnauthorized)
+					err := json.NewEncoder(w).Encode(map[string]string{
+						"error": "access denied",
+					})
+					if err != nil {
+						t.Fatal(err)
+					}
+				},
+			},
+			expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if handler, ok := tt.serverResponse[r.URL.Path]; ok {
+					handler(w, r)
+					return
+				}
+				http.Error(w, "not found", http.StatusNotFound)
+			}))
+			defer mockServer.Close()
+
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+			cmd := &cobra.Command{}
+			cmd.Flags().Bool("insecure", false, "")
+			cmd.SetContext(context.TODO())
+
+			// Redirect stderr to capture progress output
+			oldStderr := os.Stderr
+			r, w, _ := os.Pipe()
+			os.Stderr = w
+
+			// Capture stdout for the "Model pushed" message
+			oldStdout := os.Stdout
+			outR, outW, _ := os.Pipe()
+			os.Stdout = outW
+
+			err := PushHandler(cmd, []string{tt.modelName})
+
+			// Restore stderr
+			w.Close()
+			os.Stderr = oldStderr
+			// drain the pipe
+			if _, err := io.ReadAll(r); err != nil {
+				t.Fatal(err)
+			}
+
+			// Restore stdout and get output
+			outW.Close()
+			os.Stdout = oldStdout
+			stdout, _ := io.ReadAll(outR)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+				if tt.expectedOutput != "" {
+					if got := string(stdout); got != tt.expectedOutput {
+						t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
+					}
+				}
+			} else {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
+				}
+			}
+		})
+	}
+}
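The test above swaps `os.Stdout` for the write end of an `os.Pipe` to capture what `PushHandler` prints, then restores it before asserting. The same capture pattern in isolation, as a minimal runnable sketch:

```go
package main

import (
	"fmt"
	"io"
	"os"
)

func main() {
	// Swap os.Stdout for a pipe, as TestPushHandler does.
	old := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w

	fmt.Print("\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n")

	// Close the writer before reading, then restore the real stdout.
	w.Close()
	os.Stdout = old
	out, _ := io.ReadAll(r)
	fmt.Printf("captured %d bytes\n", len(out))
}
```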
@@ -514,7 +514,7 @@ func extractFileNames(input string) []string {
 	// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
 	// and followed by more characters and a file extension
 	// This will capture non filename strings, but we'll check for file existence to remove mismatches
-	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
+	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
 	re := regexp.MustCompile(regexPattern)
 
 	return re.FindAllString(input, -1)
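The only change here is dropping `svg` from the extension alternation. A standalone check of the narrowed pattern; note `(?i:...)` still matches the remaining extensions case-insensitively, which the updated tests below rely on (`five.JPG`, `ten.PNG`):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// The narrowed pattern from the hunk above, verbatim.
	re := regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`)

	input := "look at ./images/one.png and also ./diagram.svg\nplus \"./quoted with spaces/five.JPG"
	fmt.Println(re.FindAllString(input, -1))
	// [./images/one.png ./quoted with spaces/five.JPG]
}
```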
@@ -12,44 +12,45 @@ import (
 func TestExtractFilenames(t *testing.T) {
 	// Unix style paths
 	input := ` some preamble
-	./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
+	./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
-	/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
+	/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
 	res := extractFileNames(input)
 	assert.Len(t, res, 5)
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[1], "two.jpg")
 	assert.Contains(t, res[2], "three.jpeg")
 	assert.Contains(t, res[3], "four.png")
-	assert.Contains(t, res[4], "five.svg")
+	assert.Contains(t, res[4], "five.JPG")
 	assert.NotContains(t, res[4], '"')
-	assert.NotContains(t, res, "inbtween")
+	assert.NotContains(t, res, "inbetween1")
+	assert.NotContains(t, res, "./1.svg")
 
 	// Windows style paths
 	input = ` some preamble
 	c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
 	/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
-	./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
+	./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
-	d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
+	d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
-	d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
+	d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
 `
 	res = extractFileNames(input)
 	assert.Len(t, res, 10)
-	assert.NotContains(t, res, "inbtween")
+	assert.NotContains(t, res, "inbetween2")
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[0], "c:")
 	assert.Contains(t, res[1], "two.jpg")
 	assert.Contains(t, res[1], "c:")
 	assert.Contains(t, res[2], "three.jpeg")
 	assert.Contains(t, res[3], "four.png")
-	assert.Contains(t, res[4], "five.svg")
+	assert.Contains(t, res[4], "five.JPG")
 	assert.Contains(t, res[5], "six.png")
-	assert.Contains(t, res[6], "seven.svg")
+	assert.Contains(t, res[6], "seven.JPEG")
 	assert.Contains(t, res[6], "d:")
 	assert.Contains(t, res[7], "eight.png")
 	assert.Contains(t, res[7], "c:")
 	assert.Contains(t, res[8], "nine.png")
 	assert.Contains(t, res[8], "d:")
-	assert.Contains(t, res[9], "ten.svg")
+	assert.Contains(t, res[9], "ten.PNG")
 	assert.Contains(t, res[9], "E:")
 }
@@ -105,7 +105,7 @@ make apply-patches
 
 **Pin to new base commit**
 
-To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
+To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
 
 #### Applying patches
 
@@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
 	return count
 }
 
+func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
+	targetFree := (c.numCtx - numKeep) / 2
+	targetFree = max(targetFree, 1)
+
+	currentFree := c.numCtx - inputLen
+	discard := targetFree - currentFree
+
+	if discard < 0 {
+		discard = 0
+	}
+
+	return discard
+}
+
 // Frees up space in the KV cache by deleting the oldest half of history and shifting
 // the newest half into that space (saving numKeep inputs at the beginning).
 //

@@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
 	}
 
-	targetFree := (c.numCtx - numKeep) / 2
-	targetFree = max(targetFree, 1)
-
-	currentFree := c.numCtx - len(slot.Inputs)
-	discard := targetFree - currentFree
+	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
 
 	if discard <= 0 {
 		return nil
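The extracted helper makes the discard arithmetic testable on its own: it aims to leave `(numCtx-numKeep)/2` slots free, so with numCtx=2048, numKeep=5 and a full window it discards (2048-5)/2 = 1021 inputs. A quick standalone check against the expectations in the test table added below:

```go
package main

import "fmt"

// Standalone copy of the ShiftDiscard arithmetic from the hunk above,
// for a sanity check against the TestShiftDiscard expectations.
func shiftDiscard(numCtx, inputLen, numKeep int) int {
	targetFree := (numCtx - numKeep) / 2
	targetFree = max(targetFree, 1)

	currentFree := numCtx - inputLen
	discard := targetFree - currentFree
	if discard < 0 {
		discard = 0
	}
	return discard
}

func main() {
	fmt.Println(shiftDiscard(2048, 2048, 5)) // 1021 ("Shift")
	fmt.Println(shiftDiscard(2048, 5000, 5)) // 3973 ("Truncate")
	fmt.Println(shiftDiscard(2048, 512, 5))  // 0    ("No Op")
}
```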
@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
 		})
 	}
 }
+
+func TestShiftDiscard(t *testing.T) {
+	tests := []struct {
+		name     string
+		numCtx   int
+		numKeep  int
+		inputLen int
+		expected int
+	}{
+		{
+			name:     "Shift",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 2048,
+			expected: 1021,
+		},
+		{
+			name:     "Max Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 2048,
+			expected: 1,
+		},
+		{
+			name:     "No Keep",
+			numCtx:   2048,
+			numKeep:  0,
+			inputLen: 2048,
+			expected: 1024,
+		},
+		{
+			name:     "Truncate",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 5000,
+			expected: 3973,
+		},
+		{
+			name:     "Truncate Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 5000,
+			expected: 2953,
+		},
+		{
+			name:     "No Op",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 512,
+			expected: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := InputCache{numCtx: tt.numCtx}
+			result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
+			if result != tt.expected {
+				t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
+			}
+		})
+	}
+}
@@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
 	if len(inputs) > s.cache.numCtx {
-		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
+		discard := len(inputs) - s.cache.numCtx
 		newInputs := inputs[:params.numKeep]
-		newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
+		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
+
+		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
 		inputs = newInputs
 	}
 
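The rewrite computes the discard count explicitly instead of deriving the slice start inline; both forms keep the first numKeep inputs plus the newest tail, so the result is exactly numCtx long. A small self-contained sketch with assumed toy numbers (numCtx=10, numKeep=2, 15 inputs):

```go
package main

import "fmt"

func main() {
	inputs := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
	numCtx, numKeep := 10, 2

	// Same arithmetic as the hunk above: drop the oldest middle section.
	discard := len(inputs) - numCtx // 5
	newInputs := inputs[:numKeep]
	newInputs = append(newInputs, inputs[numKeep+discard:]...)

	fmt.Println(discard, len(newInputs), newInputs)
	// 5 10 [0 1 7 8 9 10 11 12 13 14]
}
```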
@@ -162,10 +164,16 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // generating image embeddings for each image
 func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 	var inputs []input
+	var parts []string
+	var matches [][]string
 
+	if s.image != nil {
 		re := regexp.MustCompile(`\[img-(\d+)\]`)
-	parts := re.Split(prompt, -1)
-	matches := re.FindAllStringSubmatch(prompt, -1)
+		parts = re.Split(prompt, -1)
+		matches = re.FindAllStringSubmatch(prompt, -1)
+	} else {
+		parts = []string{prompt}
+	}
 
 	for i, part := range parts {
 		// text - tokenize
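With this change the `[img-N]` split only runs when an image projector is loaded; otherwise the whole prompt becomes a single text part. What the two regexp calls produce for a marker-bearing prompt:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`\[img-(\d+)\]`)
	prompt := "describe [img-0] and compare with [img-1] please"

	// Split yields the text segments between image markers.
	fmt.Printf("%q\n", re.Split(prompt, -1))
	// ["describe " " and compare with " " please"]

	// FindAllStringSubmatch yields each marker and its captured index.
	fmt.Println(re.FindAllStringSubmatch(prompt, -1))
	// [[[img-0] 0] [[img-1] 1]]
}
```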
@@ -825,10 +833,21 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type multiLPath []string
+
+func (m *multiLPath) Set(value string) error {
+	*m = append(*m, value)
+	return nil
+}
+
+func (m *multiLPath) String() string {
+	return strings.Join(*m, ", ")
+}
+
 func (s *Server) loadModel(
 	params llama.ModelParams,
 	mpath string,
-	lpath string,
+	lpath multiLPath,
 	ppath string,
 	kvSize int,
 	flashAttention bool,
@@ -849,12 +868,14 @@ func (s *Server) loadModel(
 		panic(err)
 	}
 
-	if lpath != "" {
-		err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
-		if err != nil {
-			panic(err)
+	if lpath.String() != "" {
+		for _, path := range lpath {
+			err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
+			if err != nil {
+				panic(err)
+			}
 		}
 	}
 
 	if ppath != "" {
 		var err error
@@ -882,7 +903,6 @@ func main() {
 	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
 	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
 	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
-	lpath := flag.String("lora", "", "Path to lora layer file")
 	port := flag.Int("port", 8080, "Port to expose the server on")
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -892,6 +912,9 @@ func main() {
 	multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 	requirements := flag.Bool("requirements", false, "print json requirement information")
 
+	var lpaths multiLPath
+	flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
+
 	flag.Parse()
 	if *requirements {
 		printRequirements(os.Stdout)
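`flag.Var` calls `Set` once per occurrence, so the new `multiLPath` type turns `-lora` into a repeatable flag. A self-contained sketch of the behavior (the flag-set name and .gguf file names are illustrative):

```go
package main

import (
	"flag"
	"fmt"
	"strings"
)

// multiLPath is copied verbatim from the hunk above.
type multiLPath []string

func (m *multiLPath) Set(value string) error {
	*m = append(*m, value)
	return nil
}

func (m *multiLPath) String() string {
	return strings.Join(*m, ", ")
}

func main() {
	var lpaths multiLPath
	fs := flag.NewFlagSet("runner", flag.ExitOnError)
	fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

	// Each -lora occurrence appends another adapter path.
	fs.Parse([]string{"-lora", "style.gguf", "-lora", "domain.gguf"})
	fmt.Println(lpaths.String()) // style.gguf, domain.gguf
}
```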
@@ -938,7 +961,7 @@ func main() {
 	params := llama.ModelParams{
 		NumGpuLayers: *nGpuLayers,
 		MainGpu:      *mainGpu,
-		UseMmap:      !*noMmap && *lpath == "",
+		UseMmap:      !*noMmap && lpaths.String() == "",
 		UseMlock:     *mlock,
 		TensorSplit:  tensorSplitFloats,
 		Progress: func(progress float32) {
@@ -947,7 +970,7 @@ func main() {
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
+	go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
 
 	server.cond = sync.NewCond(&server.mu)
 
@@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 	// Loop through potential servers
 	finalErr := errors.New("no suitable llama servers found")
 
-	if len(adapters) > 1 {
-		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
-	}
-
 	rDir, err := runners.Refresh(build.EmbedFS)
 	if err != nil {
 		return nil, err

@@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 	}
 
 	if len(adapters) > 0 {
-		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
-		params = append(params, "--lora", adapters[0])
+		for _, adapter := range adapters {
+			params = append(params, "--lora", adapter)
+		}
 	}
 
 	if len(projectors) > 0 {
@@ -200,9 +200,9 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
-func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
-	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
-	for i, tc := range r.Message.ToolCalls {
+func toToolCalls(tc []api.ToolCall) []ToolCall {
+	toolCalls := make([]ToolCall, len(tc))
+	for i, tc := range tc {
 		toolCalls[i].ID = toolCallId()
 		toolCalls[i].Type = "function"
 		toolCalls[i].Function.Name = tc.Function.Name

@@ -215,7 +215,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		toolCalls[i].Function.Arguments = string(args)
 	}
+	return toolCalls
+}
 
+func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := toToolCalls(r.Message.ToolCalls)
 	return ChatCompletion{
 		Id:     id,
 		Object: "chat.completion",

@@ -244,6 +248,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 }
 
 func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+	toolCalls := toToolCalls(r.Message.ToolCalls)
 	return ChatCompletionChunk{
 		Id:     id,
 		Object: "chat.completion.chunk",

@@ -252,7 +257,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{{
 			Index: 0,
-			Delta: Message{Role: "assistant", Content: r.Message.Content},
+			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason

@@ -571,7 +576,7 @@ type EmbedWriter struct {
 	model string
 }
 
-func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
+func (w *BaseWriter) writeError(data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {

@@ -630,7 +635,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 func (w *ChatWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
-		return w.writeError(code, data)
+		return w.writeError(data)
 	}
 
 	return w.writeResponse(data)

@@ -679,7 +684,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 func (w *CompleteWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
-		return w.writeError(code, data)
+		return w.writeError(data)
 	}
 
 	return w.writeResponse(data)

@@ -704,7 +709,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
 func (w *ListWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
-		return w.writeError(code, data)
+		return w.writeError(data)
 	}
 
 	return w.writeResponse(data)

@@ -730,7 +735,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
 func (w *RetrieveWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
-		return w.writeError(code, data)
+		return w.writeError(data)
 	}
 
 	return w.writeResponse(data)

@@ -755,7 +760,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
 func (w *EmbedWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
-		return w.writeError(code, data)
+		return w.writeError(data)
 	}
 
 	return w.writeResponse(data)
@@ -1076,17 +1076,15 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}
 
-	resp, err := (&http.Client{
-		Transport: &http.Transport{
-			DialContext: testMakeRequestDialContext,
-		},
+	c := &http.Client{
 		CheckRedirect: regOpts.CheckRedirect,
-	}).Do(req)
-	if err != nil {
-		return nil, err
 	}
-
-	return resp, nil
+	if testMakeRequestDialContext != nil {
+		tr := http.DefaultTransport.(*http.Transport).Clone()
+		tr.DialContext = testMakeRequestDialContext
+		c.Transport = tr
+	}
+	return c.Do(req)
 }
 
 func getValue(header, key string) string {
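The rewritten client only installs a custom transport when the test hook is non-nil, and it clones `http.DefaultTransport` so proxy, TLS, and timeout defaults survive; production requests no longer carry a transport whose `DialContext` is nil. A minimal sketch of the same pattern (`testDial` stands in for `testMakeRequestDialContext`):

```go
package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
)

// testDial is nil in production and only set by tests.
var testDial func(ctx context.Context, network, addr string) (net.Conn, error)

func newClient() *http.Client {
	c := &http.Client{}
	if testDial != nil {
		// Clone keeps the default transport's settings; only the dialer changes.
		tr := http.DefaultTransport.(*http.Transport).Clone()
		tr.DialContext = testDial
		c.Transport = tr
	}
	return c
}

func main() {
	fmt.Println(newClient().Transport == nil) // true when no test dialer is set
}
```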
@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
 		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
 
 The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+		{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
 		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
 
 	[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
@@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
+		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}

@@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
+		var sb strings.Builder
+		var hasToolCalls bool
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,

@@ -1492,7 +1495,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
+			// TODO: tool call checking and filtering should be moved outside of this callback once streaming
+			// however this was a simple change for now without reworking streaming logic of this (and other)
+			// handlers
+			if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
 				ch <- res
+				return
+			}
+
+			// Streaming tool calls:
+			// If tools are recognized, use a flag to track the sending of a tool downstream
+			// This ensures that content is cleared from the message on the last chunk sent
+			sb.WriteString(r.Content)
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				res.Message.ToolCalls = toolCalls
+				res.Message.Content = ""
+				sb.Reset()
+				hasToolCalls = true
+				ch <- res
+				return
+			}
+
+			if r.Done {
+				// Send any remaining content if no tool calls were detected
+				if !hasToolCalls {
+					res.Message.Content = sb.String()
+				}
+				ch <- res
+			}
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
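The streaming path buffers chunk content in a `strings.Builder` and re-runs the model's tool-call parser on the accumulated text as each chunk arrives, emitting only once a complete call parses. A toy illustration of that accumulate-then-parse loop, reusing the chunk strings from the streaming test below (`toyParse` is a hypothetical stand-in for `parseToolCalls`):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// toyParse stands in for the model's parseToolCalls: here it simply
// succeeds once the accumulated text is a complete JSON object.
func toyParse(s string) (map[string]any, bool) {
	var v map[string]any
	if err := json.Unmarshal([]byte(s), &v); err != nil {
		return nil, false
	}
	return v, true
}

func main() {
	chunks := []string{
		`{"name":"get_`,
		`weather","arguments":{"location":"Seattle`,
		`, WA","unit":"celsius"}}`,
	}

	var sb strings.Builder
	for _, c := range chunks {
		sb.WriteString(c)
		// Partial JSON fails to parse; only the final chunk completes it.
		if call, ok := toyParse(sb.String()); ok {
			fmt.Println("tool call:", call["name"]) // tool call: get_weather
			sb.Reset()
		}
	}
}
```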
@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 

@@ -25,10 +26,14 @@ type mockRunner struct {
 	// CompletionRequest is only valid until the next call to Completion
 	llm.CompletionRequest
 	llm.CompletionResponse
+	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
 }
 
-func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
 	m.CompletionRequest = r
+	if m.CompletionFn != nil {
+		return m.CompletionFn(ctx, r, fn)
+	}
 	fn(m.CompletionResponse)
 	return nil
 }

@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
 		Model: "test",
 		Modelfile: fmt.Sprintf(`FROM %s
 		TEMPLATE """
-{{- if .System }}System: {{ .System }} {{ end }}
-{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
-{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+{{- if .Tools }}
+{{ .Tools }}
+{{ end }}
+{{- range .Messages }}
+{{- .Role }}: {{ .Content }}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}
+{{ end }}"""
 `, createBinFile(t, llm.KV{
 			"general.architecture": "llama",
 			"llama.block_count":    uint32(1),

@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
 		t.Errorf("expected status 200, got %d", w.Code)
 	}
 
-	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
 		t.Errorf("mismatch (-got +want):\n%s", diff)
 	}
 

@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
 		t.Errorf("expected status 200, got %d", w.Code)
 	}
 
-	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
 		t.Errorf("mismatch (-got +want):\n%s", diff)
 	}
 

@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
 		t.Errorf("expected status 200, got %d", w.Code)
 	}
 
-	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
 		t.Errorf("mismatch (-got +want):\n%s", diff)
 	}
 

@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
 		t.Errorf("expected status 200, got %d", w.Code)
 	}
 
-	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+	if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
 		t.Errorf("mismatch (-got +want):\n%s", diff)
 	}
 
 	checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
 	})
 
+	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
+		if w.Code != http.StatusOK {
+			t.Fatalf("failed to create test-system model: %d", w.Code)
+		}
+
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		mock.CompletionResponse = llm.CompletionResponse{
+			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
+			Done:               true,
+			DoneReason:         "done",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		}
+
+		streamRequest := true
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusOK {
+			var errResp struct {
+				Error string `json:"error"`
+			}
+			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
+				t.Logf("Failed to decode error response: %v", err)
+			} else {
+				t.Logf("Error response: %s", errResp.Error)
+			}
+		}
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		if resp.Message.ToolCalls == nil {
+			t.Error("expected tool calls, got nil")
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
+			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("messages with tools (streaming)", func(t *testing.T) {
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Simulate streaming response with multiple chunks
+		var wg sync.WaitGroup
+		wg.Add(1)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			// Send chunks with small delays to simulate streaming
+			responses := []llm.CompletionResponse{
+				{
+					Content:            `{"name":"get_`,
+					Done:               false,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `weather","arguments":{"location":"Seattle`,
+					Done:               false,
+					PromptEvalCount:    2,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `, WA","unit":"celsius"}}`,
+					Done:               true,
+					DoneReason:         "tool_call",
+					PromptEvalCount:    3,
+					PromptEvalDuration: 1,
+				},
+			}
+
+			for _, resp := range responses {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+					fn(resp)
+					time.Sleep(10 * time.Millisecond) // Small delay between chunks
+				}
+			}
+			return nil
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &stream,
+		})
+
+		wg.Wait()
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		// Read and validate the streamed responses
+		decoder := json.NewDecoder(w.Body)
+		var finalToolCall api.ToolCall
+
+		for {
+			var resp api.ChatResponse
+			if err := decoder.Decode(&resp); err == io.EOF {
+				break
+			} else if err != nil {
+				t.Fatal(err)
+			}
+
+			if resp.Done {
+				if len(resp.Message.ToolCalls) != 1 {
+					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
+				}
+				finalToolCall = resp.Message.ToolCalls[0]
+			}
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
+			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {