Skip to content

Commit d77d7c6

Browse files
committed
Base implementation
1 parent 7792a9b commit d77d7c6

File tree

9 files changed

+283
-39
lines changed

9 files changed

+283
-39
lines changed

rstest/src/lib.rs

+16
Original file line numberDiff line numberDiff line change
@@ -1469,3 +1469,19 @@ pub use rstest_macros::fixture;
14691469
/// ```
14701470
///
14711471
pub use rstest_macros::rstest;
1472+
1473+
pub struct Context {
1474+
pub name: &'static str,
1475+
pub description: Option<&'static str>,
1476+
pub case: Option<usize>,
1477+
}
1478+
1479+
impl Context {
1480+
pub fn new(name: &'static str, description: Option<&'static str>, case: Option<usize>) -> Self {
1481+
Self {
1482+
name,
1483+
description,
1484+
case,
1485+
}
1486+
}
1487+
}
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use rstest::*;
2+
3+
#[rstest]
4+
#[case::description(42)]
5+
fn with_case(#[context] ctx: Context, #[case] _c: u32) {
6+
assert_eq!("with_case", ctx.name);
7+
assert_eq!(Some("description"), ctx.description);
8+
assert_eq!(Some(0), ctx.case);
9+
}
10+
11+
#[rstest]
12+
fn without_case(#[context] ctx: Context) {
13+
assert_eq!("without_case", ctx.name);
14+
assert_eq!(None, ctx.description);
15+
assert_eq!(None, ctx.case);
16+
}

rstest/tests/rstest/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,16 @@ fn no_std() {
12161216
.assert(output);
12171217
}
12181218

1219+
#[test]
1220+
fn context() {
1221+
let (output, _) = run_test("context.rs");
1222+
1223+
TestResults::new()
1224+
.ok("with_case::case_1_description")
1225+
.ok("without_case")
1226+
.assert(output);
1227+
}
1228+
12191229
mod async_timeout_feature {
12201230
use super::*;
12211231

rstest_macros/src/parse/arguments.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::collections::HashMap;
1+
use std::collections::{HashMap, HashSet};
22

33
use quote::format_ident;
44
use syn::{FnArg, Ident, Pat};
@@ -98,6 +98,7 @@ pub(crate) struct ArgumentsInfo {
9898
args: Args,
9999
is_global_await: bool,
100100
once: Option<syn::Attribute>,
101+
contexts: HashSet<Pat>,
101102
}
102103

103104
impl ArgumentsInfo {
@@ -235,6 +236,19 @@ impl ArgumentsInfo {
235236
fn_arg
236237
})
237238
}
239+
240+
#[allow(dead_code)]
241+
pub(crate) fn add_context(&mut self, pat: Pat) {
242+
self.contexts.insert(pat);
243+
}
244+
245+
pub(crate) fn set_contexts(&mut self, contexts: impl Iterator<Item = Pat>) {
246+
contexts.for_each(|c| self.add_context(c))
247+
}
248+
249+
pub(crate) fn contexts(&self) -> impl Iterator<Item = &Pat> + '_ {
250+
self.contexts.iter()
251+
}
238252
}
239253

240254
#[cfg(test)]

rstest_macros/src/parse/context.rs

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use syn::{visit_mut::VisitMut, ItemFn, Pat};
2+
3+
use crate::error::ErrorsVec;
4+
5+
use super::just_once::JustOnceFnArgAttributeExtractor;
6+
7+
pub(crate) fn extract_context(item_fn: &mut ItemFn) -> Result<Vec<Pat>, ErrorsVec> {
8+
let mut extractor = JustOnceFnArgAttributeExtractor::from("context");
9+
extractor.visit_item_fn_mut(item_fn);
10+
extractor.take()
11+
}
12+
13+
#[cfg(test)]
14+
mod should {
15+
use super::*;
16+
use crate::test::{assert_eq, *};
17+
use rstest_test::assert_in;
18+
19+
#[rstest]
20+
#[case("fn simple(a: u32) {}")]
21+
#[case("fn more(a: u32, b: &str) {}")]
22+
#[case("fn gen<S: AsRef<str>>(a: u32, b: S) {}")]
23+
#[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")]
24+
fn not_change_anything_if_no_ignore_attribute_found(#[case] item_fn: &str) {
25+
let mut item_fn: ItemFn = item_fn.ast();
26+
let orig = item_fn.clone();
27+
28+
let by_refs = extract_context(&mut item_fn).unwrap();
29+
30+
assert_eq!(orig, item_fn);
31+
assert!(by_refs.is_empty());
32+
}
33+
34+
#[rstest]
35+
#[case::simple("fn f(#[context] a: u32) {}", "fn f(a: u32) {}", &["a"])]
36+
#[case::more_than_one(
37+
"fn f(#[context] a: u32, #[context] b: String, #[context] c: std::collection::HashMap<usize, String>) {}",
38+
r#"fn f(a: u32,
39+
b: String,
40+
c: std::collection::HashMap<usize, String>) {}"#,
41+
&["a", "b", "c"])]
42+
fn extract(#[case] item_fn: &str, #[case] expected: &str, #[case] expected_refs: &[&str]) {
43+
let mut item_fn: ItemFn = item_fn.ast();
44+
let expected: ItemFn = expected.ast();
45+
46+
let by_refs = extract_context(&mut item_fn).unwrap();
47+
48+
assert_eq!(expected, item_fn);
49+
assert_eq!(by_refs, to_pats!(expected_refs));
50+
}
51+
52+
#[test]
53+
fn raise_error() {
54+
let mut item_fn: ItemFn = "fn f(#[context] #[context] a: u32) {}".ast();
55+
56+
let err = extract_context(&mut item_fn).unwrap_err();
57+
58+
assert_in!(format!("{:?}", err), "more than once");
59+
}
60+
}

rstest_macros/src/parse/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub(crate) mod macros;
2828

2929
pub(crate) mod arguments;
3030
pub(crate) mod by_ref;
31+
pub(crate) mod context;
3132
pub(crate) mod expressions;
3233
pub(crate) mod fixture;
3334
pub(crate) mod future;

rstest_macros/src/parse/rstest.rs

+30-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use self::files::{extract_files, ValueListFromFiles};
88
use super::{
99
arguments::ArgumentsInfo,
1010
by_ref::extract_by_ref,
11-
check_timeout_attrs, extract_case_args, extract_cases, extract_excluded_trace,
12-
extract_fixtures, extract_value_list,
11+
check_timeout_attrs,
12+
context::extract_context,
13+
extract_case_args, extract_cases, extract_excluded_trace, extract_fixtures, extract_value_list,
1314
future::{extract_futures, extract_global_awt},
1415
ignore::extract_ignores,
1516
parse_vector_trailing_till_double_comma,
@@ -49,20 +50,24 @@ impl Parse for RsTestInfo {
4950

5051
impl ExtendWithFunctionAttrs for RsTestInfo {
5152
fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> {
52-
let composed_tuple!(_inner, excluded, _timeout, futures, global_awt, by_refs, ignores) = merge_errors!(
53+
let composed_tuple!(
54+
_inner, excluded, _timeout, futures, global_awt, by_refs, ignores, contexts
55+
) = merge_errors!(
5356
self.data.extend_with_function_attrs(item_fn),
5457
extract_excluded_trace(item_fn),
5558
check_timeout_attrs(item_fn),
5659
extract_futures(item_fn),
5760
extract_global_awt(item_fn),
5861
extract_by_ref(item_fn),
59-
extract_ignores(item_fn)
62+
extract_ignores(item_fn),
63+
extract_context(item_fn)
6064
)?;
6165
self.attributes.add_notraces(excluded);
6266
self.arguments.set_global_await(global_awt);
6367
self.arguments.set_futures(futures.into_iter());
6468
self.arguments.set_by_refs(by_refs.into_iter());
6569
self.arguments.set_ignores(ignores.into_iter());
70+
self.arguments.set_contexts(contexts.into_iter());
6671
self.arguments
6772
.register_inner_destructored_idents_names(item_fn);
6873
Ok(())
@@ -379,6 +384,8 @@ mod test {
379384
}
380385

381386
mod no_cases {
387+
use std::collections::HashSet;
388+
382389
use super::{assert_eq, *};
383390

384391
#[test]
@@ -563,6 +570,25 @@ mod test {
563570
assert!(info.arguments.is_future(&pat("a")));
564571
assert!(!info.arguments.is_future(&pat("b")));
565572
}
573+
574+
#[rstest]
575+
fn extract_context() {
576+
let mut item_fn =
577+
"fn f(#[context] c: Context, #[context] other: Context, more: u32) {}".ast();
578+
let expected = "fn f(c: Context, other: Context, more: u32) {}".ast();
579+
580+
let mut info = RsTestInfo::default();
581+
582+
info.extend_with_function_attrs(&mut item_fn).unwrap();
583+
584+
assert_eq!(item_fn, expected);
585+
assert_eq!(
586+
info.arguments.contexts().cloned().collect::<HashSet<_>>(),
587+
vec![pat("c"), pat("other")]
588+
.into_iter()
589+
.collect::<HashSet<_>>()
590+
);
591+
}
566592
}
567593

568594
mod parametrize_cases {

0 commit comments

Comments
 (0)