ggml: Preallocate CUDA pool memory

The GGML CUDA backend allocates additional memory for intermediate
results during computation. This memory isn't currently allocated
during worst-case graph reservation, and is therefore not included in
scheduling. Since these buffers can grow with context length, we
could run out of memory and crash.

This extends the memory allocation system down a layer, from the GGML
graph to the CUDA layer, preallocating the worst-case memory there
as well.
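
As a rough sketch, the flow this enables (backends, bufts, graph_size, and
worst_case_graph are placeholders; the real driver lives in the Go backend):

    // Sizing pass: with alloc_buffers=false, no device memory is touched and
    // the CUDA pool only records how much it would have allocated.
    ggml_backend_sched_t sched = ggml_backend_sched_new_ext(
        backends, bufts, n_backends, graph_size,
        /*parallel=*/false, /*op_offload=*/false, /*alloc_buffers=*/false);
    ggml_backend_sched_reserve(sched, worst_case_graph);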

Fixes #11753
Jesse Gross
2025-09-09 16:17:31 -07:00
committed by Jesse Gross
parent efaee8c2d6
commit 3d0b1734c0
7 changed files with 927 additions and 126 deletions

View File

@@ -159,7 +159,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
bt := C.ggml_backend_dev_buffer_type(d)
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt)
C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory))
btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
}
@@ -181,7 +180,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
d: d,
bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
})
C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory))
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
@@ -337,35 +335,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
}
// allocate buffers for each context
bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs))
for bt, c := range ctxs {
if C.ggml_get_first_tensor(c) == nil {
continue
}
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
if b == nil {
for _, b := range bbs {
C.ggml_backend_buffer_free(b)
}
for _, ctx := range ctxs {
C.ggml_free(ctx)
}
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
}
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
bbs[c] = b
}
for bs := range maps.Values(bbs) {
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
// map tensor names to tensors for easy lookup later
tensors := make(map[string]*C.struct_ggml_tensor)
for _, c := range ctxs {
@@ -403,6 +372,46 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
sched := C.ggml_backend_sched_new_ext(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(false),
C._Bool(false),
C._Bool(params.AllocMemory),
)
// allocate buffers for each context
bbs := make(map[*C.struct_ggml_context]C.ggml_backend_buffer_t, len(ctxs))
for bt, c := range ctxs {
if C.ggml_get_first_tensor(c) == nil {
continue
}
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
if b == nil {
for _, b := range bbs {
C.ggml_backend_buffer_free(b)
}
for _, ctx := range ctxs {
C.ggml_free(ctx)
}
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
}
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
bbs[c] = b
}
for bs := range maps.Values(bbs) {
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
return &Backend{
modelPath: modelPath,
allocMemory: params.AllocMemory,
@@ -410,18 +419,11 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
meta: meta,
tensorLoadTargets: targets,
tensors: tensors,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(false),
C._Bool(false),
),
schedBackends: schedBackends,
schedBufts: schedBufts,
input: deviceBufferTypes[input.d],
output: output.d,
sched: sched,
schedBackends: schedBackends,
schedBufts: schedBufts,
input: deviceBufferTypes[input.d],
output: output.d,
layers: func() map[int]layerDevice {
m := make(map[int]layerDevice)
for i, layer := range layers {

View File

@@ -35,7 +35,6 @@ extern "C" {
//
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
GGML_API void ggml_backend_buft_set_alloc (ggml_backend_buffer_type_t buft, bool alloc);
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
@@ -293,6 +292,7 @@ extern "C" {
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
GGML_API ggml_backend_sched_t ggml_backend_sched_new_ext(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload, bool alloc_buffers);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph

View File

@@ -26,6 +26,10 @@ extern "C" {
size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
// (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
bool (*is_host) (ggml_backend_buffer_type_t buft);
// (optional) returns a dummy buffer that is equivalent to one created by alloc_buffer but without actually being backed
// by memory
ggml_backend_buffer_t (*noalloc_buffer)(ggml_backend_buffer_type_t buft, size_t size);
};
struct ggml_backend_buffer_type {
@@ -116,6 +120,16 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
// (optional) reserves intermediate buffers needed for the computation
// if alloc is true, memory is actually allocated, otherwise the required amount is just returned by buffer_size
enum ggml_status (*graph_reserve) (ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc);
// (optional) returns the memory needed after calling graph_reserve
size_t (*buffer_size) (ggml_backend_t backend);
// (optional) frees memory from intermediate buffers that was allocated either by graph_compute or graph_reserve
void (*reset) (ggml_backend_t backend);
};
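// Illustrative only, not part of this change: a backend opts into these hooks
// by implementing them and wiring them into its ggml_backend_i, as the CUDA
// backend does at the end of this diff. The my_backend_* names are hypothetical.
static enum ggml_status my_backend_graph_reserve(ggml_backend_t backend, struct ggml_cgraph * cgraph, bool alloc) {
    // walk the graph in sizing mode; allocate (if alloc) or merely count scratch memory
    return GGML_STATUS_SUCCESS;
}
static size_t my_backend_buffer_size(ggml_backend_t backend) {
    return 0; // bytes measured by the last graph_reserve
}
static void my_backend_reset(ggml_backend_t backend) {
    // free intermediate buffers so the next reserve starts from scratch
}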
struct ggml_backend {

View File

@@ -35,10 +35,6 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name(buft);
}
void ggml_backend_buft_set_alloc(ggml_backend_buffer_type_t buft, bool alloc) {
buft->no_alloc = !alloc;
}
ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
if (size == 0) {
// return a dummy buffer for zero-sized allocations
@@ -46,7 +42,14 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
}
if (buft->no_alloc) {
ggml_backend_buffer_t buf = ggml_backend_buffer_init(buft, {}, NULL, size);
ggml_backend_buffer_t buf;
if (buft->iface.noalloc_buffer != NULL) {
buf = buft->iface.noalloc_buffer(buft, size);
} else {
buf = ggml_backend_buffer_init(buft, {}, NULL, size);
}
buf->no_alloc = true;
return buf;
}
@@ -688,6 +691,12 @@ struct ggml_backend_sched {
bool op_offload;
int debug;
// whether to allocate buffers on the attached ggml_backend_buffer_type_t's and during reservation
// if false, dummy buffers are used for faster memory sizing calculations
// the scheduler needs to be recreated with allocated buffers before it can be used
// for computation
bool alloc_buffers;
};
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
@@ -1474,6 +1483,17 @@ ggml_backend_sched_t ggml_backend_sched_new(
size_t graph_size,
bool parallel,
bool op_offload) {
return ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size, parallel, op_offload, true);
}
ggml_backend_sched_t ggml_backend_sched_new_ext(
ggml_backend_t * backends,
ggml_backend_buffer_type_t * bufts,
int n_backends,
size_t graph_size,
bool parallel,
bool op_offload,
bool alloc_buffers) {
GGML_ASSERT(n_backends > 0);
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1515,10 +1535,13 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
}
}
sched->bufts[b]->no_alloc = !alloc_buffers;
}
sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
sched->op_offload = op_offload;
sched->alloc_buffers = alloc_buffers;
ggml_backend_sched_reset(sched);
@@ -1533,6 +1556,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
for (int c = 0; c < sched->n_copies; c++) {
ggml_backend_event_free(sched->events[b][c]);
}
if (sched->backends[b]->iface.reset != NULL) {
sched->backends[b]->iface.reset(sched->backends[b]);
}
}
ggml_gallocr_free(sched->galloc);
ggml_free(sched->ctx);
@@ -1572,6 +1599,24 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
return false;
}
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
return false;
}
struct ggml_backend_sched_split * splits = sched->splits;
for (int i = 0; i < sched->n_splits; i++) {
struct ggml_backend_sched_split * split = &splits[i];
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
if (split_backend->iface.graph_reserve != NULL) {
enum ggml_status ec = split_backend->iface.graph_reserve(split_backend, &split->graph, sched->alloc_buffers);
if (ec != GGML_STATUS_SUCCESS) {
return false;
}
}
}
ggml_backend_sched_reset(sched);
return true;
@@ -1660,7 +1705,13 @@ size_t ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched,
int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
return ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
size_t size = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
if (backend->iface.buffer_size != NULL) {
size += backend->iface.buffer_size(backend);
}
return size;
}
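// Continuing the placeholder sketch from the commit message (a hypothetical
// helper, not part of this change): total the per-backend requirements
// (gallocr buffers plus the backend's own pool) after a sizing-mode reserve,
// then rebuild the scheduler with real allocations once everything fits.
static void example_size_then_alloc(ggml_backend_sched_t * sched, ggml_backend_t * backends,
                                    ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
    size_t total = 0;
    for (int i = 0; i < n_backends; i++) {
        total += ggml_backend_sched_get_attempted_buffer_size(*sched, backends[i]);
    }
    GGML_LOG_DEBUG("worst case memory: %zu bytes\n", total);
    ggml_backend_sched_free(*sched); // also invokes each backend's reset()
    *sched = ggml_backend_sched_new_ext(backends, bufts, n_backends, graph_size,
                                        /*parallel=*/false, /*op_offload=*/false, /*alloc_buffers=*/true);
}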
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {

View File

@@ -35,6 +35,31 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
extern bool reserving_graph;
// If we are reserving the graph, pointers may be invalid, and cudaMemcpyAsync will fail if it tries to validate them.
// However, since we don't need the result during reservation, we can skip the memcpy entirely.
static cudaError_t cudaMemcpyAsyncReserve ( void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream = 0 ) {
if (!reserving_graph) {
return cudaMemcpyAsync(dst, src, count, kind, stream);
} else {
return cudaSuccess;
}
}
static cudaError_t cudaMemcpy2DAsyncReserve ( void* dst, size_t dpitch, const void* src, size_t spitch, size_t width, size_t height, cudaMemcpyKind kind, cudaStream_t stream = 0 ) {
if (!reserving_graph) {
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, kind, stream);
} else {
return cudaSuccess;
}
}
#undef cudaMemcpyAsync
#define cudaMemcpyAsync cudaMemcpyAsyncReserve
#undef cudaMemcpy2DAsync
#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -771,6 +796,9 @@ struct ggml_cuda_pool {
virtual void * alloc(size_t size, size_t * actual_size) = 0;
virtual void free(void * ptr, size_t size) = 0;
virtual bool alloc_memory() = 0;
virtual size_t alloc_size() = 0;
};
template<typename T>
@@ -914,11 +942,11 @@ struct ggml_backend_cuda_context {
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, bool alloc);
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
pools[device] = new_pool_for_device(device, true);
}
return *pools[device];
}
@@ -926,4 +954,20 @@ struct ggml_backend_cuda_context {
ggml_cuda_pool & pool() {
return pool(device);
}
void pool_set_alloc(bool alloc) {
GGML_ASSERT(pools[device] == nullptr || pools[device]->alloc_memory() == alloc);
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device, alloc);
}
}
size_t pool_get_alloc_size() {
if (pools[device] == nullptr) {
return 0;
}
return pools[device]->alloc_size();
}
};
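// The GGML_ASSERT in pool_set_alloc encodes an invariant: a device's pool is
// created in either sizing or allocating mode and cannot switch in place; only
// ggml_backend_cuda_reset (later in this diff), which clears pools[device],
// permits creating a pool in the other mode afterwards. A hypothetical
// illustration, assuming ctx points at a ggml_backend_cuda_context:
//
//   ctx->pool_set_alloc(false);               // first use creates a sizing-mode pool
//   size_t need = ctx->pool_get_alloc_size(); // bytes the pool would have allocated
//   ctx->pool_set_alloc(false);               // ok: same mode
//   ctx->pool_set_alloc(true);                // would trip the GGML_ASSERT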

View File

@@ -355,6 +355,8 @@ const ggml_cuda_device_info & ggml_cuda_info() {
// #define DEBUG_CUDA_MALLOC
#define CUDA_ALIGNMENT 128
// buffer pool for cuda (legacy)
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
static const int MAX_BUFFERS = 256;
@@ -367,9 +369,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
size_t pool_size = 0;
bool allocate = true;
size_t last_alloc = 0;
explicit ggml_cuda_pool_leg(int device) :
device(device) {
explicit ggml_cuda_pool_leg(int device, bool alloc) :
device(device),
allocate(alloc) {
}
~ggml_cuda_pool_leg() {
@@ -377,7 +382,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
for (int i = 0; i < MAX_BUFFERS; ++i) {
ggml_cuda_buffer & b = buffer_pool[i];
if (b.ptr != nullptr) {
CUDA_CHECK(cudaFree(b.ptr));
if (allocate) {
CUDA_CHECK(cudaFree(b.ptr));
}
pool_size -= b.size;
}
}
@@ -425,8 +432,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
ggml_cuda_set_device(device);
CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
if (allocate) {
ggml_cuda_set_device(device);
if (ggml_cuda_device_malloc(&ptr, look_ahead_size, device) != cudaSuccess) {
last_alloc = look_ahead_size;
throw std::bad_alloc();
}
} else {
ptr = (void *)CUDA_ALIGNMENT;
}
*actual_size = look_ahead_size;
pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
@@ -446,10 +460,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
}
}
GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
ggml_cuda_set_device(device);
CUDA_CHECK(cudaFree(ptr));
if (allocate) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaFree(ptr));
}
pool_size -= size;
}
bool alloc_memory() override {
return allocate;
}
size_t alloc_size() override {
return pool_size + last_alloc;
}
};
// pool with virtual memory
@@ -461,18 +485,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
CUdeviceptr pool_addr = 0;
size_t pool_used = 0;
size_t pool_size = 0;
bool allocate = true;
size_t last_alloc = 0;
size_t granularity;
#if defined(GGML_USE_HIP)
std::vector<std::pair<CUdeviceptr, size_t>> mappings;
#endif
explicit ggml_cuda_pool_vmm(int device) :
explicit ggml_cuda_pool_vmm(int device, bool alloc) :
device(device),
granularity(ggml_cuda_info().devices[device].vmm_granularity) {
granularity(ggml_cuda_info().devices[device].vmm_granularity),
allocate(alloc) {
if (!allocate) {
pool_addr = (CUdeviceptr)CUDA_ALIGNMENT;
}
}
~ggml_cuda_pool_vmm() {
if (pool_addr != 0) {
if (pool_addr != 0 && allocate) {
#if defined(GGML_USE_HIP)
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
@@ -499,36 +529,50 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
if (allocate) {
// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
CUmemGenericAllocationHandle handle;
if (cuMemCreate(&handle, reserve_size, &prop, 0) != CUDA_SUCCESS) {
last_alloc = reserve_size;
throw std::bad_alloc();
}
// reserve virtual address space (if not already reserved)
if (pool_addr == 0) {
CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
// reserve virtual address space (if not already reserved)
if (pool_addr == 0) {
CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
}
// map at the end of the pool
CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
if (cuMemMap(start_ptr, reserve_size, 0, handle, 0) != CUDA_SUCCESS) {
last_alloc = reserve_size;
CU_CHECK(cuMemRelease(handle));
throw std::bad_alloc();
}
// the memory allocation handle is no longer needed after mapping
CU_CHECK(cuMemRelease(handle));
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
if (cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1) != CUDA_SUCCESS) {
CU_CHECK(cuMemUnmap(start_ptr, reserve_size));
last_alloc = reserve_size;
throw std::bad_alloc();
}
#if defined(GGML_USE_HIP)
mappings.push_back({start_ptr, reserve_size});
#endif
}
// map at the end of the pool
CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
#if defined(GGML_USE_HIP)
mappings.push_back({start_ptr, reserve_size});
#endif
// the memory allocation handle is no longer needed after mapping
CU_CHECK(cuMemRelease(handle));
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = device;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
// add to the pool
pool_size += reserve_size;
@@ -560,16 +604,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
}
bool alloc_memory() override {
return allocate;
}
size_t alloc_size() override {
return pool_size + last_alloc;
}
};
#endif // defined(GGML_USE_VMM)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device, bool alloc) {
#if defined(GGML_USE_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device, alloc));
}
#endif // defined(GGML_USE_VMM)
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device, alloc));
}
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
@@ -753,11 +805,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
}
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 128;
return CUDA_ALIGNMENT;
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_noalloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
void * dev_ptr = (void *)ggml_backend_cuda_buffer_type_get_alignment(buft);
ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
return ggml_backend_buffer_init(buft, {}, ctx, size);
}
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
size_t size = ggml_nbytes(tensor);
int64_t ne0 = tensor->ne[0];
@@ -781,6 +842,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
/* .is_host = */ NULL,
/* .noalloc_buffer = */ ggml_backend_cuda_buffer_type_noalloc_buffer,
};
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
@@ -2941,6 +3003,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
@@ -2956,6 +3019,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}
// When reserving, we force CUDA graphs, but this operation is not graph-safe, so we need to skip it
if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
continue;
}
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
@@ -3027,6 +3095,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
cuda_ctx->pool_set_alloc(true);
ggml_cuda_set_device(cuda_ctx->device);
@@ -3106,6 +3175,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
return GGML_STATUS_SUCCESS;
}
// This is used to skip operations that are not graph-safe during the reservation process.
bool reserving_graph = false;
static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
cuda_ctx->pool_set_alloc(alloc);
#ifdef USE_CUDA_GRAPH
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
#endif
ggml_cuda_set_device(cuda_ctx->device);
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
}
reserving_graph = true;
// Create cuBLAS handles early to avoid synchronous allocations during graph capture.
cuda_ctx->cublas_handle();
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
enum ggml_status result = GGML_STATUS_SUCCESS;
try {
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
} catch (const std::exception &e) {
result = GGML_STATUS_FAILED;
}
cudaGraph_t graph;
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph));
CUDA_CHECK(cudaGraphDestroy(graph));
reserving_graph = false;
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
ggml_cuda_lock_cv.notify_all();
}
}
return result;
}
static size_t ggml_backend_cuda_buffer_size(ggml_backend_t backend) {
ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context;
return ctx->pool_get_alloc_size();
}
static void ggml_backend_cuda_reset(ggml_backend_t backend) {
ggml_backend_cuda_context * ctx = (ggml_backend_cuda_context *)backend->context;
ctx->pools[ctx->device] = NULL;
}
static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -3145,6 +3279,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .graph_reserve = */ ggml_backend_cuda_graph_reserve,
/* .buffer_size = */ ggml_backend_cuda_buffer_size,
/* .reset = */ ggml_backend_cuda_reset,
};
static ggml_guid_t ggml_backend_cuda_guid() {