From f4408219e92a7b22107a68d0b3f5eb545c06aed9 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 5 Jul 2024 15:30:06 -0700 Subject: [PATCH 01/11] Refine scheduler unit tests for reliability This breaks up some of the test scenarios to create a more reliable set of tests, as well as adding a little more coverage. --- server/sched_test.go | 327 ++++++++++++++++++++++++++----------------- 1 file changed, 196 insertions(+), 131 deletions(-) diff --git a/server/sched_test.go b/server/sched_test.go index 3fbd188a..c16b407d 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "os" + "runtime" "testing" "time" @@ -94,7 +95,7 @@ func TestLoad(t *testing.T) { require.Len(t, s.expiredCh, 1) } -type bundle struct { +type reqBundle struct { ctx context.Context //nolint:containedctx ctxDone func() srv *mockLlm @@ -102,13 +103,13 @@ type bundle struct { ggml *llm.GGML } -func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { +func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { return scenario.srv, nil } -func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle { - scenario := &bundle{} - scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) +func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle { + b := &reqBundle{} + b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() f, err := os.CreateTemp(t.TempDir(), modelName) @@ -135,124 +136,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV fname := f.Name() model := &Model{Name: modelName, ModelPath: fname} - scenario.ggml, err = llm.LoadModel(model.ModelPath, 0) + b.ggml, err = llm.LoadModel(model.ModelPath, 0) require.NoError(t, err) - scenario.req = &LlmRequest{ - ctx: scenario.ctx, + if duration == nil { + duration = &api.Duration{Duration: 5 * time.Millisecond} + } + b.req = &LlmRequest{ + ctx: b.ctx, model: model, opts: api.DefaultOptions(), - sessionDuration: &api.Duration{Duration: 5 * time.Millisecond}, + sessionDuration: duration, successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} - return scenario + b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} + return b } -func TestRequests(t *testing.T) { - ctx, done := context.WithTimeout(context.Background(), 10*time.Second) +func getGpuFn() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} +} + +func getCpuFn() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "cpu"} + g.TotalMemory = 32 * format.GigaByte + g.FreeMemory = 26 * format.GigaByte + return []gpu.GpuInfo{g} +} + +func TestRequestsSameModelSameRequest(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) defer done() - - // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} - scenario1b := newScenario(t, ctx, "ollama-model-1", 11) - scenario1b.req.model = scenario1a.req.model - scenario1b.ggml = scenario1a.ggml - scenario1b.req.sessionDuration = &api.Duration{Duration: 0} - - // simple reload of same model - scenario2a := newScenario(t, ctx, "ollama-model-1", 20) - tmpModel := *scenario1a.req.model - scenario2a.req.model = &tmpModel - scenario2a.ggml = scenario1a.ggml - scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} - - // Multiple loaded models - scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) - scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) - scenario3c := newScenario(t, ctx, "ollama-model-4a", 30) - scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed - scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded - s := InitScheduler(ctx) - s.getGpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "metal"} - g.TotalMemory = 24 * format.GigaByte - g.FreeMemory = 12 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.getCpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "cpu"} - g.TotalMemory = 32 * format.GigaByte - g.FreeMemory = 26 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.newServerFn = scenario1a.newServer - slog.Info("scenario1a") - s.pendingReqCh <- scenario1a.req + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}) + b.req.model = a.req.model + b.ggml = a.ggml + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req require.Len(t, s.pendingReqCh, 1) s.Run(ctx) select { - case resp := <-scenario1a.req.successCh: - require.Equal(t, resp.llama, scenario1a.srv) + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario1a.req.errCh) - case err := <-scenario1a.req.errCh: + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } // Same runner as first request due to not needing a reload - s.newServerFn = scenario1b.newServer - slog.Info("scenario1b") - s.pendingReqCh <- scenario1b.req + s.newServerFn = b.newServer + slog.Info("b") + s.pendingReqCh <- b.req select { - case resp := <-scenario1b.req.successCh: - require.Equal(t, resp.llama, scenario1a.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario1b.req.errCh) - case err := <-scenario1b.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: + t.Fatal(err.Error()) + case <-ctx.Done(): + t.Fatal("timeout") + } +} + +func TestRequestsSimpleReloadSameModel(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}) + tmpModel := *a.req.model + b.req.model = &tmpModel + b.ggml = a.ggml + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) + require.Empty(t, s.pendingReqCh) + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } // Trigger a reload - s.newServerFn = scenario2a.newServer - scenario2a.req.model.AdapterPaths = []string{"new"} - slog.Info("scenario2a") - s.pendingReqCh <- scenario2a.req + s.newServerFn = b.newServer + b.req.model.AdapterPaths = []string{"new"} + slog.Info("b") + s.pendingReqCh <- b.req // finish first two requests, so model can reload time.Sleep(1 * time.Millisecond) - scenario1a.ctxDone() - scenario1b.ctxDone() + a.ctxDone() select { - case resp := <-scenario2a.req.successCh: - require.Equal(t, resp.llama, scenario2a.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, b.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario2a.req.errCh) - case err := <-scenario2a.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } +} + +func TestRequestsMultipleLoadedModels(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + + // Multiple loaded models + a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil) + b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil) + c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil) + c.req.opts.NumGPU = 0 // CPU load, will be allowed + d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded envconfig.MaxRunners = 1 - s.newServerFn = scenario3a.newServer - slog.Info("scenario3a") - s.pendingReqCh <- scenario3a.req - // finish prior request, so new model can load - time.Sleep(1 * time.Millisecond) - scenario2a.ctxDone() + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + s.Run(ctx) select { - case resp := <-scenario3a.req.successCh: - require.Equal(t, resp.llama, scenario3a.srv) + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3a.req.errCh) - case err := <-scenario3a.req.errCh: + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -262,15 +293,15 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() envconfig.MaxRunners = 0 - s.newServerFn = scenario3b.newServer - slog.Info("scenario3b") - s.pendingReqCh <- scenario3b.req + s.newServerFn = b.newServer + slog.Info("b") + s.pendingReqCh <- b.req select { - case resp := <-scenario3b.req.successCh: - require.Equal(t, resp.llama, scenario3b.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, b.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3b.req.errCh) - case err := <-scenario3b.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -280,15 +311,15 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() // This is a CPU load with NumGPU = 0 so it should load - s.newServerFn = scenario3c.newServer - slog.Info("scenario3c") - s.pendingReqCh <- scenario3c.req + s.newServerFn = c.newServer + slog.Info("c") + s.pendingReqCh <- c.req select { - case resp := <-scenario3c.req.successCh: - require.Equal(t, resp.llama, scenario3c.srv) + case resp := <-c.req.successCh: + require.Equal(t, resp.llama, c.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3c.req.errCh) - case err := <-scenario3c.req.errCh: + require.Empty(t, c.req.errCh) + case err := <-c.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -298,25 +329,25 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() // Try to load a model that wont fit - s.newServerFn = scenario3d.newServer - slog.Info("scenario3d") + s.newServerFn = d.newServer + slog.Info("d") s.loadedMu.Lock() require.Len(t, s.loaded, 3) s.loadedMu.Unlock() - scenario3a.ctxDone() // Won't help since this one isn't big enough to make room + a.ctxDone() // Won't help since this one isn't big enough to make room time.Sleep(2 * time.Millisecond) - s.pendingReqCh <- scenario3d.req + s.pendingReqCh <- d.req // finish prior request, so new model can load time.Sleep(6 * time.Millisecond) s.loadedMu.Lock() require.Len(t, s.loaded, 2) s.loadedMu.Unlock() - scenario3b.ctxDone() + b.ctxDone() select { - case resp := <-scenario3d.req.successCh: - require.Equal(t, resp.llama, scenario3d.srv) + case resp := <-d.req.successCh: + require.Equal(t, resp.llama, d.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3d.req.errCh) + require.Empty(t, d.req.errCh) case <-ctx.Done(): t.Fatal("timeout") } @@ -325,30 +356,59 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() } +func TestRequestsModelTooBigForSystem(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 4 * format.MebiByte + g.FreeMemory = 3 * format.MebiByte + return []gpu.GpuInfo{g} + } + + s.getCpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "cpu"} + g.TotalMemory = 4 * format.MebiByte + g.FreeMemory = 2 * format.MebiByte + return []gpu.GpuInfo{g} + } + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case <-a.req.successCh: + if runtime.GOOS == "linux" { + t.Fatal("request should have been rejected with out of space") + } + // else - Darwin and Windows don't reject right now + case err := <-a.req.errCh: + require.Contains(t, err.Error(), "too large") + case <-ctx.Done(): + t.Fatal("timeout") + } +} func TestGetRunner(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) defer done() - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 0} - scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) - scenario1b.req.sessionDuration = &api.Duration{Duration: 0} - scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) - scenario1c.req.sessionDuration = &api.Duration{Duration: 0} + a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) + c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) envconfig.MaxQueuedRequests = 1 s := InitScheduler(ctx) - s.getGpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "metal"} - g.TotalMemory = 24 * format.GigaByte - g.FreeMemory = 12 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.newServerFn = scenario1a.newServer - slog.Info("scenario1a") - successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + s.newServerFn = a.newServer + slog.Info("a") + successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) - slog.Info("scenario1b") - successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) + slog.Info("b") + successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) @@ -357,22 +417,24 @@ func TestGetRunner(t *testing.T) { s.Run(ctx) select { case resp := <-successCh1a: - require.Equal(t, resp.llama, scenario1a.srv) + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) require.Empty(t, errCh1a) + case err := <-errCh1a: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } - scenario1a.ctxDone() + a.ctxDone() // Set "a" model to idle so it can unload s.loadedMu.Lock() require.Len(t, s.loaded, 1) s.loadedMu.Unlock() - scenario1c.req.model.ModelPath = "bad path" - slog.Info("scenario1c") - successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) + c.req.model.ModelPath = "bad path" + slog.Info("c") + successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration) // Starts in pending channel, then should be quickly processsed to return an error - time.Sleep(5 * time.Millisecond) + time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload require.Empty(t, successCh1c) s.loadedMu.Lock() require.Empty(t, s.loaded) @@ -380,7 +442,7 @@ func TestGetRunner(t *testing.T) { require.Len(t, errCh1c, 1) err = <-errCh1c require.Contains(t, err.Error(), "bad path") - scenario1b.ctxDone() + b.ctxDone() } // TODO - add one scenario that triggers the bogus finished event with positive ref count @@ -389,7 +451,7 @@ func TestPrematureExpired(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil) s := InitScheduler(ctx) s.getGpuFn = func() gpu.GpuInfoList { g := gpu.GpuInfo{Library: "metal"} @@ -411,6 +473,8 @@ func TestPrematureExpired(t *testing.T) { s.loadedMu.Unlock() slog.Info("sending premature expired event now") s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe + case err := <-errCh1a: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } @@ -446,6 +510,8 @@ func TestUseLoadedRunner(t *testing.T) { select { case success := <-req.successCh: require.Equal(t, r1, success) + case err := <-req.errCh: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } @@ -625,8 +691,7 @@ func TestAlreadyCanceled(t *testing.T) { defer done() dctx, done2 := context.WithCancel(ctx) done2() - scenario1a := newScenario(t, dctx, "ollama-model-1", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 0} + scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}) s := InitScheduler(ctx) slog.Info("scenario1a") s.pendingReqCh <- scenario1a.req From 73e2c8f68fe075ea159a20bbf778c0cf801316ad Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 9 Jul 2024 15:28:25 -0700 Subject: [PATCH 02/11] Fix context exhaustion integration test for small gpus On the smaller GPUs, the initial model load of llama2 took over 30s (the default timeout for the DoGenerate helper) --- integration/context_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/integration/context_test.go b/integration/context_test.go index 46fac5ea..f1342e16 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -12,7 +12,7 @@ import ( func TestContextExhaustion(t *testing.T) { // Longer needed for small footprint GPUs - ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up the test data req := api.GenerateRequest{ @@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) { "num_ctx": 128, }, } - GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"}) + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("PullIfMissing failed: %v", err) + } + DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second) } From 51b2fd299cd568093ce796aef3e7e37ae656b02a Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:19:20 -0700 Subject: [PATCH 03/11] adjust openai chat msg processing (#5729) --- openai/openai.go | 7 +++---- openai/openai_test.go | 8 ++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 01864e48..93b63296 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { case string: messages = append(messages, api.Message{Role: msg.Role, Content: content}) case []any: - message := api.Message{Role: msg.Role} for _, c := range content { data, ok := c.(map[string]any) if !ok { @@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { if !ok { return nil, fmt.Errorf("invalid message format") } - message.Content = text + messages = append(messages, api.Message{Role: msg.Role, Content: text}) case "image_url": var url string if urlMap, ok := data["image_url"].(map[string]any); ok { @@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { if err != nil { return nil, fmt.Errorf("invalid message format") } - message.Images = append(message.Images, img) + + messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) default: return nil, fmt.Errorf("invalid message format") } } - messages = append(messages, message) default: if msg.ToolCalls == nil { return nil, fmt.Errorf("invalid message content type: %T", content) diff --git a/openai/openai_test.go b/openai/openai_test.go index 046ee69c..ad056e6d 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -161,8 +161,12 @@ func TestMiddlewareRequests(t *testing.T) { img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - if !bytes.Equal(chatReq.Messages[0].Images[0], img) { - t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0]) + if chatReq.Messages[1].Role != "user" { + t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role) + } + + if !bytes.Equal(chatReq.Messages[1].Images[0], img) { + t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0]) } }, }, From c57317cbf0c865dd1fbe4852e1cce3cf4703b7ee Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:37:12 -0700 Subject: [PATCH 04/11] OpenAI: Function Based Testing (#5752) * distinguish error forwarding * more coverage * rm comment --- openai/openai.go | 1 + openai/openai_test.go | 461 +++++++++++++++++++++++++----------------- 2 files changed, 279 insertions(+), 183 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 93b63296..de6f4bd5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc { chatReq, err := fromChatRequest(req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return } if err := json.NewEncoder(&b).Encode(chatReq); err != nil { diff --git a/openai/openai_test.go b/openai/openai_test.go index ad056e6d..f978d46c 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,` const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` const imageURL = prefix + image -func TestMiddlewareRequests(t *testing.T) { +func prepareRequest(req *http.Request, body any) { + bodyBytes, _ := json.Marshal(body) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") +} + +func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + err := json.Unmarshal(bodyBytes, capturedRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") + } + c.Next() + } +} + +func TestChatMiddleware(t *testing.T) { type testCase struct { Name string - Method string - Path string - Handler func() gin.HandlerFunc Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) } - var capturedRequest *http.Request - - captureRequestMiddleware := func() gin.HandlerFunc { - return func(c *gin.Context) { - bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - capturedRequest = c.Request - c.Next() - } - } + var capturedRequest *api.ChatRequest testCases := []testCase{ { - Name: "chat handler", - Method: http.MethodPost, - Path: "/api/chat", - Handler: ChatMiddleware, + Name: "chat handler", Setup: func(t *testing.T, req *http.Request) { body := ChatCompletionRequest{ Model: "test-model", Messages: []Message{{Role: "user", Content: "Hello"}}, } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var chatReq api.ChatRequest - if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.Code) } - if chatReq.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) + if req.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[0].Role) } - if chatReq.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) + if req.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) } }, }, { - Name: "completions handler", - Method: http.MethodPost, - Path: "/api/generate", - Handler: CompletionsMiddleware, - Setup: func(t *testing.T, req *http.Request) { - temp := float32(0.8) - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: &temp, - Stop: []string{"\n", "stop"}, - Suffix: "suffix", - } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - }, - Expected: func(t *testing.T, req *http.Request) { - var genReq api.GenerateRequest - if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil { - t.Fatal(err) - } - - if genReq.Prompt != "Hello" { - t.Fatalf("expected 'Hello', got %s", genReq.Prompt) - } - - if genReq.Options["temperature"] != 1.6 { - t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"]) - } - - stopTokens, ok := genReq.Options["stop"].([]any) - - if !ok { - t.Fatalf("expected stop tokens to be a list") - } - - if stopTokens[0] != "\n" || stopTokens[1] != "stop" { - t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) - } - - if genReq.Suffix != "suffix" { - t.Fatalf("expected 'suffix', got %s", genReq.Suffix) - } - }, - }, - { - Name: "chat handler with image content", - Method: http.MethodPost, - Path: "/api/chat", - Handler: ChatMiddleware, + Name: "chat handler with image content", Setup: func(t *testing.T, req *http.Request) { body := ChatCompletionRequest{ Model: "test-model", @@ -139,91 +85,254 @@ func TestMiddlewareRequests(t *testing.T) { }, }, } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var chatReq api.ChatRequest - if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.Code) } - if chatReq.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) + if req.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[0].Role) } - if chatReq.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) + if req.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) } img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - if chatReq.Messages[1].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role) + if req.Messages[1].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[1].Role) } - if !bytes.Equal(chatReq.Messages[1].Images[0], img) { - t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0]) + if !bytes.Equal(req.Messages[1].Images[0], img) { + t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0]) } }, }, { - Name: "embed handler single input", - Method: http.MethodPost, - Path: "/api/embed", - Handler: EmbeddingsMiddleware, + Name: "chat handler with tools", + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "What's the weather like in Paris Today?"}, + {Role: "assistant", ToolCalls: []ToolCall{{ + ID: "id", + Type: "function", + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: "get_current_weather", + Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}", + }, + }}}, + }, + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } + + if req.Messages[0].Content != "What's the weather like in Paris Today?" { + t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content) + } + + if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" { + t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"]) + } + + if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" { + t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"]) + } + }, + }, + { + Name: "chat handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: 2}}, + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid message content type") { + t.Fatalf("error was not forwarded") + } + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/chat", endpoint) + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) + + tc.Setup(t, req) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil + }) + } +} + +func TestCompletionsMiddleware(t *testing.T) { + type testCase struct { + Name string + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) + } + + var capturedRequest *api.GenerateRequest + + testCases := []testCase{ + { + Name: "completions handler", + Setup: func(t *testing.T, req *http.Request) { + temp := float32(0.8) + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, + Stop: []string{"\n", "stop"}, + Suffix: "suffix", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { + if req.Prompt != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Prompt) + } + + if req.Options["temperature"] != 1.6 { + t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) + } + + stopTokens, ok := req.Options["stop"].([]any) + + if !ok { + t.Fatalf("expected stop tokens to be a list") + } + + if stopTokens[0] != "\n" || stopTokens[1] != "stop" { + t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) + } + + if req.Suffix != "suffix" { + t.Fatalf("expected 'suffix', got %s", req.Suffix) + } + }, + }, + { + Name: "completions handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: nil, + Stop: []int{1, 2}, + Suffix: "suffix", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") { + t.Fatalf("error was not forwarded") + } + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/generate", endpoint) + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) + + tc.Setup(t, req) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil + }) + } +} + +func TestEmbeddingsMiddleware(t *testing.T) { + type testCase struct { + Name string + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) + } + + var capturedRequest *api.EmbedRequest + + testCases := []testCase{ + { + Name: "embed handler single input", Setup: func(t *testing.T, req *http.Request) { body := EmbedRequest{ Input: "Hello", Model: "test-model", } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var embedReq api.EmbedRequest - if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + if req.Input != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Input) } - if embedReq.Input != "Hello" { - t.Fatalf("expected 'Hello', got %s", embedReq.Input) - } - - if embedReq.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", embedReq.Model) + if req.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", req.Model) } }, }, { - Name: "embed handler batch input", - Method: http.MethodPost, - Path: "/api/embed", - Handler: EmbeddingsMiddleware, + Name: "embed handler batch input", Setup: func(t *testing.T, req *http.Request) { body := EmbedRequest{ Input: []string{"Hello", "World"}, Model: "test-model", } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var embedReq api.EmbedRequest - if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { - t.Fatal(err) - } - - input, ok := embedReq.Input.([]any) + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + input, ok := req.Input.([]any) if !ok { t.Fatalf("expected input to be a list") @@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) { t.Fatalf("expected 'World', got %s", input[1]) } - if embedReq.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", embedReq.Model) + if req.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", req.Model) + } + }, + }, + { + Name: "embed handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Model: "test-model", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid input") { + t.Fatalf("error was not forwarded") } }, }, } - gin.SetMode(gin.TestMode) - router := gin.New() - endpoint := func(c *gin.Context) { c.Status(http.StatusOK) } + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/embed", endpoint) + for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - router = gin.New() - router.Use(captureRequestMiddleware()) - router.Use(tc.Handler()) - router.Handle(tc.Method, tc.Path, endpoint) - req, _ := http.NewRequest(tc.Method, tc.Path, nil) + req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) - if tc.Setup != nil { - tc.Setup(t, req) - } + tc.Setup(t, req) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest) + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil }) } } @@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) { } testCases := []testCase{ - { - Name: "completions handler error forwarding", - Method: http.MethodPost, - Path: "/api/generate", - TestPath: "/api/generate", - Handler: CompletionsMiddleware, - Endpoint: func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - }, - Setup: func(t *testing.T, req *http.Request) { - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), `"invalid request"`) { - t.Fatalf("error was not forwarded") - } - }, - }, { Name: "list handler", Method: http.MethodGet, @@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) { }) }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - assert.Equal(t, http.StatusOK, resp.Code) - var listResp ListCompletion if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { t.Fatal(err) @@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) { resp := httptest.NewRecorder() router.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + tc.Expected(t, resp) }) } From e8b954c646544d40d84be50aae9cd909fcbd8f41 Mon Sep 17 00:00:00 2001 From: Josh <76125168+joshyan1@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:24:29 -0700 Subject: [PATCH 05/11] server: validate template (#5734) add template validation to modelfile --- server/images.go | 6 ++++++ server/routes.go | 14 +++++++++++--- server/routes_create_test.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/server/images.go b/server/images.go index 5e4e8858..574dec19 100644 --- a/server/images.go +++ b/server/images.go @@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio layers = append(layers, baseLayer.Layer) } case "license", "template", "system": + if c.Name == "template" { + if _, err := template.Parse(c.Args); err != nil { + return fmt.Errorf("%w: %s", errBadTemplate, err) + } + } + if c.Name != "license" { // replace layers = slices.DeleteFunc(layers, func(layer *Layer) bool { diff --git a/server/routes.go b/server/routes.go index c33b7195..85db7924 100644 --- a/server/routes.go +++ b/server/routes.go @@ -56,6 +56,7 @@ func init() { } var errRequired = errors.New("is required") +var errBadTemplate = errors.New("template error") func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() @@ -609,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) { quantization := cmp.Or(r.Quantize, r.Quantization) if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { + if errors.Is(err, errBadTemplate) { + ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} + } ch <- gin.H{"error": err.Error()} - } + } }() if r.Stream != nil && !*r.Stream { @@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) { return } case gin.H: + status, ok := r["status"].(int) + if !ok { + status = http.StatusInternalServerError + } if errorMsg, ok := r["error"].(string); ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg}) + c.JSON(status, gin.H{"error": errorMsg}) return } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"}) + c.JSON(status, gin.H{"error": "unexpected error format in progress response"}) return } default: diff --git a/server/routes_create_test.go b/server/routes_create_test.go index cb548ebd..3234ea5e 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) { if string(system) != "Say bye!" { t.Errorf("expected \"Say bye!\", actual %s", system) } + + t.Run("incomplete template", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + + t.Run("template with unclosed if", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + + t.Run("template with undefined function", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) } func TestCreateLicenses(t *testing.T) { From 69a2d4ccffa7532680bed245fad77bb166ec0bb9 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 19 Jul 2024 19:11:25 -0700 Subject: [PATCH 06/11] Fix generate test flakyness (#5804) --- server/routes_generate_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index c914b300..5c0caff1 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) { getCpuFn: gpu.GetCPUInfo, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { - // add 10ms delay to simulate loading - time.Sleep(10 * time.Millisecond) + // add small delay to simulate loading + time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } @@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) { getCpuFn: gpu.GetCPUInfo, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + // add small delay to simulate loading + time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } From 20090f3172c4848584060cbd51e7c9b14c3630cb Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 19 Jul 2024 20:19:26 -0700 Subject: [PATCH 07/11] preserve last assistant message (#5802) --- template/template.go | 3 ++- template/template_test.go | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/template/template.go b/template/template.go index b5bfb16c..f7453791 100644 --- a/template/template.go +++ b/template/template.go @@ -264,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") { cut = true + return false } return cut @@ -273,7 +274,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ "System": system, "Prompt": prompt, - "Response": "", + "Response": response, }); err != nil { return err } diff --git a/template/template_test.go b/template/template_test.go index ae0db80b..b46e1df5 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) { Hello friend![/INST] Hello human![INST] What is your name?[/INST] `, }, + { + "mistral assistant", + []template{ + {"no response", `[INST] {{ .Prompt }}[/INST] `}, + {"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`}, + {"messages", ` +{{- range $i, $m := .Messages }} +{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }} +{{- end }}`}, + }, + Values{ + Messages: []api.Message{ + {Role: "user", Content: "Hello friend!"}, + {Role: "assistant", Content: "Hello human!"}, + {Role: "user", Content: "What is your name?"}, + {Role: "assistant", Content: "My name is Ollama and I"}, + }, + }, + `[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`, + }, { "chatml", []template{ From 1475eab95f5a4ddc5b1bb169df0d89e71732dfa0 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 20 Jul 2024 13:41:21 -0400 Subject: [PATCH 08/11] add patch for tekken (#5807) --- llm/patches/10-tekken.diff | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 llm/patches/10-tekken.diff diff --git a/llm/patches/10-tekken.diff b/llm/patches/10-tekken.diff new file mode 100644 index 00000000..56a583e0 --- /dev/null +++ b/llm/patches/10-tekken.diff @@ -0,0 +1,43 @@ +diff --git a/include/llama.h b/include/llama.h +index bb4b05ba..a92174e0 100644 +--- a/include/llama.h ++++ b/include/llama.h +@@ -92,6 +92,7 @@ extern "C" { + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, ++ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + }; + + // note: these values should be synchronized with ggml_rope +diff --git a/src/llama.cpp b/src/llama.cpp +index 18364976..435b6fe5 100644 +--- a/src/llama.cpp ++++ b/src/llama.cpp +@@ -5429,6 +5429,12 @@ static void llm_load_vocab( + } else if ( + tokenizer_pre == "jais") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; ++ } else if ( ++ tokenizer_pre == "tekken") { ++ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; ++ vocab.tokenizer_clean_spaces = false; ++ vocab.tokenizer_ignore_merges = true; ++ vocab.tokenizer_add_bos = true; + } else { + LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; +@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe { + " ?[^(\\s|.,!?…。,、।۔،)]+", + }; + break; ++ case LLAMA_VOCAB_PRE_TYPE_TEKKEN: ++ // original regex from tokenizer.json ++ // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ++ regex_exprs = { ++ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ++ }; ++ break; + default: + // default regex for BPE tokenization pre-processing + regex_exprs = { From 283948c83b5cbf74f6cf86dce4434238e64d6e1c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 19 Jul 2024 15:07:26 -0700 Subject: [PATCH 09/11] Adjust windows ROCm discovery The v5 hip library returns unsupported GPUs which wont enumerate at inference time in the runner so this makes sure we align discovery. The gfx906 cards are no longer supported so we shouldn't compile with that GPU type as it wont enumerate at runtime. --- docs/gpu.md | 15 +++++++++++++-- gpu/amd_hip_windows.go | 5 +++-- gpu/amd_windows.go | 3 ++- llm/generate/gen_windows.ps1 | 2 +- llm/server.go | 2 ++ 5 files changed, 21 insertions(+), 6 deletions(-) diff --git a/docs/gpu.md b/docs/gpu.md index 80f276c3..e669ea32 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm` ## AMD Radeon Ollama supports the following AMD GPUs: + +### Linux Support | Family | Cards and accelerators | | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` | | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` | | AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` | -### Overrides +### Windows Support +With ROCm v6.1, the following GPUs are supported on Windows. + +| Family | Cards and accelerators | +| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | +| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` | +| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` | + + +### Overrides on Linux Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In some cases you can force the system to try to use a similar LLVM target that is close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4) @@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the server. If you have an unsupported AMD GPU you can experiment using the list of supported types below. -At this time, the known supported GPU types are the following LLVM Targets. +At this time, the known supported GPU types on linux are the following LLVM Targets. This table shows some example GPUs that map to these LLVM targets: | **LLVM Target** | **An Example GPU** | |-----------------|---------------------| diff --git a/gpu/amd_hip_windows.go b/gpu/amd_hip_windows.go index 2586278c..98806234 100644 --- a/gpu/amd_hip_windows.go +++ b/gpu/amd_hip_windows.go @@ -33,9 +33,10 @@ type HipLib struct { } func NewHipLib() (*HipLib, error) { - h, err := windows.LoadLibrary("amdhip64.dll") + // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs + h, err := windows.LoadLibrary("amdhip64_6.dll") if err != nil { - return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err) + return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err) } hl := &HipLib{} hl.dll = h diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index 425259d7..20aed447 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo { continue } if gfxOverride == "" { - if !slices.Contains[[]string, string](supported, gfx) { + // Strip off Target Features when comparing + if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported) // TODO - consider discrete markdown just for ROCM troubleshooting? slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index beb964f9..d8bce92d 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -7,8 +7,8 @@ function amdGPUs { return $env:AMDGPU_TARGETS } # Current supported rocblas list from ROCm v6.1.2 on windows + # https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus $GPU_LIST = @( - "gfx906:xnack-" "gfx1030" "gfx1100" "gfx1101" diff --git a/llm/server.go b/llm/server.go index 36c0e0b5..ba7eab03 100644 --- a/llm/server.go +++ b/llm/server.go @@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr filteredEnv := []string{} for _, ev := range s.cmd.Env { if strings.HasPrefix(ev, "CUDA_") || + strings.HasPrefix(ev, "ROCR_") || strings.HasPrefix(ev, "ROCM_") || strings.HasPrefix(ev, "HIP_") || + strings.HasPrefix(ev, "GPU_") || strings.HasPrefix(ev, "HSA_") || strings.HasPrefix(ev, "GGML_") || strings.HasPrefix(ev, "PATH=") || From 5534f2cc6a3f29022998950472741d16e7a66b40 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sat, 20 Jul 2024 21:48:12 -0400 Subject: [PATCH 10/11] llm: consider `head_dim` in llama arch (#5817) --- llm/patches/11-embd_kv.diff | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 llm/patches/11-embd_kv.diff diff --git a/llm/patches/11-embd_kv.diff b/llm/patches/11-embd_kv.diff new file mode 100644 index 00000000..ad17a700 --- /dev/null +++ b/llm/patches/11-embd_kv.diff @@ -0,0 +1,19 @@ +diff --git a/src/llama.cpp b/src/llama.cpp +index 2b9ace28..e60d3d8d 100644 +--- a/src/llama.cpp ++++ b/src/llama.cpp +@@ -6052,10 +6052,10 @@ static bool llm_load_tensors( + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + +- layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); +- layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); +- layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); +- layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); ++ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); ++ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); ++ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); ++ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); From 80ee9b5e47fc0ea99d1f3f33224923266627c15c Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Sun, 21 Jul 2024 00:22:11 -0400 Subject: [PATCH 11/11] Remove out of space test temporarily (#5825) --- server/sched_test.go | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/server/sched_test.go b/server/sched_test.go index 7991e7c5..9ddd1fab 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -7,7 +7,6 @@ import ( "fmt" "log/slog" "os" - "runtime" "testing" "time" @@ -356,42 +355,6 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { s.loadedMu.Unlock() } -func TestRequestsModelTooBigForSystem(t *testing.T) { - ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer done() - s := InitScheduler(ctx) - s.getGpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "metal"} - g.TotalMemory = 4 * format.MebiByte - g.FreeMemory = 3 * format.MebiByte - return []gpu.GpuInfo{g} - } - - s.getCpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "cpu"} - g.TotalMemory = 4 * format.MebiByte - g.FreeMemory = 2 * format.MebiByte - return []gpu.GpuInfo{g} - } - a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) - - s.newServerFn = a.newServer - slog.Info("a") - s.pendingReqCh <- a.req - require.Len(t, s.pendingReqCh, 1) - s.Run(ctx) - select { - case <-a.req.successCh: - if runtime.GOOS == "linux" { - t.Fatal("request should have been rejected with out of space") - } - // else - Darwin and Windows don't reject right now - case err := <-a.req.errCh: - require.Contains(t, err.Error(), "too large") - case <-ctx.Done(): - t.Fatal("timeout") - } -} func TestGetRunner(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) defer done()