diff --git a/ml/device.go b/ml/device.go index f0654127..7d86dfdd 100644 --- a/ml/device.go +++ b/ml/device.go @@ -509,11 +509,9 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string { // to crash at inference time and requires deeper validation before we include // it in the supported devices list. func (d DeviceInfo) NeedsInitValidation() bool { - // At this time the only library we know needs a 2nd pass is ROCm since - // rocblas will crash on unsupported devices. We want to find those crashes - // during bootstrap discovery so we can eliminate those GPUs before the user - // tries to run inference on them - return d.Library == "ROCm" + // ROCm: rocblas will crash on unsupported devices. + // CUDA: verify CC is supported by the version of the library + return d.Library == "ROCm" || d.Library == "CUDA" } // Set the init validation environment variable