Skip to content

Commit 8a16adb

Browse files
committed
refactor!: reduce allocations in run; make SessionOutputs not a map
1 parent 552727e commit 8a16adb

File tree

5 files changed

+383
-65
lines changed

5 files changed

+383
-65
lines changed

src/io_binding.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,17 @@ impl IoBinding {
231231
Some(Arc::clone(&self.session))
232232
)
233233
}
234-
});
234+
})
235+
.collect::<Vec<_>>();
235236

236237
// output values will be freed when the `Value`s in `SessionOutputs` drop
237238

238-
Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, &self.session.allocator, output_values_ptr.cast()))
239+
Ok(SessionOutputs::new_backed(
240+
self.output_names.iter().map(String::as_str).collect(),
241+
output_values,
242+
&self.session.allocator,
243+
output_values_ptr.cast()
244+
))
239245
} else {
240246
Ok(SessionOutputs::new_empty())
241247
}

src/session/async.rs

+5-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::{
1212
use ort_sys::{OrtStatus, c_void};
1313

1414
use crate::{
15-
error::{Result, assert_non_null_pointer},
15+
error::Result,
1616
session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner},
1717
value::Value
1818
};
@@ -138,17 +138,9 @@ crate::extern_system_fn! {
138138
let ctx = unsafe { Box::from_raw(user_data.cast::<AsyncInferenceContext<'_, '_>>()) };
139139

140140
// Reconvert name ptrs to CString so drop impl is called and memory is freed
141-
drop(
142-
ctx.input_name_ptrs
143-
.into_iter()
144-
.chain(ctx.output_name_ptrs)
145-
.map(|p| {
146-
assert_non_null_pointer(p, "c_char for CString")?;
147-
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
148-
})
149-
.collect::<Result<Vec<_>>>()
150-
.expect("Input name should not be null")
151-
);
141+
for p in ctx.input_name_ptrs {
142+
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
143+
}
152144

153145
if let Err(e) = crate::error::status_to_result(status) {
154146
ctx.inner.emplace_value(Err(e));
@@ -164,7 +156,7 @@ crate::extern_system_fn! {
164156
})
165157
.collect();
166158

167-
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names.into_iter(), outputs)));
159+
ctx.inner.emplace_value(Ok(SessionOutputs::new(ctx.output_names, outputs)));
168160
ctx.inner.wake();
169161
}
170162
}

src/session/mod.rs

+5-12
Original file line numberDiff line numberDiff line change
@@ -324,18 +324,11 @@ impl Session {
324324
.collect();
325325

326326
// Reconvert name ptrs to CString so drop impl is called and memory is freed
327-
drop(
328-
input_names_ptr
329-
.into_iter()
330-
.chain(output_names_ptr.into_iter())
331-
.map(|p| {
332-
assert_non_null_pointer(p, "c_char for CString")?;
333-
unsafe { Ok(CString::from_raw(p.cast_mut().cast())) }
334-
})
335-
.collect::<Result<Vec<_>>>()?
336-
);
337-
338-
Ok(SessionOutputs::new(output_names.into_iter(), outputs))
327+
for p in input_names_ptr.into_iter().chain(output_names_ptr.into_iter()) {
328+
drop(unsafe { CString::from_raw(p.cast_mut().cast()) });
329+
}
330+
331+
Ok(SessionOutputs::new(output_names, outputs))
339332
}
340333

341334
/// Asynchronously run input data through the ONNX graph, performing inference.

0 commit comments

Comments
 (0)