mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-21 22:33:56 +00:00
llm: New memory management
This changes the memory allocation strategy from upfront estimation to tracking actual allocations done by the engine and reacting to that. The goal is avoid issues caused by both under-estimation (crashing) and over-estimation (low performance due to under-utilized GPUs). It is currently opt-in and can be enabled for models running on the Ollama engine by setting OLLAMA_NEW_ESTIMATES=1. Behavior in other cases is unchanged and will continue to use the existing estimates.
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -216,6 +215,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// modelPath is the location of the model to be loaded
|
||||
modelPath string
|
||||
|
||||
// loadMu prevents more than one load attempt from occurring at a time
|
||||
loadMu sync.Mutex
|
||||
|
||||
// is the server ready to process requests?
|
||||
// protects access to model and image
|
||||
ready sync.WaitGroup
|
||||
@@ -723,21 +728,12 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
type multiLPath []string
|
||||
|
||||
func (m *multiLPath) Set(value string) error {
|
||||
*m = append(*m, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
// loadModel allocates memory based on the given parameters and loads the weights. The
|
||||
// memory allocated is worst case for text models but not for vision.
|
||||
func (s *Server) loadModel(
|
||||
params llama.ModelParams,
|
||||
mpath string,
|
||||
lpath multiLPath,
|
||||
lpath []string,
|
||||
ppath string,
|
||||
kvSize int,
|
||||
kvCacheType string,
|
||||
@@ -757,12 +753,10 @@ func (s *Server) loadModel(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if lpath.String() != "" {
|
||||
for _, path := range lpath {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, path := range lpath {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -783,26 +777,81 @@ func (s *Server) loadModel(
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
// load is the handler called by the Ollama server to process different
|
||||
// load operations
|
||||
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
|
||||
s.loadMu.Lock()
|
||||
defer s.loadMu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if s.status != llm.ServerStatusLaunched {
|
||||
http.Error(w, "model already loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req llm.LoadRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("load", "request", req)
|
||||
|
||||
switch req.Operation {
|
||||
// LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response
|
||||
|
||||
case llm.LoadOperationCommit:
|
||||
s.batchSize = req.BatchSize
|
||||
s.parallel = req.Parallel
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
gpuIDs := llama.EnumerateGPUs()
|
||||
tensorSplit := make([]float32, len(gpuIDs))
|
||||
numGPU := 0
|
||||
for i := range gpuIDs {
|
||||
for _, layers := range req.GPULayers {
|
||||
if gpuIDs[i] == layers.ID {
|
||||
tensorSplit[i] = float32(len(layers.Layers))
|
||||
numGPU += len(layers.Layers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params := llama.ModelParams{
|
||||
NumGpuLayers: numGPU,
|
||||
MainGpu: req.MainGPU,
|
||||
UseMmap: req.UseMmap && len(req.LoraPath) == 0,
|
||||
TensorSplit: tensorSplit,
|
||||
Progress: func(progress float32) {
|
||||
s.progress = progress
|
||||
},
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusLoadingModel
|
||||
go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache)
|
||||
|
||||
case llm.LoadOperationClose:
|
||||
// No-op for us
|
||||
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resp := llm.LoadResponse{Success: true}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
ppath := fs.String("mmproj", "", "Path to projector binary file")
|
||||
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
nGpuLayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
mainGpu := fs.Int("main-gpu", 0, "Main GPU")
|
||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
noMmap := fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
|
||||
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
|
||||
var lpaths multiLPath
|
||||
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(fs.Output(), "Runner usage\n")
|
||||
@@ -817,35 +866,11 @@ func Execute(args []string) error {
|
||||
llama.BackendInit()
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
parallel: *parallel,
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
splits := strings.Split(*tensorSplit, ",")
|
||||
tensorSplitFloats = make([]float32, len(splits))
|
||||
for i, s := range splits {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats[i] = float32(f)
|
||||
}
|
||||
}
|
||||
|
||||
params := llama.ModelParams{
|
||||
NumGpuLayers: *nGpuLayers,
|
||||
MainGpu: *mainGpu,
|
||||
UseMmap: !*noMmap && lpaths.String() == "",
|
||||
TensorSplit: tensorSplitFloats,
|
||||
Progress: func(progress float32) {
|
||||
server.progress = progress
|
||||
},
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
@@ -863,6 +888,7 @@ func Execute(args []string) error {
|
||||
defer listener.Close()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /load", server.load)
|
||||
mux.HandleFunc("/embedding", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
|
||||
Reference in New Issue
Block a user