Skip to content

Commit 2b14108

Browse files
turturicaturturica
turturica
authored and
turturica
committed
Add package scoped fixtures #2283
1 parent 372bcdb commit 2b14108

File tree

8 files changed

+112
-15
lines changed

8 files changed

+112
-15
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Hugo van Kemenade
8787
Hui Wang (coldnight)
8888
Ian Bicking
8989
Ian Lesperance
90+
Ionuț Turturică
9091
Jaap Broekhuizen
9192
Jan Balster
9293
Janne Vanhala

_pytest/fixtures.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def pytest_sessionstart(session):
3636
import _pytest.nodes
3737

3838
scopename2class.update({
39+
'package': _pytest.python.Package,
3940
'class': _pytest.python.Class,
4041
'module': _pytest.python.Module,
4142
'function': _pytest.nodes.Item,
@@ -48,6 +49,7 @@ def pytest_sessionstart(session):
4849

4950

5051
scope2props = dict(session=())
52+
scope2props["package"] = ("fspath",)
5153
scope2props["module"] = ("fspath", "module")
5254
scope2props["class"] = scope2props["module"] + ("cls",)
5355
scope2props["instance"] = scope2props["class"] + ("instance", )
@@ -156,9 +158,11 @@ def get_parametrized_fixture_keys(item, scopenum):
156158
continue
157159
if scopenum == 0: # session
158160
key = (argname, param_index)
159-
elif scopenum == 1: # module
161+
elif scopenum == 1: # package
160162
key = (argname, param_index, item.fspath)
161-
elif scopenum == 2: # class
163+
elif scopenum == 2: # module
164+
key = (argname, param_index, item.fspath)
165+
elif scopenum == 3: # class
162166
key = (argname, param_index, item.fspath, item.cls)
163167
yield key
164168

@@ -596,7 +600,7 @@ class ScopeMismatchError(Exception):
596600
"""
597601

598602

599-
scopes = "session module class function".split()
603+
scopes = "session package module class function".split()
600604
scopenum_function = scopes.index("function")
601605

602606

_pytest/main.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -405,17 +405,30 @@ def collect(self):
405405

406406
def _collect(self, arg):
407407
names = self._parsearg(arg)
408-
path = names.pop(0)
409-
if path.check(dir=1):
408+
argpath = names.pop(0)
409+
paths = []
410+
if argpath.check(dir=1):
410411
assert not names, "invalid arg %r" % (arg,)
411-
for path in path.visit(fil=lambda x: x.check(file=1),
412-
rec=self._recurse, bf=True, sort=True):
413-
for x in self._collectfile(path):
414-
yield x
412+
for path in argpath.visit(fil=lambda x: x.check(file=1),
413+
rec=self._recurse, bf=True, sort=True):
414+
pkginit = path.dirpath().join('__init__.py')
415+
if pkginit.exists() and not any(x in pkginit.parts() for x in paths):
416+
for x in self._collectfile(pkginit):
417+
yield x
418+
paths.append(x.fspath.dirpath())
419+
420+
if not any(x in path.parts() for x in paths):
421+
for x in self._collectfile(path):
422+
yield x
415423
else:
416-
assert path.check(file=1)
417-
for x in self.matchnodes(self._collectfile(path), names):
418-
yield x
424+
assert argpath.check(file=1)
425+
pkginit = argpath.dirpath().join('__init__.py')
426+
if not self.isinitpath(argpath) and pkginit.exists():
427+
for x in self._collectfile(pkginit):
428+
yield x
429+
else:
430+
for x in self.matchnodes(self._collectfile(argpath), names):
431+
yield x
419432

420433
def _collectfile(self, path):
421434
ihook = self.gethookproxy(path)

_pytest/python.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from _pytest.config import hookimpl
1818

1919
import _pytest
20+
from _pytest.main import Session
2021
import pluggy
2122
from _pytest import fixtures
2223
from _pytest import nodes
@@ -157,7 +158,7 @@ def pytest_collect_file(path, parent):
157158
ext = path.ext
158159
if ext == ".py":
159160
if not parent.session.isinitpath(path):
160-
for pat in parent.config.getini('python_files'):
161+
for pat in parent.config.getini('python_files') + ['__init__.py']:
161162
if path.fnmatch(pat):
162163
break
163164
else:
@@ -167,9 +168,23 @@ def pytest_collect_file(path, parent):
167168

168169

169170
def pytest_pycollect_makemodule(path, parent):
171+
if path.basename == '__init__.py':
172+
return Package(path, parent)
170173
return Module(path, parent)
171174

172175

176+
def pytest_ignore_collect(path, config):
177+
# Skip duplicate packages.
178+
keepduplicates = config.getoption("keepduplicates")
179+
if keepduplicates:
180+
duplicate_paths = config.pluginmanager._duplicatepaths
181+
if path.basename == '__init__.py':
182+
if path in duplicate_paths:
183+
return True
184+
else:
185+
duplicate_paths.add(path)
186+
187+
173188
@hookimpl(hookwrapper=True)
174189
def pytest_pycollect_makeitem(collector, name, obj):
175190
outcome = yield
@@ -475,6 +490,36 @@ def setup(self):
475490
self.addfinalizer(teardown_module)
476491

477492

493+
class Package(Session, Module):
494+
495+
def __init__(self, fspath, parent=None, config=None, session=None, nodeid=None):
496+
session = parent.session
497+
nodes.FSCollector.__init__(
498+
self, fspath, parent=parent,
499+
config=config, session=session, nodeid=nodeid)
500+
self.name = fspath.pyimport().__name__
501+
self.trace = session.trace
502+
self._norecursepatterns = session._norecursepatterns
503+
for path in list(session.config.pluginmanager._duplicatepaths):
504+
if path.dirname == fspath.dirname and path != fspath:
505+
session.config.pluginmanager._duplicatepaths.remove(path)
506+
pass
507+
508+
def isinitpath(self, path):
509+
return path in self.session._initialpaths
510+
511+
def collect(self):
512+
path = self.fspath.dirpath()
513+
pkg_prefix = None
514+
for path in path.visit(fil=lambda x: 1,
515+
rec=self._recurse, bf=True, sort=True):
516+
if pkg_prefix and pkg_prefix in path.parts():
517+
continue
518+
for x in self._collectfile(path):
519+
yield x
520+
if isinstance(x, Package):
521+
pkg_prefix = path.dirpath()
522+
478523
def _get_xunit_setup_teardown(holder, attr_name, param_obj=None):
479524
"""
480525
Return a callable to perform xunit-style setup or teardown if

changelog/2283.feature

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Pytest now supports package-level fixtures.

testing/python/fixture.py

+33
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,39 @@ def test_x(one):
14481448
reprec = testdir.inline_run("..")
14491449
reprec.assertoutcome(passed=2)
14501450

1451+
def test_package_xunit_fixture(self, testdir):
1452+
testdir.makepyfile(__init__="""\
1453+
values = []
1454+
""")
1455+
package = testdir.mkdir("package")
1456+
package.join("__init__.py").write(dedent("""\
1457+
from .. import values
1458+
def setup_module():
1459+
values.append("package")
1460+
def teardown_module():
1461+
values[:] = []
1462+
"""))
1463+
package.join("test_x.py").write(dedent("""\
1464+
from .. import values
1465+
def test_x():
1466+
assert values == ["package"]
1467+
"""))
1468+
package = testdir.mkdir("package2")
1469+
package.join("__init__.py").write(dedent("""\
1470+
from .. import values
1471+
def setup_module():
1472+
values.append("package2")
1473+
def teardown_module():
1474+
values[:] = []
1475+
"""))
1476+
package.join("test_x.py").write(dedent("""\
1477+
from .. import values
1478+
def test_x():
1479+
assert values == ["package2"]
1480+
"""))
1481+
reprec = testdir.inline_run()
1482+
reprec.assertoutcome(passed=2)
1483+
14511484

14521485
class TestAutouseDiscovery(object):
14531486

testing/test_collection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def test_continue_on_collection_errors_maxfail(testdir):
835835

836836
def test_fixture_scope_sibling_conftests(testdir):
837837
"""Regression test case for https://github.com/pytest-dev/pytest/issues/2836"""
838-
foo_path = testdir.mkpydir("foo")
838+
foo_path = testdir.mkdir("foo")
839839
foo_path.join("conftest.py").write(_pytest._code.Source("""
840840
import pytest
841841
@pytest.fixture

testing/test_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class TestY(TestX):
192192
started = reprec.getcalls("pytest_collectstart")
193193
finished = reprec.getreports("pytest_collectreport")
194194
assert len(started) == len(finished)
195-
assert len(started) == 7 # XXX extra TopCollector
195+
assert len(started) == 8 # XXX extra TopCollector
196196
colfail = [x for x in finished if x.failed]
197197
assert len(colfail) == 1
198198

0 commit comments

Comments
 (0)