merge upstream update and resolve the conflicts

This commit is contained in:
likelovewant
2024-07-22 17:00:43 +08:00
16 changed files with 611 additions and 329 deletions

View File

@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs:
+ ### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
- ### Overrides
+ ### Windows Support
+ With ROCm v6.1, the following GPUs are supported on Windows.
+ | Family | Cards and accelerators |
+ | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
+ | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
+ | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
+ ### Overrides on Linux
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
some cases you can force the system to try to use a similar LLVM target that is
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
server. If you have an unsupported AMD GPU you can experiment using the list of
supported types below.
- At this time, the known supported GPU types are the following LLVM Targets.
+ At this time, the known supported GPU types on linux are the following LLVM Targets.
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|

View File

@@ -33,9 +33,10 @@ type HipLib struct {
}
func NewHipLib() (*HipLib, error) {
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs/ this repo will consist with v5.7
h, err := windows.LoadLibrary("amdhip64.dll")
if err != nil {
- return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
+ return nil, fmt.Errorf("unable to load amdhip64.dll, please make sure to upgrade to the latest amd driver: %w", err)
}
hl := &HipLib{}
hl.dll = h

View File

@@ -92,8 +92,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue
}
if gfxOverride == "" {
- if !slices.Contains[[]string, string](supported, gfx) {
- //slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
+ // Strip off Target Features when comparing
+ if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
+ slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
continue
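A minimal, self-contained sketch of the comparison introduced above, stripping ROCm target-feature suffixes before matching against the supported list. The sample target strings are illustrative and not taken from this diff.

package main

import (
    "fmt"
    "slices"
    "strings"
)

func main() {
    // ROCm can report a GPU target with feature flags appended, e.g. "gfx1030:sramecc+:xnack-".
    // The change above compares only the part before the first ':' against the supported targets.
    supported := []string{"gfx1030", "gfx1100", "gfx1102"}
    gfx := "gfx1030:sramecc+:xnack-"

    target := strings.Split(gfx, ":")[0]
    fmt.Println(target, slices.Contains(supported, target)) // prints: gfx1030 true
}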

View File

@@ -12,7 +12,7 @@ import (
func TestContextExhaustion(t *testing.T) {
// Longer needed for small footprint GPUs
- ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128, "num_ctx": 128,
}, },
} }
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"}) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
} }

View File

@@ -0,0 +1,43 @@
diff --git a/include/llama.h b/include/llama.h
index bb4b05ba..a92174e0 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -92,6 +92,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
};
// note: these values should be synchronized with ggml_rope
diff --git a/src/llama.cpp b/src/llama.cpp
index 18364976..435b6fe5 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -5429,6 +5429,12 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
+ } else if (
+ tokenizer_pre == "tekken") {
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+ vocab.tokenizer_clean_spaces = false;
+ vocab.tokenizer_ignore_merges = true;
+ vocab.tokenizer_add_bos = true;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe {
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
+ // original regex from tokenizer.json
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ regex_exprs = {
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {

View File

@@ -0,0 +1,19 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace28..e60d3d8d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6052,10 +6052,10 @@ static bool llm_load_tensors(
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
- layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
- layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
- layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
- layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
// optional bias tensors
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

View File

@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
filteredEnv := []string{}
for _, ev := range s.cmd.Env {
if strings.HasPrefix(ev, "CUDA_") ||
+ strings.HasPrefix(ev, "ROCR_") ||
strings.HasPrefix(ev, "ROCM_") ||
strings.HasPrefix(ev, "HIP_") ||
+ strings.HasPrefix(ev, "GPU_") ||
strings.HasPrefix(ev, "HSA_") ||
strings.HasPrefix(ev, "GGML_") ||
strings.HasPrefix(ev, "PATH=") ||

View File

@@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
- message := api.Message{Role: msg.Role}
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
@@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if !ok {
return nil, fmt.Errorf("invalid message format")
}
- message.Content = text
+ messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
@@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
- message.Images = append(message.Images, img)
+ messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default:
return nil, fmt.Errorf("invalid message format")
}
}
- messages = append(messages, message)
default:
if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content)
@@ -878,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+ return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
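With this change a multi-part OpenAI message is fanned out into one api.Message per content part instead of being folded into a single message; the updated test below checks Messages[0] for the text part and Messages[1] for the image part. A rough sketch of that fan-out with stand-in types, since the real content-part and api.Message types live elsewhere:

package main

import "fmt"

// part and message are stand-ins for the OpenAI content part and api.Message types,
// used only to illustrate the fan-out behaviour.
type part struct {
    Type string
    Text string
    URL  string
}

type message struct {
    Role    string
    Content string
    Images  []string
}

func main() {
    parts := []part{
        {Type: "text", Text: "Hello"},
        {Type: "image_url", URL: "data:image/jpeg;base64,..."},
    }

    var messages []message
    for _, p := range parts {
        switch p.Type {
        case "text":
            messages = append(messages, message{Role: "user", Content: p.Text})
        case "image_url":
            // the real handler base64-decodes the URL payload into api.ImageData
            messages = append(messages, message{Role: "user", Images: []string{p.URL}})
        }
    }
    fmt.Println(len(messages)) // 2: one text message, one image message
}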

View File

@@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
- func TestMiddlewareRequests(t *testing.T) {
- type testCase struct {
- Name string
- Method string
- Path string
- Handler func() gin.HandlerFunc
- Setup func(t *testing.T, req *http.Request)
- Expected func(t *testing.T, req *http.Request)
- }
- var capturedRequest *http.Request
- captureRequestMiddleware := func() gin.HandlerFunc {
+ func prepareRequest(req *http.Request, body any) {
+ bodyBytes, _ := json.Marshal(body)
+ req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ req.Header.Set("Content-Type", "application/json")
+ }
+ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- capturedRequest = c.Request
+ err := json.Unmarshal(bodyBytes, capturedRequest)
+ if err != nil {
+ c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
+ }
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
}
+ var capturedRequest *api.ChatRequest
testCases := []testCase{
{
Name: "chat handler",
- Method: http.MethodPost,
- Path: "/api/chat",
- Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
+ prepareRequest(req, body)
},
- Expected: func(t *testing.T, req *http.Request) {
- var chatReq api.ChatRequest
- if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
- t.Fatal(err)
+ Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.Code)
}
- if chatReq.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+ if req.Messages[0].Role != "user" {
+ t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
- if chatReq.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+ if req.Messages[0].Content != "Hello" {
+ t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
Handler: CompletionsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, req *http.Request) {
var genReq api.GenerateRequest
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
}
if genReq.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
}
if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if genReq.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
} }
},
},
{
Name: "chat handler with image content",
- Method: http.MethodPost,
- Path: "/api/chat",
- Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
@@ -139,87 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
},
},
}
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
+ prepareRequest(req, body)
},
- Expected: func(t *testing.T, req *http.Request) {
- var chatReq api.ChatRequest
- if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
- t.Fatal(err)
+ Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.Code)
}
- if chatReq.Messages[0].Role != "user" {
- t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+ if req.Messages[0].Role != "user" {
+ t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
- if chatReq.Messages[0].Content != "Hello" {
- t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+ if req.Messages[0].Content != "Hello" {
+ t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
- if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
- t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+ if req.Messages[1].Role != "user" {
+ t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
+ }
+ if !bytes.Equal(req.Messages[1].Images[0], img) {
+ t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
- Method: http.MethodPost,
- Path: "/api/embed",
- Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
+ prepareRequest(req, body)
},
- Expected: func(t *testing.T, req *http.Request) {
- var embedReq api.EmbedRequest
- if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
- t.Fatal(err)
+ Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+ if req.Input != "Hello" {
+ t.Fatalf("expected 'Hello', got %s", req.Input)
}
- if embedReq.Input != "Hello" {
- t.Fatalf("expected 'Hello', got %s", embedReq.Input)
- }
- if embedReq.Model != "test-model" {
- t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+ if req.Model != "test-model" {
+ t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler batch input",
- Method: http.MethodPost,
- Path: "/api/embed",
- Handler: EmbeddingsMiddleware,
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
- bodyBytes, _ := json.Marshal(body)
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
- req.Header.Set("Content-Type", "application/json")
+ prepareRequest(req, body)
},
- Expected: func(t *testing.T, req *http.Request) {
- var embedReq api.EmbedRequest
- if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
- t.Fatal(err)
- }
- input, ok := embedReq.Input.([]any)
+ Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+ input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
@@ -233,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
t.Fatalf("expected 'World', got %s", input[1]) t.Fatalf("expected 'World', got %s", input[1])
} }
if embedReq.Model != "test-model" { if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", embedReq.Model) t.Fatalf("expected 'test-model', got %s", req.Model)
} }
}, },
}, },
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
- gin.SetMode(gin.TestMode)
- router := gin.New()
+ if !strings.Contains(resp.Body.String(), "invalid input") {
+ t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
- router = gin.New()
- router.Use(captureRequestMiddleware())
- router.Use(tc.Handler())
- router.Handle(tc.Method, tc.Path, endpoint)
- req, _ := http.NewRequest(tc.Method, tc.Path, nil)
- if tc.Setup != nil {
+ req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
tc.Setup(t, req)
- }
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
- tc.Expected(t, capturedRequest)
+ tc.Expected(t, capturedRequest, resp)
+ capturedRequest = nil
})
}
}
@@ -280,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
}
testCases := []testCase{
{
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
@@ -326,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
- assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
@@ -391,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
+ assert.Equal(t, http.StatusOK, resp.Code)
tc.Expected(t, resp)
})
}

View File

@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
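A minimal sketch of the validation step added above, using the standard text/template package in place of Ollama's template wrapper; the unclosed `{{ if .Prompt }}` mirrors one of the new test cases further down.

package main

import (
    "errors"
    "fmt"
    "text/template"
)

var errBadTemplate = errors.New("template error")

func main() {
    // An unparseable TEMPLATE directive is rejected at create time and wrapped in
    // errBadTemplate so the handler can answer with 400 instead of 500.
    if _, err := template.New("").Parse("{{ if .Prompt }}"); err != nil {
        fmt.Println(fmt.Errorf("%w: %s", errBadTemplate, err))
    }
}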

View File

@@ -56,6 +56,7 @@ func init() {
}
var errRequired = errors.New("is required")
+ var errBadTemplate = errors.New("template error")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
@@ -609,6 +610,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+ if errors.Is(err, errBadTemplate) {
+ ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
+ }
ch <- gin.H{"error": err.Error()}
}
}()
@@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
return
}
case gin.H:
+ status, ok := r["status"].(int)
+ if !ok {
+ status = http.StatusInternalServerError
+ }
if errorMsg, ok := r["error"].(string); ok {
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+ c.JSON(status, gin.H{"error": errorMsg})
return
} else {
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+ c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
return
}
default:

View File

@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system)
}
t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
}
func TestCreateLicenses(t *testing.T) {

View File

@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
- // add 10ms delay to simulate loading
- time.Sleep(10 * time.Millisecond)
+ // add small delay to simulate loading
+ time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+ // add small delay to simulate loading
+ time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}

View File

@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
require.Len(t, s.expiredCh, 1)
}
- type bundle struct {
+ type reqBundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
@@ -102,13 +102,13 @@ type bundle struct {
ggml *llm.GGML
}
- func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+ func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return scenario.srv, nil
}
- func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
- scenario := &bundle{}
- scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
+ b := &reqBundle{}
+ b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
- scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
+ b.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err)
- scenario.req = &LlmRequest{
- ctx: scenario.ctx,
+ if duration == nil {
+ duration = &api.Duration{Duration: 5 * time.Millisecond}
+ }
+ b.req = &LlmRequest{
+ ctx: b.ctx,
model: model,
opts: api.DefaultOptions(),
- sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
+ sessionDuration: duration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
- scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
+ b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
- return scenario
+ return b
}
- func TestRequests(t *testing.T) {
+ func getGpuFn() gpu.GpuInfoList {
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"} g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g} return []gpu.GpuInfo{g}
} }
s.getCpuFn = func() gpu.GpuInfoList {
func getCpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"} g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g} return []gpu.GpuInfo{g}
} }
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a") func TestRequestsSameModelSameRequest(t *testing.T) {
s.pendingReqCh <- scenario1a.req ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
b.req.model = a.req.model
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
- case resp := <-scenario1a.req.successCh:
- require.Equal(t, resp.llama, scenario1a.srv)
+ case resp := <-a.req.successCh:
+ require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario1a.req.errCh)
- case err := <-scenario1a.req.errCh:
+ require.Empty(t, a.req.errCh)
+ case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Same runner as first request due to not needing a reload
- s.newServerFn = scenario1b.newServer
- slog.Info("scenario1b")
- s.pendingReqCh <- scenario1b.req
+ s.newServerFn = b.newServer
+ slog.Info("b")
+ s.pendingReqCh <- b.req
select {
- case resp := <-scenario1b.req.successCh:
- require.Equal(t, resp.llama, scenario1a.srv)
+ case resp := <-b.req.successCh:
+ require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario1b.req.errCh)
- case err := <-scenario1b.req.errCh:
+ require.Empty(t, b.req.errCh)
+ case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
tmpModel := *a.req.model
b.req.model = &tmpModel
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
// Trigger a reload
- s.newServerFn = scenario2a.newServer
- scenario2a.req.model.AdapterPaths = []string{"new"}
- slog.Info("scenario2a")
- s.pendingReqCh <- scenario2a.req
+ s.newServerFn = b.newServer
+ b.req.model.AdapterPaths = []string{"new"}
+ slog.Info("b")
+ s.pendingReqCh <- b.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
- scenario1a.ctxDone()
- scenario1b.ctxDone()
+ a.ctxDone()
select {
- case resp := <-scenario2a.req.successCh:
- require.Equal(t, resp.llama, scenario2a.srv)
+ case resp := <-b.req.successCh:
+ require.Equal(t, resp.llama, b.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario2a.req.errCh)
- case err := <-scenario2a.req.errCh:
+ require.Empty(t, b.req.errCh)
+ case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
envconfig.MaxRunners = 1
- s.newServerFn = scenario3a.newServer
- slog.Info("scenario3a")
- s.pendingReqCh <- scenario3a.req
- // finish prior request, so new model can load
- time.Sleep(1 * time.Millisecond)
- scenario2a.ctxDone()
+ s.newServerFn = a.newServer
+ slog.Info("a")
+ s.pendingReqCh <- a.req
+ s.Run(ctx)
select {
- case resp := <-scenario3a.req.successCh:
- require.Equal(t, resp.llama, scenario3a.srv)
+ case resp := <-a.req.successCh:
+ require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario3a.req.errCh)
- case err := <-scenario3a.req.errCh:
+ require.Empty(t, a.req.errCh)
+ case err := <-a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -262,15 +292,15 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
envconfig.MaxRunners = 0
- s.newServerFn = scenario3b.newServer
- slog.Info("scenario3b")
- s.pendingReqCh <- scenario3b.req
+ s.newServerFn = b.newServer
+ slog.Info("b")
+ s.pendingReqCh <- b.req
select {
- case resp := <-scenario3b.req.successCh:
- require.Equal(t, resp.llama, scenario3b.srv)
+ case resp := <-b.req.successCh:
+ require.Equal(t, resp.llama, b.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario3b.req.errCh)
- case err := <-scenario3b.req.errCh:
+ require.Empty(t, b.req.errCh)
+ case err := <-b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -280,15 +310,15 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load
- s.newServerFn = scenario3c.newServer
- slog.Info("scenario3c")
- s.pendingReqCh <- scenario3c.req
+ s.newServerFn = c.newServer
+ slog.Info("c")
+ s.pendingReqCh <- c.req
select {
- case resp := <-scenario3c.req.successCh:
- require.Equal(t, resp.llama, scenario3c.srv)
+ case resp := <-c.req.successCh:
+ require.Equal(t, resp.llama, c.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario3c.req.errCh)
- case err := <-scenario3c.req.errCh:
+ require.Empty(t, c.req.errCh)
+ case err := <-c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
@@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
s.loadedMu.Unlock()
// Try to load a model that wont fit
- s.newServerFn = scenario3d.newServer
- slog.Info("scenario3d")
+ s.newServerFn = d.newServer
+ slog.Info("d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
- scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
+ a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
- s.pendingReqCh <- scenario3d.req
+ s.pendingReqCh <- d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
- scenario3b.ctxDone()
+ b.ctxDone()
select {
- case resp := <-scenario3d.req.successCh:
- require.Equal(t, resp.llama, scenario3d.srv)
+ case resp := <-d.req.successCh:
+ require.Equal(t, resp.llama, d.srv)
require.Empty(t, s.pendingReqCh)
- require.Empty(t, scenario3d.req.errCh)
+ require.Empty(t, d.req.errCh)
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
- scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
- scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
- scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
- scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
- scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
- scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
+ a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
+ b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
+ c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
- s.getGpuFn = func() gpu.GpuInfoList {
- g := gpu.GpuInfo{Library: "metal"}
- g.TotalMemory = 24 * format.GigaByte
- g.FreeMemory = 12 * format.GigaByte
- return []gpu.GpuInfo{g}
- }
- s.newServerFn = scenario1a.newServer
- slog.Info("scenario1a")
- successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+ s.getGpuFn = getGpuFn
+ s.getCpuFn = getCpuFn
+ s.newServerFn = a.newServer
+ slog.Info("a")
+ successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
- slog.Info("scenario1b")
- successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
+ slog.Info("b")
+ successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
s.Run(ctx)
select {
case resp := <-successCh1a:
- require.Equal(t, resp.llama, scenario1a.srv)
+ require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a)
+ case err := <-errCh1a:
+ t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
- scenario1a.ctxDone()
+ a.ctxDone() // Set "a" model to idle so it can unload
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
- scenario1c.req.model.ModelPath = "bad path"
- slog.Info("scenario1c")
- successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
+ c.req.model.ModelPath = "bad path"
+ slog.Info("c")
+ successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error
- time.Sleep(5 * time.Millisecond)
+ time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
s.loadedMu.Lock()
require.Empty(t, s.loaded)
@@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
- scenario1b.ctxDone()
+ b.ctxDone()
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
- scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+ scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
@@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case err := <-req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
@@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
- scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
- scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
+ scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req

View File

@@ -264,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
+ return false
}
return cut
@@ -273,7 +274,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
- "Response": "",
+ "Response": response,
}); err != nil {
return err
}

View File

@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
},
{
"mistral assistant",
[]template{
{"no response", `[INST] {{ .Prompt }}[/INST] `},
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
{"messages", `
{{- range $i, $m := .Messages }}
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
{{- end }}`},
},
Values{
Messages: []api.Message{
{Role: "user", Content: "Hello friend!"},
{Role: "assistant", Content: "Hello human!"},
{Role: "user", Content: "What is your name?"},
{Role: "assistant", Content: "My name is Ollama and I"},
},
},
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
},
{
"chatml",
[]template{