@@ -318,6 +318,34 @@ impl From<ort_sys::OrtMemType> for MemoryType {
318
318
}
319
319
}
320
320
321
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
322
+ #[ allow( clippy:: upper_case_acronyms) ]
323
+ pub enum DeviceType {
324
+ CPU ,
325
+ GPU ,
326
+ FPGA
327
+ }
328
+
329
+ impl From < DeviceType > for ort_sys:: OrtMemoryInfoDeviceType {
330
+ fn from ( value : DeviceType ) -> Self {
331
+ match value {
332
+ DeviceType :: CPU => ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_CPU ,
333
+ DeviceType :: GPU => ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_GPU ,
334
+ DeviceType :: FPGA => ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_FPGA
335
+ }
336
+ }
337
+ }
338
+
339
+ impl From < ort_sys:: OrtMemoryInfoDeviceType > for DeviceType {
340
+ fn from ( value : ort_sys:: OrtMemoryInfoDeviceType ) -> Self {
341
+ match value {
342
+ ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_CPU => DeviceType :: CPU ,
343
+ ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_GPU => DeviceType :: GPU ,
344
+ ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_FPGA => DeviceType :: FPGA
345
+ }
346
+ }
347
+ }
348
+
321
349
/// Describes allocation properties for value memory.
322
350
///
323
351
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
@@ -445,10 +473,17 @@ impl MemoryInfo {
445
473
raw as _
446
474
}
447
475
476
+ /// Returns the type of device (CPU/GPU) this memory is allocated on.
477
+ pub fn device_type ( & self ) -> DeviceType {
478
+ let mut raw: ort_sys:: OrtMemoryInfoDeviceType = ort_sys:: OrtMemoryInfoDeviceType :: OrtMemoryInfoDeviceType_CPU ;
479
+ ortsys ! [ unsafe MemoryInfoGetDeviceType ( self . ptr. as_ptr( ) , & mut raw) ] ;
480
+ raw. into ( )
481
+ }
482
+
448
483
/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
449
484
/// it could be extracted to an `ndarray` or slice.
450
485
pub fn is_cpu_accessible ( & self ) -> bool {
451
- self . allocation_device ( ) == AllocationDevice :: CPU || matches ! ( self . memory_type ( ) , MemoryType :: CPUInput | MemoryType :: CPUOutput )
486
+ self . device_type ( ) == DeviceType :: CPU
452
487
}
453
488
454
489
pub fn ptr ( & self ) -> * mut ort_sys:: OrtMemoryInfo {
0 commit comments