Skip to content

Commit 82dcf84

Browse files
committed
refactor!: put ptr() behind trait
1 parent 1563c13 commit 82dcf84

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+358
-251
lines changed

src/adapter.rs

+18-6
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@ use std::{
44
sync::Arc
55
};
66

7-
use crate::{Allocator, Result, ortsys, util};
7+
use crate::{Allocator, AsPointer, Result, ortsys, util};
88

99
#[derive(Debug)]
1010
pub(crate) struct AdapterInner {
11-
pub(crate) ptr: NonNull<ort_sys::OrtLoraAdapter>
11+
ptr: NonNull<ort_sys::OrtLoraAdapter>
12+
}
13+
14+
impl AsPointer for AdapterInner {
15+
type Sys = ort_sys::OrtLoraAdapter;
16+
17+
fn ptr(&self) -> *const Self::Sys {
18+
self.ptr.as_ptr()
19+
}
1220
}
1321

1422
impl Drop for AdapterInner {
@@ -25,7 +33,7 @@ pub struct Adapter {
2533
impl Adapter {
2634
pub fn from_file(path: impl AsRef<Path>, allocator: Option<&Allocator>) -> Result<Self> {
2735
let path = util::path_to_os_char(path);
28-
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
36+
let allocator_ptr = allocator.map(|c| c.ptr().cast_mut()).unwrap_or_else(ptr::null_mut);
2937
let mut ptr = ptr::null_mut();
3038
ortsys![unsafe CreateLoraAdapter(path.as_ptr(), allocator_ptr, &mut ptr)?];
3139
Ok(Adapter {
@@ -36,7 +44,7 @@ impl Adapter {
3644
}
3745

3846
pub fn from_memory(bytes: &[u8], allocator: Option<&Allocator>) -> Result<Self> {
39-
let allocator_ptr = allocator.map(|c| c.ptr()).unwrap_or_else(ptr::null_mut);
47+
let allocator_ptr = allocator.map(|c| c.ptr().cast_mut()).unwrap_or_else(ptr::null_mut);
4048
let mut ptr = ptr::null_mut();
4149
ortsys![unsafe CreateLoraAdapterFromArray(bytes.as_ptr().cast(), bytes.len(), allocator_ptr, &mut ptr)?];
4250
Ok(Adapter {
@@ -45,9 +53,13 @@ impl Adapter {
4553
})
4654
})
4755
}
56+
}
57+
58+
impl AsPointer for Adapter {
59+
type Sys = ort_sys::OrtLoraAdapter;
4860

49-
pub fn ptr(&self) -> *mut ort_sys::OrtLoraAdapter {
50-
self.inner.ptr.as_ptr()
61+
fn ptr(&self) -> *const Self::Sys {
62+
self.inner.ptr()
5163
}
5264
}
5365

src/environment.rs

+10-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use tracing::{Level, debug};
99

1010
#[cfg(feature = "load-dynamic")]
1111
use crate::G_ORT_DYLIB_PATH;
12-
use crate::{error::Result, execution_providers::ExecutionProviderDispatch, extern_system_fn, ortsys};
12+
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, extern_system_fn, ortsys};
1313

1414
struct EnvironmentSingleton {
1515
lock: RwLock<Option<Arc<Environment>>>
@@ -33,24 +33,25 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(No
3333
#[derive(Debug)]
3434
pub struct Environment {
3535
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
36-
pub(crate) env_ptr: NonNull<ort_sys::OrtEnv>,
36+
ptr: NonNull<ort_sys::OrtEnv>,
3737
pub(crate) has_global_threadpool: bool
3838
}
3939

4040
unsafe impl Send for Environment {}
4141
unsafe impl Sync for Environment {}
4242

43-
impl Environment {
44-
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
45-
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
46-
self.env_ptr.as_ptr()
43+
impl AsPointer for Environment {
44+
type Sys = ort_sys::OrtEnv;
45+
46+
fn ptr(&self) -> *const Self::Sys {
47+
self.ptr.as_ptr()
4748
}
4849
}
4950

5051
impl Drop for Environment {
5152
fn drop(&mut self) {
52-
debug!(ptr = ?self.env_ptr.as_ptr(), "Releasing environment");
53-
ortsys![unsafe ReleaseEnv(self.env_ptr.as_ptr())];
53+
debug!(ptr = ?self.ptr(), "Releasing environment");
54+
ortsys![unsafe ReleaseEnv(self.ptr_mut())];
5455
}
5556
}
5657

@@ -213,7 +214,7 @@ impl EnvironmentBuilder {
213214
let env = Arc::new(Environment {
214215
execution_providers: self.execution_providers,
215216
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
216-
env_ptr: unsafe { NonNull::new_unchecked(env_ptr) },
217+
ptr: unsafe { NonNull::new_unchecked(env_ptr) },
217218
has_global_threadpool
218219
});
219220
env_lock.replace(Arc::clone(&env));

src/execution_providers/acl.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ impl ExecutionProvider for ACLExecutionProvider {
4343
}
4444

4545
#[allow(unused, unreachable_code)]
46-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
46+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
4747
#[cfg(any(feature = "load-dynamic", feature = "acl"))]
4848
{
49+
use crate::AsPointer;
50+
4951
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
50-
return crate::error::status_to_result(unsafe {
51-
OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.session_options_ptr.as_ptr(), self.use_arena.into())
52-
});
52+
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.ptr_mut(), self.use_arena.into()) });
5353
}
5454

5555
Err(Error::new(format!("`{}` was not registered because its corresponding Cargo feature is not enabled.", self.as_str())))

src/execution_providers/armnn.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ impl ExecutionProvider for ArmNNExecutionProvider {
4343
}
4444

4545
#[allow(unused, unreachable_code)]
46-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
46+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
4747
#[cfg(any(feature = "load-dynamic", feature = "armnn"))]
4848
{
49+
use crate::AsPointer;
50+
4951
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
50-
return crate::error::status_to_result(unsafe {
51-
OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.session_options_ptr.as_ptr(), self.use_arena.into())
52-
});
52+
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.ptr_mut(), self.use_arena.into()) });
5353
}
5454

