ggml: Enable op_offload to improve partial offload performance

When a model is partially offloaded to system RAM, we can either
do the calculations on the CPU or we can temporarily transfer the
data to the GPU to do the calculations there. Small batches tend
to be better on the CPU, large batches on the GPU.

The llamarunner used the GPU in most cases and the ollamarunner
used the CPU. Although the ollamarunner saw an improvement in
token generation performance, there was a large performance hit
in prompt processing (3-10x).

There is an existing heuristic to dynamically switch between these
two modes but in practice it doesn't have enough information to
accurately make that decision. This adds authoritative data to make
the check work to get the best of both worlds.

Fixes #12037
This commit is contained in:
Jesse Gross
2025-10-27 16:32:05 -07:00
committed by Jesse Gross
parent 26465fb85f
commit afaf7ce8c3
15 changed files with 405 additions and 128 deletions

View File

@@ -106,6 +106,11 @@ type Context interface {
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
// SetBatchSize provides a hint on the batch size to optimize processing
// Uses heuristics if not set
SetBatchSize(int)
Compute(...Tensor)
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

View File

@@ -386,7 +386,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(false),
C._Bool(false),
C._Bool(true),
C._Bool(params.AllocMemory),
)
@@ -749,6 +749,9 @@ type Context struct {
ctx *C.struct_ggml_context
graph *C.struct_ggml_cgraph
// batchSize is a hint to optimize processing
batchSize int
// buft is the buffer type used for new tensors
buft C.ggml_backend_buffer_type_t
@@ -805,6 +808,10 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
return c
}
func (c *Context) SetBatchSize(batchSize int) {
c.batchSize = batchSize
}
func (c *Context) Compute(tensors ...ml.Tensor) {
c.ComputeWithNotify(nil, tensors...)
}
@@ -815,6 +822,11 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
if cb != nil {
go cb()
}
if c.batchSize > 0 {
C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
}
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
@@ -836,6 +848,10 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
}
func (c *Context) Reserve() {
if c.batchSize > 0 {
C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize))
}
reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))

View File

@@ -98,7 +98,7 @@ extern "C" {
GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
// NOTE: will be removed, use device version instead
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
@@ -317,6 +317,9 @@ extern "C" {
GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Provide a hint on the batch size to optimize processing (uses heuristics if unset)
GGML_API void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size);
// Initialize backend buffers from a measure graph
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success

View File

@@ -112,8 +112,8 @@ extern "C" {
// compute the graph with the plan
enum ggml_status (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph (always async if supported by the backend)
enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
// compute graph (always async if supported by the backend). batch_size may be -1 if unknown
enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size);
// (optional) event synchronization
// record an event on this stream

View File

@@ -368,14 +368,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba
}
enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph, -1);
ggml_backend_synchronize(backend);
return err;
}
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
GGML_ASSERT(backend);
return backend->iface.graph_compute(backend, cgraph);
return backend->iface.graph_compute(backend, cgraph, batch_size);
}
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -750,6 +750,8 @@ struct ggml_backend_sched {
bool op_offload;
int batch_size; // a hint on the batch size to optimize processing, -1 to use heuristics
int debug;
// allocate buffers on attached ggml_backend_buffer_type_t's and during reservation
@@ -848,7 +850,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
// check if a backend with higher prio wants to offload the op
if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
if (sched->op_offload && (sched->batch_size < 0 || sched->batch_size >= 32) && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
for (int b = 0; b < src_backend_id; b++) {
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
SET_CAUSE(tensor, "1.off");
@@ -1584,7 +1586,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph, sched->batch_size);
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
@@ -1606,7 +1608,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv, sched->batch_size);
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
@@ -1698,6 +1700,7 @@ ggml_backend_sched_t ggml_backend_sched_new_ext(
sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
sched->op_offload = op_offload;
sched->batch_size = -1;
sched->alloc_buffers = alloc_buffers;
ggml_backend_sched_reset(sched);
@@ -1734,6 +1737,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
free(sched);
}
void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) {
sched->batch_size = batch_size;
}
void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
// reset state for the next run

View File

@@ -224,7 +224,7 @@ static void ggml_backend_blas_free(ggml_backend_t backend) {
delete backend;
}
static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -254,6 +254,7 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
return GGML_STATUS_SUCCESS;
GGML_UNUSED(backend);
GGML_UNUSED(batch_size);
}
static struct ggml_backend_i blas_backend_i = {

View File

@@ -164,7 +164,7 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe
GGML_UNUSED(backend);
}
static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
@@ -184,6 +184,8 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
return ggml_graph_compute(cgraph, &cplan);
GGML_UNUSED(batch_size);
}
static const struct ggml_backend_i ggml_backend_cpu_i = {

View File

@@ -2775,31 +2775,19 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool use_cuda_graph) {
int batch_size, bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
// This fix was added in llama.cpp and Ollama in parallel, but with
// different tensor names.
// llama.cpp: https://github.com/ggml-org/llama.cpp/pull/14741
// ollama: https://github.com/ollama/ollama/pull/11525
const std::string gemma3n_per_layer_proj_src1_name_ollama = " (reshaped)";
const std::string gemma3n_node_name_ollama = "node_";
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_bias_suffix = "_exps.bias";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2821,30 +2809,34 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
// ollama
// workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
// number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
!(node->ne[0] == 256 && node->ne[2] == 1 && node->ne[3] == 1 && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name_ollama) != std::string::npos : false && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name_ollama : false) &&
node->src[1] ? std::string(node->src[1]->name).find(ffn_moe_bias_suffix) == std::string::npos : false &&
// upstream
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
// If we have an explicit batch size hint then we don't need to use the tensor name heuristics
if (batch_size >= 0) {
if (batch_size > 1) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%d]\n", __func__, batch_size);
#endif
}
} else {
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
}
if (node->op == GGML_OP_CPY) {
@@ -3247,7 +3239,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
}
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
cuda_ctx->pool_set_alloc(true);
@@ -3286,7 +3278,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {

View File

@@ -410,10 +410,12 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml
GGML_UNUSED(dst);
}
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
ggml_metal_t ctx = (ggml_metal_t)backend->context;
return ggml_metal_graph_compute(ctx, cgraph);
GGML_UNUSED(batch_size);
}
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {

View File

@@ -12039,7 +12039,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
return num_adds;
}
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -12235,6 +12235,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
return GGML_STATUS_SUCCESS;
UNUSED(backend);
UNUSED(batch_size);
}
// Sort the graph for improved parallelism.