Skip to content

Commit 552727e

Browse files
committed
feat: expose ortsys macro; make api return a reference
1 parent 516db5f commit 552727e

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

src/lib.rs

+19-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![doc(html_logo_url = "https://raw.githubusercontent.com/pykeio/ort/v2/docs/icon.png")]
22
#![cfg_attr(docsrs, feature(doc_cfg))]
33
#![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)]
4+
#![allow(clippy::macro_metavars_in_unsafe)]
45
#![warn(clippy::unwrap_used)]
56

67
//! <div align=center>
@@ -165,23 +166,23 @@ pub fn info() -> &'static str {
165166
/// ```
166167
/// # use std::ffi::CStr;
167168
/// # fn main() -> ort::Result<()> {
168-
/// let api = ort::api().as_ptr();
169-
/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) };
169+
/// let api = ort::api();
170+
/// let build_info = unsafe { CStr::from_ptr(api.GetBuildInfoString.unwrap()()) };
170171
/// println!("{}", build_info.to_string_lossy());
171172
/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED
172173
/// # Ok(())
173174
/// # }
174175
/// ```
175176
///
176177
/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html).
177-
pub fn api() -> NonNull<ort_sys::OrtApi> {
178+
pub fn api() -> &'static ort_sys::OrtApi {
178179
struct ApiPointer(NonNull<ort_sys::OrtApi>);
179180
unsafe impl Send for ApiPointer {}
180181
unsafe impl Sync for ApiPointer {}
181182

182183
static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();
183184

184-
G_ORT_API
185+
let ptr = G_ORT_API
185186
.get_or_init(|| {
186187
#[cfg(feature = "load-dynamic")]
187188
unsafe {
@@ -227,55 +228,52 @@ pub fn api() -> NonNull<ort_sys::OrtApi> {
227228
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
228229
}
229230
})
230-
.0
231+
.0;
232+
unsafe { ptr.as_ref() }
231233
}
232234

235+
#[macro_export]
233236
macro_rules! ortsys {
234237
($method:ident) => {
235-
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
236-
};
237-
(unsafe $method:ident) => {
238-
unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) }
238+
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
239239
};
240240
($method:ident($($n:expr),+ $(,)?)) => {
241-
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
241+
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
242242
};
243243
(unsafe $method:ident($($n:expr),+ $(,)?)) => {
244-
unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
244+
unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
245245
};
246246
($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
247-
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
247+
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
248248
};
249249
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
250-
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
250+
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
251251
};
252252
($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
253-
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
253+
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
254254
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
255255
};
256256
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
257-
let _x = unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
257+
let _x = unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
258258
$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
259259
_x
260260
}};
261261
($method:ident($($n:expr),+ $(,)?)?) => {
262-
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
262+
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
263263
};
264264
(unsafe $method:ident($($n:expr),+ $(,)?)?) => {
265-
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
265+
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
266266
};
267267
($method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {
268-
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
268+
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
269269
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
270270
};
271271
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
272-
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
272+
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
273273
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
274274
}};
275275
}
276276

277-
pub(crate) use ortsys;
278-
279277
pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
280278
if raw.is_null() {
281279
return Ok(String::new());

src/session/mod.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -505,12 +505,12 @@ mod dangerous {
505505
use super::*;
506506

507507
pub(super) fn extract_inputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
508-
let f = ortsys![unsafe SessionGetInputCount];
508+
let f = ortsys![SessionGetInputCount];
509509
extract_io_count(f, session_ptr)
510510
}
511511

512512
pub(super) fn extract_outputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
513-
let f = ortsys![unsafe SessionGetOutputCount];
513+
let f = ortsys![SessionGetOutputCount];
514514
extract_io_count(f, session_ptr)
515515
}
516516

@@ -525,12 +525,12 @@ mod dangerous {
525525
}
526526

527527
fn extract_input_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
528-
let f = ortsys![unsafe SessionGetInputName];
528+
let f = ortsys![SessionGetInputName];
529529
extract_io_name(f, session_ptr, allocator, i)
530530
}
531531

532532
fn extract_output_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
533-
let f = ortsys![unsafe SessionGetOutputName];
533+
let f = ortsys![SessionGetOutputName];
534534
extract_io_name(f, session_ptr, allocator, i)
535535
}
536536

@@ -568,14 +568,14 @@ mod dangerous {
568568

569569
pub(super) fn extract_input(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Input> {
570570
let input_name = extract_input_name(session_ptr, allocator, i)?;
571-
let f = ortsys![unsafe SessionGetInputTypeInfo];
571+
let f = ortsys![SessionGetInputTypeInfo];
572572
let input_type = extract_io(f, session_ptr, i)?;
573573
Ok(Input { name: input_name, input_type })
574574
}
575575

576576
pub(super) fn extract_output(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Output> {
577577
let output_name = extract_output_name(session_ptr, allocator, i)?;
578-
let f = ortsys![unsafe SessionGetOutputTypeInfo];
578+
let f = ortsys![SessionGetOutputTypeInfo];
579579
let output_type = extract_io(f, session_ptr, i)?;
580580
Ok(Output { name: output_name, output_type })
581581
}

0 commit comments

Comments
 (0)