5555
Err(Error::new(format!("`{}` was not registered because its corresponding Cargo feature is not enabled.", self.as_str())))

src/execution_providers/cann.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,11 @@ impl ExecutionProvider for CANNExecutionProvider {
139139
}
140140

141141
#[allow(unused, unreachable_code)]
142-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
142+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
143143
#[cfg(any(feature = "load-dynamic", feature = "cann"))]
144144
{
145+
use crate::AsPointer;
146+
145147
let mut cann_options: *mut ort_sys::OrtCANNProviderOptions = std::ptr::null_mut();
146148
crate::ortsys![unsafe CreateCANNProviderOptions(&mut cann_options)?];
147149
let ffi_options = self.options.to_ffi();
@@ -152,7 +154,7 @@ impl ExecutionProvider for CANNExecutionProvider {
152154
return Err(e);
153155
}
154156

155-
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.session_options_ptr.as_ptr(), cann_options)];
157+
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.ptr_mut(), cann_options)];
156158
crate::ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
157159
return crate::error::status_to_result(status);
158160
}

src/execution_providers/coreml.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ impl ExecutionProvider for CoreMLExecutionProvider {
6363
}
6464

6565
#[allow(unused, unreachable_code)]
66-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
66+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
6767
#[cfg(any(feature = "load-dynamic", feature = "coreml"))]
6868
{
69+
use crate::AsPointer;
70+
6971
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut ort_sys::OrtSessionOptions, flags: u32) -> ort_sys::OrtStatusPtr);
7072
let mut flags = 0;
7173
if self.use_cpu_only {
@@ -77,9 +79,7 @@ impl ExecutionProvider for CoreMLExecutionProvider {
7779
if self.only_enable_device_with_ane {
7880
flags |= 0x004;
7981
}
80-
return crate::error::status_to_result(unsafe {
81-
OrtSessionOptionsAppendExecutionProvider_CoreML(session_builder.session_options_ptr.as_ptr(), flags)
82-
});
82+
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_CoreML(session_builder.ptr_mut(), flags) });
8383
}
8484

8585
Err(Error::new(format!("`{}` was not registered because its corresponding Cargo feature is not enabled.", self.as_str())))

src/execution_providers/cpu.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::{
2+
AsPointer,
23
error::Result,
34
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
45
ortsys,
@@ -43,11 +44,11 @@ impl ExecutionProvider for CPUExecutionProvider {
4344
true
4445
}
4546

46-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
47+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
4748
if self.use_arena {
48-
ortsys![unsafe EnableCpuMemArena(session_builder.session_options_ptr.as_ptr())?];
49+
ortsys![unsafe EnableCpuMemArena(session_builder.ptr_mut())?];
4950
} else {
50-
ortsys![unsafe DisableCpuMemArena(session_builder.session_options_ptr.as_ptr())?];
51+
ortsys![unsafe DisableCpuMemArena(session_builder.ptr_mut())?];
5152
}
5253
Ok(())
5354
}

