Use runners for GPU discovery (#12090)

This revamps how we discover GPUs in the system by leveraging the Ollama
runner.  This should eliminate inconsistency between our GPU discovery and the
runners capabilities at runtime, particularly for cases where we try to filter
out unsupported GPUs.  Now the runner does that implicitly based on the actual
device list.  In some cases free VRAM reporting can be unreliable which can
leaad to scheduling mistakes, so this also includes a patch to leverage more
reliable VRAM reporting libraries if available.

Automatic workarounds have been removed as only one GPU leveraged this, which
is now documented. This GPU will soon fall off the support matrix with the next
ROCm bump.

Additional cleanup of the scheduler and discovery packages can be done in the
future once we have switched on the new memory management code, and removed
support for the llama runner.
This commit is contained in:
Daniel Hiltgen
2025-10-01 15:12:32 -07:00
committed by GitHub
parent 6b50f2b9cd
commit bc8909fb38
57 changed files with 3288 additions and 3819 deletions

View File

@@ -1557,8 +1557,8 @@ func Serve(ln net.Listener) error {
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
gpus := discover.GetGPUInfo()
gpus.LogDetails()
gpus := discover.GPUDevices(ctx, nil)
discover.LogDetails(gpus)
var totalVRAM uint64
for _, gpu := range gpus {

View File

@@ -36,8 +36,8 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -229,8 +229,8 @@ func TestChatDebugRenderOnly(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading

View File

@@ -74,8 +74,8 @@ func TestGenerateChat(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -618,8 +618,8 @@ func TestGenerate(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -994,8 +994,8 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)

View File

@@ -274,8 +274,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
@@ -425,8 +425,8 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
@@ -607,8 +607,8 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
@@ -52,8 +53,8 @@ type Scheduler struct {
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func() discover.GpuInfoList
getCpuFn func() discover.GpuInfoList
getGpuFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList
getCpuFn func() discover.GpuInfo
reschedDelay time.Duration
}
@@ -148,7 +149,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
slog.Debug("reloading", "runner", runner)
@@ -166,9 +172,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
// Get a refreshed GPU list
var gpus discover.GpuInfoList
if pending.opts.NumGPU == 0 {
gpus = s.getCpuFn()
gpus = discover.GpuInfoList{s.getCpuFn()}
} else {
gpus = s.getGpuFn()
gpus = s.getGpuFn(ctx, runnersSnapshot)
}
if envconfig.MaxRunners() <= 0 {
@@ -343,7 +349,11 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
runner.refMu.Unlock()
} else {
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
finished := runner.waitForVRAMRecovery()
runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
s.loadedMu.Unlock()
@@ -429,7 +439,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
s.loadedMu.Unlock()
err := llama.Load(req.ctx, gpus, requireFull)
gpuIDs, err := llama.Load(req.ctx, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
return true
@@ -448,7 +458,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
gpus: gpus,
gpus: gpuIDs,
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
loading: true,
@@ -497,11 +507,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
}
func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
runners := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -512,7 +518,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID)
predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID)
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -522,7 +528,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
if p, ok := predMap[allGpus[i].DeviceID]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
@@ -546,8 +552,8 @@ type runnerRef struct {
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus discover.GpuInfoList // Recorded at time of provisioning
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
vramSize uint64
totalSize uint64
@@ -571,7 +577,6 @@ func (runner *runnerRef) unload() {
runner.llama.Close()
}
runner.model = nil
runner.llama = nil
runner.Options = nil
runner.gpus = nil
}
@@ -618,14 +623,14 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
// a before and after GPU memory allocation. The returned channel
// will be notified when we're done waiting, or have timed out and should
// proceed anyway
func (runner *runnerRef) waitForVRAMRecovery() chan any {
func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any {
finished := make(chan any, 1)
// CPU or Metal don't need checking, so no waiting required
// windows can page VRAM, only cuda currently can report accurate used vram usage
if len(runner.gpus) == 0 ||
(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
(runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "Metal")) ||
(runtime.GOOS == "windows" && runner.gpus[0].Library != "CUDA") {
finished <- struct{}{}
slog.Debug("no need to wait for VRAM recovery", "runner", runner)
return finished
@@ -633,7 +638,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
start := time.Now()
// Establish a baseline before we unload
gpusBefore := discover.GetGPUInfo()
gpusBefore := s.getGpuFn(context.Background(), runners)
var totalMemoryBefore, freeMemoryBefore uint64
for _, gpu := range gpusBefore {
totalMemoryBefore += gpu.TotalMemory
@@ -651,7 +656,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
}
// Query GPUs, look for free to go back up
gpusNow := discover.GetGPUInfo()
gpusNow := s.getGpuFn(context.Background(), runners)
var totalMemoryNow, freeMemoryNow uint64
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
@@ -678,8 +683,7 @@ func (runner *runnerRef) LogValue() slog.Value {
}
if len(runner.gpus) > 0 {
attrs = append(attrs,
slog.String("inference", runner.gpus[0].Library),
slog.Int("devices", len(runner.gpus)),
slog.Any("inference", runner.gpus),
)
}
attrs = append(attrs,
@@ -695,6 +699,32 @@ func (runner *runnerRef) LogValue() slog.Value {
return slog.GroupValue(attrs...)
}
// Implements discover.RunnerDiscovery
func (runner *runnerRef) GetPort() int {
if runner.llama != nil {
return runner.llama.GetPort()
}
return -1
}
func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
if runner.llama != nil {
return runner.llama.GetDeviceInfos(ctx)
}
return nil
}
func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID {
return runner.gpus
}
func (runner *runnerRef) HasExited() bool {
if runner.llama != nil {
return runner.llama.HasExited()
}
return true
}
type ByDurationAndName []*runnerRef
func (a ByDurationAndName) Len() int { return len(a) }

View File

@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
func TestMain(m *testing.M) {
@@ -61,7 +62,7 @@ func TestLoad(t *testing.T) {
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
server.modelPath = model
return server, nil
@@ -109,7 +110,7 @@ func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f
return scenario.srv, nil
}
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle {
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration, vramByGPU map[ml.DeviceID]uint64) *reqBundle {
b := &reqBundle{}
b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
@@ -146,22 +147,24 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}}
b.srv = &mockLlm{vramSize: vramSize, vramByGPU: vramByGPU}
return b
}
func getGpuFn() discover.GpuInfoList {
g := discover.GpuInfo{Library: "metal"}
func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
slog.Info("test getGpuFn called", "runners", runners)
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []discover.GpuInfo{g}
}
func getCpuFn() discover.GpuInfoList {
g := discover.GpuInfo{Library: "cpu"}
func getCpuFn() discover.GpuInfo {
slog.Info("test getCpuFn called")
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []discover.GpuInfo{g}
return g
}
func TestRequestsSameModelSameRequest(t *testing.T) {
@@ -170,8 +173,8 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
b.req.model = a.req.model
b.f = a.f
@@ -208,13 +211,13 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
ctx, done := context.WithTimeout(t.Context(), 5000*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
tmpModel := *a.req.model
b.req.model = &tmpModel
b.f = a.f
@@ -243,6 +246,15 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
a.ctxDone()
// Report recovered VRAM usage
time.Sleep(1 * time.Millisecond)
s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
slog.Info("XXX altered getGpuFn called")
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 24 * format.GigaByte
return []discover.GpuInfo{g}
}
select {
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
@@ -259,15 +271,18 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
s.getGpuFn = getGpuFn // 1 metal GPU
s.getCpuFn = getCpuFn // 1 CPU
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded
a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 1 * format.GigaByte})
a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 10 * format.GigaByte})
b.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
c := newScenarioRequest(t, ctx, "model-c-10g-cpu", 10*format.GigaByte, nil, nil /* No GPU load */)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
b.req.sessionDuration = &api.Duration{Duration: 10 * time.Millisecond} // longer than b to cause the scheduler to favor unloading b over c
d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 13 * format.GigaByte}) // Needs prior unloaded
t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
s.newServerFn = a.newServer
@@ -338,7 +353,16 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
// Mark b done so it can unload
b.ctxDone()
// Report recovered VRAM usage so scheduler will finish waiting and unload
time.Sleep(1 * time.Millisecond)
s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 24 * format.GigaByte
return []discover.GpuInfo{g}
}
select {
case resp := <-d.req.successCh:
require.Equal(t, resp.llama, d.srv)
@@ -347,6 +371,19 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
case <-ctx.Done():
t.Fatal("timeout")
}
// Wait for b to close
closeWait:
for {
select {
case <-ctx.Done():
t.Fatal("timeout")
default:
if b.srv.closeCalled {
break closeWait
}
time.Sleep(1 * time.Millisecond)
}
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
@@ -356,9 +393,9 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
defer done()
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
t.Setenv("OLLAMA_MAX_QUEUE", "1")
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
@@ -420,7 +457,7 @@ func TestExpireRunner(t *testing.T) {
var f *ggml.GGML
gpus := discover.GpuInfoList{}
server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
server.modelPath = model
return server, nil
@@ -458,10 +495,10 @@ func TestPrematureExpired(t *testing.T) {
defer done()
// Same model, same request
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil)
s := InitScheduler(ctx)
s.getGpuFn = func() discover.GpuInfoList {
g := discover.GpuInfo{Library: "metal"}
s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []discover.GpuInfo{g}
@@ -509,7 +546,7 @@ func TestUseLoadedRunner(t *testing.T) {
sessionDuration: &api.Duration{Duration: 2},
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
@@ -532,22 +569,32 @@ func TestUpdateFreeSpace(t *testing.T) {
defer done()
gpus := discover.GpuInfoList{
{
Library: "a",
ID: "1",
DeviceID: ml.DeviceID{
ID: "1",
},
},
{
Library: "a",
ID: "2",
DeviceID: ml.DeviceID{
ID: "2",
},
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}}
llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}}
r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
gpuIDs := []ml.DeviceID{
{
ID: "1",
},
{
ID: "2",
},
}
llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 50, {ID: "2"}: 50}}
llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 125, {ID: "2"}: 75}}
r1 := &runnerRef{llama: llm1, gpus: gpuIDs, numParallel: 1}
r2 := &runnerRef{llama: llm2, gpus: gpuIDs, numParallel: 1}
s := InitScheduler(ctx)
s.loadedMu.Lock()
@@ -584,7 +631,7 @@ func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
llm := &mockLlm{vramByGPU: map[string]uint64{}}
llm := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
do := api.DefaultOptions()
runner := &runnerRef{
model: &Model{
@@ -631,8 +678,8 @@ func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
llm2 := &mockLlm{vramByGPU: map[string]uint64{}}
llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
s := InitScheduler(ctx)
s.unloadAllRunners()
@@ -650,7 +697,7 @@ func TestUnloadAllRunners(t *testing.T) {
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
r1 := &runnerRef{llama: llm1, numParallel: 1}
r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
r1.unload()
@@ -664,7 +711,7 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}, nil)
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
@@ -691,24 +738,28 @@ type mockLlm struct {
closeCalled bool
vramSize uint64
totalSize uint64
vramByGPU map[string]uint64
vramByGPU map[ml.DeviceID]uint64
}
func (s *mockLlm) ModelPath() string {
return s.modelPath
}
func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error {
func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
if requireFull {
for _, g := range gpus {
if g.FreeMemory >= s.vramSize {
return nil
return []ml.DeviceID{g.DeviceID}, nil
}
}
return llm.ErrLoadRequiredFull
return nil, llm.ErrLoadRequiredFull
}
return nil
gpuIDs := make([]ml.DeviceID, len(gpus))
for i := range gpus {
gpuIDs[i] = gpus[i].DeviceID
}
return gpuIDs, nil
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
@@ -732,7 +783,11 @@ func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] }
func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) GetPort() int { return -1 }
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }