diff --git a/llama/llama.go b/llama/llama.go
index f8a051ea..582d4128 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -63,8 +63,13 @@ func BackendInit() {
 	C.llama_backend_init()
 }
 
-func EnumerateGPUs() []ml.DeviceID {
-	var ids []ml.DeviceID
+type Devices struct {
+	ml.DeviceID
+	LlamaID uint64
+}
+
+func EnumerateGPUs() []Devices {
+	var ids []Devices
 
 	for i := range C.ggml_backend_dev_count() {
 		device := C.ggml_backend_dev_get(i)
@@ -74,9 +79,12 @@ func EnumerateGPUs() []ml.DeviceID {
 			C.GGML_BACKEND_DEVICE_TYPE_IGPU:
 			var props C.struct_ggml_backend_dev_props
 			C.ggml_backend_dev_get_props(device, &props)
-			ids = append(ids, ml.DeviceID{
-				ID:      C.GoString(props.id),
-				Library: C.GoString(props.library),
+			ids = append(ids, Devices{
+				DeviceID: ml.DeviceID{
+					ID:      C.GoString(props.id),
+					Library: C.GoString(props.library),
+				},
+				LlamaID: uint64(i),
 			})
 		}
 	}
@@ -231,6 +239,7 @@ func (c *Context) GetLogitsIth(i int) []float32 {
 }
 
 type ModelParams struct {
+	Devices      []uint64
 	NumGpuLayers int
 	MainGpu      int
 	UseMmap      bool
@@ -254,6 +263,21 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	cparams.use_mmap = C.bool(params.UseMmap)
 	cparams.vocab_only = C.bool(params.VocabOnly)
 
+	var devices []C.ggml_backend_dev_t
+	for _, llamaID := range params.Devices {
+		devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
+	}
+	if len(devices) > 0 {
+		devices = append(devices, C.ggml_backend_dev_t(C.NULL))
+		devicesData := &devices[0]
+
+		var devicesPin runtime.Pinner
+		devicesPin.Pin(devicesData)
+		defer devicesPin.Unpin()
+
+		cparams.devices = devicesData
+	}
+
 	if len(params.TensorSplit) > 0 {
 		tensorSplitData := &params.TensorSplit[0]
 
diff --git a/ml/device.go b/ml/device.go
index dc91359f..040764fe 100644
--- a/ml/device.go
+++ b/ml/device.go
@@ -8,6 +8,7 @@ import (
 	"hash/maphash"
 	"io"
 	"log/slog"
+	"math"
 	"net/http"
 	"runtime"
 	"slices"
@@ -28,6 +29,22 @@ type GPULayers struct {
 	Layers []int
 }
 
+// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
+func (g GPULayers) FirstLayer() int {
+	if len(g.Layers) == 0 {
+		return math.MaxInt
+	}
+
+	first := g.Layers[0]
+	for i := 1; i < len(g.Layers); i++ {
+		if g.Layers[i] < first {
+			first = g.Layers[i]
+		}
+	}
+
+	return first
+}
+
 func (g GPULayers) String() string {
 	if len(g.Layers) == 0 {
 		return ""
@@ -54,6 +71,17 @@ func (g GPULayers) String() string {
 // GPULayersList is a set of layer allocations across multiple GPUs
 type GPULayersList []GPULayers
 
+func (l GPULayersList) Len() int      { return len(l) }
+func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
+
+// Less orders GPUs by the first (lowest) layer offloaded to each
+func (l GPULayersList) Less(i, j int) bool {
+	li := l[i].FirstLayer()
+	lj := l[j].FirstLayer()
+
+	return li < lj
+}
+
 func (l GPULayersList) String() string {
 	if l.Sum() > 0 {
 		return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 16c84a78..a23ddd61 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -12,6 +12,7 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -900,19 +901,24 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) {
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	gpuIDs := llama.EnumerateGPUs()
-	tensorSplit := make([]float32, len(gpuIDs))
 	numGPU := 0
-	for i := range gpuIDs {
-		for _, layers := range req.GPULayers {
-			if gpuIDs[i] == layers.DeviceID {
-				tensorSplit[i] = float32(len(layers.Layers))
+	var tensorSplit []float32
+	var llamaIDs []uint64
+
+	gpuIDs := llama.EnumerateGPUs()
+	sort.Sort(req.GPULayers)
+	for _, layers := range req.GPULayers {
+		for i := range gpuIDs {
+			if gpuIDs[i].DeviceID == layers.DeviceID {
 				numGPU += len(layers.Layers)
+				tensorSplit = append(tensorSplit, float32(len(layers.Layers)))
+				llamaIDs = append(llamaIDs, gpuIDs[i].LlamaID)
 			}
 		}
 	}
 
 	params := llama.ModelParams{
+		Devices:      llamaIDs,
 		NumGpuLayers: numGPU,
 		MainGpu:      req.MainGPU,
 		UseMmap:      req.UseMmap && len(req.LoraPath) == 0,
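
Note (illustrative, not part of the patch): GPULayersList now satisfies sort.Interface, ordering GPUs by FirstLayer(). Because the runner appends the tensorSplit entries and llama.cpp device IDs while iterating the sorted list, the i-th split fraction corresponds to the i-th entry of ModelParams.Devices. Below is a minimal sketch of the ordering behavior, assuming this repository's module path; the device IDs and layer splits are invented.

package main

import (
	"fmt"
	"sort"

	"github.com/ollama/ollama/ml"
)

func main() {
	// Hypothetical allocation: enumeration order does not match layer order.
	req := ml.GPULayersList{
		{DeviceID: ml.DeviceID{ID: "GPU-1", Library: "CUDA"}, Layers: []int{16, 17, 18}},
		{DeviceID: ml.DeviceID{ID: "GPU-0", Library: "CUDA"}, Layers: []int{0, 1, 2}},
	}

	// Less compares FirstLayer(), so the GPU holding layer 0 sorts first.
	sort.Sort(req)
	fmt.Println(req) // GPU-0's allocation now precedes GPU-1's
}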