src/execution_providers/cuda.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,11 @@ impl ExecutionProvider for CUDAExecutionProvider {
264264
}
265265

266266
#[allow(unused, unreachable_code)]
267-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
267+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
268268
#[cfg(any(feature = "load-dynamic", feature = "cuda"))]
269269
{
270+
use crate::AsPointer;
271+
270272
let mut cuda_options: *mut ort_sys::OrtCUDAProviderOptionsV2 = std::ptr::null_mut();
271273
crate::ortsys![unsafe CreateCUDAProviderOptions(&mut cuda_options)?];
272274
let ffi_options = self.options.to_ffi();
@@ -277,7 +279,7 @@ impl ExecutionProvider for CUDAExecutionProvider {
277279
return Err(e);
278280
}
279281

280-
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(session_builder.session_options_ptr.as_ptr(), cuda_options)];
282+
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CUDA_V2(session_builder.ptr_mut(), cuda_options)];
281283
crate::ortsys![unsafe ReleaseCUDAProviderOptions(cuda_options)];
282284
return crate::error::status_to_result(status);
283285
}

src/execution_providers/directml.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@ impl ExecutionProvider for DirectMLExecutionProvider {
4343
}
4444

4545
#[allow(unused, unreachable_code)]
46-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
46+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
4747
#[cfg(any(feature = "load-dynamic", feature = "directml"))]
4848
{
49+
use crate::AsPointer;
50+
4951
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_DML(options: *mut ort_sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
50-
return crate::error::status_to_result(unsafe {
51-
OrtSessionOptionsAppendExecutionProvider_DML(session_builder.session_options_ptr.as_ptr(), self.device_id as _)
52-
});
52+
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_DML(session_builder.ptr_mut(), self.device_id as _) });
5353
}
5454

5555
Err(Error::new(format!("`{}` was not registered because its corresponding Cargo feature is not enabled.", self.as_str())))

src/execution_providers/migraphx.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ impl ExecutionProvider for MIGraphXExecutionProvider {
8484
}
8585

8686
#[allow(unused, unreachable_code)]
87-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
87+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
8888
#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
8989
{
90+
use crate::AsPointer;
91+
9092
let options = ort_sys::OrtMIGraphXProviderOptions {
9193
device_id: self.device_id,
9294
migraphx_fp16_enable: self.enable_fp16.into(),
@@ -103,7 +105,7 @@ impl ExecutionProvider for MIGraphXExecutionProvider {
103105
migraphx_save_model_path: self.save_model_path.as_ref().map(|c| c.as_ptr()).unwrap_or_else(std::ptr::null),
104106
migraphx_exhaustive_tune: self.exhaustive_tune
105107
};
106-
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options)?];
108+
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.ptr_mut(), &options)?];
107109
return Ok(());
108110
}
109111

src/execution_providers/mod.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ pub trait ExecutionProvider: Send + Sync {
104104
}
105105

106106
/// Attempts to register this execution provider on the given session.
107-
fn register(&self, session_builder: &SessionBuilder) -> Result<()>;
107+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()>;
108108
}
109109

110110
/// Trait used for execution providers that can have arbitrary configuration keys applied.
@@ -180,17 +180,20 @@ impl ExecutionProviderOptions {
180180
.insert(CString::new(key).expect("unexpected nul in key string"), CString::new(value).expect("unexpected nul in value string"));
181181
}
182182

183+
#[allow(unused)]
183184
pub fn to_ffi(&self) -> ExecutionProviderOptionsFFI {
184185
let (key_ptrs, value_ptrs) = self.0.iter().map(|(k, v)| (k.as_ptr(), v.as_ptr())).unzip();
185186
ExecutionProviderOptionsFFI { key_ptrs, value_ptrs }
186187
}
187188
}
188189

190+
#[allow(unused)]
189191
pub(crate) struct ExecutionProviderOptionsFFI {
190192
key_ptrs: Vec<*const c_char>,
191193
value_ptrs: Vec<*const c_char>
192194
}
193195

