Skip to content

Commit 2628378

Browse files
committed
fix: query device type to determine CPU accessibility
1 parent 636a133 commit 2628378

File tree

1 file changed

+36
-1
lines changed

1 file changed

+36
-1
lines changed

src/memory.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,34 @@ impl From<ort_sys::OrtMemType> for MemoryType {
318318
}
319319
}
320320

321+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
322+
#[allow(clippy::upper_case_acronyms)]
323+
pub enum DeviceType {
324+
CPU,
325+
GPU,
326+
FPGA
327+
}
328+
329+
impl From<DeviceType> for ort_sys::OrtMemoryInfoDeviceType {
330+
fn from(value: DeviceType) -> Self {
331+
match value {
332+
DeviceType::CPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
333+
DeviceType::GPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
334+
DeviceType::FPGA => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA
335+
}
336+
}
337+
}
338+
339+
impl From<ort_sys::OrtMemoryInfoDeviceType> for DeviceType {
340+
fn from(value: ort_sys::OrtMemoryInfoDeviceType) -> Self {
341+
match value {
342+
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU => DeviceType::CPU,
343+
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU => DeviceType::GPU,
344+
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA => DeviceType::FPGA
345+
}
346+
}
347+
}
348+
321349
/// Describes allocation properties for value memory.
322350
///
323351
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
@@ -445,10 +473,17 @@ impl MemoryInfo {
445473
raw as _
446474
}
447475

476+
/// Returns the type of device (CPU/GPU) this memory is allocated on.
477+
pub fn device_type(&self) -> DeviceType {
478+
let mut raw: ort_sys::OrtMemoryInfoDeviceType = ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU;
479+
ortsys![unsafe MemoryInfoGetDeviceType(self.ptr.as_ptr(), &mut raw)];
480+
raw.into()
481+
}
482+
448483
/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
449484
/// it could be extracted to an `ndarray` or slice.
450485
pub fn is_cpu_accessible(&self) -> bool {
451-
self.allocation_device() == AllocationDevice::CPU || matches!(self.memory_type(), MemoryType::CPUInput | MemoryType::CPUOutput)
486+
self.device_type() == DeviceType::CPU
452487
}
453488

454489
pub fn ptr(&self) -> *mut ort_sys::OrtMemoryInfo {

0 commit comments

Comments
 (0)