mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 14:26:30 +00:00
megrge upstream update and reslove the conflicts
This commit is contained in:
15
docs/gpu.md
15
docs/gpu.md
@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
|
|||||||
|
|
||||||
## AMD Radeon
|
## AMD Radeon
|
||||||
Ollama supports the following AMD GPUs:
|
Ollama supports the following AMD GPUs:
|
||||||
|
|
||||||
|
### Linux Support
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
||||||
|
|
||||||
### Overrides
|
### Windows Support
|
||||||
|
With ROCm v6.1, the following GPUs are supported on Windows.
|
||||||
|
|
||||||
|
| Family | Cards and accelerators |
|
||||||
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||||
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
|
||||||
|
|
||||||
|
### Overrides on Linux
|
||||||
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
||||||
some cases you can force the system to try to use a similar LLVM target that is
|
some cases you can force the system to try to use a similar LLVM target that is
|
||||||
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
||||||
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
|
|||||||
server. If you have an unsupported AMD GPU you can experiment using the list of
|
server. If you have an unsupported AMD GPU you can experiment using the list of
|
||||||
supported types below.
|
supported types below.
|
||||||
|
|
||||||
At this time, the known supported GPU types are the following LLVM Targets.
|
At this time, the known supported GPU types on linux are the following LLVM Targets.
|
||||||
This table shows some example GPUs that map to these LLVM targets:
|
This table shows some example GPUs that map to these LLVM targets:
|
||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
|
|||||||
@@ -33,9 +33,10 @@ type HipLib struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewHipLib() (*HipLib, error) {
|
func NewHipLib() (*HipLib, error) {
|
||||||
|
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs/ this repo will consist with v5.7
|
||||||
h, err := windows.LoadLibrary("amdhip64.dll")
|
h, err := windows.LoadLibrary("amdhip64.dll")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
|
return nil, fmt.Errorf("unable to load amdhip64.dll, please make sure to upgrade to the latest amd driver: %w", err)
|
||||||
}
|
}
|
||||||
hl := &HipLib{}
|
hl := &HipLib{}
|
||||||
hl.dll = h
|
hl.dll = h
|
||||||
|
|||||||
@@ -92,8 +92,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
if !slices.Contains[[]string, string](supported, gfx) {
|
// Strip off Target Features when comparing
|
||||||
//slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
|
||||||
|
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
func TestContextExhaustion(t *testing.T) {
|
func TestContextExhaustion(t *testing.T) {
|
||||||
// Longer needed for small footprint GPUs
|
// Longer needed for small footprint GPUs
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
"num_ctx": 128,
|
"num_ctx": 128,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
|
}
|
||||||
|
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|||||||
43
llm/patches/10-tekken.diff
Normal file
43
llm/patches/10-tekken.diff
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
diff --git a/include/llama.h b/include/llama.h
|
||||||
|
index bb4b05ba..a92174e0 100644
|
||||||
|
--- a/include/llama.h
|
||||||
|
+++ b/include/llama.h
|
||||||
|
@@ -92,6 +92,7 @@ extern "C" {
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
|
||||||
|
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
|
||||||
|
};
|
||||||
|
|
||||||
|
// note: these values should be synchronized with ggml_rope
|
||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 18364976..435b6fe5 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -5429,6 +5429,12 @@ static void llm_load_vocab(
|
||||||
|
} else if (
|
||||||
|
tokenizer_pre == "jais") {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||||
|
+ } else if (
|
||||||
|
+ tokenizer_pre == "tekken") {
|
||||||
|
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
|
||||||
|
+ vocab.tokenizer_clean_spaces = false;
|
||||||
|
+ vocab.tokenizer_ignore_merges = true;
|
||||||
|
+ vocab.tokenizer_add_bos = true;
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe {
|
||||||
|
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
+ case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
|
||||||
|
+ // original regex from tokenizer.json
|
||||||
|
+ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
|
+ regex_exprs = {
|
||||||
|
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
+ };
|
||||||
|
+ break;
|
||||||
|
default:
|
||||||
|
// default regex for BPE tokenization pre-processing
|
||||||
|
regex_exprs = {
|
||||||
19
llm/patches/11-embd_kv.diff
Normal file
19
llm/patches/11-embd_kv.diff
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 2b9ace28..e60d3d8d 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -6052,10 +6052,10 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
- layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||||
|
- layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
- layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||||
|
- layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||||
|
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
|
||||||
|
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
||||||
|
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
||||||
|
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
|
||||||
|
|
||||||
|
// optional bias tensors
|
||||||
|
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
filteredEnv := []string{}
|
filteredEnv := []string{}
|
||||||
for _, ev := range s.cmd.Env {
|
for _, ev := range s.cmd.Env {
|
||||||
if strings.HasPrefix(ev, "CUDA_") ||
|
if strings.HasPrefix(ev, "CUDA_") ||
|
||||||
|
strings.HasPrefix(ev, "ROCR_") ||
|
||||||
strings.HasPrefix(ev, "ROCM_") ||
|
strings.HasPrefix(ev, "ROCM_") ||
|
||||||
strings.HasPrefix(ev, "HIP_") ||
|
strings.HasPrefix(ev, "HIP_") ||
|
||||||
|
strings.HasPrefix(ev, "GPU_") ||
|
||||||
strings.HasPrefix(ev, "HSA_") ||
|
strings.HasPrefix(ev, "HSA_") ||
|
||||||
strings.HasPrefix(ev, "GGML_") ||
|
strings.HasPrefix(ev, "GGML_") ||
|
||||||
strings.HasPrefix(ev, "PATH=") ||
|
strings.HasPrefix(ev, "PATH=") ||
|
||||||
|
|||||||
@@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
case string:
|
case string:
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
case []any:
|
case []any:
|
||||||
message := api.Message{Role: msg.Role}
|
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
data, ok := c.(map[string]any)
|
data, ok := c.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Content = text
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||||
case "image_url":
|
case "image_url":
|
||||||
var url string
|
var url string
|
||||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
@@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
message.Images = append(message.Images, img)
|
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid message format")
|
return nil, fmt.Errorf("invalid message format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messages = append(messages, message)
|
|
||||||
default:
|
default:
|
||||||
if msg.ToolCalls == nil {
|
if msg.ToolCalls == nil {
|
||||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
@@ -878,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||||||
chatReq, err := fromChatRequest(req)
|
chatReq, err := fromChatRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
|||||||
@@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,`
|
|||||||
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
const imageURL = prefix + image
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func TestMiddlewareRequests(t *testing.T) {
|
func prepareRequest(req *http.Request, body any) {
|
||||||
type testCase struct {
|
bodyBytes, _ := json.Marshal(body)
|
||||||
Name string
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
Method string
|
req.Header.Set("Content-Type", "application/json")
|
||||||
Path string
|
}
|
||||||
Handler func() gin.HandlerFunc
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
var capturedRequest *http.Request
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
|
|
||||||
captureRequestMiddleware := func() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
capturedRequest = c.Request
|
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||||
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.ChatRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
Name: "chat handler",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
temp := float32(0.8)
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: &temp,
|
|
||||||
Stop: []string{"\n", "stop"},
|
|
||||||
Suffix: "suffix",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
|
||||||
var genReq api.GenerateRequest
|
|
||||||
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Prompt != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Options["temperature"] != 1.6 {
|
|
||||||
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
|
||||||
}
|
|
||||||
|
|
||||||
stopTokens, ok := genReq.Options["stop"].([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected stop tokens to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Suffix != "suffix" {
|
|
||||||
t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "chat handler with image content",
|
Name: "chat handler with image content",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
@@ -139,87 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
if req.Messages[1].Role != "user" {
|
||||||
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
||||||
|
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler with tools",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{
|
||||||
|
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
||||||
|
{Role: "assistant", ToolCalls: []ToolCall{{
|
||||||
|
ID: "id",
|
||||||
|
Type: "function",
|
||||||
|
Function: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != 200 {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
||||||
|
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
||||||
|
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
||||||
|
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{{Role: "user", Content: 2}},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompletionsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
Stop: []string{"\n", "stop"},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if req.Prompt != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Options["temperature"] != 1.6 {
|
||||||
|
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTokens, ok := req.Options["stop"].([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected stop tokens to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Suffix != "suffix" {
|
||||||
|
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: nil,
|
||||||
|
Stop: []int{1, 2},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.EmbedRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "embed handler single input",
|
Name: "embed handler single input",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Handler: EmbeddingsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: "Hello",
|
Input: "Hello",
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
var embedReq api.EmbedRequest
|
if req.Input != "Hello" {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
t.Fatalf("expected 'Hello', got %s", req.Input)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if embedReq.Input != "Hello" {
|
if req.Model != "test-model" {
|
||||||
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
}
|
|
||||||
|
|
||||||
if embedReq.Model != "test-model" {
|
|
||||||
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "embed handler batch input",
|
Name: "embed handler batch input",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Handler: EmbeddingsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: []string{"Hello", "World"},
|
Input: []string{"Hello", "World"},
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
var embedReq api.EmbedRequest
|
input, ok := req.Input.([]any)
|
||||||
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
input, ok := embedReq.Input.([]any)
|
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected input to be a list")
|
t.Fatalf("expected input to be a list")
|
||||||
@@ -233,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
|
|||||||
t.Fatalf("expected 'World', got %s", input[1])
|
t.Fatalf("expected 'World', got %s", input[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
if embedReq.Model != "test-model" {
|
if req.Model != "test-model" {
|
||||||
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
if !strings.Contains(resp.Body.String(), "invalid input") {
|
||||||
router := gin.New()
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
endpoint := func(c *gin.Context) {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
router = gin.New()
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
||||||
router.Use(captureRequestMiddleware())
|
|
||||||
router.Use(tc.Handler())
|
|
||||||
router.Handle(tc.Method, tc.Path, endpoint)
|
|
||||||
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
|
||||||
|
|
||||||
if tc.Setup != nil {
|
|
||||||
tc.Setup(t, req)
|
tc.Setup(t, req)
|
||||||
}
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest)
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -280,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
|
||||||
Name: "completions handler error forwarding",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
TestPath: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Name: "list handler",
|
Name: "list handler",
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
@@ -326,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
|
|
||||||
var listResp ListCompletion
|
var listResp ListCompletion
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -391,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
|
|||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
tc.Expected(t, resp)
|
tc.Expected(t, resp)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||||||
layers = append(layers, baseLayer.Layer)
|
layers = append(layers, baseLayer.Layer)
|
||||||
}
|
}
|
||||||
case "license", "template", "system":
|
case "license", "template", "system":
|
||||||
|
if c.Name == "template" {
|
||||||
|
if _, err := template.Parse(c.Args); err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", errBadTemplate, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if c.Name != "license" {
|
if c.Name != "license" {
|
||||||
// replace
|
// replace
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var errRequired = errors.New("is required")
|
var errRequired = errors.New("is required")
|
||||||
|
var errBadTemplate = errors.New("template error")
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
@@ -609,6 +610,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
|||||||
|
|
||||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||||
|
if errors.Is(err, errBadTemplate) {
|
||||||
|
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||||
|
}
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case gin.H:
|
case gin.H:
|
||||||
|
status, ok := r["status"].(int)
|
||||||
|
if !ok {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
if errorMsg, ok := r["error"].(string); ok {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
c.JSON(status, gin.H{"error": errorMsg})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
|
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
|
|||||||
if string(system) != "Say bye!" {
|
if string(system) != "Say bye!" {
|
||||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("incomplete template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with unclosed if", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("template with undefined function", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
getCpuFn: gpu.GetCPUInfo,
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
reschedDelay: 250 * time.Millisecond,
|
reschedDelay: 250 * time.Millisecond,
|
||||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
// add 10ms delay to simulate loading
|
// add small delay to simulate loading
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(time.Millisecond)
|
||||||
req.successCh <- &runnerRef{
|
req.successCh <- &runnerRef{
|
||||||
llama: &mock,
|
llama: &mock,
|
||||||
}
|
}
|
||||||
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
|
|||||||
getCpuFn: gpu.GetCPUInfo,
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
reschedDelay: 250 * time.Millisecond,
|
reschedDelay: 250 * time.Millisecond,
|
||||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add small delay to simulate loading
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
req.successCh <- &runnerRef{
|
req.successCh <- &runnerRef{
|
||||||
llama: &mock,
|
llama: &mock,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
|
|||||||
require.Len(t, s.expiredCh, 1)
|
require.Len(t, s.expiredCh, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type bundle struct {
|
type reqBundle struct {
|
||||||
ctx context.Context //nolint:containedctx
|
ctx context.Context //nolint:containedctx
|
||||||
ctxDone func()
|
ctxDone func()
|
||||||
srv *mockLlm
|
srv *mockLlm
|
||||||
@@ -102,13 +102,13 @@ type bundle struct {
|
|||||||
ggml *llm.GGML
|
ggml *llm.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
return scenario.srv, nil
|
return scenario.srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
|
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
|
||||||
scenario := &bundle{}
|
b := &reqBundle{}
|
||||||
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
|
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
@@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
|||||||
|
|
||||||
fname := f.Name()
|
fname := f.Name()
|
||||||
model := &Model{Name: modelName, ModelPath: fname}
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
scenario.req = &LlmRequest{
|
if duration == nil {
|
||||||
ctx: scenario.ctx,
|
duration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
|
}
|
||||||
|
b.req = &LlmRequest{
|
||||||
|
ctx: b.ctx,
|
||||||
model: model,
|
model: model,
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
|
sessionDuration: duration,
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||||
return scenario
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequests(t *testing.T) {
|
func getGpuFn() gpu.GpuInfoList {
|
||||||
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer done()
|
|
||||||
|
|
||||||
// Same model, same request
|
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
|
||||||
scenario1b.req.model = scenario1a.req.model
|
|
||||||
scenario1b.ggml = scenario1a.ggml
|
|
||||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
|
|
||||||
// simple reload of same model
|
|
||||||
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
|
||||||
tmpModel := *scenario1a.req.model
|
|
||||||
scenario2a.req.model = &tmpModel
|
|
||||||
scenario2a.ggml = scenario1a.ggml
|
|
||||||
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
|
||||||
|
|
||||||
// Multiple loaded models
|
|
||||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
|
||||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
|
||||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
|
||||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
|
||||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
|
||||||
|
|
||||||
s := InitScheduler(ctx)
|
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
g.TotalMemory = 24 * format.GigaByte
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
g.FreeMemory = 12 * format.GigaByte
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
s.getCpuFn = func() gpu.GpuInfoList {
|
|
||||||
|
func getCpuFn() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "cpu"}
|
g := gpu.GpuInfo{Library: "cpu"}
|
||||||
g.TotalMemory = 32 * format.GigaByte
|
g.TotalMemory = 32 * format.GigaByte
|
||||||
g.FreeMemory = 26 * format.GigaByte
|
g.FreeMemory = 26 * format.GigaByte
|
||||||
return []gpu.GpuInfo{g}
|
return []gpu.GpuInfo{g}
|
||||||
}
|
}
|
||||||
s.newServerFn = scenario1a.newServer
|
|
||||||
slog.Info("scenario1a")
|
func TestRequestsSameModelSameRequest(t *testing.T) {
|
||||||
s.pendingReqCh <- scenario1a.req
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
|
||||||
|
b.req.model = a.req.model
|
||||||
|
b.ggml = a.ggml
|
||||||
|
|
||||||
|
s.newServerFn = a.newServer
|
||||||
|
slog.Info("a")
|
||||||
|
s.pendingReqCh <- a.req
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1a.req.successCh:
|
case resp := <-a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1a.req.errCh)
|
require.Empty(t, a.req.errCh)
|
||||||
case err := <-scenario1a.req.errCh:
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same runner as first request due to not needing a reload
|
// Same runner as first request due to not needing a reload
|
||||||
s.newServerFn = scenario1b.newServer
|
s.newServerFn = b.newServer
|
||||||
slog.Info("scenario1b")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario1b.req
|
s.pendingReqCh <- b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario1b.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario1b.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario1b.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
|
t.Fatal(err.Error())
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
|
||||||
|
tmpModel := *a.req.model
|
||||||
|
b.req.model = &tmpModel
|
||||||
|
b.ggml = a.ggml
|
||||||
|
|
||||||
|
s.newServerFn = a.newServer
|
||||||
|
slog.Info("a")
|
||||||
|
s.pendingReqCh <- a.req
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
s.Run(ctx)
|
||||||
|
select {
|
||||||
|
case resp := <-a.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, a.srv)
|
||||||
|
require.Empty(t, s.pendingReqCh)
|
||||||
|
require.Empty(t, a.req.errCh)
|
||||||
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger a reload
|
// Trigger a reload
|
||||||
s.newServerFn = scenario2a.newServer
|
s.newServerFn = b.newServer
|
||||||
scenario2a.req.model.AdapterPaths = []string{"new"}
|
b.req.model.AdapterPaths = []string{"new"}
|
||||||
slog.Info("scenario2a")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario2a.req
|
s.pendingReqCh <- b.req
|
||||||
// finish first two requests, so model can reload
|
// finish first two requests, so model can reload
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
scenario1a.ctxDone()
|
a.ctxDone()
|
||||||
scenario1b.ctxDone()
|
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario2a.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario2a.srv)
|
require.Equal(t, resp.llama, b.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario2a.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario2a.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestsMultipleLoadedModels(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = getGpuFn
|
||||||
|
s.getCpuFn = getCpuFn
|
||||||
|
|
||||||
|
// Multiple loaded models
|
||||||
|
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
|
||||||
|
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
|
||||||
|
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
|
||||||
|
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||||
|
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
|
||||||
|
|
||||||
envconfig.MaxRunners = 1
|
envconfig.MaxRunners = 1
|
||||||
s.newServerFn = scenario3a.newServer
|
s.newServerFn = a.newServer
|
||||||
slog.Info("scenario3a")
|
slog.Info("a")
|
||||||
s.pendingReqCh <- scenario3a.req
|
s.pendingReqCh <- a.req
|
||||||
// finish prior request, so new model can load
|
s.Run(ctx)
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
scenario2a.ctxDone()
|
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3a.req.successCh:
|
case resp := <-a.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3a.req.errCh)
|
require.Empty(t, a.req.errCh)
|
||||||
case err := <-scenario3a.req.errCh:
|
case err := <-a.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -262,15 +292,15 @@ func TestRequests(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
envconfig.MaxRunners = 0
|
envconfig.MaxRunners = 0
|
||||||
s.newServerFn = scenario3b.newServer
|
s.newServerFn = b.newServer
|
||||||
slog.Info("scenario3b")
|
slog.Info("b")
|
||||||
s.pendingReqCh <- scenario3b.req
|
s.pendingReqCh <- b.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3b.req.successCh:
|
case resp := <-b.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3b.srv)
|
require.Equal(t, resp.llama, b.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3b.req.errCh)
|
require.Empty(t, b.req.errCh)
|
||||||
case err := <-scenario3b.req.errCh:
|
case err := <-b.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -280,15 +310,15 @@ func TestRequests(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// This is a CPU load with NumGPU = 0 so it should load
|
// This is a CPU load with NumGPU = 0 so it should load
|
||||||
s.newServerFn = scenario3c.newServer
|
s.newServerFn = c.newServer
|
||||||
slog.Info("scenario3c")
|
slog.Info("c")
|
||||||
s.pendingReqCh <- scenario3c.req
|
s.pendingReqCh <- c.req
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3c.req.successCh:
|
case resp := <-c.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3c.srv)
|
require.Equal(t, resp.llama, c.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3c.req.errCh)
|
require.Empty(t, c.req.errCh)
|
||||||
case err := <-scenario3c.req.errCh:
|
case err := <-c.req.errCh:
|
||||||
t.Fatal(err.Error())
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
@@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// Try to load a model that wont fit
|
// Try to load a model that wont fit
|
||||||
s.newServerFn = scenario3d.newServer
|
s.newServerFn = d.newServer
|
||||||
slog.Info("scenario3d")
|
slog.Info("d")
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 3)
|
require.Len(t, s.loaded, 3)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||||
time.Sleep(2 * time.Millisecond)
|
time.Sleep(2 * time.Millisecond)
|
||||||
s.pendingReqCh <- scenario3d.req
|
s.pendingReqCh <- d.req
|
||||||
// finish prior request, so new model can load
|
// finish prior request, so new model can load
|
||||||
time.Sleep(6 * time.Millisecond)
|
time.Sleep(6 * time.Millisecond)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 2)
|
require.Len(t, s.loaded, 2)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
scenario3b.ctxDone()
|
b.ctxDone()
|
||||||
select {
|
select {
|
||||||
case resp := <-scenario3d.req.successCh:
|
case resp := <-d.req.successCh:
|
||||||
require.Equal(t, resp.llama, scenario3d.srv)
|
require.Equal(t, resp.llama, d.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, scenario3d.req.errCh)
|
require.Empty(t, d.req.errCh)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
|
|||||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
|
||||||
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
envconfig.MaxQueuedRequests = 1
|
envconfig.MaxQueuedRequests = 1
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = getGpuFn
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
s.getCpuFn = getCpuFn
|
||||||
g.TotalMemory = 24 * format.GigaByte
|
s.newServerFn = a.newServer
|
||||||
g.FreeMemory = 12 * format.GigaByte
|
slog.Info("a")
|
||||||
return []gpu.GpuInfo{g}
|
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||||
}
|
|
||||||
s.newServerFn = scenario1a.newServer
|
|
||||||
slog.Info("scenario1a")
|
|
||||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
slog.Info("scenario1b")
|
slog.Info("b")
|
||||||
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||||
require.Len(t, s.pendingReqCh, 1)
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
require.Empty(t, successCh1b)
|
require.Empty(t, successCh1b)
|
||||||
require.Len(t, errCh1b, 1)
|
require.Len(t, errCh1b, 1)
|
||||||
@@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
|
|||||||
s.Run(ctx)
|
s.Run(ctx)
|
||||||
select {
|
select {
|
||||||
case resp := <-successCh1a:
|
case resp := <-successCh1a:
|
||||||
require.Equal(t, resp.llama, scenario1a.srv)
|
require.Equal(t, resp.llama, a.srv)
|
||||||
require.Empty(t, s.pendingReqCh)
|
require.Empty(t, s.pendingReqCh)
|
||||||
require.Empty(t, errCh1a)
|
require.Empty(t, errCh1a)
|
||||||
|
case err := <-errCh1a:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
scenario1a.ctxDone()
|
a.ctxDone() // Set "a" model to idle so it can unload
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Len(t, s.loaded, 1)
|
require.Len(t, s.loaded, 1)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
scenario1c.req.model.ModelPath = "bad path"
|
c.req.model.ModelPath = "bad path"
|
||||||
slog.Info("scenario1c")
|
slog.Info("c")
|
||||||
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||||
// Starts in pending channel, then should be quickly processsed to return an error
|
// Starts in pending channel, then should be quickly processsed to return an error
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||||
require.Empty(t, successCh1c)
|
require.Empty(t, successCh1c)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
require.Empty(t, s.loaded)
|
require.Empty(t, s.loaded)
|
||||||
@@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
|
|||||||
require.Len(t, errCh1c, 1)
|
require.Len(t, errCh1c, 1)
|
||||||
err = <-errCh1c
|
err = <-errCh1c
|
||||||
require.Contains(t, err.Error(), "bad path")
|
require.Contains(t, err.Error(), "bad path")
|
||||||
scenario1b.ctxDone()
|
b.ctxDone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||||
@@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
|
|||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
// Same model, same request
|
// Same model, same request
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
g := gpu.GpuInfo{Library: "metal"}
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
@@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
|
|||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
slog.Info("sending premature expired event now")
|
slog.Info("sending premature expired event now")
|
||||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||||
|
case err := <-errCh1a:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case success := <-req.successCh:
|
case success := <-req.successCh:
|
||||||
require.Equal(t, r1, success)
|
require.Equal(t, r1, success)
|
||||||
|
case err := <-req.errCh:
|
||||||
|
t.Fatal(err.Error())
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
@@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
|
|||||||
defer done()
|
defer done()
|
||||||
dctx, done2 := context.WithCancel(ctx)
|
dctx, done2 := context.WithCancel(ctx)
|
||||||
done2()
|
done2()
|
||||||
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
|
||||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
slog.Info("scenario1a")
|
slog.Info("scenario1a")
|
||||||
s.pendingReqCh <- scenario1a.req
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
|||||||
@@ -264,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
cut = true
|
cut = true
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return cut
|
return cut
|
||||||
@@ -273,7 +274,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
"Response": "",
|
"Response": response,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
|
|||||||
|
|
||||||
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"mistral assistant",
|
||||||
|
[]template{
|
||||||
|
{"no response", `[INST] {{ .Prompt }}[/INST] `},
|
||||||
|
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
|
{"messages", `
|
||||||
|
{{- range $i, $m := .Messages }}
|
||||||
|
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
|
||||||
|
{{- end }}`},
|
||||||
|
},
|
||||||
|
Values{
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello friend!"},
|
||||||
|
{Role: "assistant", Content: "Hello human!"},
|
||||||
|
{Role: "user", Content: "What is your name?"},
|
||||||
|
{Role: "assistant", Content: "My name is Ollama and I"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"chatml",
|
"chatml",
|
||||||
[]template{
|
[]template{
|
||||||
|
|||||||
Reference in New Issue
Block a user