196+
#[allow(unused)]
194197
impl ExecutionProviderOptionsFFI {
195198
pub fn key_ptrs(&self) -> *const *const c_char {
196199
self.key_ptrs.as_ptr()
@@ -228,7 +231,10 @@ macro_rules! get_ep_register {
228231
#[allow(unused)]
229232
pub(crate) use get_ep_register;
230233

231-
pub(crate) fn apply_execution_providers(session_builder: &SessionBuilder, execution_providers: impl Iterator<Item = ExecutionProviderDispatch>) -> Result<()> {
234+
pub(crate) fn apply_execution_providers(
235+
session_builder: &mut SessionBuilder,
236+
execution_providers: impl Iterator<Item = ExecutionProviderDispatch>
237+
) -> Result<()> {
232238
let execution_providers: Vec<_> = execution_providers.collect();
233239
let mut fallback_to_cpu = !execution_providers.is_empty();
234240
for ex in execution_providers {

src/execution_providers/nnapi.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ impl ExecutionProvider for NNAPIExecutionProvider {
7676
}
7777

7878
#[allow(unused, unreachable_code)]
79-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
79+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
8080
#[cfg(any(feature = "load-dynamic", feature = "nnapi"))]
8181
{
82+
use crate::AsPointer;
83+
8284
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_Nnapi(options: *mut ort_sys::OrtSessionOptions, flags: u32) -> ort_sys::OrtStatusPtr);
8385
let mut flags = 0;
8486
if self.use_fp16 {
@@ -93,9 +95,7 @@ impl ExecutionProvider for NNAPIExecutionProvider {
9395
if self.cpu_only {
9496
flags |= 0x008;
9597
}
96-
return crate::error::status_to_result(unsafe {
97-
OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.session_options_ptr.as_ptr(), flags)
98-
});
98+
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_Nnapi(session_builder.ptr_mut(), flags) });
9999
}
100100

101101
Err(Error::new(format!("`{}` was not registered because its corresponding Cargo feature is not enabled.", self.as_str())))

src/execution_providers/onednn.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ impl ExecutionProvider for OneDNNExecutionProvider {
4646
}
4747

4848
#[allow(unused, unreachable_code)]
49-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
49+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
5050
#[cfg(any(feature = "load-dynamic", feature = "onednn"))]
5151
{
52+
use crate::AsPointer;
53+
5254
let mut dnnl_options: *mut ort_sys::OrtDnnlProviderOptions = std::ptr::null_mut();
5355
crate::ortsys![unsafe CreateDnnlProviderOptions(&mut dnnl_options)?];
5456
let ffi_options = self.options.to_ffi();
@@ -59,7 +61,7 @@ impl ExecutionProvider for OneDNNExecutionProvider {
5961
return Err(e);
6062
}
6163

62-
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_Dnnl(session_builder.session_options_ptr.as_ptr(), dnnl_options)];
64+
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_Dnnl(session_builder.ptr_mut(), dnnl_options)];
6365
crate::ortsys![unsafe ReleaseDnnlProviderOptions(dnnl_options)];
6466
return Ok(());
6567
}

src/execution_providers/openvino.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,13 @@ impl ExecutionProvider for OpenVINOExecutionProvider {
120120
}
121121

122122
#[allow(unused, unreachable_code)]
123-
fn register(&self, session_builder: &SessionBuilder) -> Result<()> {
123+
fn register(&self, session_builder: &mut SessionBuilder) -> Result<()> {
124124
#[cfg(any(feature = "load-dynamic", feature = "openvino"))]
125125
{
126126
use std::ffi::CString;
127127

128+
use crate::AsPointer;
129+
128130
let device_type = self.device_type.as_deref().map(CString::new).transpose()?;
129131
let device_id = self.device_id.as_deref().map(CString::new).transpose()?;
130132
let cache_dir = self.cache_dir.as_deref().map(CString::new).transpose()?;
@@ -145,7 +147,7 @@ impl ExecutionProvider for OpenVINOExecutionProvider {
145147
enable_npu_fast_compile: self.enable_npu_fast_compile.into()
146148
};
147149
return crate::error::status_to_result(
148-
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_builder.session_options_ptr.as_ptr(), std::ptr::addr_of!(openvino_options))]
150+
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_OpenVINO(session_builder.ptr_mut(), std::ptr::addr_of!(openvino_options))]
149151
);
150152
}
151153

0 commit comments

Comments
 (0)