Mirror of https://github.com/likelovewant/ollama-for-amd.git (synced 2025-12-21 22:33:56 +00:00)
llm: New memory management
This changes the memory allocation strategy from upfront estimation to tracking the actual allocations made by the engine and reacting to them. The goal is to avoid the problems caused by both under-estimation (crashing) and over-estimation (low performance due to under-utilized GPUs). It is currently opt-in and can be enabled for models running on the Ollama engine by setting OLLAMA_NEW_ESTIMATES=1. Behavior in all other cases is unchanged and continues to use the existing estimates.
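Concretely, the diff below turns model loading into a small HTTP protocol between the Ollama server and the runner: the runner now starts in a "launched" state with nothing loaded, and the server drives it through load operations (fit, alloc, commit, close) via a new POST /load endpoint, reading back the actual backend memory report after each step. The sketch below shows how a scheduler might drive that loop; it is illustrative only, and the request/response structs are simplified stand-ins for the llm.LoadRequest and llm.LoadResponse types used in the diff (the real Operation and Memory fields are richer types):

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
    )

    // Simplified stand-ins for llm.LoadRequest / llm.LoadResponse.
    type loadRequest struct {
        Operation  string // assumed string encoding; the real field is a typed constant
        NumThreads int
        KvSize     int
    }

    type loadResponse struct {
        Success bool
        Memory  json.RawMessage // backend memory report, left opaque here
    }

    func postLoad(url string, req loadRequest) (loadResponse, error) {
        body, err := json.Marshal(req)
        if err != nil {
            return loadResponse{}, err
        }
        httpResp, err := http.Post(url, "application/json", bytes.NewReader(body))
        if err != nil {
            return loadResponse{}, err
        }
        defer httpResp.Body.Close()
        var resp loadResponse
        return resp, json.NewDecoder(httpResp.Body).Decode(&resp)
    }

    func main() {
        url := "http://127.0.0.1:8080/load"
        req := loadRequest{Operation: "fit", NumThreads: 8, KvSize: 4096}

        // 1. "fit": the runner builds the model with AllocMemory=false, so the
        //    response reports what the layout would need without touching GPU memory.
        resp, err := postLoad(url, req)
        if err != nil || !resp.Success {
            fmt.Println("layout does not fit, adjust and retry:", err)
            return
        }

        // 2. "commit": same request; memory is now allocated for real and
        //    weight loading starts in the background.
        req.Operation = "commit"
        if _, err := postLoad(url, req); err != nil {
            fmt.Println("commit failed:", err)
        }
    }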
@@ -14,6 +14,7 @@ import (
     "net"
     "net/http"
     "os"
+    "reflect"
     "regexp"
     "runtime"
     "strconv"
@@ -259,6 +260,16 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 }
 
 type Server struct {
+    // modelPath is the location of the model to be loaded
+    modelPath string
+
+    // loadMu prevents more than one load attempt from occurring at a time
+    loadMu sync.Mutex
+
+    // lastLoad is the load request from the previous load attempt. Used to
+    // detect if we can reuse an existing memory allocation.
+    lastLoad llm.LoadRequest
+
     // is the server ready to process requests?
     // protects access to model and image
     ready sync.WaitGroup
@@ -720,17 +731,6 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
     }
 }
 
-type multiLPath []string
-
-func (m *multiLPath) Set(value string) error {
-    *m = append(*m, value)
-    return nil
-}
-
-func (m *multiLPath) String() string {
-    return strings.Join(*m, ", ")
-}
-
 func (s *Server) reserveWorstCaseGraph() error {
     ctx := s.model.Backend().NewContext()
     defer ctx.Close()
@@ -828,15 +828,28 @@ func (s *Server) reserveWorstCaseGraph() error {
     return nil
 }
 
-func (s *Server) initModel(
+// allocModel pre-allocates the maximum needed memory for a model
+// based on the given parameters
+func (s *Server) allocModel(
     mpath string,
     params ml.BackendParams,
-    lpath multiLPath,
+    loraPath []string,
     parallel int,
     kvCacheType string,
     kvSize int,
     multiUserCache bool,
-) error {
+) (panicErr error) {
+    // Convert memory allocation panics to errors
+    defer func() {
+        if r := recover(); r != nil {
+            if err, ok := r.(error); ok {
+                panicErr = err
+            } else {
+                panic(r)
+            }
+        }
+    }()
+
     var err error
     s.model, err = model.New(mpath, params)
     if err != nil {
@@ -844,7 +857,7 @@ func (s *Server) initModel(
     }
 
     // TODO(jessegross): LoRA loading
-    if lpath.String() != "" {
+    if len(loraPath) > 0 {
         return errors.New("loras are not yet implemented")
     }
 
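The signature change from ") error {" to ") (panicErr error) {" in allocModel pairs a named return value with a deferred recover, so panics raised inside the allocation path come back as ordinary errors while any non-error panic value is re-thrown. A self-contained sketch of that idiom (the function names here are invented for illustration, not taken from the diff):

    package main

    import (
        "errors"
        "fmt"
    )

    // tryAlloc is a stand-in for allocation code that signals failure by
    // panicking with an error value (as a backend might on OOM).
    func tryAlloc(fail bool) {
        if fail {
            panic(errors.New("out of memory"))
        }
    }

    // allocate converts error-valued panics into a returned error; any
    // other panic value is re-raised untouched.
    func allocate(fail bool) (panicErr error) {
        defer func() {
            if r := recover(); r != nil {
                if err, ok := r.(error); ok {
                    panicErr = err
                } else {
                    panic(r)
                }
            }
        }()
        tryAlloc(fail)
        return nil
    }

    func main() {
        fmt.Println(allocate(false)) // <nil>
        fmt.Println(allocate(true))  // out of memory
    }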
@@ -865,63 +878,122 @@ func (s *Server) initModel(
     return s.reserveWorstCaseGraph()
 }
 
-func (s *Server) load(
-    ctx context.Context,
-    mpath string,
-    params ml.BackendParams,
-    lpath multiLPath,
-    parallel int,
-    kvCacheType string,
-    kvSize int,
-    multiUserCache bool,
-) {
-    err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
-    if err != nil {
-        var noMem ml.ErrNoMem
-        if errors.As(err, &noMem) {
-            // We can't yet handle this but in the future we will
-            s.cache.Close()
-            if s.model != nil {
-                s.model.Backend().Close()
-            }
-        }
-
-        panic(err)
+// closeModel frees all memory associated with a model
+func (s *Server) closeModel() {
+    s.cache.Close()
+    s.cache = nil
+    if s.model != nil {
+        s.model.Backend().Close()
+        s.model = nil
     }
 }
 
-    slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
-
-    err = s.model.Backend().Load(ctx,
+// loadModel loads the weights for a model. The memory must already
+// have been allocated with allocModel
+func (s *Server) loadModel() {
+    err := s.model.Backend().Load(context.TODO(),
         func(progress float32) {
             s.progress = progress
         })
     if err != nil {
-        panic(err)
+        panic(fmt.Errorf("failed to load model: %v", err))
     }
 
     s.status = llm.ServerStatusReady
     s.ready.Done()
 }
 
+// load is the handler called by the Ollama server to process different
+// load operations
+func (s *Server) load(w http.ResponseWriter, r *http.Request) {
+    s.loadMu.Lock()
+    defer s.loadMu.Unlock()
+
+    w.Header().Set("Content-Type", "application/json")
+
+    if s.status != llm.ServerStatusLaunched {
+        http.Error(w, "model already loaded", http.StatusInternalServerError)
+        return
+    }
+
+    var req llm.LoadRequest
+    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+        http.Error(w, "bad request", http.StatusBadRequest)
+        return
+    }
+
+    slog.Info("load", "request", req)
+
+    if req.Operation == llm.LoadOperationClose {
+        s.closeModel()
+        if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
+            http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+        }
+        return
+    }
+
+    s.lastLoad.Operation = req.Operation
+    loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
+
+    s.lastLoad = req
+
+    if loadModel {
+        s.closeModel()
+
+        params := ml.BackendParams{
+            AllocMemory:    req.Operation != llm.LoadOperationFit,
+            NumThreads:     req.NumThreads,
+            GPULayers:      req.GPULayers,
+            FlashAttention: req.FlashAttention,
+        }
+
+        s.batchSize = req.BatchSize
+
+        err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
+        if err != nil {
+            s.closeModel()
+
+            var noMem ml.ErrNoMem
+            if errors.As(err, &noMem) {
+                resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
+                if err := json.NewEncoder(w).Encode(&resp); err != nil {
+                    http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+                }
+
+                return
+            }
+
+            http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError)
+            return
+        }
+    }
+
+    mem := s.model.Backend().BackendMemory()
+
+    switch req.Operation {
+    case llm.LoadOperationFit:
+        // LoadOperationFit can't be used for anything else, so just close it
+        s.closeModel()
+
+    // LoadOperationAlloc should stay open for future operations
+
+    case llm.LoadOperationCommit:
+        s.status = llm.ServerStatusLoadingModel
+        go s.loadModel()
+    }
+
+    resp := llm.LoadResponse{Success: true, Memory: mem}
+    if err := json.NewEncoder(w).Encode(&resp); err != nil {
+        http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+        return
+    }
+}
+
 func Execute(args []string) error {
     fs := flag.NewFlagSet("runner", flag.ExitOnError)
     mpath := fs.String("model", "", "Path to model binary file")
-    parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
-    batchSize := fs.Int("batch-size", 512, "Batch size")
-    numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
-    mainGPU := fs.Int("main-gpu", 0, "Main GPU")
-    flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
-    kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
-    kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
     port := fs.Int("port", 8080, "Port to expose the server on")
-    threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
     _ = fs.Bool("verbose", false, "verbose output (default: disabled)")
     _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
-    tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
-    multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
 
-    var lpaths multiLPath
-    fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
-
     fs.Usage = func() {
         fmt.Fprintf(fs.Output(), "Runner usage\n")
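A subtlety in the new load handler above: it overwrites s.lastLoad.Operation with the incoming operation before the reflect.DeepEqual comparison, so two requests that differ only in their operation (for example, an alloc followed by a commit of the same configuration) compare equal and reuse the existing allocation instead of rebuilding it. A toy illustration of that normalize-then-compare pattern (the struct here is invented for the example):

    package main

    import (
        "fmt"
        "reflect"
    )

    type request struct {
        Operation string
        KvSize    int
    }

    func main() {
        last := request{Operation: "alloc", KvSize: 4096}
        next := request{Operation: "commit", KvSize: 4096}

        // Normalize the field that is allowed to differ, then deep-compare
        // the rest; equal means the previous allocation can be reused.
        last.Operation = next.Operation
        fmt.Println(reflect.DeepEqual(next, last)) // true -> reuse
    }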
@@ -933,39 +1005,17 @@ func Execute(args []string) error {
     slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
     slog.Info("starting ollama engine")
 
-    ctx, cancel := context.WithCancel(context.Background())
-    defer cancel()
-
     server := &Server{
-        batchSize: *batchSize,
-        status:    llm.ServerStatusLoadingModel,
+        modelPath: *mpath,
+        status:    llm.ServerStatusLaunched,
     }
 
     server.cond = sync.NewCond(&server.mu)
     server.ready.Add(1)
 
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
-    // TODO(jessegross): Parameters that need to be implemented:
-    // no-mmap
-
-    var tensorSplitFloats []float32
-    if *tensorSplit != "" {
-        splits := strings.Split(*tensorSplit, ",")
-        tensorSplitFloats = make([]float32, len(splits))
-        for i, s := range splits {
-            f, _ := strconv.ParseFloat(s, 32)
-            tensorSplitFloats[i] = float32(f)
-        }
-    }
-
-    params := ml.BackendParams{
-        NumThreads:     *threads,
-        NumGPULayers:   *numGPULayers,
-        MainGPU:        *mainGPU,
-        TensorSplit:    tensorSplitFloats,
-        FlashAttention: *flashAttention,
-    }
-
-    go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
     go server.run(ctx)
 
     addr := "127.0.0.1:" + strconv.Itoa(*port)
@@ -978,6 +1028,7 @@ func Execute(args []string) error {
 
     mux := http.NewServeMux()
     // TODO: support embeddings
+    mux.HandleFunc("POST /load", server.load)
     mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
         http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
     })