llm: New memory management

This changes the memory allocation strategy from upfront estimation to
tracking the actual allocations made by the engine and reacting to them. The
goal is to avoid issues caused by both under-estimation (crashing) and
over-estimation (low performance due to under-utilized GPUs).

It is currently opt-in and can be enabled for models running on the
Ollama engine by setting OLLAMA_NEW_ESTIMATES=1. Behavior in other
cases is unchanged and will continue to use the existing estimates.
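For example, assuming the usual server entry point, the new path can be
exercised by starting the server with the flag set:

OLLAMA_NEW_ESTIMATES=1 ollama serve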
Jesse Gross
2025-05-29 12:21:48 -07:00
committed by Jesse Gross
parent ef7d26ba2c
commit d5a0d8d904
26 changed files with 1860 additions and 900 deletions


@@ -28,7 +28,6 @@ type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
opts api.Options
origNumCtx int // Track the initial ctx request
sessionDuration *api.Duration
successCh chan *runnerRef
errCh chan error
@@ -41,10 +40,17 @@ type Scheduler struct {
expiredCh chan *runnerRef
unloadedCh chan any
loaded map[string]*runnerRef
// loadedMu protects loaded and activeLoading
loadedMu sync.Mutex
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int)
// activeLoading is the model that we are currently working on loading,
// including by evicting one or more other models. We can only load
// one model at a time but new requests to models that are already loaded can
// happen in parallel
activeLoading llm.LlamaServer
loaded map[string]*runnerRef
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func() discover.GpuInfoList
getCpuFn func() discover.GpuInfoList
@@ -56,9 +62,6 @@ type Scheduler struct {
// on a large GPU can cause stalling
var defaultModelsPerGPU = 3
// Default automatic value for parallel setting
var defaultParallel = 1
var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
func InitScheduler(ctx context.Context) *Scheduler {
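The central interface change in this hunk is the loadFn signature; condensed from the lines above (not verbatim), it moves from the old form to the new one:

// old: the scheduler loop decided parallelism and got no result back
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int)

// new: the load path picks parallelism itself; requireFull asks for a fully
// on-GPU load and the bool reports whether a model must be evicted first
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool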
@@ -79,24 +82,36 @@ func InitScheduler(ctx context.Context) *Scheduler {
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
if m.CheckCapabilities(model.CapabilityVision) == nil {
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
req := &LlmRequest{
ctx: c,
model: model,
model: m,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
select {
case s.pendingReqCh <- req:
default:
req.errCh <- ErrMaxQueue
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
} else {
select {
case s.pendingReqCh <- req:
default:
req.errCh <- ErrMaxQueue
}
}
return req.successCh, req.errCh
}
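Condensed from the hunk above (a sketch, not the verbatim code): GetRunner now checks for an already-loaded, compatible runner before falling back to the pending queue.

s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
    // Fast path: attach to the loaded runner without going through the scheduler.
    req.useLoadedRunner(runner, s.finishedReqCh)
} else {
    select {
    case s.pendingReqCh <- req:
    default:
        req.errCh <- ErrMaxQueue
    }
}
return req.successCh, req.errCh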
@@ -122,21 +137,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
pending.schedAttempts++
if pending.origNumCtx == 0 {
pending.origNumCtx = pending.opts.NumCtx
}
if pending.ctx.Err() != nil {
slog.Debug("pending request cancelled or timed out, skipping scheduling")
continue
}
numParallel := int(envconfig.NumParallel())
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains(pending.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
numParallel = 1
slog.Warn("mllama does not currently support parallel requests")
}
for {
var runnerToExpire *runnerRef
@@ -195,84 +200,26 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
// Embedding models should always be loaded with parallel=1
if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
numParallel = 1
}
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode
if numParallel <= 0 {
numParallel = defaultParallel
}
pending.opts.NumCtx = pending.origNumCtx * numParallel
if loadedCount == 0 {
slog.Debug("cpu mode with first model, loading")
s.loadFn(pending, ggml, gpus, numParallel)
break
}
runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
if runnerToExpire == nil {
slog.Debug("cpu mode with available system memory or first model, loading")
s.loadFn(pending, ggml, gpus, numParallel)
break
}
// else we need to expire a runner
} else if loadedCount == 0 {
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
g := pickBestFullFitByLibrary(pending, ggml, gpus, &numParallel)
if g != nil {
gpus = g
} else {
// Only allow partial loads when this is the first model
gpus = pickBestPartialFitByLibrary(pending, ggml, gpus, &numParallel)
}
s.loadFn(pending, ggml, gpus, numParallel)
s.loadFn(pending, ggml, gpus, false)
break
}
if runnerToExpire == nil {
// More than one loaded model, so we have to see if the
// new one fits
//
// We want to avoid loading on any GPUs that have other
// models still loading on them to avoid potential races
// with VRAM consumption ramping up during load
availGpus := s.filterGPUsWithoutLoadingModels(gpus)
// More than one loaded model, so we have to see if the
// new one fits
// Update free memory from currently loaded models
s.updateFreeSpace(availGpus)
fitGpus := pickBestFullFitByLibrary(pending, ggml, availGpus, &numParallel)
if fitGpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, ggml, fitGpus, numParallel)
break
}
// We couldn't find a set of GPUs to fully load the new
// model. If no other models are loading (both GPU lists
// are the same) then we need to unload another model to
// make room
if len(availGpus) < len(gpus) {
// There are other requests pending, and this one
// needs more time, so put it on the back of the
// queue so that we might satisfy other pending
// requests that aren't blocked
go func() {
// Process in a go routine to avoid deadlocking
// the scheduler if our queue is full
slog.Debug("delaying scheduling while other models finish loading", "attempts", pending.schedAttempts, "model", pending.model.ModelPath)
time.Sleep(s.reschedDelay)
s.pendingReqCh <- pending
}()
break
}
runnerToExpire = s.findRunnerToUnload()
needEvict := s.loadFn(pending, ggml, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
}
runnerToExpire = s.findRunnerToUnload()
}
if runnerToExpire == nil {
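Condensed, the replacement scheduling decision reads roughly as follows (a sketch of the new lines in the hunk above; cancellation and GPU refresh handling omitted):

if loadedCount == 0 {
    // Nothing loaded yet: let the load path pick placement and allow a
    // partial load if the model cannot fully fit on the GPUs.
    s.loadFn(pending, ggml, gpus, false)
    break
}
// Other models are resident: only proceed if this one fully fits;
// otherwise the load path reports that another model must be evicted.
needEvict := s.loadFn(pending, ggml, gpus, true)
if !needEvict {
    slog.Debug("new model fits with existing models, loading")
    break
}
runnerToExpire = s.findRunnerToUnload()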
@@ -293,8 +240,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
select {
case <-ctx.Done():
@@ -434,26 +379,72 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
}()
}
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int) {
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
numParallel := int(envconfig.NumParallel())
if numParallel < 1 {
numParallel = 1
}
// Embedding models should always be loaded with parallel=1
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
numParallel = 1
}
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
numParallel = 1
slog.Warn("mllama does not currently support parallel requests")
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
s.loadedMu.Lock()
llama := s.activeLoading
if llama == nil {
var err error
llama, err = s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
s.loadedMu.Unlock()
return false
}
s.activeLoading = llama
} else {
if s.activeLoading.ModelPath() != req.model.ModelPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath))
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
s.loadedMu.Unlock()
err := llama.Load(req.ctx, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
return true
}
slog.Info("Load failed", "model", req.model.ModelPath, "error", err)
s.activeLoading.Close()
s.activeLoading = nil
req.errCh <- err
return false
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -461,8 +452,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
Options: &req.opts,
sessionDuration: sessionDuration,
gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
loading: true,
pid: llama.Pid(),
}
@@ -477,6 +468,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
oldRunner.unload()
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -503,6 +495,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
}()
req.successCh <- runner
}()
return false
}
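Taken together, load now owns the runner lifecycle across eviction rounds: the server process is created once, parked in activeLoading while the scheduler frees VRAM, and then retried. A condensed sketch of the new flow in the hunk above (runnerRef construction and logging omitted):

s.loadedMu.Lock()
llama := s.activeLoading
if llama == nil {
    var err error
    llama, err = s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
    if err != nil {
        req.errCh <- err
        s.loadedMu.Unlock()
        return false
    }
    s.activeLoading = llama
}
s.loadedMu.Unlock()

if err := llama.Load(req.ctx, gpus, requireFull); err != nil {
    if errors.Is(err, llm.ErrLoadRequiredFull) {
        return true // caller evicts a runner and calls load again with the same activeLoading
    }
    s.activeLoading.Close()
    s.activeLoading = nil
    req.errCh <- err
    return false
}

// Success: hand the runner off to the loaded map.
s.loadedMu.Lock()
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()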
func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
@@ -521,7 +515,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
predMap[predKey{gpu.Library, gpu.ID}] += r.llama.EstimatedVRAMByGPU(gpu.ID)
predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID)
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -548,41 +542,17 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
}
}
// While models are loading the VRAM consumption numbers will be indeterminate, so we have
// to avoid scheduling another model on the same GPU(s) that haven't stabilized.
// This routine returns the set of GPUs that do not have an active loading model.
// If all GPUs have loading models, an empty list will be returned (not a single CPU entry)
func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList) discover.GpuInfoList {
ret := append(discover.GpuInfoList{}, allGpus...)
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for _, runner := range s.loaded {
if runner.loading {
slog.Debug("overlapping loads detected", "gpus", runner.gpus, "model", runner.modelPath)
for _, busyGPU := range runner.gpus {
for i := range ret {
if ret[i].ID == busyGPU.ID {
ret = append(ret[:i], ret[i+1:]...)
break
}
}
}
}
}
return ret
}
// TODO consolidate sched_types.go
type runnerRef struct {
refMu sync.Mutex
refCount uint // prevent unloading if > 0
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus discover.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
estimatedTotal uint64
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus discover.GpuInfoList // Recorded at time of provisioning
vramSize uint64
totalSize uint64
sessionDuration time.Duration
expireTimer *time.Timer
@@ -631,9 +601,6 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
optsNew.NumGPU = -1
}
// Normalize the NumCtx for parallelism
optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
@@ -694,7 +661,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
freeMemoryNow += gpu.FreeMemory
}
// If we're within ~80% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
finished <- struct{}{}
return
@@ -719,8 +686,8 @@ func (runner *runnerRef) LogValue() slog.Value {
)
}
attrs = append(attrs,
slog.String("size", format.HumanBytes2(runner.estimatedTotal)),
slog.String("vram", format.HumanBytes2(runner.estimatedVRAM)),
slog.String("size", format.HumanBytes2(runner.totalSize)),
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
@@ -750,95 +717,7 @@ func (a ByDurationAndName) Less(i, j int) bool {
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
// The list of GPUs returned will always be the same brand (library)
// If the model can not be fit fully within the available GPU(s) nil is returned
// If numParallel is <= 0, this will attempt to optimize parallelism based on available VRAM, and adjust
// opts.NumCtx accordingly
func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
var numParallelToTry []int
if *numParallel <= 0 {
// If no specific parallel setting was provided, try larger then smaller, always end with 1
numParallelToTry = append(numParallelToTry, defaultParallel, 1)
} else {
numParallelToTry = []int{*numParallel}
}
for _, gl := range gpus.ByLibrary() {
sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl)))
if !envconfig.SchedSpread() {
for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCtx * p
// Try to pack into as few GPUs as possible, starting from 1 GPU
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
gpuSubset := sgl[:numGPUs]
ok, estimatedVRAM := llm.PredictServerFit(gpuSubset, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p)
if ok {
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
"model", req.model.ModelPath,
"library", sgl[0].Library,
"parallel", p,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", numGPUs)
*numParallel = p
return gpuSubset
}
}
}
} else {
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCtx * p
if ok, estimatedVRAM := llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
slog.Info("new model will fit in available VRAM, loading",
"model", req.model.ModelPath,
"library", sgl[0].Library,
"parallel", p,
"required", format.HumanBytes2(estimatedVRAM),
"gpus", len(sgl))
*numParallel = p
return sgl
}
}
}
}
return nil
}
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
if *numParallel <= 0 {
*numParallel = 1
req.opts.NumCtx = req.origNumCtx
}
byLibrary := gpus.ByLibrary()
if len(byLibrary) <= 1 {
return gpus
}
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
}
}
return byLibrary[bestFit]
}
// func (a BySize) Less(i, j int) bool { return a[i].vramSize < a[j].vramSize }
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload() *runnerRef {
@@ -875,6 +754,13 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
if s.activeLoading != nil {
slog.Debug("shutting down currently loading runner")
s.activeLoading.Close()
s.activeLoading = nil
}
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
@@ -901,18 +787,3 @@ func (s *Scheduler) expireRunner(model *Model) {
runner.refMu.Unlock()
}
}
// If other runners are loaded, make sure the pending request will fit in system memory
// If not, pick a runner to unload, else return nil and the request can be loaded
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
slog.Debug("evaluating if CPU model load will fit in available system memory")
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
if estimate.TotalSize <= gpus[0].FreeMemory {
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
return nil
}
// TODO - optimization: try to find CPU only runners first, or partial offloads with enough in system memory to make room
return s.findRunnerToUnload()
}