ggml: handle all streams (#13350)

Follow up from #12992

Free all streams, and keep the alloc logic aligned across streams.
This commit is contained in:
Daniel Hiltgen
2025-12-05 16:10:33 -08:00
committed by GitHub
parent 31b8c6a214
commit c146a138e3
6 changed files with 55 additions and 35 deletions

View File

@@ -62,7 +62,7 @@ index 74b7f070c..8d2cc167f 100644
GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index fe0da71ca..0787e443c 100644
index f1c178f31..1110ca372 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -109,6 +109,11 @@ int ggml_cuda_get_device() {
@@ -77,7 +77,7 @@ index fe0da71ca..0787e443c 100644
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -4380,7 +4385,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
@@ -4386,7 +4391,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
@@ -89,7 +89,7 @@ index fe0da71ca..0787e443c 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -4835,6 +4843,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
@@ -4841,6 +4849,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
@@ -101,7 +101,7 @@ index fe0da71ca..0787e443c 100644
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -4851,6 +4864,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
@@ -4857,6 +4870,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,