Mirror of https://github.com/likelovewant/ollama-for-amd.git, synced 2025-12-21 14:26:30 +00:00
Use runners for GPU discovery (#12090)
This revamps how we discover GPUs in the system by leveraging the Ollama runner. This should eliminate inconsistencies between our GPU discovery and the runners' capabilities at runtime, particularly for cases where we try to filter out unsupported GPUs; now the runner does that implicitly based on the actual device list. In some cases free VRAM reporting can be unreliable, which can lead to scheduling mistakes, so this also includes a patch to leverage more reliable VRAM reporting libraries when available. Automatic workarounds have been removed, as only one GPU relied on them; that case is now documented, and the GPU will soon fall off the support matrix with the next ROCm bump. Additional cleanup of the scheduler and discovery packages can be done in the future once we have switched on the new memory management code and removed support for the llama runner.
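As a quick orientation before the diff: the scheduler no longer calls discover.GetGPUInfo() directly; it refreshes GPU state through an injected hook that also receives the currently loaded runners, so the runner itself can participate in device discovery. Below is a minimal, hedged sketch of that hook shape. The signatures and the stub body are copied from the Scheduler struct and test-helper changes in this commit; the package wrapper, the chooseGpus helper, and the main function are illustrative assumptions only, not part of the change.

package main

import (
	"context"

	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/ml"
)

// gpuDiscoveryFn mirrors the new Scheduler.getGpuFn field: discovery takes a context
// plus the currently loaded runners instead of being a zero-argument global lookup.
type gpuDiscoveryFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList

// stubGpuFn mirrors the test helper in this diff: it reports a single "metal"
// device with fixed total and free memory.
func stubGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
	g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
	g.TotalMemory = 24 * format.GigaByte
	g.FreeMemory = 12 * format.GigaByte
	return []discover.GpuInfo{g}
}

// chooseGpus is a hypothetical helper showing the selection logic used in
// processPending below: CPU-only requests wrap the single CPU entry in a
// GpuInfoList, everything else goes through the runners-aware GPU hook.
func chooseGpus(ctx context.Context, numGPU int, getGpuFn gpuDiscoveryFn,
	getCpuFn func() discover.GpuInfo, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
	if numGPU == 0 {
		return discover.GpuInfoList{getCpuFn()}
	}
	return getGpuFn(ctx, runners)
}

func main() {
	cpu := func() discover.GpuInfo {
		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}}
		g.TotalMemory = 32 * format.GigaByte
		g.FreeMemory = 26 * format.GigaByte
		return g
	}
	// With numGPU != 0 the GPU hook is consulted; tests stub it exactly like stubGpuFn.
	_ = chooseGpus(context.Background(), -1, stubGpuFn, cpu, nil)
}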
@@ -1557,8 +1557,8 @@ func Serve(ln net.Listener) error {
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
- gpus := discover.GetGPUInfo()
- gpus.LogDetails()
+ gpus := discover.GPUDevices(ctx, nil)
+ discover.LogDetails(gpus)

var totalVRAM uint64
for _, gpu := range gpus {
@@ -36,8 +36,8 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -229,8 +229,8 @@ func TestChatDebugRenderOnly(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -74,8 +74,8 @@ func TestGenerateChat(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -618,8 +618,8 @@ func TestGenerate(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
@@ -994,8 +994,8 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
@@ -274,8 +274,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
@@ -425,8 +425,8 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
@@ -607,8 +607,8 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
- getGpuFn: discover.GetGPUInfo,
- getCpuFn: discover.GetCPUInfo,
+ getGpuFn: getGpuFn,
+ getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
@@ -52,8 +53,8 @@ type Scheduler struct {
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
- getGpuFn func() discover.GpuInfoList
- getCpuFn func() discover.GpuInfoList
+ getGpuFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList
+ getCpuFn func() discover.GpuInfo
reschedDelay time.Duration
}
@@ -148,7 +149,12 @@ func (s *Scheduler) processPending(ctx context.Context) {
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
+ runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
+ for _, r := range s.loaded {
+ runnersSnapshot = append(runnersSnapshot, r)
+ }
s.loadedMu.Unlock()

if runner != nil {
if runner.needsReload(ctx, pending) {
slog.Debug("reloading", "runner", runner)
@@ -166,9 +172,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
// Get a refreshed GPU list
var gpus discover.GpuInfoList
if pending.opts.NumGPU == 0 {
- gpus = s.getCpuFn()
+ gpus = discover.GpuInfoList{s.getCpuFn()}
} else {
- gpus = s.getGpuFn()
+ gpus = s.getGpuFn(ctx, runnersSnapshot)
}

if envconfig.MaxRunners() <= 0 {
@@ -343,7 +349,11 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
runner.refMu.Unlock()
} else {
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
- finished := runner.waitForVRAMRecovery()
+ runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
+ for _, r := range s.loaded {
+ runnersSnapshot = append(runnersSnapshot, r)
+ }
+ finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
s.loadedMu.Unlock()
@@ -429,7 +439,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
s.loadedMu.Unlock()

- err := llama.Load(req.ctx, gpus, requireFull)
+ gpuIDs, err := llama.Load(req.ctx, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
return true
@@ -448,7 +458,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
- gpus: gpus,
+ gpus: gpuIDs,
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
loading: true,
@@ -497,11 +507,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
}

func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
- type predKey struct {
- Library string
- ID string
- }
- predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
+ predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
runners := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -512,7 +518,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
- predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID)
+ predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID)
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -522,7 +528,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
- if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
+ if p, ok := predMap[allGpus[i].DeviceID]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
@@ -546,8 +552,8 @@ type runnerRef struct {
llama llm.LlamaServer
pid int
- loading bool // True only during initial load, then false forever
- gpus discover.GpuInfoList // Recorded at time of provisioning
+ loading bool // True only during initial load, then false forever
+ gpus []ml.DeviceID // Recorded at time of provisioning
vramSize uint64
totalSize uint64
@@ -571,7 +577,6 @@ func (runner *runnerRef) unload() {
runner.llama.Close()
}
runner.model = nil
runner.llama = nil
runner.Options = nil
- runner.gpus = nil
}
@@ -618,14 +623,14 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
// a before and after GPU memory allocation. The returned channel
// will be notified when we're done waiting, or have timed out and should
// proceed anyway
- func (runner *runnerRef) waitForVRAMRecovery() chan any {
+ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any {
finished := make(chan any, 1)

// CPU or Metal don't need checking, so no waiting required
// windows can page VRAM, only cuda currently can report accurate used vram usage
if len(runner.gpus) == 0 ||
- (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "metal")) ||
- (runtime.GOOS == "windows" && runner.gpus[0].Library != "cuda") {
+ (len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "Metal")) ||
+ (runtime.GOOS == "windows" && runner.gpus[0].Library != "CUDA") {
finished <- struct{}{}
slog.Debug("no need to wait for VRAM recovery", "runner", runner)
return finished
@@ -633,7 +638,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
start := time.Now()

// Establish a baseline before we unload
- gpusBefore := discover.GetGPUInfo()
+ gpusBefore := s.getGpuFn(context.Background(), runners)
var totalMemoryBefore, freeMemoryBefore uint64
for _, gpu := range gpusBefore {
totalMemoryBefore += gpu.TotalMemory
@@ -651,7 +656,7 @@
}

// Query GPUs, look for free to go back up
- gpusNow := discover.GetGPUInfo()
+ gpusNow := s.getGpuFn(context.Background(), runners)
var totalMemoryNow, freeMemoryNow uint64
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
@@ -678,8 +683,7 @@ func (runner *runnerRef) LogValue() slog.Value {
}
if len(runner.gpus) > 0 {
attrs = append(attrs,
- slog.String("inference", runner.gpus[0].Library),
- slog.Int("devices", len(runner.gpus)),
+ slog.Any("inference", runner.gpus),
)
}
attrs = append(attrs,
@@ -695,6 +699,32 @@ func (runner *runnerRef) LogValue() slog.Value {
return slog.GroupValue(attrs...)
}

+ // Implements discover.RunnerDiscovery
+ func (runner *runnerRef) GetPort() int {
+ if runner.llama != nil {
+ return runner.llama.GetPort()
+ }
+ return -1
+ }
+
+ func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+ if runner.llama != nil {
+ return runner.llama.GetDeviceInfos(ctx)
+ }
+ return nil
+ }
+
+ func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID {
+ return runner.gpus
+ }
+
+ func (runner *runnerRef) HasExited() bool {
+ if runner.llama != nil {
+ return runner.llama.HasExited()
+ }
+ return true
+ }
+
type ByDurationAndName []*runnerRef

func (a ByDurationAndName) Len() int { return len(a) }
@@ -17,6 +17,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/ml"
)

func TestMain(m *testing.M) {
@@ -61,7 +62,7 @@ func TestLoad(t *testing.T) {
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")

- server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
+ server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
server.modelPath = model
return server, nil
@@ -109,7 +110,7 @@ func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f
return scenario.srv, nil
}

- func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle {
+ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration, vramByGPU map[ml.DeviceID]uint64) *reqBundle {
b := &reqBundle{}
b.ctx, b.ctxDone = context.WithCancel(ctx)
t.Helper()
@@ -146,22 +147,24 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
- b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}}
+ b.srv = &mockLlm{vramSize: vramSize, vramByGPU: vramByGPU}
return b
}

- func getGpuFn() discover.GpuInfoList {
- g := discover.GpuInfo{Library: "metal"}
+ func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+ slog.Info("test getGpuFn called", "runners", runners)
+ g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []discover.GpuInfo{g}
}

- func getCpuFn() discover.GpuInfoList {
- g := discover.GpuInfo{Library: "cpu"}
+ func getCpuFn() discover.GpuInfo {
+ slog.Info("test getCpuFn called")
+ g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
- return []discover.GpuInfo{g}
+ return g
}

func TestRequestsSameModelSameRequest(t *testing.T) {
@@ -170,8 +173,8 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
- a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
- b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
+ a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
+ b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
b.req.model = a.req.model
b.f = a.f
@@ -208,13 +211,13 @@
}

func TestRequestsSimpleReloadSameModel(t *testing.T) {
- ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
+ ctx, done := context.WithTimeout(t.Context(), 5000*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
- a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
- b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
+ a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
+ b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
tmpModel := *a.req.model
b.req.model = &tmpModel
b.f = a.f
@@ -243,6 +246,15 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
a.ctxDone()
+ // Report recovered VRAM usage
+ time.Sleep(1 * time.Millisecond)
+ s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+ slog.Info("XXX altered getGpuFn called")
+ g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
+ g.TotalMemory = 24 * format.GigaByte
+ g.FreeMemory = 24 * format.GigaByte
+ return []discover.GpuInfo{g}
+ }
select {
case resp := <-b.req.successCh:
require.Equal(t, resp.llama, b.srv)
@@ -259,15 +271,18 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
- s.getGpuFn = getGpuFn
- s.getCpuFn = getCpuFn
+ s.getGpuFn = getGpuFn // 1 metal GPU
+ s.getCpuFn = getCpuFn // 1 CPU

// Multiple loaded models
- a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
- b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil)
- c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil)
- c.req.opts.NumGPU = 0 // CPU load, will be allowed
- d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded
+ a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 1 * format.GigaByte})
+ a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
+ b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 10 * format.GigaByte})
+ b.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
+ c := newScenarioRequest(t, ctx, "model-c-10g-cpu", 10*format.GigaByte, nil, nil /* No GPU load */)
+ c.req.opts.NumGPU = 0 // CPU load, will be allowed
+ b.req.sessionDuration = &api.Duration{Duration: 10 * time.Millisecond} // longer than b to cause the scheduler to favor unloading b over c
+ d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 13 * format.GigaByte}) // Needs prior unloaded

t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
s.newServerFn = a.newServer
@@ -338,7 +353,16 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
+ // Mark b done so it can unload
b.ctxDone()
+ // Report recovered VRAM usage so scheduler will finish waiting and unload
+ time.Sleep(1 * time.Millisecond)
+ s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+ g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
+ g.TotalMemory = 24 * format.GigaByte
+ g.FreeMemory = 24 * format.GigaByte
+ return []discover.GpuInfo{g}
+ }
select {
case resp := <-d.req.successCh:
require.Equal(t, resp.llama, d.srv)
@@ -347,6 +371,19 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
case <-ctx.Done():
t.Fatal("timeout")
}
+ // Wait for b to close
+ closeWait:
+ for {
+ select {
+ case <-ctx.Done():
+ t.Fatal("timeout")
+ default:
+ if b.srv.closeCalled {
+ break closeWait
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ }
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
@@ -356,9 +393,9 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
defer done()

- a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
- b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
- c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
+ a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
+ b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
+ c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
t.Setenv("OLLAMA_MAX_QUEUE", "1")
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
@@ -420,7 +457,7 @@ func TestExpireRunner(t *testing.T) {
var f *ggml.GGML
gpus := discover.GpuInfoList{}
- server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
+ server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
server.modelPath = model
return server, nil
@@ -458,10 +495,10 @@ func TestPrematureExpired(t *testing.T) {
defer done()

// Same model, same request
- scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
+ scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil)
s := InitScheduler(ctx)
- s.getGpuFn = func() discover.GpuInfoList {
- g := discover.GpuInfo{Library: "metal"}
+ s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
+ g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []discover.GpuInfo{g}
@@ -509,7 +546,7 @@ func TestUseLoadedRunner(t *testing.T) {
sessionDuration: &api.Duration{Duration: 2},
}
finished := make(chan *LlmRequest)
- llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
+ llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
@@ -532,22 +569,32 @@ func TestUpdateFreeSpace(t *testing.T) {
defer done()
gpus := discover.GpuInfoList{
{
- Library: "a",
- ID: "1",
+ DeviceID: ml.DeviceID{
+ ID: "1",
+ },
},
{
- Library: "a",
- ID: "2",
+ DeviceID: ml.DeviceID{
+ ID: "2",
+ },
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
- llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}}
- llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}}
- r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
- r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
+ gpuIDs := []ml.DeviceID{
+ {
+ ID: "1",
+ },
+ {
+ ID: "2",
+ },
+ }
+ llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 50, {ID: "2"}: 50}}
+ llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{{ID: "1"}: 125, {ID: "2"}: 75}}
+ r1 := &runnerRef{llama: llm1, gpus: gpuIDs, numParallel: 1}
+ r2 := &runnerRef{llama: llm2, gpus: gpuIDs, numParallel: 1}

s := InitScheduler(ctx)
s.loadedMu.Lock()
@@ -584,7 +631,7 @@ func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()

- llm := &mockLlm{vramByGPU: map[string]uint64{}}
+ llm := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
do := api.DefaultOptions()
runner := &runnerRef{
model: &Model{
@@ -631,8 +678,8 @@ func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()

- llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
- llm2 := &mockLlm{vramByGPU: map[string]uint64{}}
+ llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
+ llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
s := InitScheduler(ctx)
s.unloadAllRunners()
@@ -650,7 +697,7 @@
}

func TestUnload(t *testing.T) {
- llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
+ llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
r1 := &runnerRef{llama: llm1, numParallel: 1}
r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
r1.unload()
@@ -664,7 +711,7 @@ func TestAlreadyCanceled(t *testing.T) {
defer done()
dctx, done2 := context.WithCancel(ctx)
done2()
- scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
+ scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}, nil)
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
@@ -691,24 +738,28 @@ type mockLlm struct {
closeCalled bool
vramSize uint64
totalSize uint64
- vramByGPU map[string]uint64
+ vramByGPU map[ml.DeviceID]uint64
}

func (s *mockLlm) ModelPath() string {
return s.modelPath
}

- func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error {
+ func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
if requireFull {
for _, g := range gpus {
if g.FreeMemory >= s.vramSize {
- return nil
+ return []ml.DeviceID{g.DeviceID}, nil
}
}

- return llm.ErrLoadRequiredFull
+ return nil, llm.ErrLoadRequiredFull
}
- return nil
+ gpuIDs := make([]ml.DeviceID, len(gpus))
+ for i := range gpus {
+ gpuIDs[i] = gpus[i].DeviceID
+ }
+ return gpuIDs, nil
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
@@ -732,7 +783,11 @@ func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
- func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
- func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
- func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] }
- func (s *mockLlm) Pid() int { return -1 }
+ func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
+ func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
+ func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
+ func (s *mockLlm) Pid() int { return -1 }
+ func (s *mockLlm) GetPort() int { return -1 }
+ func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+ func (s *mockLlm) HasExited() bool { return false }
+ func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }