diff --git a/discover/runner.go b/discover/runner.go index 44737aa2..c963de6f 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -147,7 +147,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml. wg.Add(1) go func(i int) { defer wg.Done() - extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1]) + extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true) devices[i].AddInitValidation(extraEnvs) if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { slog.Debug("filtering device which didn't fully initialize", @@ -333,7 +333,8 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml. defer cancel() // Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct - devFilter := ml.GetVisibleDevicesEnv(devices) + // We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment + devFilter := ml.GetVisibleDevicesEnv(devices, false) for dir := range libDirs { updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter) diff --git a/llm/server.go b/llm/server.go index fa4e438d..e9d0a030 100644 --- a/llm/server.go +++ b/llm/server.go @@ -227,7 +227,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st modelPath, gpuLibs, status, - ml.GetVisibleDevicesEnv(gpus), + ml.GetVisibleDevicesEnv(gpus, false), ) s := llmServer{ diff --git a/ml/device.go b/ml/device.go index 7d86dfdd..f892b512 100644 --- a/ml/device.go +++ b/ml/device.go @@ -494,13 +494,14 @@ func FlashAttentionSupported(l []DeviceInfo) bool { // Given the list of GPUs this instantiation is targeted for, // figure out the visible devices environment variables -func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string { +// Set mustFilter true to enable filtering of CUDA devices +func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string { if len(l) == 0 { return nil } env := map[string]string{} for _, d := range l { - d.updateVisibleDevicesEnv(env) + d.updateVisibleDevicesEnv(env, mustFilter) } return env } @@ -532,7 +533,7 @@ func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool { return false } -func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) { +func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) { var envVar string switch d.Library { case "ROCm": @@ -541,8 +542,15 @@ func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) { if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } + case "CUDA": + if !mustFilter { + // By default we try to avoid filtering CUDA devices because ROCm also + // looks at the CUDA env var, and gets confused in mixed vendor environments. + return + } + envVar = "CUDA_VISIBLE_DEVICES" default: - // CUDA and Vulkan are not filtered via env var, but via scheduling decisions + // Vulkan is not filtered via env var, but via scheduling decisions return } v, existing := env[envVar]