moving to scripts

This commit is contained in:
eneller
2021-11-16 23:55:48 +01:00
parent f591ca2077
commit 14bfb7f96f
2575 changed files with 465862 additions and 0 deletions

View File

@@ -0,0 +1,41 @@
# XX this does not belong here -- b/c it's here, these things only apply to
# the tests in trio/_core/tests, not in trio/tests. For now there's some
# copy-paste...
#
# this stuff should become a proper pytest plugin
import pytest
import inspect
from ..testing import trio_test, MockClock
RUN_SLOW = True
def pytest_addoption(parser):
parser.addoption("--run-slow", action="store_true", help="run slow tests")
def pytest_configure(config):
global RUN_SLOW
RUN_SLOW = config.getoption("--run-slow", True)
@pytest.fixture
def mock_clock():
return MockClock()
@pytest.fixture
def autojump_clock():
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)

View File

@@ -0,0 +1,21 @@
regular = "hi"
from .. import _deprecate
_deprecate.enable_attribute_deprecations(__name__)
# Make sure that we don't trigger infinite recursion when accessing module
# attributes in between calling enable_attribute_deprecations and defining
# __deprecated_attributes__:
import sys
this_mod = sys.modules[__name__]
assert this_mod.regular == "hi"
assert not hasattr(this_mod, "dep1")
__deprecated_attributes__ = {
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
"dep2": _deprecate.DeprecatedAttribute(
"value2", "1.2", issue=1, instead="instead-string"
),
}

View File

@@ -0,0 +1,49 @@
import pytest
import attr
from ..testing import assert_checkpoints
from .. import abc as tabc
async def test_AsyncResource_defaults():
@attr.s
class MyAR(tabc.AsyncResource):
record = attr.ib(factory=list)
async def aclose(self):
self.record.append("ac")
async with MyAR() as myar:
assert isinstance(myar, MyAR)
assert myar.record == []
assert myar.record == ["ac"]
def test_abc_generics():
# Pythons below 3.5.2 had a typing.Generic that would throw
# errors when instantiating or subclassing a parameterized
# version of a class with any __slots__. This is why RunVar
# (which has slots) is not generic. This tests that
# the generic ABCs are fine, because while they are slotted
# they don't actually define any slots.
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)
def send_nowait(self, value):
raise RuntimeError
async def send(self, value):
raise RuntimeError # pragma: no cover
def clone(self):
raise RuntimeError # pragma: no cover
async def aclose(self):
pass # pragma: no cover
channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)

View File

@@ -0,0 +1,407 @@
import pytest
from ..testing import wait_all_tasks_blocked, assert_checkpoints
import trio
from trio import open_memory_channel, EndOfChannel
async def test_channel():
with pytest.raises(TypeError):
open_memory_channel(1.0)
with pytest.raises(ValueError):
open_memory_channel(-1)
s, r = open_memory_channel(2)
repr(s) # smoke test
repr(r) # smoke test
s.send_nowait(1)
with assert_checkpoints():
await s.send(2)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
with assert_checkpoints():
assert await r.receive() == 1
assert r.receive_nowait() == 2
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
s.send_nowait("last")
await s.aclose()
with pytest.raises(trio.ClosedResourceError):
await s.send("too late")
with pytest.raises(trio.ClosedResourceError):
s.send_nowait("too late")
with pytest.raises(trio.ClosedResourceError):
s.clone()
await s.aclose()
assert r.receive_nowait() == "last"
with pytest.raises(EndOfChannel):
await r.receive()
await r.aclose()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
with pytest.raises(trio.ClosedResourceError):
await r.receive_nowait()
await r.aclose()
async def test_553(autojump_clock):
s, r = open_memory_channel(1)
with trio.move_on_after(10) as timeout_scope:
await r.receive()
assert timeout_scope.cancelled_caught
await s.send("Test for PR #553")
async def test_channel_multiple_producers():
async def producer(send_channel, i):
# We close our handle when we're done with it
async with send_channel:
for j in range(3 * i, 3 * (i + 1)):
await send_channel.send(j)
send_channel, receive_channel = open_memory_channel(0)
async with trio.open_nursery() as nursery:
# We hand out clones to all the new producers, and then close the
# original.
async with send_channel:
for i in range(10):
nursery.start_soon(producer, send_channel.clone(), i)
got = []
async for value in receive_channel:
got.append(value)
got.sort()
assert got == list(range(30))
async def test_channel_multiple_consumers():
successful_receivers = set()
received = []
async def consumer(receive_channel, i):
async for value in receive_channel:
successful_receivers.add(i)
received.append(value)
async with trio.open_nursery() as nursery:
send_channel, receive_channel = trio.open_memory_channel(1)
async with send_channel:
for i in range(5):
nursery.start_soon(consumer, receive_channel, i)
await wait_all_tasks_blocked()
for i in range(10):
await send_channel.send(i)
assert successful_receivers == set(range(5))
assert len(received) == 10
assert set(received) == set(range(10))
async def test_close_basics():
async def send_block(s, expect):
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
await s.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r):
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_close_sync():
async def send_block(s, expect):
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
s.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r):
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel(0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_receive_channel_clone_and_close():
s, r = open_memory_channel(10)
r2 = r.clone()
r3 = r.clone()
s.send_nowait(None)
await r.aclose()
with r2:
pass
with pytest.raises(trio.ClosedResourceError):
r.clone()
with pytest.raises(trio.ClosedResourceError):
r2.clone()
# Can still send, r3 is still open
s.send_nowait(None)
await r3.aclose()
# But now the receiver is really closed
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
async def test_close_multiple_send_handles():
# With multiple send handles, closing one handle only wakes senders on
# that handle, but others can continue just fine
s1, r = open_memory_channel(0)
s2 = s1.clone()
async def send_will_close():
with pytest.raises(trio.ClosedResourceError):
await s1.send("nope")
async def send_will_succeed():
await s2.send("ok")
async with trio.open_nursery() as nursery:
nursery.start_soon(send_will_close)
nursery.start_soon(send_will_succeed)
await wait_all_tasks_blocked()
await s1.aclose()
assert await r.receive() == "ok"
async def test_close_multiple_receive_handles():
# With multiple receive handles, closing one handle only wakes receivers on
# that handle, but others can continue just fine
s, r1 = open_memory_channel(0)
r2 = r1.clone()
async def receive_will_close():
with pytest.raises(trio.ClosedResourceError):
await r1.receive()
async def receive_will_succeed():
assert await r2.receive() == "ok"
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_will_close)
nursery.start_soon(receive_will_succeed)
await wait_all_tasks_blocked()
await r1.aclose()
await s.send("ok")
async def test_inf_capacity():
s, r = open_memory_channel(float("inf"))
# It's accepted, and we can send all day without blocking
with s:
for i in range(10):
s.send_nowait(i)
got = []
async for i in r:
got.append(i)
assert got == list(range(10))
async def test_statistics():
s, r = open_memory_channel(2)
assert s.statistics() == r.statistics()
stats = s.statistics()
assert stats.current_buffer_used == 0
assert stats.max_buffer_size == 2
assert stats.open_send_channels == 1
assert stats.open_receive_channels == 1
assert stats.tasks_waiting_send == 0
assert stats.tasks_waiting_receive == 0
s.send_nowait(None)
assert s.statistics().current_buffer_used == 1
s2 = s.clone()
assert s.statistics().open_send_channels == 2
await s.aclose()
assert s2.statistics().open_send_channels == 1
r2 = r.clone()
assert s2.statistics().open_receive_channels == 2
await r2.aclose()
assert s2.statistics().open_receive_channels == 1
async with trio.open_nursery() as nursery:
s2.send_nowait(None) # fill up the buffer
assert s.statistics().current_buffer_used == 2
nursery.start_soon(s2.send, None)
nursery.start_soon(s2.send, None)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_send == 2
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_send == 0
# empty out the buffer again
try:
while True:
r.receive_nowait()
except trio.WouldBlock:
pass
async with trio.open_nursery() as nursery:
nursery.start_soon(r.receive)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_receive == 1
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_receive == 0
async def test_channel_fairness():
# We can remove an item we just sent, and send an item back in after, if
# no-one else is waiting.
s, r = open_memory_channel(1)
s.send_nowait(1)
assert r.receive_nowait() == 1
s.send_nowait(2)
assert r.receive_nowait() == 2
# But if someone else is waiting to receive, then they "own" the item we
# send, so we can't receive it (even though we run first):
result = None
async def do_receive(r):
nonlocal result
result = await r.receive()
async with trio.open_nursery() as nursery:
nursery.start_soon(do_receive, r)
await wait_all_tasks_blocked()
s.send_nowait(2)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
assert result == 2
# And the analogous situation for send: if we free up a space, we can't
# immediately send something in it if someone is already waiting to do
# that
s, r = open_memory_channel(1)
s.send_nowait(1)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
async with trio.open_nursery() as nursery:
nursery.start_soon(s.send, 2)
await wait_all_tasks_blocked()
assert r.receive_nowait() == 1
with pytest.raises(trio.WouldBlock):
s.send_nowait(3)
assert (await r.receive()) == 2
async def test_unbuffered():
s, r = open_memory_channel(0)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
with pytest.raises(trio.WouldBlock):
s.send_nowait(1)
async def do_send(s, v):
with assert_checkpoints():
await s.send(v)
async with trio.open_nursery() as nursery:
nursery.start_soon(do_send, s, 1)
with assert_checkpoints():
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()

View File

@@ -0,0 +1,243 @@
import pytest
import inspect
import warnings
from .._deprecate import (
TrioDeprecationWarning,
warn_deprecated,
deprecated,
deprecated_alias,
)
from . import module_with_deprecations
@pytest.fixture
def recwarn_always(recwarn):
warnings.simplefilter("always")
# ResourceWarnings about unclosed sockets can occur nondeterministically
# (during GC) which throws off the tests in this file
warnings.simplefilter("ignore", ResourceWarning)
return recwarn
def _here():
info = inspect.getframeinfo(inspect.currentframe().f_back)
return (info.filename, info.lineno)
def test_warn_deprecated(recwarn_always):
def deprecated_thing():
warn_deprecated("ice", "1.2", issue=1, instead="water")
deprecated_thing()
filename, lineno = _here()
assert len(recwarn_always) == 1
got = recwarn_always.pop(TrioDeprecationWarning)
assert "ice is deprecated" in got.message.args[0]
assert "Trio 1.2" in got.message.args[0]
assert "water instead" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert got.filename == filename
assert got.lineno == lineno - 1
def test_warn_deprecated_no_instead_or_issue(recwarn_always):
# Explicitly no instead or issue
warn_deprecated("water", "1.3", issue=None, instead=None)
assert len(recwarn_always) == 1
got = recwarn_always.pop(TrioDeprecationWarning)
assert "water is deprecated" in got.message.args[0]
assert "no replacement" in got.message.args[0]
assert "Trio 1.3" in got.message.args[0]
def test_warn_deprecated_stacklevel(recwarn_always):
def nested1():
nested2()
def nested2():
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
filename, lineno = _here()
nested1()
got = recwarn_always.pop(TrioDeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
def old(): # pragma: no cover
pass
def new(): # pragma: no cover
pass
def test_warn_deprecated_formatting(recwarn_always):
warn_deprecated(old, "1.0", issue=1, instead=new)
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.old is deprecated" in got.message.args[0]
assert "test_deprecate.new instead" in got.message.args[0]
@deprecated("1.5", issue=123, instead=new)
def deprecated_old():
return 3
def test_deprecated_decorator(recwarn_always):
assert deprecated_old() == 3
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
assert "1.5" in got.message.args[0]
assert "test_deprecate.new" in got.message.args[0]
assert "issues/123" in got.message.args[0]
class Foo:
@deprecated("1.0", issue=123, instead="crying")
def method(self):
return 7
def test_deprecated_decorator_method(recwarn_always):
f = Foo()
assert f.method() == 7
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
@deprecated("1.2", thing="the thing", issue=None, instead=None)
def deprecated_with_thing():
return 72
def test_deprecated_decorator_with_explicit_thing(recwarn_always):
assert deprecated_with_thing() == 72
got = recwarn_always.pop(TrioDeprecationWarning)
assert "the thing is deprecated" in got.message.args[0]
def new_hotness():
return "new hotness"
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
def test_deprecated_alias(recwarn_always):
assert old_hotness() == "new hotness"
got = recwarn_always.pop(TrioDeprecationWarning)
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
assert "1.23" in got.message.args[0]
assert "test_deprecate.new_hotness instead" in got.message.args[0]
assert "issues/1" in got.message.args[0]
assert ".. deprecated:: 1.23" in old_hotness.__doc__
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
assert "issues/1>`__" in old_hotness.__doc__
class Alias:
def new_hotness_method(self):
return "new hotness method"
old_hotness_method = deprecated_alias(
"Alias.old_hotness_method", new_hotness_method, "3.21", issue=1
)
def test_deprecated_alias_method(recwarn_always):
obj = Alias()
assert obj.old_hotness_method() == "new hotness method"
got = recwarn_always.pop(TrioDeprecationWarning)
msg = got.message.args[0]
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
assert "test_deprecate.Alias.new_hotness_method instead" in msg
@deprecated("2.1", issue=1, instead="hi")
def docstring_test1(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead="hi")
def docstring_test2(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=1, instead=None)
def docstring_test3(): # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead=None)
def docstring_test4(): # pragma: no cover
"""Hello!"""
def test_deprecated_docstring_munging():
assert (
docstring_test1.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test2.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
"""
)
assert (
docstring_test3.__doc__
== """Hello!
.. deprecated:: 2.1
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test4.__doc__
== """Hello!
.. deprecated:: 2.1
"""
)
def test_module_with_deprecations(recwarn_always):
assert module_with_deprecations.regular == "hi"
assert len(recwarn_always) == 0
filename, lineno = _here()
assert module_with_deprecations.dep1 == "value1"
got = recwarn_always.pop(TrioDeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
assert "module_with_deprecations.dep1" in got.message.args[0]
assert "Trio 1.1" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert "value1 instead" in got.message.args[0]
assert module_with_deprecations.dep2 == "value2"
got = recwarn_always.pop(TrioDeprecationWarning)
assert "instead-string instead" in got.message.args[0]
with pytest.raises(AttributeError):
module_with_deprecations.asdf

View File

@@ -0,0 +1,156 @@
import re
import sys
import importlib
import types
import inspect
import enum
import pytest
import trio
import trio.testing
from .. import _core
from .. import _util
def test_core_is_properly_reexported():
# Each export from _core should be re-exported by exactly one of these
# three modules:
sources = [trio, trio.lowlevel, trio.testing]
for symbol in dir(_core):
if symbol.startswith("_") or symbol == "tests":
continue
found = 0
for source in sources:
if symbol in dir(source) and getattr(source, symbol) is getattr(
_core, symbol
):
found += 1
print(symbol, found)
assert found == 1
def public_modules(module):
yield module
for name, class_ in module.__dict__.items():
if name.startswith("_"): # pragma: no cover
continue
if not isinstance(class_, types.ModuleType):
continue
if not class_.__name__.startswith(module.__name__): # pragma: no cover
continue
if class_ is module:
continue
# We should rename the trio.tests module (#274), but until then we use
# a special-case hack:
if class_.__name__ == "trio.tests":
continue
yield from public_modules(class_)
PUBLIC_MODULES = list(public_modules(trio))
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
# It doesn't make sense for downstream redistributors to run this test, since
# they might be using a newer version of Python with additional symbols which
# won't be reflected in trio.socket, and this shouldn't cause downstream test
# runs to start failing.
@pytest.mark.redistributors_should_skip
# pylint/jedi often have trouble with alpha releases, where Python's internals
# are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
sys.version_info.releaselevel == "alpha",
reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.filterwarnings(
# https://github.com/PyCQA/astroid/issues/681
"ignore:the imp module is deprecated.*:DeprecationWarning"
)
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["pylint", "jedi"])
@pytest.mark.filterwarnings(
"ignore:"
+ re.escape(
"The distutils package is deprecated and slated for removal in Python 3.12. "
"Use setuptools or check PEP 632 for potential alternatives"
)
+ ":DeprecationWarning",
"ignore:"
+ re.escape("The distutils.sysconfig module is deprecated, use sysconfig instead")
+ ":DeprecationWarning",
)
def test_static_tool_sees_all_symbols(tool, modname):
module = importlib.import_module(modname)
def no_underscores(symbols):
return {symbol for symbol in symbols if not symbol.startswith("_")}
runtime_names = no_underscores(dir(module))
# We should rename the trio.tests module (#274), but until then we use a
# special-case hack:
if modname == "trio":
runtime_names.remove("tests")
if tool == "pylint":
from pylint.lint import PyLinter
linter = PyLinter()
ast = linter.get_ast(module.__file__, modname)
static_names = no_underscores(ast)
elif tool == "jedi":
import jedi
# Simulate typing "import trio; trio.<TAB>"
script = jedi.Script("import {}; {}.".format(modname, modname))
completions = script.complete()
static_names = no_underscores(c.name for c in completions)
else: # pragma: no cover
assert False
# It's expected that the static set will contain more names than the
# runtime set:
# - static tools are sometimes sloppy and include deleted names
# - some symbols are platform-specific at runtime, but always show up in
# static analysis (e.g. in trio.socket or trio.lowlevel)
# So we check that the runtime names are a subset of the static names.
missing_names = runtime_names - static_names
if missing_names: # pragma: no cover
print("{} can't see the following names in {}:".format(tool, modname))
print()
for name in sorted(missing_names):
print(" {}".format(name))
assert False
def test_classes_are_final():
for module in PUBLIC_MODULES:
for name, class_ in module.__dict__.items():
if not isinstance(class_, type):
continue
# Deprecated classes are exported with a leading underscore
if name.startswith("_"): # pragma: no cover
continue
# Abstract classes can be subclassed, because that's the whole
# point of ABCs
if inspect.isabstract(class_):
continue
# Exceptions are allowed to be subclassed, because exception
# subclassing isn't used to inherit behavior.
if issubclass(class_, BaseException):
continue
# These are classes that are conceptually abstract, but
# inspect.isabstract returns False for boring reasons.
if class_ in {trio.abc.Instrument, trio.socket.SocketType}:
continue
# Enums have their own metaclass, so we can't use our metaclasses.
# And I don't think there's a lot of risk from people subclassing
# enums...
if issubclass(class_, enum.Enum):
continue
# ... insert other special cases here ...
assert isinstance(class_, _util.Final)

View File

@@ -0,0 +1,198 @@
import io
import os
import pytest
from unittest import mock
from unittest.mock import sentinel
import trio
from trio import _core
from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS
@pytest.fixture
def path(tmpdir):
return os.fspath(tmpdir.join("test"))
@pytest.fixture
def wrapped():
return mock.Mock(spec_set=io.StringIO)
@pytest.fixture
def async_file(wrapped):
return trio.wrap_file(wrapped)
def test_wrap_invalid():
with pytest.raises(TypeError):
trio.wrap_file(str())
def test_wrap_non_iobase():
class FakeFile:
def close(self): # pragma: no cover
pass
def write(self): # pragma: no cover
pass
wrapped = FakeFile()
assert not isinstance(wrapped, io.IOBase)
async_file = trio.wrap_file(wrapped)
assert isinstance(async_file, AsyncIOWrapper)
del FakeFile.write
with pytest.raises(TypeError):
trio.wrap_file(FakeFile())
def test_wrapped_property(async_file, wrapped):
assert async_file.wrapped is wrapped
def test_dir_matches_wrapped(async_file, wrapped):
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
# all supported attrs in wrapped should be available in async_file
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
# all supported attrs not in wrapped should not be available in async_file
assert not any(
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
)
def test_unsupported_not_forwarded():
class FakeFile(io.RawIOBase):
def unsupported_attr(self): # pragma: no cover
pass
async_file = trio.wrap_file(FakeFile())
assert hasattr(async_file.wrapped, "unsupported_attr")
with pytest.raises(AttributeError):
getattr(async_file, "unsupported_attr")
def test_sync_attrs_forwarded(async_file, wrapped):
for attr_name in _FILE_SYNC_ATTRS:
if attr_name not in dir(async_file):
continue
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
def test_sync_attrs_match_wrapper(async_file, wrapped):
for attr_name in _FILE_SYNC_ATTRS:
if attr_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, attr_name)
with pytest.raises(AttributeError):
getattr(wrapped, attr_name)
def test_async_methods_generated_once(async_file):
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
def test_async_methods_signature(async_file):
# use read as a representative of all async methods
assert async_file.read.__name__ == "read"
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
assert "io.StringIO.read" in async_file.read.__doc__
async def test_async_methods_wrap(async_file, wrapped):
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
meth = getattr(async_file, meth_name)
wrapped_meth = getattr(wrapped, meth_name)
value = await meth(sentinel.argument, keyword=sentinel.keyword)
wrapped_meth.assert_called_once_with(
sentinel.argument, keyword=sentinel.keyword
)
assert value == wrapped_meth()
wrapped.reset_mock()
async def test_async_methods_match_wrapper(async_file, wrapped):
for meth_name in _FILE_ASYNC_METHODS:
if meth_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, meth_name)
with pytest.raises(AttributeError):
getattr(wrapped, meth_name)
async def test_open(path):
f = await trio.open_file(path, "w")
assert isinstance(f, AsyncIOWrapper)
await f.aclose()
async def test_open_context_manager(path):
async with await trio.open_file(path, "w") as f:
assert isinstance(f, AsyncIOWrapper)
assert not f.closed
assert f.closed
async def test_async_iter():
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
expected = list(async_file.wrapped)
result = []
async_file.wrapped.seek(0)
async for line in async_file:
result.append(line)
assert result == expected
async def test_aclose_cancelled(path):
with _core.CancelScope() as cscope:
f = await trio.open_file(path, "w")
cscope.cancel()
with pytest.raises(_core.Cancelled):
await f.write("a")
with pytest.raises(_core.Cancelled):
await f.aclose()
assert f.closed
async def test_detach_rewraps_asynciobase():
raw = io.BytesIO()
buffered = io.BufferedReader(raw)
async_file = trio.wrap_file(buffered)
detached = await async_file.detach()
assert isinstance(detached, AsyncIOWrapper)
assert detached.wrapped is raw

View File

@@ -0,0 +1,94 @@
import pytest
import attr
from ..abc import SendStream, ReceiveStream
from .._highlevel_generic import StapledStream
@attr.s
class RecordSendStream(SendStream):
record = attr.ib(factory=list)
async def send_all(self, data):
self.record.append(("send_all", data))
async def wait_send_all_might_not_block(self):
self.record.append("wait_send_all_might_not_block")
async def aclose(self):
self.record.append("aclose")
@attr.s
class RecordReceiveStream(ReceiveStream):
record = attr.ib(factory=list)
async def receive_some(self, max_bytes=None):
self.record.append(("receive_some", max_bytes))
async def aclose(self):
self.record.append("aclose")
async def test_StapledStream():
send_stream = RecordSendStream()
receive_stream = RecordReceiveStream()
stapled = StapledStream(send_stream, receive_stream)
assert stapled.send_stream is send_stream
assert stapled.receive_stream is receive_stream
await stapled.send_all(b"foo")
await stapled.wait_send_all_might_not_block()
assert send_stream.record == [
("send_all", b"foo"),
"wait_send_all_might_not_block",
]
send_stream.record.clear()
await stapled.send_eof()
assert send_stream.record == ["aclose"]
send_stream.record.clear()
async def fake_send_eof():
send_stream.record.append("send_eof")
send_stream.send_eof = fake_send_eof
await stapled.send_eof()
assert send_stream.record == ["send_eof"]
send_stream.record.clear()
assert receive_stream.record == []
await stapled.receive_some(1234)
assert receive_stream.record == [("receive_some", 1234)]
assert send_stream.record == []
receive_stream.record.clear()
await stapled.aclose()
assert receive_stream.record == ["aclose"]
assert send_stream.record == ["aclose"]
async def test_StapledStream_with_erroring_close():
# Make sure that if one of the aclose methods errors out, then the other
# one still gets called.
class BrokenSendStream(RecordSendStream):
async def aclose(self):
await super().aclose()
raise ValueError
class BrokenReceiveStream(RecordReceiveStream):
async def aclose(self):
await super().aclose()
raise ValueError
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
with pytest.raises(ValueError) as excinfo:
await stapled.aclose()
assert isinstance(excinfo.value.__context__, ValueError)
assert stapled.send_stream.record == ["aclose"]
assert stapled.receive_stream.record == ["aclose"]

View File

@@ -0,0 +1,295 @@
import pytest
import socket as stdlib_socket
import errno
import attr
import trio
from trio import open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream
from trio.testing import open_stream_to_socket_listener
from .. import socket as tsocket
from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6
async def test_open_tcp_listeners_basic():
listeners = await open_tcp_listeners(0)
assert isinstance(listeners, list)
for obj in listeners:
assert isinstance(obj, SocketListener)
# Binds to wildcard address by default
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
listener = listeners[0]
# Make sure the backlog is at least 2
c1 = await open_stream_to_socket_listener(listener)
c2 = await open_stream_to_socket_listener(listener)
s1 = await listener.accept()
s2 = await listener.accept()
# Note that we don't know which client stream is connected to which server
# stream
await s1.send_all(b"x")
await s2.send_all(b"x")
assert await c1.receive_some(1) == b"x"
assert await c2.receive_some(1) == b"x"
for resource in [c1, c2, s1, s2] + listeners:
await resource.aclose()
async def test_open_tcp_listeners_specific_port_specific_host():
# Pick a port
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
host, port = sock.getsockname()
sock.close()
(listener,) = await open_tcp_listeners(port, host=host)
async with listener:
assert listener.socket.getsockname() == (host, port)
@binds_ipv6
async def test_open_tcp_listeners_ipv6_v6only():
# Check IPV6_V6ONLY is working properly
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
async with ipv6_listener:
_, port, *_ = ipv6_listener.socket.getsockname()
with pytest.raises(OSError):
await open_tcp_stream("127.0.0.1", port)
async def test_open_tcp_listeners_rebind():
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
sockaddr1 = l1.socket.getsockname()
# Plain old rebinding while it's still there should fail, even if we have
# SO_REUSEADDR set
with stdlib_socket.socket() as probe:
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
with pytest.raises(OSError):
probe.bind(sockaddr1)
# Now use the first listener to set up some connections in various states,
# and make sure that they don't create any obstacle to rebinding a second
# listener after the first one is closed.
c_established = await open_stream_to_socket_listener(l1)
s_established = await l1.accept()
c_time_wait = await open_stream_to_socket_listener(l1)
s_time_wait = await l1.accept()
# Server-initiated close leaves socket in TIME_WAIT
await s_time_wait.aclose()
await l1.aclose()
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
sockaddr2 = l2.socket.getsockname()
assert sockaddr1 == sockaddr2
assert s_established.socket.getsockname() == sockaddr2
assert c_time_wait.socket.getpeername() == sockaddr2
for resource in [
l1,
l2,
c_established,
s_established,
c_time_wait,
s_time_wait,
]:
await resource.aclose()
class FakeOSError(OSError):
pass
@attr.s
class FakeSocket(tsocket.SocketType):
family = attr.ib()
type = attr.ib()
proto = attr.ib()
closed = attr.ib(default=False)
poison_listen = attr.ib(default=False)
backlog = attr.ib(default=None)
def getsockopt(self, level, option):
if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
return True
assert False # pragma: no cover
def setsockopt(self, level, option, value):
pass
async def bind(self, sockaddr):
pass
def listen(self, backlog):
assert self.backlog is None
assert backlog is not None
self.backlog = backlog
if self.poison_listen:
raise FakeOSError("whoops")
def close(self):
self.closed = True
@attr.s
class FakeSocketFactory:
poison_after = attr.ib()
sockets = attr.ib(factory=list)
raise_on_family = attr.ib(factory=dict) # family => errno
def socket(self, family, type, proto):
if family in self.raise_on_family:
raise OSError(self.raise_on_family[family], "nope")
sock = FakeSocket(family, type, proto)
self.poison_after -= 1
if self.poison_after == 0:
sock.poison_listen = True
self.sockets.append(sock)
return sock
@attr.s
class FakeHostnameResolver:
family_addr_pairs = attr.ib()
async def getaddrinfo(self, host, port, family, type, proto, flags):
return [
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
for family, addr in self.family_addr_pairs
]
async def test_open_tcp_listeners_multiple_host_cleanup_on_error():
# If we were trying to bind to multiple hosts and one of them failed, they
# call get cleaned up before returning
fsf = FakeSocketFactory(3)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver(
[
(tsocket.AF_INET, "1.1.1.1"),
(tsocket.AF_INET, "2.2.2.2"),
(tsocket.AF_INET, "3.3.3.3"),
]
)
)
with pytest.raises(FakeOSError):
await open_tcp_listeners(80, host="example.org")
assert len(fsf.sockets) == 3
for sock in fsf.sockets:
assert sock.closed
async def test_open_tcp_listeners_port_checking():
for host in ["127.0.0.1", None]:
with pytest.raises(TypeError):
await open_tcp_listeners(None, host=host)
with pytest.raises(TypeError):
await open_tcp_listeners(b"80", host=host)
with pytest.raises(TypeError):
await open_tcp_listeners("http", host=host)
async def test_serve_tcp():
async def handler(stream):
await stream.send_all(b"x")
async with trio.open_nursery() as nursery:
listeners = await nursery.start(serve_tcp, handler, 0)
stream = await open_stream_to_socket_listener(listeners[0])
async with stream:
await stream.receive_some(1) == b"x"
nursery.cancel_scope.cancel()
@pytest.mark.parametrize(
"try_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
@pytest.mark.parametrize(
"fail_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
async def test_open_tcp_listeners_some_address_families_unavailable(
try_families, fail_families
):
fsf = FakeSocketFactory(
10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families}
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(family, "foo") for family in try_families])
)
should_succeed = try_families - fail_families
if not should_succeed:
with pytest.raises(OSError) as exc_info:
await open_tcp_listeners(80, host="example.org")
assert "This system doesn't support" in str(exc_info.value)
if isinstance(exc_info.value.__cause__, trio.MultiError):
for subexc in exc_info.value.__cause__.exceptions:
assert "nope" in str(subexc)
else:
assert isinstance(exc_info.value.__cause__, OSError)
assert "nope" in str(exc_info.value.__cause__)
else:
listeners = await open_tcp_listeners(80)
for listener in listeners:
should_succeed.remove(listener.socket.family)
assert not should_succeed
async def test_open_tcp_listeners_socket_fails_not_afnosupport():
fsf = FakeSocketFactory(
10,
raise_on_family={
tsocket.AF_INET: errno.EAFNOSUPPORT,
tsocket.AF_INET6: errno.EINVAL,
},
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")])
)
with pytest.raises(OSError) as exc_info:
await open_tcp_listeners(80, host="example.org")
assert exc_info.value.errno == errno.EINVAL
assert exc_info.value.__cause__ is None
assert "nope" in str(exc_info.value)
# We used to have an elaborate test that opened a real TCP listening socket
# and then tried to measure its backlog by making connections to it. And most
# of the time, it worked. But no matter what we tried, it was always fragile,
# because it had to do things like use timeouts to guess when the listening
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
# effectively is no backlog), sometimes the host might not be enough resources
# to give us the full requested backlog... it was a mess. So now we just check
# that the backlog argument is passed through correctly.
async def test_open_tcp_listeners_backlog():
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for (given, expected) in [
(None, 0xFFFF),
(99999999, 0xFFFF),
(10, 10),
(1, 1),
]:
listeners = await open_tcp_listeners(0, backlog=given)
assert listeners
for listener in listeners:
assert listener.socket.backlog == expected

View File

@@ -0,0 +1,571 @@
import pytest
import sys
import socket
import attr
import trio
from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP
from trio._highlevel_open_tcp_stream import (
reorder_for_rfc_6555_section_5_4,
close_all,
open_tcp_stream,
format_host_port,
)
def test_close_all():
class CloseMe:
closed = False
def close(self):
self.closed = True
class CloseKiller:
def close(self):
raise OSError
c = CloseMe()
with close_all() as to_close:
to_close.add(c)
assert c.closed
c = CloseMe()
with pytest.raises(RuntimeError):
with close_all() as to_close:
to_close.add(c)
raise RuntimeError
assert c.closed
c = CloseMe()
with pytest.raises(OSError):
with close_all() as to_close:
to_close.add(CloseKiller())
to_close.add(c)
assert c.closed
def test_reorder_for_rfc_6555_section_5_4():
def fake4(i):
return (
AF_INET,
SOCK_STREAM,
IPPROTO_TCP,
"",
("10.0.0.{}".format(i), 80),
)
def fake6(i):
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80))
for fake in fake4, fake6:
# No effect on homogeneous lists
targets = [fake(0), fake(1), fake(2)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0), fake(1), fake(2)]
# Single item lists also OK
targets = [fake(0)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0)]
# If the list starts out with different families in positions 0 and 1,
# then it's left alone
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
targets = list(orig)
reorder_for_rfc_6555_section_5_4(targets)
assert targets == orig
# If not, it's reordered
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
def test_format_host_port():
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port("example.com", 443) == "example.com:443"
assert format_host_port(b"example.com", 443) == "example.com:443"
assert format_host_port("::1", "http") == "[::1]:http"
assert format_host_port(b"::1", "http") == "[::1]:http"
# Make sure we can connect to localhost using real kernel sockets
async def test_open_tcp_stream_real_socket_smoketest():
listen_sock = trio.socket.socket()
await listen_sock.bind(("127.0.0.1", 0))
_, listen_port = listen_sock.getsockname()
listen_sock.listen(1)
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
server_sock, _ = await listen_sock.accept()
await client_stream.send_all(b"x")
assert await server_sock.recv(1) == b"x"
await client_stream.aclose()
server_sock.close()
listen_sock.close()
async def test_open_tcp_stream_input_validation():
with pytest.raises(ValueError):
await open_tcp_stream(None, 80)
with pytest.raises(TypeError):
await open_tcp_stream("127.0.0.1", b"80")
def can_bind_127_0_0_2():
with socket.socket() as s:
try:
s.bind(("127.0.0.2", 0))
except OSError:
return False
return s.getsockname()[0] == "127.0.0.2"
async def test_local_address_real():
with trio.socket.socket() as listener:
await listener.bind(("127.0.0.1", 0))
listener.listen()
# It's hard to test local_address properly, because you need multiple
# local addresses that you can bind to. Fortunately, on most Linux
# systems, you can bind to any 127.*.*.* address, and they all go
# through the loopback interface. So we can use a non-standard
# loopback address. On other systems, the only address we know for
# certain we have is 127.0.0.1, so we can't really test local_address=
# properly -- passing local_address=127.0.0.1 is indistinguishable
# from not passing local_address= at all. But, we can still do a smoke
# test to make sure the local_address= code doesn't crash.
if can_bind_127_0_0_2():
local_address = "127.0.0.2"
else:
local_address = "127.0.0.1"
async with await open_tcp_stream(
*listener.getsockname(), local_address=local_address
) as client_stream:
assert client_stream.socket.getsockname()[0] == local_address
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
assert client_stream.socket.getsockopt(
trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT
)
server_sock, remote_addr = await listener.accept()
await client_stream.aclose()
server_sock.close()
assert remote_addr[0] == local_address
# Trying to connect to an ipv4 address with the ipv6 wildcard
# local_address should fail
with pytest.raises(OSError):
await open_tcp_stream(*listener.getsockname(), local_address="::")
# But the ipv4 wildcard address should work
async with await open_tcp_stream(
*listener.getsockname(), local_address="0.0.0.0"
) as client_stream:
server_sock, remote_addr = await listener.accept()
server_sock.close()
assert remote_addr == client_stream.socket.getsockname()
# Now, thorough tests using fake sockets
@attr.s(eq=False)
class FakeSocket(trio.socket.SocketType):
scenario = attr.ib()
family = attr.ib()
type = attr.ib()
proto = attr.ib()
ip = attr.ib(default=None)
port = attr.ib(default=None)
succeeded = attr.ib(default=False)
closed = attr.ib(default=False)
failing = attr.ib(default=False)
async def connect(self, sockaddr):
self.ip = sockaddr[0]
self.port = sockaddr[1]
assert self.ip not in self.scenario.sockets
self.scenario.sockets[self.ip] = self
self.scenario.connect_times[self.ip] = trio.current_time()
delay, result = self.scenario.ip_dict[self.ip]
await trio.sleep(delay)
if result == "error":
raise OSError("sorry")
if result == "postconnect_fail":
self.failing = True
self.succeeded = True
def close(self):
self.closed = True
# called when SocketStream is constructed
def setsockopt(self, *args, **kwargs):
if self.failing:
# raise something that isn't OSError as SocketStream
# ignores those
raise KeyboardInterrupt
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
def __init__(self, port, ip_list, supported_families):
# ip_list have to be unique
ip_order = [ip for (ip, _, _) in ip_list]
assert len(set(ip_order)) == len(ip_list)
ip_dict = {}
for ip, delay, result in ip_list:
assert 0 <= delay
assert result in ["error", "success", "postconnect_fail"]
ip_dict[ip] = (delay, result)
self.port = port
self.ip_order = ip_order
self.ip_dict = ip_dict
self.supported_families = supported_families
self.socket_count = 0
self.sockets = {}
self.connect_times = {}
def socket(self, family, type, proto):
if family not in self.supported_families:
raise OSError("pretending not to support this family")
self.socket_count += 1
return FakeSocket(self, family, type, proto)
def _ip_to_gai_entry(self, ip):
if ":" in ip:
family = trio.socket.AF_INET6
sockaddr = (ip, self.port, 0, 0)
else:
family = trio.socket.AF_INET
sockaddr = (ip, self.port)
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
async def getaddrinfo(self, host, port, family, type, proto, flags):
assert host == b"test.example.com"
assert port == self.port
assert family == trio.socket.AF_UNSPEC
assert type == trio.socket.SOCK_STREAM
assert proto == 0
assert flags == 0
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
async def getnameinfo(self, sockaddr, flags): # pragma: no cover
raise NotImplementedError
def check(self, succeeded):
# sockets only go into self.sockets when connect is called; make sure
# all the sockets that were created did in fact go in there.
assert self.socket_count == len(self.sockets)
for ip, socket in self.sockets.items():
assert ip in self.ip_dict
if socket is not succeeded:
assert socket.closed
assert socket.port == self.port
async def run_scenario(
# The port to connect to
port,
# A list of
# (ip, delay, result)
# tuples, where delay is in seconds and result is "success" or "error"
# The ip's will be returned from getaddrinfo in this order, and then
# connect() calls to them will have the given result.
ip_list,
*,
# If False, AF_INET4/6 sockets error out on creation, before connect is
# even called.
ipv4_supported=True,
ipv6_supported=True,
# Normally, we return (winning_sock, scenario object)
# If this is True, we require there to be an exception, and return
# (exception, scenario object)
expect_error=(),
**kwargs,
):
supported_families = set()
if ipv4_supported:
supported_families.add(trio.socket.AF_INET)
if ipv6_supported:
supported_families.add(trio.socket.AF_INET6)
scenario = Scenario(port, ip_list, supported_families)
trio.socket.set_custom_hostname_resolver(scenario)
trio.socket.set_custom_socket_factory(scenario)
try:
stream = await open_tcp_stream("test.example.com", port, **kwargs)
assert expect_error == ()
scenario.check(stream.socket)
return (stream.socket, scenario)
except AssertionError: # pragma: no cover
raise
except expect_error as exc:
scenario.check(None)
return (exc, scenario)
async def test_one_host_quick_success(autojump_clock):
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 0.123
async def test_one_host_slow_success(autojump_clock):
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 100
async def test_one_host_quick_fail(autojump_clock):
exc, scenario = await run_scenario(
82, [("1.2.3.4", 0.123, "error")], expect_error=OSError
)
assert isinstance(exc, OSError)
assert trio.current_time() == 0.123
async def test_one_host_slow_fail(autojump_clock):
exc, scenario = await run_scenario(
83, [("1.2.3.4", 100, "error")], expect_error=OSError
)
assert isinstance(exc, OSError)
assert trio.current_time() == 100
async def test_one_host_failed_after_connect(autojump_clock):
exc, scenario = await run_scenario(
83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt
)
assert isinstance(exc, KeyboardInterrupt)
# With the default 0.250 second delay, the third attempt will win
async def test_basic_fallthrough(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
"3.3.3.3": 0.500,
}
async def test_early_success(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 0.1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert sock.ip == "2.2.2.2"
assert trio.current_time() == (0.250 + 0.1)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
# 3.3.3.3 was never even started
}
# With a 0.450 second delay, the first attempt will win
async def test_custom_delay(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
happy_eyeballs_delay=0.450,
)
assert sock.ip == "1.1.1.1"
assert trio.current_time() == 1
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.450,
"3.3.3.3": 0.900,
}
async def test_custom_errors_expedite(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "success"),
# .25 is the default timeout
("4.4.4.4", 0.25, "success"),
],
)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_all_fail(autojump_clock):
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "error"),
("4.4.4.4", 0.250, "error"),
],
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert isinstance(exc.__cause__, trio.MultiError)
assert len(exc.__cause__.exceptions) == 4
assert trio.current_time() == (0.1 + 0.2 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_multi_success(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.5, "error"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10 - 1, "success"),
("4.4.4.4", 10 - 2, "success"),
("5.5.5.5", 0.5, "error"),
],
happy_eyeballs_delay=1,
)
assert not scenario.sockets["1.1.1.1"].succeeded
assert (
scenario.sockets["2.2.2.2"].succeeded
or scenario.sockets["3.3.3.3"].succeeded
or scenario.sockets["4.4.4.4"].succeeded
)
assert not scenario.sockets["5.5.5.5"].succeeded
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
assert trio.current_time() == (0.5 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.5,
"3.3.3.3": 1.5,
"4.4.4.4": 2.5,
"5.5.5.5": 3.5,
}
async def test_does_reorder(autojump_clock):
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "error"),
# This would win if we tried it first...
("2.2.2.2", 1, "success"),
# But in fact we try this first, because of section 5.4
("::3", 0.5, "success"),
],
happy_eyeballs_delay=1,
)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.5
assert scenario.connect_times == {
"1.1.1.1": 0,
"::3": 1,
}
async def test_handles_no_ipv4(autojump_clock):
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 10, "success"),
("2.2.2.2", 0, "success"),
("::3", 0.1, "success"),
("4.4.4.4", 0, "success"),
],
happy_eyeballs_delay=1,
ipv4_supported=False,
)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"::1": 0,
"::3": 1.0,
}
async def test_handles_no_ipv6(autojump_clock):
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 0, "success"),
("2.2.2.2", 10, "success"),
("::3", 0, "success"),
("4.4.4.4", 0.1, "success"),
],
happy_eyeballs_delay=1,
ipv6_supported=False,
)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"2.2.2.2": 0,
"4.4.4.4": 1.0,
}
async def test_no_hosts(autojump_clock):
exc, scenario = await run_scenario(80, [], expect_error=OSError)
assert "no results found" in str(exc)
async def test_cancel(autojump_clock):
with trio.move_on_after(5) as cancel_scope:
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "success"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10, "success"),
("4.4.4.4", 10, "success"),
],
expect_error=trio.MultiError,
)
# What comes out should be 1 or more Cancelled errors that all belong
# to this cancel_scope; this is the easiest way to check that
raise exc
assert cancel_scope.cancelled_caught
assert trio.current_time() == 5
# This should have been called already, but just to make sure, since the
# exception-handling logic in run_scenario is a bit complicated and the
# main thing we care about here is that all the sockets were cleaned up.
scenario.check(succeeded=False)

View File

@@ -0,0 +1,67 @@
import os
import socket
import tempfile
import pytest
from trio import open_unix_socket, Path
from trio._highlevel_open_unix_stream import close_on_error
if not hasattr(socket, "AF_UNIX"):
pytestmark = pytest.mark.skip("Needs unix socket support")
def test_close_on_error():
class CloseMe:
closed = False
def close(self):
self.closed = True
with close_on_error(CloseMe()) as c:
pass
assert not c.closed
with pytest.raises(RuntimeError):
with close_on_error(CloseMe()) as c:
raise RuntimeError
assert c.closed
@pytest.mark.parametrize("filename", [4, 4.5])
async def test_open_with_bad_filename_type(filename):
with pytest.raises(TypeError):
await open_unix_socket(filename)
async def test_open_bad_socket():
# mktemp is marked as insecure, but that's okay, we don't want the file to
# exist
name = tempfile.mktemp()
with pytest.raises(FileNotFoundError):
await open_unix_socket(name)
async def test_open_unix_socket():
for name_type in [Path, str]:
name = tempfile.mktemp()
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
with serv_sock:
serv_sock.bind(name)
try:
serv_sock.listen(1)
# The actual function we're testing
unix_socket = await open_unix_socket(name_type(name))
async with unix_socket:
client, _ = serv_sock.accept()
with client:
await unix_socket.send_all(b"test")
assert client.recv(2048) == b"test"
client.sendall(b"response")
received = await unix_socket.receive_some(2048)
assert received == b"response"
finally:
os.unlink(name)

View File

@@ -0,0 +1,145 @@
import pytest
from functools import partial
import errno
import attr
import trio
from trio.testing import memory_stream_pair, wait_all_tasks_blocked
@attr.s(hash=False, eq=False)
class MemoryListener(trio.abc.Listener):
closed = attr.ib(default=False)
accepted_streams = attr.ib(factory=list)
queued_streams = attr.ib(factory=(lambda: trio.open_memory_channel(1)))
accept_hook = attr.ib(default=None)
async def connect(self):
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
return client
async def accept(self):
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
await self.accept_hook()
stream = await self.queued_streams[1].receive()
self.accepted_streams.append(stream)
return stream
async def aclose(self):
self.closed = True
await trio.lowlevel.checkpoint()
async def test_serve_listeners_basic():
listeners = [MemoryListener(), MemoryListener()]
record = []
def close_hook():
# Make sure this is a forceful close
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
async def handler(stream):
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
stream.receive_stream.close_hook = close_hook
async def client(listener):
s = await listener.connect()
assert await s.receive_some(10) == b"123"
await s.send_all(b"456")
async def do_tests(parent_nursery):
async with trio.open_nursery() as nursery:
for listener in listeners:
for _ in range(3):
nursery.start_soon(client, listener)
await wait_all_tasks_blocked()
# verifies that all 6 streams x 2 directions each were closed ok
assert len(record) == 12
parent_nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
l2 = await nursery.start(trio.serve_listeners, handler, listeners)
assert l2 == listeners
# This is just split into another function because gh-136 isn't
# implemented yet
nursery.start_soon(do_tests, nursery)
for listener in listeners:
assert listener.closed
async def test_serve_listeners_accept_unrecognized_error():
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
listener = MemoryListener()
async def raise_error():
raise error
listener.accept_hook = raise_error
with pytest.raises(type(error)) as excinfo:
await trio.serve_listeners(None, [listener])
assert excinfo.value is error
async def test_serve_listeners_accept_capacity_error(autojump_clock, caplog):
listener = MemoryListener()
async def raise_EMFILE():
raise OSError(errno.EMFILE, "out of file descriptors")
listener.accept_hook = raise_EMFILE
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
# = 10 times total
with trio.move_on_after(0.950):
await trio.serve_listeners(None, [listener])
assert len(caplog.records) == 10
for record in caplog.records:
assert "retrying" in record.msg
assert record.exc_info[1].errno == errno.EMFILE
async def test_serve_listeners_connection_nursery(autojump_clock):
listener = MemoryListener()
async def handler(stream):
await trio.sleep(1)
class Done(Exception):
pass
async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED):
async with trio.open_nursery() as nursery:
task_status.started(nursery)
await wait_all_tasks_blocked()
assert len(nursery.child_tasks) == 10
raise Done
with pytest.raises(Done):
async with trio.open_nursery() as nursery:
handler_nursery = await nursery.start(connection_watcher)
await nursery.start(
partial(
trio.serve_listeners,
handler,
[listener],
handler_nursery=handler_nursery,
)
)
for _ in range(10):
nursery.start_soon(listener.connect)

View File

@@ -0,0 +1,267 @@
import pytest
import sys
import socket as stdlib_socket
import errno
from .. import _core
from ..testing import (
check_half_closeable_stream,
wait_all_tasks_blocked,
assert_checkpoints,
)
from .._highlevel_socket import *
from .. import socket as tsocket
async def test_SocketStream_basics():
# stdlib socket bad (even if connected)
a, b = stdlib_socket.socketpair()
with a, b:
with pytest.raises(TypeError):
SocketStream(a)
# DGRAM socket bad
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
with pytest.raises(ValueError):
SocketStream(sock)
a, b = tsocket.socketpair()
with a, b:
s = SocketStream(a)
assert s.socket is a
# Use a real, connected socket to test socket options, because
# socketpair() might give us a unix socket that doesn't support any of
# these options
with tsocket.socket() as listen_sock:
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
with tsocket.socket() as client_sock:
await client_sock.connect(listen_sock.getsockname())
s = SocketStream(client_sock)
# TCP_NODELAY enabled by default
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
# We can disable it though
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
assert isinstance(b, bytes)
async def test_SocketStream_send_all():
BIG = 10000000
a_sock, b_sock = tsocket.socketpair()
with a_sock, b_sock:
a = SocketStream(a_sock)
b = SocketStream(b_sock)
# Check a send_all that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
async def sender():
data = bytearray(BIG)
await a.send_all(data)
# send_all uses memoryviews internally, which temporarily "lock"
# the object they view. If it doesn't clean them up properly, then
# some bytearray operations might raise an error afterwards, which
# would be a pretty weird and annoying side-effect to spring on
# users. So test that this doesn't happen, by forcing the
# bytearray's underlying buffer to be realloc'ed:
data += bytes(BIG)
# (Note: the above line of code doesn't do a very good job at
# testing anything, because:
# - on CPython, the refcount GC generally cleans up memoryviews
# for us even if we're sloppy.
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
# bytearray code conspire so that resizing never fails if
# resizing forces the bytearray's internal buffer to move, then
# all memoryview references are automagically updated (!!).
# See:
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
# But I'm leaving the test here in hopes that if this ever changes
# and we break our implementation of send_all, then we'll get some
# early warning...)
async def receiver():
# Make sure the sender fills up the kernel buffers and blocks
await wait_all_tasks_blocked()
nbytes = 0
while nbytes < BIG:
nbytes += len(await b.receive_some(BIG))
assert nbytes == BIG
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# We know that we received BIG bytes of NULs so far. Make sure that
# was all the data in there.
await a.send_all(b"e")
assert await b.receive_some(10) == b"e"
await a.send_eof()
assert await b.receive_some(10) == b""
async def fill_stream(s):
async def sender():
while True:
await s.send_all(b"x" * 10000)
async def waiter(nursery):
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(waiter, nursery)
async def test_SocketStream_generic():
async def stream_maker():
left, right = tsocket.socketpair()
return SocketStream(left), SocketStream(right)
async def clogged_stream_maker():
left, right = await stream_maker()
await fill_stream(left)
await fill_stream(right)
return left, right
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
async def test_SocketListener():
# Not a Trio socket
with stdlib_socket.socket() as s:
s.bind(("127.0.0.1", 0))
s.listen(10)
with pytest.raises(TypeError):
SocketListener(s)
# Not a SOCK_STREAM
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(ValueError) as excinfo:
SocketListener(s)
excinfo.match(r".*SOCK_STREAM")
# Didn't call .listen()
# macOS has no way to check for this, so skip testing it there.
if sys.platform != "darwin":
with tsocket.socket() as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(ValueError) as excinfo:
SocketListener(s)
excinfo.match(r".*listen")
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
assert listener.socket is listen_sock
client_sock = tsocket.socket()
await client_sock.connect(listen_sock.getsockname())
with assert_checkpoints():
server_stream = await listener.accept()
assert isinstance(server_stream, SocketStream)
assert server_stream.socket.getsockname() == listen_sock.getsockname()
assert server_stream.socket.getpeername() == client_sock.getsockname()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
client_sock.close()
await server_stream.aclose()
async def test_SocketListener_socket_closed_underfoot():
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
# Close the socket, not the listener
listen_sock.close()
# SocketListener gives correct error
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
async def test_SocketListener_accept_errors():
class FakeSocket(tsocket.SocketType):
def __init__(self, events):
self._events = iter(events)
type = tsocket.SOCK_STREAM
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
def getsockopt(self, level, opt):
return True
def setsockopt(self, level, opt, value):
pass
async def accept(self):
await _core.checkpoint()
event = next(self._events)
if isinstance(event, BaseException):
raise event
else:
return event, None
fake_server_sock = FakeSocket([])
fake_listen_sock = FakeSocket(
[
OSError(errno.ECONNABORTED, "Connection aborted"),
OSError(errno.EPERM, "Permission denied"),
OSError(errno.EPROTO, "Bad protocol"),
fake_server_sock,
OSError(errno.EMFILE, "Out of file descriptors"),
OSError(errno.EFAULT, "attempt to write to read-only memory"),
OSError(errno.ENOBUFS, "out of buffers"),
fake_server_sock,
]
)
l = SocketListener(fake_listen_sock)
with assert_checkpoints():
s = await l.accept()
assert s.socket is fake_server_sock
for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]:
with assert_checkpoints():
with pytest.raises(OSError) as excinfo:
await l.accept()
assert excinfo.value.errno == code
with assert_checkpoints():
s = await l.accept()
assert s.socket is fake_server_sock
async def test_socket_stream_works_when_peer_has_already_closed():
sock_a, sock_b = tsocket.socketpair()
with sock_a, sock_b:
await sock_b.send(b"x")
sock_b.close()
stream = SocketStream(sock_a)
assert await stream.receive_some(1) == b"x"
assert await stream.receive_some(1) == b""

View File

@@ -0,0 +1,113 @@
import pytest
from functools import partial
import attr
import trio
from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP
import trio.testing
from .test_ssl import client_ctx, SERVER_CTX
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_stream,
open_ssl_over_tcp_listeners,
serve_ssl_over_tcp,
)
async def echo_handler(stream):
async with stream:
try:
while True:
data = await stream.receive_some(10000)
if not data:
break
await stream.send_all(data)
except trio.BrokenResourceError:
pass
# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attr.s
class FakeHostnameResolver(trio.abc.HostnameResolver):
sockaddr = attr.ib()
async def getaddrinfo(self, *args):
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
async def getnameinfo(self, *args): # pragma: no cover
raise NotImplementedError
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# noqa is needed because flake8 doesn't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811
async with trio.open_nursery() as nursery:
(listener,) = await nursery.start(
partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1")
)
async with listener:
sockaddr = listener.transport_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)
# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# We have the trust but not the hostname
# (checks custom ssl_context + hostname checking)
stream = await open_ssl_over_tcp_stream(
"xyzzy.example.org", 80, ssl_context=client_ctx
)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# This one should work!
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org", 80, ssl_context=client_ctx
)
async with stream:
assert isinstance(stream, trio.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"
# Check https_compatible settings are being passed through
assert not stream._https_compatible
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
https_compatible=True,
# also, smoke test happy_eyeballs_delay
happy_eyeballs_delay=1,
)
async with stream:
assert stream._https_compatible
# Stop the echo server
nursery.cancel_scope.cancel()
async def test_open_ssl_over_tcp_listeners():
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
async with listener:
assert isinstance(listener, trio.SSLListener)
tl = listener.transport_listener
assert isinstance(tl, trio.SocketListener)
assert tl.socket.getsockname()[0] == "127.0.0.1"
assert not listener._https_compatible
(listener,) = await open_ssl_over_tcp_listeners(
0, SERVER_CTX, host="127.0.0.1", https_compatible=True
)
async with listener:
assert listener._https_compatible

View File

@@ -0,0 +1,262 @@
import os
import pathlib
import pytest
import trio
from trio._path import AsyncAutoWrapperType as Type
from trio._file_io import AsyncIOWrapper
@pytest.fixture
def path(tmpdir):
p = str(tmpdir.join("test"))
return trio.Path(p)
def method_pair(path, method_name):
path = pathlib.Path(path)
async_path = trio.Path(path)
return getattr(path, method_name), getattr(async_path, method_name)
async def test_open_is_async_context_manager(path):
async with await path.open("w") as f:
assert isinstance(f, AsyncIOWrapper)
assert f.closed
async def test_magic():
path = trio.Path("test")
assert str(path) == "test"
assert bytes(path) == b"test"
cls_pairs = [
(trio.Path, pathlib.Path),
(pathlib.Path, trio.Path),
(trio.Path, trio.Path),
]
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
async def test_cmp_magic(cls_a, cls_b):
a, b = cls_a(""), cls_b("")
assert a == b
assert not a != b
a, b = cls_a("a"), cls_b("b")
assert a < b
assert b > a
# this is intentionally testing equivalence with none, due to the
# other=sentinel logic in _forward_magic
assert not a == None # noqa
assert not b == None # noqa
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
# __*div__ does not properly raise NotImplementedError like the other comparison
# magic, so trio.Path's implementation does not get dispatched
cls_pairs = [
(trio.Path, pathlib.Path),
(trio.Path, trio.Path),
(trio.Path, str),
(str, trio.Path),
]
@pytest.mark.parametrize("cls_a,cls_b", cls_pairs)
async def test_div_magic(cls_a, cls_b):
a, b = cls_a("a"), cls_b("b")
result = a / b
assert isinstance(result, trio.Path)
assert str(result) == os.path.join("a", "b")
@pytest.mark.parametrize(
"cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)]
)
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
async def test_hash_magic(cls_a, cls_b, path):
a, b = cls_a(path), cls_b(path)
assert hash(a) == hash(b)
async def test_forwarded_properties(path):
# use `name` as a representative of forwarded properties
assert "name" in dir(path)
assert path.name == "test"
async def test_async_method_signature(path):
# use `resolve` as a representative of wrapped methods
assert path.resolve.__name__ == "resolve"
assert path.resolve.__qualname__ == "Path.resolve"
assert "pathlib.Path.resolve" in path.resolve.__doc__
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
async def test_compare_async_stat_methods(method_name):
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert result == async_result
async def test_invalid_name_not_wrapped(path):
with pytest.raises(AttributeError):
getattr(path, "invalid_fake_attr")
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
async def test_async_methods_rewrap(method_name):
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert isinstance(async_result, trio.Path)
assert str(result) == str(async_result)
async def test_forward_methods_rewrap(path, tmpdir):
with_name = path.with_name("foo")
with_suffix = path.with_suffix(".py")
assert isinstance(with_name, trio.Path)
assert with_name == tmpdir.join("foo")
assert isinstance(with_suffix, trio.Path)
assert with_suffix == tmpdir.join("test.py")
async def test_forward_properties_rewrap(path):
assert isinstance(path.parent, trio.Path)
async def test_forward_methods_without_rewrap(path, tmpdir):
path = await path.parent.resolve()
assert path.as_uri().startswith("file:///")
async def test_repr():
path = trio.Path(".")
assert repr(path) == "trio.Path('.')"
class MockWrapped:
unsupported = "unsupported"
_private = "private"
class MockWrapper:
_forwards = MockWrapped
_wraps = MockWrapped
async def test_type_forwards_unsupported():
with pytest.raises(TypeError):
Type.generate_forwards(MockWrapper, {})
async def test_type_wraps_unsupported():
with pytest.raises(TypeError):
Type.generate_wraps(MockWrapper, {})
async def test_type_forwards_private():
Type.generate_forwards(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
async def test_type_wraps_private():
Type.generate_wraps(MockWrapper, {"unsupported": None})
assert not hasattr(MockWrapper, "_private")
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
async def test_path_wraps_path(path, meth):
wrapped = await path.absolute()
result = meth(path, wrapped)
if result is None:
result = path
assert wrapped == result
async def test_path_nonpath():
with pytest.raises(TypeError):
trio.Path(1)
async def test_open_file_can_open_path(path):
async with await trio.open_file(path, "w") as f:
assert f.name == os.fspath(path)
async def test_globmethods(path):
# Populate a directory tree
await path.mkdir()
await (path / "foo").mkdir()
await (path / "foo" / "_bar.txt").write_bytes(b"")
await (path / "bar.txt").write_bytes(b"")
await (path / "bar.dat").write_bytes(b"")
# Path.glob
for _pattern, _results in {
"*.txt": {"bar.txt"},
"**/*.txt": {"_bar.txt", "bar.txt"},
}.items():
entries = set()
for entry in await path.glob(_pattern):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == _results
# Path.rglob
entries = set()
for entry in await path.rglob("*.txt"):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"_bar.txt", "bar.txt"}
async def test_iterdir(path):
# Populate a directory
await path.mkdir()
await (path / "foo").mkdir()
await (path / "bar.txt").write_bytes(b"")
entries = set()
for entry in await path.iterdir():
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"bar.txt", "foo"}
async def test_classmethods():
assert isinstance(await trio.Path.home(), trio.Path)
# pathlib.Path has only two classmethods
assert str(await trio.Path.home()) == os.path.expanduser("~")
assert str(await trio.Path.cwd()) == os.getcwd()
# Wrapped method has docstring
assert trio.Path.home.__doc__

View File

@@ -0,0 +1,40 @@
import trio
async def scheduler_trace():
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
async def tracer(name):
for i in range(50):
trace.append((name, i))
await trio.sleep(0)
async with trio.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(tracer, i)
return tuple(trace)
def test_the_trio_scheduler_is_not_deterministic():
# At least, not yet. See https://github.com/python-trio/trio/issues/32
traces = []
for _ in range(10):
traces.append(trio.run(scheduler_trace))
assert len(set(traces)) == len(traces)
def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch):
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
traces = []
for _ in range(10):
state = trio._core._run._r.getstate()
try:
trio._core._run._r.seed(0)
traces.append(trio.run(scheduler_trace))
finally:
trio._core._run._r.setstate(state)
assert len(traces) == 10
assert len(set(traces)) == 1

View File

@@ -0,0 +1,177 @@
import signal
import pytest
import trio
from .. import _core
from .._util import signal_raise
from .._signals import open_signal_receiver, _signal_handler
async def test_open_signal_receiver():
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL) as receiver:
# Raise it a few times, to exercise signal coalescing, both at the
# call_soon level and at the SignalQueue level
signal_raise(signal.SIGILL)
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
signal_raise(signal.SIGILL)
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
with pytest.raises(RuntimeError):
await receiver.__anext__()
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
orig = signal.getsignal(signal.SIGILL)
with pytest.raises(ValueError):
with open_signal_receiver(signal.SIGILL, 1234567):
pass # pragma: no cover
# Still restored even if we errored out
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_empty_fail():
with pytest.raises(TypeError, match="No signals were provided"):
with open_signal_receiver():
pass
async def test_open_signal_receiver_restore_handler_after_duplicate_signal():
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
pass
# Still restored correctly
assert signal.getsignal(signal.SIGILL) is orig
async def test_catch_signals_wrong_thread():
async def naughty():
with open_signal_receiver(signal.SIGINT):
pass # pragma: no cover
with pytest.raises(RuntimeError):
await trio.to_thread.run_sync(trio.run, naughty)
async def test_open_signal_receiver_conflict():
with pytest.raises(trio.BusyResourceError):
with open_signal_receiver(signal.SIGILL) as receiver:
async with trio.open_nursery() as nursery:
nursery.start_soon(receiver.__anext__)
nursery.start_soon(receiver.__anext__)
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
# processed.
async def wait_run_sync_soon_idempotent_queue_barrier():
ev = trio.Event()
token = _core.current_trio_token()
token.run_sync_soon(ev.set, idempotent=True)
await ev.wait()
async def test_open_signal_receiver_no_starvation():
# Set up a situation where there are always 2 pending signals available to
# report, and make sure that instead of getting the same signal reported
# over and over, it alternates between reporting both of them.
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
try:
print(signal.getsignal(signal.SIGILL))
previous = None
for _ in range(10):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
if previous is None:
previous = await receiver.__anext__()
else:
got = await receiver.__anext__()
assert got in [signal.SIGILL, signal.SIGFPE]
assert got != previous
previous = got
# Clear out the last signal so it doesn't get redelivered
while receiver._pending_signal_count() != 0:
await receiver.__anext__()
except: # pragma: no cover
# If there's an unhandled exception above, then exiting the
# open_signal_receiver block might cause the signal to be
# redelivered and give us a core dump instead of a traceback...
import traceback
traceback.print_exc()
async def test_catch_signals_race_condition_on_exit():
delivered_directly = set()
def direct_handler(signo, frame):
delivered_directly.add(signo)
print(1)
# Test the version where the call_soon *doesn't* have a chance to run
# before we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
print(2)
# Test the version where the call_soon *does* have a chance to run before
# we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
# Again, but with a SIG_IGN signal:
print(3)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
# test passes if the process reaches this point without dying
print(4)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 1
# test passes if the process reaches this point without dying
# Check exception chaining if there are multiple exception-raising
# handlers
def raise_handler(signum, _):
raise RuntimeError(signum)
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
with pytest.raises(RuntimeError) as excinfo:
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
exc = excinfo.value
signums = {exc.args[0]}
assert isinstance(exc.__context__, RuntimeError)
signums.add(exc.__context__.args[0])
assert signums == {signal.SIGILL, signal.SIGFPE}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,482 @@
import os
import signal
import subprocess
import sys
import pytest
import random
from functools import partial
from .. import (
_core,
move_on_after,
fail_after,
sleep,
sleep_forever,
Process,
open_process,
run_process,
TrioDeprecationWarning,
)
from .._core.tests.tutil import slow, skip_if_fbsd_pipes_broken
from ..testing import wait_all_tasks_blocked
posix = os.name == "posix"
if posix:
from signal import SIGKILL, SIGTERM, SIGUSR1
else:
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
# Since Windows has very few command-line utilities generally available,
# all of our subprocesses are Python processes running short bits of
# (mostly) cross-platform code.
def python(code):
return [sys.executable, "-u", "-c", "import sys; " + code]
EXIT_TRUE = python("sys.exit(0)")
EXIT_FALSE = python("sys.exit(1)")
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
SLEEP = lambda seconds: python("import time; time.sleep({})".format(seconds))
def got_signal(proc, sig):
if posix:
return proc.returncode == -sig
else:
return proc.returncode != 0
async def test_basic():
async with await open_process(EXIT_TRUE) as proc:
pass
assert isinstance(proc, Process)
assert proc._pidfd is None
assert proc.returncode == 0
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
async with await open_process(EXIT_FALSE) as proc:
pass
assert proc.returncode == 1
assert repr(proc) == "<trio.Process {!r}: {}>".format(
EXIT_FALSE, "exited with status 1"
)
async def test_auto_update_returncode():
p = await open_process(SLEEP(9999))
assert p.returncode is None
assert "running" in repr(p)
p.kill()
p._proc.wait()
assert p.returncode is not None
assert "exited" in repr(p)
assert p._pidfd is None
assert p.returncode is not None
async def test_multi_wait():
async with await open_process(SLEEP(10)) as proc:
# Check that wait (including multi-wait) tolerates being cancelled
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# Now try waiting for real
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
async def test_kill_when_context_cancelled():
with move_on_after(100) as scope:
async with await open_process(SLEEP(10)) as proc:
assert proc.poll() is None
scope.cancel()
await sleep_forever()
assert scope.cancelled_caught
assert got_signal(proc, SIGKILL)
assert repr(proc) == "<trio.Process {!r}: {}>".format(
SLEEP(10), "exited with signal 9" if posix else "exited with status 1"
)
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
"data = sys.stdin.buffer.read(); "
"sys.stdout.buffer.write(data); "
"sys.stderr.buffer.write(data[::-1])"
)
async def test_pipes():
async with await open_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
msg = b"the quick brown fox jumps over the lazy dog"
async def feed_input():
await proc.stdin.send_all(msg)
await proc.stdin.aclose()
async def check_output(stream, expected):
seen = bytearray()
async for chunk in stream:
seen += chunk
assert seen == expected
async with _core.open_nursery() as nursery:
# fail eventually if something is broken
nursery.cancel_scope.deadline = _core.current_time() + 30.0
nursery.start_soon(feed_input)
nursery.start_soon(check_output, proc.stdout, msg)
nursery.start_soon(check_output, proc.stderr, msg[::-1])
assert not nursery.cancel_scope.cancelled_caught
assert 0 == await proc.wait()
async def test_interactive():
# Test some back-and-forth with a subprocess. This one works like so:
# in: 32\n
# out: 0000...0000\n (32 zeroes)
# err: 1111...1111\n (64 ones)
# in: 10\n
# out: 2222222222\n (10 twos)
# err: 3333....3333\n (20 threes)
# in: EOF
# out: EOF
# err: EOF
async with await open_process(
python(
"idx = 0\n"
"while True:\n"
" line = sys.stdin.readline()\n"
" if line == '': break\n"
" request = int(line.strip())\n"
" print(str(idx * 2) * request)\n"
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
" idx += 1\n"
),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
newline = b"\n" if posix else b"\r\n"
async def expect(idx, request):
async with _core.open_nursery() as nursery:
async def drain_one(stream, count, digit):
while count > 0:
result = await stream.receive_some(count)
assert result == (
"{}".format(digit).encode("utf-8") * len(result)
)
count -= len(result)
assert count == 0
assert await stream.receive_some(len(newline)) == newline
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
with fail_after(5):
await proc.stdin.send_all(b"12")
await sleep(0.1)
await proc.stdin.send_all(b"345" + newline)
await expect(0, 12345)
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
await expect(1, 100)
await expect(2, 200)
await proc.stdin.send_all(b"0" + newline)
await expect(3, 0)
await proc.stdin.send_all(b"999999")
with move_on_after(0.1) as scope:
await expect(4, 0)
assert scope.cancelled_caught
await proc.stdin.send_all(newline)
await expect(4, 999999)
await proc.stdin.aclose()
assert await proc.stdout.receive_some(1) == b""
assert await proc.stderr.receive_some(1) == b""
assert proc.returncode == 0
async def test_run():
data = bytes(random.randint(0, 255) for _ in range(2 ** 18))
result = await run_process(
CAT, stdin=data, capture_stdout=True, capture_stderr=True
)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == b""
result = await run_process(CAT, capture_stdout=True)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == b""
assert result.stderr is None
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=data,
capture_stdout=True,
capture_stderr=True,
)
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == data[::-1]
# invalid combinations
with pytest.raises(UnicodeError):
await run_process(CAT, stdin="oh no, it's text")
with pytest.raises(ValueError):
await run_process(CAT, stdin=subprocess.PIPE)
with pytest.raises(ValueError):
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
with pytest.raises(ValueError):
await run_process(CAT, capture_stderr=True, stderr=None)
async def test_run_check():
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
with pytest.raises(subprocess.CalledProcessError) as excinfo:
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
assert excinfo.value.cmd == cmd
assert excinfo.value.returncode == 1
assert excinfo.value.stderr == b"test\n"
assert excinfo.value.stdout is None
result = await run_process(
cmd, capture_stdout=True, capture_stderr=True, check=False
)
assert result.args == cmd
assert result.stdout == b""
assert result.stderr == b"test\n"
assert result.returncode == 1
@skip_if_fbsd_pipes_broken
async def test_run_with_broken_pipe():
result = await run_process(
[sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072
)
assert result.returncode == 0
assert result.stdout is result.stderr is None
async def test_stderr_stdout():
async with await open_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
assert proc.stdout is not None
assert proc.stderr is None
await proc.stdio.send_all(b"1234")
await proc.stdio.send_eof()
output = []
while True:
chunk = await proc.stdio.receive_some(16)
if chunk == b"":
break
output.append(chunk)
assert b"".join(output) == b"12344321"
assert proc.returncode == 0
# equivalent test with run_process()
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=b"1234",
capture_stdout=True,
stderr=subprocess.STDOUT,
)
assert result.returncode == 0
assert result.stdout == b"12344321"
assert result.stderr is None
# this one hits the branch where stderr=STDOUT but stdout
# is not redirected
async with await open_process(
CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT
) as proc:
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.aclose()
assert proc.returncode == 0
if posix:
try:
r, w = os.pipe()
async with await open_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=w,
stderr=subprocess.STDOUT,
) as proc:
os.close(w)
assert proc.stdio is None
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.send_all(b"1234")
await proc.stdin.aclose()
assert await proc.wait() == 0
assert os.read(r, 4096) == b"12344321"
assert os.read(r, 4096) == b""
finally:
os.close(r)
async def test_errors():
with pytest.raises(TypeError) as excinfo:
await open_process(["ls"], encoding="utf-8")
assert "unbuffered byte streams" in str(excinfo.value)
assert "the 'encoding' option is not supported" in str(excinfo.value)
if posix:
with pytest.raises(TypeError) as excinfo:
await open_process(["ls"], shell=True)
with pytest.raises(TypeError) as excinfo:
await open_process("ls", shell=False)
async def test_signals():
async def test_one_signal(send_it, signum):
with move_on_after(1.0) as scope:
async with await open_process(SLEEP(3600)) as proc:
send_it(proc)
assert not scope.cancelled_caught
if posix:
assert proc.returncode == -signum
else:
assert proc.returncode != 0
await test_one_signal(Process.kill, SIGKILL)
await test_one_signal(Process.terminate, SIGTERM)
# Test that we can send arbitrary signals.
#
# We used to use SIGINT here, but it turns out that the Python interpreter
# has race conditions that can cause it to explode in weird ways if it
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
# to terminate the target process, and Python doesn't try to do anything
# clever to handle it.
if posix:
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
@pytest.mark.skipif(not posix, reason="POSIX specific")
async def test_wait_reapable_fails():
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
# process to exit but then fail with ECHILD. Make sure we
# support this case as the stdlib subprocess module does.
async with await open_process(SLEEP(3600)) as proc:
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
nursery.cancel_scope.deadline = _core.current_time() + 1.0
assert not nursery.cancel_scope.cancelled_caught
assert proc.returncode == 0 # exit status unknowable, so...
finally:
signal.signal(signal.SIGCHLD, old_sigchld)
@slow
def test_waitid_eintr():
# This only matters on PyPy (where we're coding EINTR handling
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
if not wait_child_exiting.__module__.endswith("waitid"):
pytest.skip("waitid only")
from .._subprocess_platform.waitid import sync_wait_reapable
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
def on_alarm(sig, frame):
nonlocal got_alarm
got_alarm = True
sleeper.kill()
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
try:
signal.alarm(1)
sync_wait_reapable(sleeper.pid)
assert sleeper.wait(timeout=1) == -9
finally:
if sleeper.returncode is None: # pragma: no cover
# We only get here if something fails in the above;
# if the test passes, wait() will reap the process
sleeper.kill()
sleeper.wait()
signal.signal(signal.SIGALRM, old_sigalrm)
async def test_custom_deliver_cancel():
custom_deliver_cancel_called = False
async def custom_deliver_cancel(proc):
nonlocal custom_deliver_cancel_called
custom_deliver_cancel_called = True
proc.terminate()
# Make sure this does get cancelled when the process exits, and that
# the process really exited.
try:
await sleep_forever()
finally:
assert proc.returncode is not None
async with _core.open_nursery() as nursery:
nursery.start_soon(
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel)
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert custom_deliver_cancel_called
async def test_warn_on_failed_cancel_terminate(monkeypatch):
original_terminate = Process.terminate
def broken_terminate(self):
original_terminate(self)
raise OSError("whoops")
monkeypatch.setattr(Process, "terminate", broken_terminate)
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@pytest.mark.skipif(os.name != "posix", reason="posix only")
async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch):
monkeypatch.setattr(Process, "terminate", lambda *args: None)
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()

View File

@@ -0,0 +1,570 @@
import pytest
import weakref
from ..testing import wait_all_tasks_blocked, assert_checkpoints
from .. import _core
from .. import _timeouts
from .._timeouts import sleep_forever, move_on_after
from .._sync import *
async def test_Event():
e = Event()
assert not e.is_set()
assert e.statistics().tasks_waiting == 0
e.set()
assert e.is_set()
with assert_checkpoints():
await e.wait()
e = Event()
record = []
async def child():
record.append("sleeping")
await e.wait()
record.append("woken")
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
nursery.start_soon(child)
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping"]
assert e.statistics().tasks_waiting == 2
e.set()
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping", "woken", "woken"]
async def test_CapacityLimiter():
with pytest.raises(TypeError):
CapacityLimiter(1.0)
with pytest.raises(ValueError):
CapacityLimiter(-1)
c = CapacityLimiter(2)
repr(c) # smoke test
assert c.total_tokens == 2
assert c.borrowed_tokens == 0
assert c.available_tokens == 2
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == 1
stats = c.statistics()
assert stats.borrowed_tokens == 1
assert stats.total_tokens == 2
assert stats.borrowers == [_core.current_task()]
assert stats.tasks_waiting == 0
# Can't re-acquire when we already have it
with pytest.raises(RuntimeError):
c.acquire_nowait()
assert c.borrowed_tokens == 1
with pytest.raises(RuntimeError):
await c.acquire()
assert c.borrowed_tokens == 1
# We can acquire on behalf of someone else though
with assert_checkpoints():
await c.acquire_on_behalf_of("someone")
# But then we've run out of capacity
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_on_behalf_of_nowait("third party")
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
# Until we release one
c.release_on_behalf_of(_core.current_task())
assert c.statistics().borrowers == ["someone"]
c.release_on_behalf_of("someone")
assert c.borrowed_tokens == 0
with assert_checkpoints():
async with c:
assert c.borrowed_tokens == 1
async with _core.open_nursery() as nursery:
await c.acquire_on_behalf_of("value 1")
await c.acquire_on_behalf_of("value 2")
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
await wait_all_tasks_blocked()
assert c.borrowed_tokens == 2
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of("value 2")
# Fairness:
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_nowait()
c.release_on_behalf_of("value 3")
c.release_on_behalf_of("value 1")
async def test_CapacityLimiter_inf():
from math import inf
c = CapacityLimiter(inf)
repr(c) # smoke test
assert c.total_tokens == inf
assert c.borrowed_tokens == 0
assert c.available_tokens == inf
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == inf
async def test_CapacityLimiter_change_total_tokens():
c = CapacityLimiter(2)
with pytest.raises(TypeError):
c.total_tokens = 1.0
with pytest.raises(ValueError):
c.total_tokens = 0
with pytest.raises(ValueError):
c.total_tokens = -10
assert c.total_tokens == 2
async with _core.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(c.acquire_on_behalf_of, i)
await wait_all_tasks_blocked()
assert set(c.statistics().borrowers) == {0, 1}
assert c.statistics().tasks_waiting == 3
c.total_tokens += 2
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
assert c.statistics().tasks_waiting == 1
c.total_tokens -= 3
assert c.borrowed_tokens == 4
assert c.total_tokens == 1
c.release_on_behalf_of(0)
c.release_on_behalf_of(1)
c.release_on_behalf_of(2)
assert set(c.statistics().borrowers) == {3}
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of(3)
assert set(c.statistics().borrowers) == {4}
assert c.statistics().tasks_waiting == 0
# regression test for issue #548
async def test_CapacityLimiter_memleak_548():
limiter = CapacityLimiter(total_tokens=1)
await limiter.acquire()
async with _core.open_nursery() as n:
n.start_soon(limiter.acquire)
await wait_all_tasks_blocked() # give it a chance to run the task
n.cancel_scope.cancel()
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
# leak memory all the while the limiter is active
assert len(limiter._pending_borrowers) == 0
async def test_Semaphore():
with pytest.raises(TypeError):
Semaphore(1.0)
with pytest.raises(ValueError):
Semaphore(-1)
s = Semaphore(1)
repr(s) # smoke test
assert s.value == 1
assert s.max_value is None
s.release()
assert s.value == 2
assert s.statistics().tasks_waiting == 0
s.acquire_nowait()
assert s.value == 1
with assert_checkpoints():
await s.acquire()
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
s.release()
assert s.value == 1
with assert_checkpoints():
async with s:
assert s.value == 0
assert s.value == 1
s.acquire_nowait()
record = []
async def do_acquire(s):
record.append("started")
await s.acquire()
record.append("finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_acquire, s)
await wait_all_tasks_blocked()
assert record == ["started"]
assert s.value == 0
s.release()
# Fairness:
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
assert record == ["started", "finished"]
async def test_Semaphore_bounded():
with pytest.raises(TypeError):
Semaphore(1, max_value=1.0)
with pytest.raises(ValueError):
Semaphore(2, max_value=1)
bs = Semaphore(1, max_value=1)
assert bs.max_value == 1
repr(bs) # smoke test
with pytest.raises(ValueError):
bs.release()
assert bs.value == 1
bs.acquire_nowait()
assert bs.value == 0
bs.release()
assert bs.value == 1
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
async def test_Lock_and_StrictFIFOLock(lockcls):
l = lockcls() # noqa
assert not l.locked()
# make sure locks can be weakref'ed (gh-331)
r = weakref.ref(l)
assert r() is l
repr(l) # smoke test
# make sure repr uses the right name for subclasses
assert lockcls.__name__ in repr(l)
with assert_checkpoints():
async with l:
assert l.locked()
repr(l) # smoke test (repr branches on locked/unlocked)
assert not l.locked()
l.acquire_nowait()
assert l.locked()
l.release()
assert not l.locked()
with assert_checkpoints():
await l.acquire()
assert l.locked()
l.release()
assert not l.locked()
l.acquire_nowait()
with pytest.raises(RuntimeError):
# Error out if we already own the lock
l.acquire_nowait()
l.release()
with pytest.raises(RuntimeError):
# Error out if we don't own the lock
l.release()
holder_task = None
async def holder():
nonlocal holder_task
holder_task = _core.current_task()
async with l:
await sleep_forever()
async with _core.open_nursery() as nursery:
assert not l.locked()
nursery.start_soon(holder)
await wait_all_tasks_blocked()
assert l.locked()
# WouldBlock if someone else holds the lock
with pytest.raises(_core.WouldBlock):
l.acquire_nowait()
# Can't release a lock someone else holds
with pytest.raises(RuntimeError):
l.release()
statistics = l.statistics()
print(statistics)
assert statistics.locked
assert statistics.owner is holder_task
assert statistics.tasks_waiting == 0
nursery.start_soon(holder)
await wait_all_tasks_blocked()
statistics = l.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
nursery.cancel_scope.cancel()
statistics = l.statistics()
assert not statistics.locked
assert statistics.owner is None
assert statistics.tasks_waiting == 0
async def test_Condition():
with pytest.raises(TypeError):
Condition(Semaphore(1))
with pytest.raises(TypeError):
Condition(StrictFIFOLock)
l = Lock() # noqa
c = Condition(l)
assert not l.locked()
assert not c.locked()
with assert_checkpoints():
await c.acquire()
assert l.locked()
assert c.locked()
c = Condition()
assert not c.locked()
c.acquire_nowait()
assert c.locked()
with pytest.raises(RuntimeError):
c.acquire_nowait()
c.release()
with pytest.raises(RuntimeError):
# Can't wait without holding the lock
await c.wait()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify_all()
finished_waiters = set()
async def waiter(i):
async with c:
await c.wait()
finished_waiters.add(i)
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify()
assert c.locked()
await wait_all_tasks_blocked()
assert finished_waiters == {0}
async with c:
c.notify_all()
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1, 2}
finished_waiters = set()
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify(2)
statistics = c.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
assert statistics.lock_statistics.tasks_waiting == 2
# exiting the context manager hands off the lock to the first task
assert c.statistics().lock_statistics.tasks_waiting == 1
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1}
async with c:
c.notify_all()
# After being cancelled still hold the lock (!)
# (Note that c.__aexit__ checks that we hold the lock as well)
with _core.CancelScope() as scope:
async with c:
scope.cancel()
try:
await c.wait()
finally:
assert c.locked()
from .._sync import async_cm
from .._channel import open_memory_channel
# Three ways of implementing a Lock in terms of a channel. Used to let us put
# the channel through the generic lock tests.
@async_cm
class ChannelLock1:
def __init__(self, capacity):
self.s, self.r = open_memory_channel(capacity)
for _ in range(capacity - 1):
self.s.send_nowait(None)
def acquire_nowait(self):
self.s.send_nowait(None)
async def acquire(self):
await self.s.send(None)
def release(self):
self.r.receive_nowait()
@async_cm
class ChannelLock2:
def __init__(self):
self.s, self.r = open_memory_channel(10)
self.s.send_nowait(None)
def acquire_nowait(self):
self.r.receive_nowait()
async def acquire(self):
await self.r.receive()
def release(self):
self.s.send_nowait(None)
@async_cm
class ChannelLock3:
def __init__(self):
self.s, self.r = open_memory_channel(0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
self.acquired = False
def acquire_nowait(self):
assert not self.acquired
self.acquired = True
async def acquire(self):
if self.acquired:
await self.s.send(None)
else:
self.acquired = True
await _core.checkpoint()
def release(self):
try:
self.r.receive_nowait()
except _core.WouldBlock:
assert self.acquired
self.acquired = False
lock_factories = [
lambda: CapacityLimiter(1),
lambda: Semaphore(1),
Lock,
StrictFIFOLock,
lambda: ChannelLock1(10),
lambda: ChannelLock1(1),
ChannelLock2,
ChannelLock3,
]
lock_factory_names = [
"CapacityLimiter(1)",
"Semaphore(1)",
"Lock",
"StrictFIFOLock",
"ChannelLock1(10)",
"ChannelLock1(1)",
"ChannelLock2",
"ChannelLock3",
]
generic_lock_test = pytest.mark.parametrize(
"lock_factory", lock_factories, ids=lock_factory_names
)
# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
async def test_generic_lock_exclusion(lock_factory):
LOOPS = 10
WORKERS = 5
in_critical_section = False
acquires = 0
async def worker(lock_like):
nonlocal in_critical_section, acquires
for _ in range(LOOPS):
async with lock_like:
acquires += 1
assert not in_critical_section
in_critical_section = True
await _core.checkpoint()
await _core.checkpoint()
assert in_critical_section
in_critical_section = False
async with _core.open_nursery() as nursery:
lock_like = lock_factory()
for _ in range(WORKERS):
nursery.start_soon(worker, lock_like)
assert not in_critical_section
assert acquires == LOOPS * WORKERS
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
async def test_generic_lock_fifo_fairness(lock_factory):
initial_order = []
record = []
LOOPS = 5
async def loopy(name, lock_like):
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
async with lock_like:
record.append(name)
lock_like = lock_factory()
async with _core.open_nursery() as nursery:
nursery.start_soon(loopy, 1, lock_like)
nursery.start_soon(loopy, 2, lock_like)
nursery.start_soon(loopy, 3, lock_like)
# The first three could be in any order due to scheduling randomness,
# but after that they should repeat in the same order
for i in range(LOOPS):
assert record[3 * i : 3 * (i + 1)] == initial_order
@generic_lock_test
async def test_generic_lock_acquire_nowait_blocks_acquire(lock_factory):
lock_like = lock_factory()
record = []
async def lock_taker():
record.append("started")
async with lock_like:
pass
record.append("finished")
async with _core.open_nursery() as nursery:
lock_like.acquire_nowait()
nursery.start_soon(lock_taker)
await wait_all_tasks_blocked()
assert record == ["started"]
lock_like.release()

View File

@@ -0,0 +1,657 @@
# XX this should get broken up, like testing.py did
import tempfile
import pytest
from .._core.tests.tutil import can_bind_ipv6
from .. import sleep
from .. import _core
from .._highlevel_generic import aclose_forcefully
from ..testing import *
from ..testing._check_streams import _assert_raises
from ..testing._memory_streams import _UnboundedByteQueue
from .. import socket as tsocket
from .._highlevel_socket import SocketListener
async def test_wait_all_tasks_blocked():
record = []
async def busy_bee():
for _ in range(10):
await _core.checkpoint()
record.append("busy bee exhausted")
async def waiting_for_bee_to_leave():
await wait_all_tasks_blocked()
record.append("quiet at last!")
async with _core.open_nursery() as nursery:
nursery.start_soon(busy_bee)
nursery.start_soon(waiting_for_bee_to_leave)
nursery.start_soon(waiting_for_bee_to_leave)
# check cancellation
record = []
async def cancelled_while_waiting():
try:
await wait_all_tasks_blocked()
except _core.Cancelled:
record.append("ok")
async with _core.open_nursery() as nursery:
nursery.start_soon(cancelled_while_waiting)
nursery.cancel_scope.cancel()
assert record == ["ok"]
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
record = []
async def timeout_task():
record.append("tt start")
await sleep(5)
record.append("tt finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(timeout_task)
await wait_all_tasks_blocked()
assert record == ["tt start"]
mock_clock.jump(10)
await wait_all_tasks_blocked()
assert record == ["tt start", "tt finished"]
async def test_wait_all_tasks_blocked_with_cushion():
record = []
async def blink():
record.append("blink start")
await sleep(0.01)
await sleep(0.01)
await sleep(0.01)
record.append("blink end")
async def wait_no_cushion():
await wait_all_tasks_blocked()
record.append("wait_no_cushion end")
async def wait_small_cushion():
await wait_all_tasks_blocked(0.02)
record.append("wait_small_cushion end")
async def wait_big_cushion():
await wait_all_tasks_blocked(0.03)
record.append("wait_big_cushion end")
async with _core.open_nursery() as nursery:
nursery.start_soon(blink)
nursery.start_soon(wait_no_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_big_cushion)
assert record == [
"blink start",
"wait_no_cushion end",
"blink end",
"wait_small_cushion end",
"wait_small_cushion end",
"wait_big_cushion end",
]
################################################################
async def test_assert_checkpoints(recwarn):
with assert_checkpoints():
await _core.checkpoint()
with pytest.raises(AssertionError):
with assert_checkpoints():
1 + 1
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that's not a checkpoint.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_checkpoints():
await partial_yield()
# But both together count as a checkpoint
with assert_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
async def test_assert_no_checkpoints(recwarn):
with assert_no_checkpoints():
1 + 1
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint()
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that doesn't make *either* version of assert_{no_,}yields happy.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await partial_yield()
# And both together also count as a checkpoint
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
################################################################
async def test_Sequencer():
record = []
def t(val):
print(val)
record.append(val)
async def f1(seq):
async with seq(1):
t(("f1", 1))
async with seq(3):
t(("f1", 3))
async with seq(4):
t(("f1", 4))
async def f2(seq):
async with seq(0):
t(("f2", 0))
async with seq(2):
t(("f2", 2))
seq = Sequencer()
async with _core.open_nursery() as nursery:
nursery.start_soon(f1, seq)
nursery.start_soon(f2, seq)
async with seq(5):
await wait_all_tasks_blocked()
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
seq = Sequencer()
# Catches us if we try to re-use a sequence point:
async with seq(0):
pass
with pytest.raises(RuntimeError):
async with seq(0):
pass # pragma: no cover
async def test_Sequencer_cancel():
# Killing a blocked task makes everything blow up
record = []
seq = Sequencer()
async def child(i):
with _core.CancelScope() as scope:
if i == 1:
scope.cancel()
try:
async with seq(i):
pass # pragma: no cover
except RuntimeError:
record.append("seq({}) RuntimeError".format(i))
async with _core.open_nursery() as nursery:
nursery.start_soon(child, 1)
nursery.start_soon(child, 2)
async with seq(0):
pass # pragma: no cover
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
# Late arrivals also get errors
with pytest.raises(RuntimeError):
async with seq(3):
pass # pragma: no cover
################################################################
async def test__assert_raises():
with pytest.raises(AssertionError):
with _assert_raises(RuntimeError):
1 + 1
with pytest.raises(TypeError):
with _assert_raises(RuntimeError):
"foo" + 1
with _assert_raises(RuntimeError):
raise RuntimeError
# This is a private implementation detail, but it's complex enough to be worth
# testing directly
async def test__UnboundeByteQueue():
ubq = _UnboundedByteQueue()
ubq.put(b"123")
ubq.put(b"456")
assert ubq.get_nowait(1) == b"1"
assert ubq.get_nowait(10) == b"23456"
ubq.put(b"789")
assert ubq.get_nowait() == b"789"
with pytest.raises(_core.WouldBlock):
ubq.get_nowait(10)
with pytest.raises(_core.WouldBlock):
ubq.get_nowait()
with pytest.raises(TypeError):
ubq.put("string")
ubq.put(b"abc")
with assert_checkpoints():
assert await ubq.get(10) == b"abc"
ubq.put(b"def")
ubq.put(b"ghi")
with assert_checkpoints():
assert await ubq.get(1) == b"d"
with assert_checkpoints():
assert await ubq.get() == b"efghi"
async def putter(data):
await wait_all_tasks_blocked()
ubq.put(data)
async def getter(expect):
with assert_checkpoints():
assert await ubq.get() == expect
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"xyz")
nursery.start_soon(putter, b"xyz")
# Two gets at the same time -> BusyResourceError
with pytest.raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"asdf")
nursery.start_soon(getter, b"asdf")
# Closing
ubq.close()
with pytest.raises(_core.ClosedResourceError):
ubq.put(b"---")
assert ubq.get_nowait(10) == b""
assert ubq.get_nowait() == b""
assert await ubq.get(10) == b""
assert await ubq.get() == b""
# close is idempotent
ubq.close()
# close wakes up blocked getters
ubq2 = _UnboundedByteQueue()
async def closer():
await wait_all_tasks_blocked()
ubq2.close()
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"")
nursery.start_soon(closer)
async def test_MemorySendStream():
mss = MemorySendStream()
async def do_send_all(data):
with assert_checkpoints():
await mss.send_all(data)
await do_send_all(b"123")
assert mss.get_data_nowait(1) == b"1"
assert mss.get_data_nowait() == b"23"
with assert_checkpoints():
await mss.wait_send_all_might_not_block()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait(10)
await do_send_all(b"456")
with assert_checkpoints():
assert await mss.get_data() == b"456"
# Call send_all twice at once; one should get BusyResourceError and one
# should succeed. But we can't let the error propagate, because it might
# cause the other to be cancelled before it can finish doing its thing,
# and we don't know which one will get the error.
resource_busy_count = 0
async def do_send_all_count_resourcebusy():
nonlocal resource_busy_count
try:
await do_send_all(b"xxx")
except _core.BusyResourceError:
resource_busy_count += 1
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all_count_resourcebusy)
nursery.start_soon(do_send_all_count_resourcebusy)
assert resource_busy_count == 1
with assert_checkpoints():
await mss.aclose()
assert await mss.get_data() == b"xxx"
assert await mss.get_data() == b""
with pytest.raises(_core.ClosedResourceError):
await do_send_all(b"---")
# hooks
assert mss.send_all_hook is None
assert mss.wait_send_all_might_not_block_hook is None
assert mss.close_hook is None
record = []
async def send_all_hook():
# hook runs after send_all does its work (can pull data out)
assert mss2.get_data_nowait() == b"abc"
record.append("send_all_hook")
async def wait_send_all_might_not_block_hook():
record.append("wait_send_all_might_not_block_hook")
def close_hook():
record.append("close_hook")
mss2 = MemorySendStream(
send_all_hook, wait_send_all_might_not_block_hook, close_hook
)
assert mss2.send_all_hook is send_all_hook
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
assert mss2.close_hook is close_hook
await mss2.send_all(b"abc")
await mss2.wait_send_all_might_not_block()
await aclose_forcefully(mss2)
mss2.close()
assert record == [
"send_all_hook",
"wait_send_all_might_not_block_hook",
"close_hook",
"close_hook",
]
async def test_MemoryReceiveStream():
mrs = MemoryReceiveStream()
async def do_receive_some(max_bytes):
with assert_checkpoints():
return await mrs.receive_some(max_bytes)
mrs.put_data(b"abc")
assert await do_receive_some(1) == b"a"
assert await do_receive_some(10) == b"bc"
mrs.put_data(b"abc")
assert await do_receive_some(None) == b"abc"
with pytest.raises(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 10)
nursery.start_soon(do_receive_some, 10)
assert mrs.receive_some_hook is None
mrs.put_data(b"def")
mrs.put_eof()
mrs.put_eof()
assert await do_receive_some(10) == b"def"
assert await do_receive_some(10) == b""
assert await do_receive_some(10) == b""
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"---")
async def receive_some_hook():
mrs2.put_data(b"xxx")
record = []
def close_hook():
record.append("closed")
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
assert mrs2.receive_some_hook is receive_some_hook
assert mrs2.close_hook is close_hook
mrs2.put_data(b"yyy")
assert await mrs2.receive_some(10) == b"yyyxxx"
assert await mrs2.receive_some(10) == b"xxx"
assert await mrs2.receive_some(10) == b"xxx"
mrs2.put_data(b"zzz")
mrs2.receive_some_hook = None
assert await mrs2.receive_some(10) == b"zzz"
mrs2.put_data(b"lost on close")
with assert_checkpoints():
await mrs2.aclose()
assert record == ["closed"]
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_MemoryRecvStream_closing():
mrs = MemoryReceiveStream()
# close with no pending data
mrs.close()
with pytest.raises(_core.ClosedResourceError):
assert await mrs.receive_some(10) == b""
# repeated closes ok
mrs.close()
# put_data now fails
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"123")
mrs2 = MemoryReceiveStream()
# close with pending data
mrs2.put_data(b"xyz")
mrs2.close()
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_memory_stream_pump():
mss = MemorySendStream()
mrs = MemoryReceiveStream()
# no-op if no data present
memory_stream_pump(mss, mrs)
await mss.send_all(b"123")
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b"123"
await mss.send_all(b"456")
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"4"
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert not memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"56"
mss.close()
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b""
async def test_memory_stream_one_way_pair():
s, r = memory_stream_one_way_pair()
assert s.send_all_hook is not None
assert s.wait_send_all_might_not_block_hook is None
assert s.close_hook is not None
assert r.receive_some_hook is None
await s.send_all(b"123")
assert await r.receive_some(10) == b"123"
async def receiver(expected):
assert await r.receive_some(10) == expected
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"abc")
await wait_all_tasks_blocked()
await s.send_all(b"abc")
# And this fails if we don't pump from close_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
await s.aclose()
s, r = memory_stream_one_way_pair()
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
s.close()
s, r = memory_stream_one_way_pair()
old = s.send_all_hook
s.send_all_hook = None
await s.send_all(b"456")
async def cancel_after_idle(nursery):
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async def check_for_cancel():
with pytest.raises(_core.Cancelled):
# This should block forever... or until cancelled. Even though we
# sent some data on the send stream.
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(cancel_after_idle, nursery)
nursery.start_soon(check_for_cancel)
s.send_all_hook = old
await s.send_all(b"789")
assert await r.receive_some(10) == b"456789"
async def test_memory_stream_pair():
a, b = memory_stream_pair()
await a.send_all(b"123")
await b.send_all(b"abc")
assert await b.receive_some(10) == b"123"
assert await a.receive_some(10) == b"abc"
await a.send_eof()
assert await b.receive_some(10) == b""
async def sender():
await wait_all_tasks_blocked()
await b.send_all(b"xyz")
async def receiver():
assert await a.receive_some(10) == b"xyz"
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver)
nursery.start_soon(sender)
async def test_memory_streams_with_generic_tests():
async def one_way_stream_maker():
return memory_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, None)
async def half_closeable_stream_maker():
return memory_stream_pair()
await check_half_closeable_stream(half_closeable_stream_maker, None)
async def test_lockstep_streams_with_generic_tests():
async def one_way_stream_maker():
return lockstep_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
async def two_way_stream_maker():
return lockstep_stream_pair()
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
async def test_open_stream_to_socket_listener():
async def check(listener):
async with listener:
client_stream = await open_stream_to_socket_listener(listener)
async with client_stream:
server_stream = await listener.accept()
async with server_stream:
await client_stream.send_all(b"x")
await server_stream.receive_some(1) == b"x"
# Listener bound to localhost
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
sock.listen(10)
await check(SocketListener(sock))
# Listener bound to IPv4 wildcard (needs special handling)
sock = tsocket.socket()
await sock.bind(("0.0.0.0", 0))
sock.listen(10)
await check(SocketListener(sock))
if can_bind_ipv6:
# Listener bound to IPv6 wildcard (needs special handling)
sock = tsocket.socket(family=tsocket.AF_INET6)
await sock.bind(("::", 0))
sock.listen(10)
await check(SocketListener(sock))
if hasattr(tsocket, "AF_UNIX"):
# Listener bound to Unix-domain socket
sock = tsocket.socket(family=tsocket.AF_UNIX)
# can't use pytest's tmpdir; if we try then macOS says "OSError:
# AF_UNIX path too long"
with tempfile.TemporaryDirectory() as tmpdir:
path = "{}/sock".format(tmpdir)
await sock.bind(path)
sock.listen(10)
await check(SocketListener(sock))

View File

@@ -0,0 +1,587 @@
import threading
import queue as stdlib_queue
import time
import weakref
import pytest
from trio._core import TrioToken, current_trio_token
from .. import _core
from .. import Event, CapacityLimiter, sleep
from ..testing import wait_all_tasks_blocked
from .._core.tests.tutil import buggy_pypy_asyncgens
from .._threads import (
to_thread_run_sync,
current_default_thread_limiter,
from_thread_run,
from_thread_run_sync,
)
from .._core.tests.test_ki import ki_self
async def test_do_in_trio_thread():
trio_thread = threading.current_thread()
async def check_case(do_in_trio_thread, fn, expected, trio_token=None):
record = []
def threadfn():
try:
record.append(("start", threading.current_thread()))
x = do_in_trio_thread(fn, record, trio_token=trio_token)
record.append(("got", x))
except BaseException as exc:
print(exc)
record.append(("error", type(exc)))
child_thread = threading.Thread(target=threadfn, daemon=True)
child_thread.start()
while child_thread.is_alive():
print("yawn")
await sleep(0.01)
assert record == [("start", child_thread), ("f", trio_thread), expected]
token = _core.current_trio_token()
def f(record):
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
return 2
await check_case(from_thread_run_sync, f, ("got", 2), trio_token=token)
def f(record):
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
raise ValueError
await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token)
async def f(record):
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
return 3
await check_case(from_thread_run, f, ("got", 3), trio_token=token)
async def f(record):
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
raise KeyError
await check_case(from_thread_run, f, ("error", KeyError), trio_token=token)
async def test_do_in_trio_thread_from_trio_thread():
with pytest.raises(RuntimeError):
from_thread_run_sync(lambda: None) # pragma: no branch
async def foo(): # pragma: no cover
pass
with pytest.raises(RuntimeError):
from_thread_run(foo)
def test_run_in_trio_thread_ki():
# if we get a control-C during a run_in_trio_thread, then it propagates
# back to the caller (slick!)
record = set()
async def check_run_in_trio_thread():
token = _core.current_trio_token()
def trio_thread_fn():
print("in Trio thread")
assert not _core.currently_ki_protected()
print("ki_self")
try:
ki_self()
finally:
import sys
print("finally", sys.exc_info())
async def trio_thread_afn():
trio_thread_fn()
def external_thread_fn():
try:
print("running")
from_thread_run_sync(trio_thread_fn, trio_token=token)
except KeyboardInterrupt:
print("ok1")
record.add("ok1")
try:
from_thread_run(trio_thread_afn, trio_token=token)
except KeyboardInterrupt:
print("ok2")
record.add("ok2")
thread = threading.Thread(target=external_thread_fn)
thread.start()
print("waiting")
while thread.is_alive():
await sleep(0.01)
print("waited, joining")
thread.join()
print("done")
_core.run(check_run_in_trio_thread)
assert record == {"ok1", "ok2"}
def test_await_in_trio_thread_while_main_exits():
record = []
ev = Event()
async def trio_fn():
record.append("sleeping")
ev.set()
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
def thread_fn(token):
try:
from_thread_run(trio_fn, trio_token=token)
except _core.Cancelled:
record.append("cancelled")
async def main():
token = _core.current_trio_token()
thread = threading.Thread(target=thread_fn, args=(token,))
thread.start()
await ev.wait()
assert record == ["sleeping"]
return thread
thread = _core.run(main)
thread.join()
assert record == ["sleeping", "cancelled"]
async def test_run_in_worker_thread():
trio_thread = threading.current_thread()
def f(x):
return (x, threading.current_thread())
x, child_thread = await to_thread_run_sync(f, 1)
assert x == 1
assert child_thread != trio_thread
def g():
raise ValueError(threading.current_thread())
with pytest.raises(ValueError) as excinfo:
await to_thread_run_sync(g)
print(excinfo.value.args)
assert excinfo.value.args[0] != trio_thread
async def test_run_in_worker_thread_cancellation():
register = [None]
def f(q):
# Make the thread block for a controlled amount of time
register[0] = "blocking"
q.get()
register[0] = "finished"
async def child(q, cancellable):
record.append("start")
try:
return await to_thread_run_sync(f, q, cancellable=cancellable)
finally:
record.append("exit")
record = []
q = stdlib_queue.Queue()
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, True)
# Give it a chance to get started. (This is important because
# to_thread_run_sync does a checkpoint_if_cancelled before
# blocking on the thread, and we don't want to trigger this.)
await wait_all_tasks_blocked()
assert record == ["start"]
# Then cancel it.
nursery.cancel_scope.cancel()
# The task exited, but the thread didn't:
assert register[0] != "finished"
# Put the thread out of its misery:
q.put(None)
while register[0] != "finished":
time.sleep(0.01)
# This one can't be cancelled
record = []
register[0] = None
async with _core.open_nursery() as nursery:
nursery.start_soon(child, q, False)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
with _core.CancelScope(shield=True):
for _ in range(10):
await _core.checkpoint()
# It's still running
assert record == ["start"]
q.put(None)
# Now it exits
# But if we cancel *before* it enters, the entry is itself a cancellation
# point
with _core.CancelScope() as scope:
scope.cancel()
await child(q, False)
assert scope.cancelled_caught
# Make sure that if trio.run exits, and then the thread finishes, then that's
# handled gracefully. (Requires that the thread result machinery be prepared
# for call_soon to raise RunFinishedError.)
def test_run_in_worker_thread_abandoned(capfd, monkeypatch):
monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)
q1 = stdlib_queue.Queue()
q2 = stdlib_queue.Queue()
def thread_fn():
q1.get()
q2.put(threading.current_thread())
async def main():
async def child():
await to_thread_run_sync(thread_fn, cancellable=True)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
_core.run(main)
q1.put(None)
# This makes sure:
# - the thread actually ran
# - that thread has finished before we check for its output
thread = q2.get()
while thread.is_alive():
time.sleep(0.01) # pragma: no cover
# Make sure we don't have a "Exception in thread ..." dump to the console:
out, err = capfd.readouterr()
assert "Exception in thread" not in out
assert "Exception in thread" not in err
@pytest.mark.parametrize("MAX", [3, 5, 10])
@pytest.mark.parametrize("cancel", [False, True])
@pytest.mark.parametrize("use_default_limiter", [False, True])
async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
# This test is a bit tricky. The goal is to make sure that if we set
# limiter=CapacityLimiter(MAX), then in fact only MAX threads are ever
# running at a time, even if there are more concurrent calls to
# to_thread_run_sync, and even if some of those are cancelled. And
# also to make sure that the default limiter actually limits.
COUNT = 2 * MAX
gate = threading.Event()
lock = threading.Lock()
if use_default_limiter:
c = current_default_thread_limiter()
orig_total_tokens = c.total_tokens
c.total_tokens = MAX
limiter_arg = None
else:
c = CapacityLimiter(MAX)
orig_total_tokens = MAX
limiter_arg = c
try:
# We used to use regular variables and 'nonlocal' here, but it turns
# out that it's not safe to assign to closed-over variables that are
# visible in multiple threads, at least as of CPython 3.6 and PyPy
# 5.8:
#
# https://bugs.python.org/issue30744
# https://bitbucket.org/pypy/pypy/issues/2591/
#
# Mutating them in-place is OK though (as long as you use proper
# locking etc.).
class state:
pass
state.ran = 0
state.high_water = 0
state.running = 0
state.parked = 0
token = _core.current_trio_token()
def thread_fn(cancel_scope):
print("thread_fn start")
from_thread_run_sync(cancel_scope.cancel, trio_token=token)
with lock:
state.ran += 1
state.running += 1
state.high_water = max(state.high_water, state.running)
# The Trio thread below watches this value and uses it as a
# signal that all the stats calculations have finished.
state.parked += 1
gate.wait()
with lock:
state.parked -= 1
state.running -= 1
print("thread_fn exiting")
async def run_thread(event):
with _core.CancelScope() as cancel_scope:
await to_thread_run_sync(
thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
)
print("run_thread finished, cancelled:", cancel_scope.cancelled_caught)
event.set()
async with _core.open_nursery() as nursery:
print("spawning")
events = []
for i in range(COUNT):
events.append(Event())
nursery.start_soon(run_thread, events[-1])
await wait_all_tasks_blocked()
# In the cancel case, we in particular want to make sure that the
# cancelled tasks don't release the semaphore. So let's wait until
# at least one of them has exited, and that everything has had a
# chance to settle down from this, before we check that everyone
# who's supposed to be waiting is waiting:
if cancel:
print("waiting for first cancellation to clear")
await events[0].wait()
await wait_all_tasks_blocked()
# Then wait until the first MAX threads are parked in gate.wait(),
# and the next MAX threads are parked on the semaphore, to make
# sure no-one is sneaking past, and to make sure the high_water
# check below won't fail due to scheduling issues. (It could still
# fail if too many threads are let through here.)
while state.parked != MAX or c.statistics().tasks_waiting != MAX:
await sleep(0.01) # pragma: no cover
# Then release the threads
gate.set()
assert state.high_water == MAX
if cancel:
# Some threads might still be running; need to wait to them to
# finish before checking that all threads ran. We can do this
# using the CapacityLimiter.
while c.borrowed_tokens > 0:
await sleep(0.01) # pragma: no cover
assert state.ran == COUNT
assert state.running == 0
finally:
c.total_tokens = orig_total_tokens
async def test_run_in_worker_thread_custom_limiter():
# Basically just checking that we only call acquire_on_behalf_of and
# release_on_behalf_of, since that's part of our documented API.
record = []
class CustomLimiter:
async def acquire_on_behalf_of(self, borrower):
record.append("acquire")
self._borrower = borrower
def release_on_behalf_of(self, borrower):
record.append("release")
assert borrower == self._borrower
await to_thread_run_sync(lambda: None, limiter=CustomLimiter())
assert record == ["acquire", "release"]
async def test_run_in_worker_thread_limiter_error():
record = []
class BadCapacityLimiter:
async def acquire_on_behalf_of(self, borrower):
record.append("acquire")
def release_on_behalf_of(self, borrower):
record.append("release")
raise ValueError
bs = BadCapacityLimiter()
with pytest.raises(ValueError) as excinfo:
await to_thread_run_sync(lambda: None, limiter=bs)
assert excinfo.value.__context__ is None
assert record == ["acquire", "release"]
record = []
# If the original function raised an error, then the semaphore error
# chains with it
d = {}
with pytest.raises(ValueError) as excinfo:
await to_thread_run_sync(lambda: d["x"], limiter=bs)
assert isinstance(excinfo.value.__context__, KeyError)
assert record == ["acquire", "release"]
async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
# Test the unlikely but possible case where trying to spawn a thread fails
def bad_start(self, *args):
raise RuntimeError("the engines canna take it captain")
monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)
limiter = current_default_thread_limiter()
assert limiter.borrowed_tokens == 0
# We get an appropriate error, and the limiter is cleanly released
with pytest.raises(RuntimeError) as excinfo:
await to_thread_run_sync(lambda: None) # pragma: no cover
assert "engines" in str(excinfo.value)
assert limiter.borrowed_tokens == 0
async def test_trio_to_thread_run_sync_token():
# Test that to_thread_run_sync automatically injects the current trio token
# into a spawned thread
def thread_fn():
callee_token = from_thread_run_sync(_core.current_trio_token)
return callee_token
caller_token = _core.current_trio_token()
callee_token = await to_thread_run_sync(thread_fn)
assert callee_token == caller_token
async def test_trio_to_thread_run_sync_expected_error():
# Test correct error when passed async function
async def async_fn(): # pragma: no cover
pass
with pytest.raises(TypeError, match="expected a sync function"):
await to_thread_run_sync(async_fn)
async def test_trio_from_thread_run_sync():
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run_sync()
def thread_fn():
trio_time = from_thread_run_sync(_core.current_time)
return trio_time
trio_time = await to_thread_run_sync(thread_fn)
assert isinstance(trio_time, float)
# Test correct error when passed async function
async def async_fn(): # pragma: no cover
pass
def thread_fn():
from_thread_run_sync(async_fn)
with pytest.raises(TypeError, match="expected a sync function"):
await to_thread_run_sync(thread_fn)
async def test_trio_from_thread_run():
# Test that to_thread_run_sync correctly "hands off" the trio token to
# trio.from_thread.run()
record = []
async def back_in_trio_fn():
_core.current_time() # implicitly checks that we're in trio
record.append("back in trio")
def thread_fn():
record.append("in thread")
from_thread_run(back_in_trio_fn)
await to_thread_run_sync(thread_fn)
assert record == ["in thread", "back in trio"]
# Test correct error when passed sync function
def sync_fn(): # pragma: no cover
pass
with pytest.raises(TypeError, match="appears to be synchronous"):
await to_thread_run_sync(from_thread_run, sync_fn)
async def test_trio_from_thread_token():
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync()
# share the same Trio token
def thread_fn():
callee_token = from_thread_run_sync(_core.current_trio_token)
return callee_token
caller_token = _core.current_trio_token()
callee_token = await to_thread_run_sync(thread_fn)
assert callee_token == caller_token
async def test_trio_from_thread_token_kwarg():
# Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can
# use an explicitly defined token
def thread_fn(token):
callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token)
return callee_token
caller_token = _core.current_trio_token()
callee_token = await to_thread_run_sync(thread_fn, caller_token)
assert callee_token == caller_token
async def test_from_thread_no_token():
# Test that a "raw call" to trio.from_thread.run() fails because no token
# has been provided
with pytest.raises(RuntimeError):
from_thread_run_sync(_core.current_time)
def test_run_fn_as_system_task_catched_badly_typed_token():
with pytest.raises(RuntimeError):
from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype")
async def test_from_thread_inside_trio_thread():
def not_called(): # pragma: no cover
assert False
trio_token = _core.current_trio_token()
with pytest.raises(RuntimeError):
from_thread_run_sync(not_called, trio_token=trio_token)
@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy")
def test_from_thread_run_during_shutdown():
save = []
record = []
async def agen():
try:
yield
finally:
with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True):
await to_thread_run_sync(from_thread_run, sleep, 0)
record.append("ok")
async def main():
save.append(agen())
await save[-1].asend(None)
_core.run(main)
assert record == ["ok"]
async def test_trio_token_weak_referenceable():
token = current_trio_token()
assert isinstance(token, TrioToken)
weak_reference = weakref.ref(token)
assert token is weak_reference()

View File

@@ -0,0 +1,104 @@
import outcome
import pytest
import time
from .._core.tests.tutil import slow
from .. import _core
from ..testing import assert_checkpoints
from .._timeouts import *
async def check_takes_about(f, expected_dur):
start = time.perf_counter()
result = await outcome.acapture(f)
dur = time.perf_counter() - start
print(dur / expected_dur)
# 1.5 is an arbitrary fudge factor because there's always some delay
# between when we become eligible to wake up and when we actually do. We
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
# now we bumped up the sleep to 1 second, marked the tests as slow, and
# hopefully now the proportional error will be less huge.
#
# We also also for durations that are a hair shorter than expected. For
# example, here's a run on Windows where a 1.0 second sleep was measured
# to take 0.9999999999999858 seconds:
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
# I believe that what happened here is that Windows's low clock resolution
# meant that our calls to time.monotonic() returned exactly the same
# values as the calls inside the actual run loop, but the two subtractions
# returned slightly different values because the run loop's clock adds a
# random floating point offset to both times, which should cancel out, but
# lol floating point we got slightly different rounding errors. (That
# value above is exactly 128 ULPs below 1.0, which would make sense if it
# started as a 1 ULP error at a different dynamic range.)
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
return result.unwrap()
# How long to (attempt to) sleep for when testing. Smaller numbers make the
# test suite go faster.
TARGET = 1.0
@slow
async def test_sleep():
async def sleep_1():
await sleep_until(_core.current_time() + TARGET)
await check_takes_about(sleep_1, TARGET)
async def sleep_2():
await sleep(TARGET)
await check_takes_about(sleep_2, TARGET)
with pytest.raises(ValueError):
await sleep(-1)
with assert_checkpoints():
await sleep(0)
# This also serves as a test of the trivial move_on_at
with move_on_at(_core.current_time()):
with pytest.raises(_core.Cancelled):
await sleep(0)
@slow
async def test_move_on_after():
with pytest.raises(ValueError):
with move_on_after(-1):
pass # pragma: no cover
async def sleep_3():
with move_on_after(TARGET):
await sleep(100)
await check_takes_about(sleep_3, TARGET)
@slow
async def test_fail():
async def sleep_4():
with fail_at(_core.current_time() + TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_4, TARGET)
with fail_at(_core.current_time() + 100):
await sleep(0)
async def sleep_5():
with fail_after(TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_5, TARGET)
with fail_after(100):
await sleep(0)
with pytest.raises(ValueError):
with fail_after(-1):
pass # pragma: no cover

View File

@@ -0,0 +1,265 @@
import errno
import select
import os
import tempfile
import sys
import pytest
from .._core.tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
from .. import _core, move_on_after
from ..testing import wait_all_tasks_blocked, check_one_way_stream
posix = os.name == "posix"
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
if posix:
from .._unix_pipes import FdStream
else:
with pytest.raises(ImportError):
from .._unix_pipes import FdStream
# Have to use quoted types so import doesn't crash on windows
async def make_pipe() -> "Tuple[FdStream, FdStream]":
"""Makes a new pair of pipes."""
(r, w) = os.pipe()
return FdStream(w), FdStream(r)
async def make_clogged_pipe():
s, r = await make_pipe()
try:
while True:
# We want to totally fill up the pipe buffer.
# This requires working around a weird feature that POSIX pipes
# have.
# If you do a write of <= PIPE_BUF bytes, then it's guaranteed
# to either complete entirely, or not at all. So if we tried to
# write PIPE_BUF bytes, and the buffer's free space is only
# PIPE_BUF/2, then the write will raise BlockingIOError... even
# though a smaller write could still succeed! To avoid this,
# make sure to write >PIPE_BUF bytes each time, which disables
# the special behavior.
# For details, search for PIPE_BUF here:
# http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
# for the getattr:
# https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
buf_size = getattr(select, "PIPE_BUF", 8192)
os.write(s.fileno(), b"x" * buf_size * 2)
except BlockingIOError:
pass
return s, r
async def test_send_pipe():
r, w = os.pipe()
async with FdStream(w) as send:
assert send.fileno() == w
await send.send_all(b"123")
assert (os.read(r, 8)) == b"123"
os.close(r)
async def test_receive_pipe():
r, w = os.pipe()
async with FdStream(r) as recv:
assert (recv.fileno()) == r
os.write(w, b"123")
assert (await recv.receive_some(8)) == b"123"
os.close(w)
async def test_pipes_combined():
write, read = await make_pipe()
count = 2 ** 20
async def sender():
big = bytearray(count)
await write.send_all(big)
async def reader():
await wait_all_tasks_blocked()
received = 0
while received < count:
received += len(await read.receive_some(4096))
assert received == count
async with _core.open_nursery() as n:
n.start_soon(sender)
n.start_soon(reader)
await read.aclose()
await write.aclose()
async def test_pipe_errors():
with pytest.raises(TypeError):
FdStream(None)
r, w = os.pipe()
os.close(w)
async with FdStream(r) as s:
with pytest.raises(ValueError):
await s.receive_some(0)
async def test_del():
w, r = await make_pipe()
f1, f2 = w.fileno(), r.fileno()
del w, r
gc_collect_harder()
with pytest.raises(OSError) as excinfo:
os.close(f1)
assert excinfo.value.errno == errno.EBADF
with pytest.raises(OSError) as excinfo:
os.close(f2)
assert excinfo.value.errno == errno.EBADF
async def test_async_with():
w, r = await make_pipe()
async with w, r:
pass
assert w.fileno() == -1
assert r.fileno() == -1
with pytest.raises(OSError) as excinfo:
os.close(w.fileno())
assert excinfo.value.errno == errno.EBADF
with pytest.raises(OSError) as excinfo:
os.close(r.fileno())
assert excinfo.value.errno == errno.EBADF
async def test_misdirected_aclose_regression():
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
w, r = await make_pipe()
old_r_fd = r.fileno()
# Close the original objects
await w.aclose()
await r.aclose()
# Do a little dance to get a new pipe whose receive handle matches the old
# receive handle.
r2_fd, w2_fd = os.pipe()
if r2_fd != old_r_fd: # pragma: no cover
os.dup2(r2_fd, old_r_fd)
os.close(r2_fd)
async with FdStream(old_r_fd) as r2:
assert r2.fileno() == old_r_fd
# And now set up a background task that's working on the new receive
# handle
async def expect_eof():
assert await r2.receive_some(10) == b""
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_eof)
await wait_all_tasks_blocked()
# Here's the key test: does calling aclose() again on the *old*
# handle, cause the task blocked on the *new* handle to raise
# ClosedResourceError?
await r.aclose()
await wait_all_tasks_blocked()
# Guess we survived! Close the new write handle so that the task
# gets an EOF and can exit cleanly.
os.close(w2_fd)
async def test_close_at_bad_time_for_receive_some(monkeypatch):
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
# https://github.com/python-trio/trio/issues/661
#
# This tests what happens if the pipe gets closed in the moment *between*
# when receive_some wakes up, and when it tries to call os.read
async def expect_closedresourceerror():
with pytest.raises(_core.ClosedResourceError):
await r.receive_some(10)
orig_wait_readable = _core._run.TheIOManager.wait_readable
async def patched_wait_readable(*args, **kwargs):
await orig_wait_readable(*args, **kwargs)
await r.aclose()
monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
s, r = await make_pipe()
async with s, r:
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_closedresourceerror)
await wait_all_tasks_blocked()
# Trigger everything by waking up the receiver
await s.send_all(b"x")
async def test_close_at_bad_time_for_send_all(monkeypatch):
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
# https://github.com/python-trio/trio/issues/661
#
# This tests what happens if the pipe gets closed in the moment *between*
# when send_all wakes up, and when it tries to call os.write
async def expect_closedresourceerror():
with pytest.raises(_core.ClosedResourceError):
await s.send_all(b"x" * 100)
orig_wait_writable = _core._run.TheIOManager.wait_writable
async def patched_wait_writable(*args, **kwargs):
await orig_wait_writable(*args, **kwargs)
await s.aclose()
monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
s, r = await make_clogged_pipe()
async with s, r:
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_closedresourceerror)
await wait_all_tasks_blocked()
# Trigger everything by waking up the sender
await r.receive_some(10000)
# On FreeBSD, directories are readable, and we haven't found any other trick
# for making an unreadable fd, so there's no way to run this test. Fortunately
# the logic this is testing doesn't depend on the platform, so testing on
# other platforms is probably good enough.
@pytest.mark.skipif(
sys.platform.startswith("freebsd"),
reason="no way to make read() return a bizarro error on FreeBSD",
)
async def test_bizarro_OSError_from_receive():
# Make sure that if the read syscall returns some bizarro error, then we
# get a BrokenResourceError. This is incredibly unlikely; there's almost
# no way to trigger a failure here intentionally (except for EBADF, but we
# exploit that to detect file closure, so it takes a different path). So
# we set up a strange scenario where the pipe fd somehow transmutes into a
# directory fd, causing os.read to raise IsADirectoryError (yes, that's a
# real built-in exception type).
s, r = await make_pipe()
async with s, r:
dir_fd = os.open("/", os.O_DIRECTORY, 0)
try:
os.dup2(dir_fd, r.fileno())
with pytest.raises(_core.BrokenResourceError):
await r.receive_some(10)
finally:
os.close(dir_fd)
@skip_if_fbsd_pipes_broken
async def test_pipe_fully():
await check_one_way_stream(make_pipe, make_clogged_pipe)

View File

@@ -0,0 +1,189 @@
import signal
import pytest
import trio
from .. import _core
from .._core.tests.tutil import (
ignore_coroutine_never_awaited_warnings,
create_asyncio_future_in_new_loop,
)
from .._util import (
signal_raise,
ConflictDetector,
is_main_thread,
coroutine_or_error,
generic_function,
Final,
NoPublicConstructor,
)
from ..testing import wait_all_tasks_blocked
def test_signal_raise():
record = []
def handler(signum, _):
record.append(signum)
old = signal.signal(signal.SIGFPE, handler)
try:
signal_raise(signal.SIGFPE)
finally:
signal.signal(signal.SIGFPE, old)
assert record == [signal.SIGFPE]
async def test_ConflictDetector():
ul1 = ConflictDetector("ul1")
ul2 = ConflictDetector("ul2")
with ul1:
with ul2:
print("ok")
with pytest.raises(_core.BusyResourceError) as excinfo:
with ul1:
with ul1:
pass # pragma: no cover
assert "ul1" in str(excinfo.value)
async def wait_with_ul1():
with ul1:
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError) as excinfo:
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_with_ul1)
nursery.start_soon(wait_with_ul1)
assert "ul1" in str(excinfo.value)
def test_module_metadata_is_fixed_up():
import trio
import trio.testing
assert trio.Cancelled.__module__ == "trio"
assert trio.open_nursery.__module__ == "trio"
assert trio.abc.Stream.__module__ == "trio.abc"
assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
assert trio.testing.trio_test.__module__ == "trio.testing"
# Also check methods
assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
assert trio.abc.Stream.send_all.__module__ == "trio.abc"
# And names
assert trio.Cancelled.__name__ == "Cancelled"
assert trio.Cancelled.__qualname__ == "Cancelled"
assert trio.abc.SendStream.send_all.__name__ == "send_all"
assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
assert trio.to_thread.__name__ == "trio.to_thread"
assert trio.to_thread.run_sync.__name__ == "run_sync"
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
async def test_is_main_thread():
assert is_main_thread()
def not_main_thread():
assert not is_main_thread()
await trio.to_thread.run_sync(not_main_thread)
# @coroutine is deprecated since python 3.8, which is fine with us.
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
def test_coroutine_or_error():
class Deferred:
"Just kidding"
with ignore_coroutine_never_awaited_warnings():
async def f(): # pragma: no cover
pass
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(f())
assert "expecting an async function" in str(excinfo.value)
import asyncio
@asyncio.coroutine
def generator_based_coro(): # pragma: no cover
yield from asyncio.sleep(1)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(generator_based_coro())
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(create_asyncio_future_in_new_loop())
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(create_asyncio_future_in_new_loop)
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(Deferred())
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(lambda: Deferred())
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(len, [[1, 2, 3]])
assert "appears to be synchronous" in str(excinfo.value)
async def async_gen(arg): # pragma: no cover
yield
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(async_gen, [0])
msg = "expected an async function but got an async generator"
assert msg in str(excinfo.value)
# Make sure no references are kept around to keep anything alive
del excinfo
def test_generic_function():
@generic_function
def test_func(arg):
"""Look, a docstring!"""
return arg
assert test_func is test_func[int] is test_func[int, str]
assert test_func(42) == test_func[int](42) == 42
assert test_func.__doc__ == "Look, a docstring!"
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func"
assert test_func.__name__ == "test_func"
assert test_func.__module__ == __name__
def test_final_metaclass():
class FinalClass(metaclass=Final):
pass
with pytest.raises(TypeError):
class SubClass(FinalClass):
pass
def test_no_public_constructor_metaclass():
class SpecialClass(metaclass=NoPublicConstructor):
pass
with pytest.raises(TypeError):
SpecialClass()
with pytest.raises(TypeError):
class SubClass(SpecialClass):
pass
# Private constructor should not raise
assert isinstance(SpecialClass._create(), SpecialClass)

View File

@@ -0,0 +1,220 @@
import os
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
from .._core.tests.tutil import slow
import trio
from .. import _core
from .. import _timeouts
if on_windows:
from .._core._windows_cffi import ffi, kernel32
from .._wait_for_object import (
WaitForSingleObject,
WaitForMultipleObjects_sync,
)
async def test_WaitForMultipleObjects_sync():
# This does a series of tests where we set/close the handle before
# initiating the waiting for it.
#
# Note that closing the handle (not signaling) will cause the
# *initiation* of a wait to return immediately. But closing a handle
# that is already being waited on will not stop whatever is waiting
# for it.
# One handle
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle1)
WaitForMultipleObjects_sync(handle1)
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync one OK")
# Two handles, signal first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle1)
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync set first OK")
# Two handles, signal second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle2)
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync set second OK")
# Two handles, close first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle1)
with pytest.raises(OSError):
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync close first OK")
# Two handles, close second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle2)
with pytest.raises(OSError):
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync close second OK")
@slow
async def test_WaitForMultipleObjects_sync_slow():
# This does a series of test in which the main thread sync-waits for
# handles, while we spawn a thread to set the handles after a short while.
TIMEOUT = 0.3
# One handle
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1
)
await _timeouts.sleep(TIMEOUT)
# If we would comment the line below, the above thread will be stuck,
# and Trio won't exit this scope
kernel32.SetEvent(handle1)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync_slow one OK")
# Two handles, signal first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
)
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle1)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
# Two handles, signal second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2
)
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle2)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
async def test_WaitForSingleObject():
# This does a series of test for setting/closing the handle before
# initiating the wait.
# Test already set
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle)
await WaitForSingleObject(handle) # should return at once
kernel32.CloseHandle(handle)
print("test_WaitForSingleObject already set OK")
# Test already set, as int
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle_int = int(ffi.cast("intptr_t", handle))
kernel32.SetEvent(handle)
await WaitForSingleObject(handle_int) # should return at once
kernel32.CloseHandle(handle)
print("test_WaitForSingleObject already set OK")
# Test already closed
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle)
with pytest.raises(OSError):
await WaitForSingleObject(handle) # should return at once
print("test_WaitForSingleObject already closed OK")
# Not a handle
with pytest.raises(TypeError):
await WaitForSingleObject("not a handle") # Wrong type
# with pytest.raises(OSError):
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
print("test_WaitForSingleObject not a handle OK")
@slow
async def test_WaitForSingleObject_slow():
# This does a series of test for setting the handle in another task,
# and cancelling the wait task.
# Set the timeout used in the tests. We test the waiting time against
# the timeout with a certain margin.
TIMEOUT = 0.3
async def signal_soon_async(handle):
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle)
# Test handle is SET after TIMEOUT in separate coroutine
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(WaitForSingleObject, handle)
nursery.start_soon(signal_soon_async, handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow set from task OK")
# Test handle is SET after TIMEOUT in separate coroutine, as int
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle_int = int(ffi.cast("intptr_t", handle))
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(WaitForSingleObject, handle_int)
nursery.start_soon(signal_soon_async, handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow set from task as int OK")
# Test handle is CLOSED after 1 sec - NOPE see comment above
# Test cancellation
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
with _timeouts.move_on_after(TIMEOUT):
await WaitForSingleObject(handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow cancellation OK")

View File

@@ -0,0 +1,110 @@
import errno
import select
import os
import sys
import pytest
from .._core.tests.tutil import gc_collect_harder
from .. import _core, move_on_after
from ..testing import wait_all_tasks_blocked, check_one_way_stream
if sys.platform == "win32":
from .._windows_pipes import PipeSendStream, PipeReceiveStream
from .._core._windows_cffi import _handle, kernel32
from asyncio.windows_utils import pipe
else:
pytestmark = pytest.mark.skip(reason="windows only")
pipe = None # type: Any
PipeSendStream = None # type: Any
PipeReceiveStream = None # type: Any
async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]":
"""Makes a new pair of pipes."""
(r, w) = pipe()
return PipeSendStream(w), PipeReceiveStream(r)
async def test_pipe_typecheck():
with pytest.raises(TypeError):
PipeSendStream(1.0)
with pytest.raises(TypeError):
PipeReceiveStream(None)
async def test_pipe_error_on_close():
# Make sure we correctly handle a failure from kernel32.CloseHandle
r, w = pipe()
send_stream = PipeSendStream(w)
receive_stream = PipeReceiveStream(r)
assert kernel32.CloseHandle(_handle(r))
assert kernel32.CloseHandle(_handle(w))
with pytest.raises(OSError):
await send_stream.aclose()
with pytest.raises(OSError):
await receive_stream.aclose()
async def test_pipes_combined():
write, read = await make_pipe()
count = 2 ** 20
replicas = 3
async def sender():
async with write:
big = bytearray(count)
for _ in range(replicas):
await write.send_all(big)
async def reader():
async with read:
await wait_all_tasks_blocked()
total_received = 0
while True:
# 5000 is chosen because it doesn't evenly divide 2**20
received = len(await read.receive_some(5000))
if not received:
break
total_received += received
assert total_received == count * replicas
async with _core.open_nursery() as n:
n.start_soon(sender)
n.start_soon(reader)
async def test_async_with():
w, r = await make_pipe()
async with w, r:
pass
with pytest.raises(_core.ClosedResourceError):
await w.send_all(b"")
with pytest.raises(_core.ClosedResourceError):
await r.receive_some(10)
async def test_close_during_write():
w, r = await make_pipe()
async with _core.open_nursery() as nursery:
async def write_forever():
with pytest.raises(_core.ClosedResourceError) as excinfo:
while True:
await w.send_all(b"x" * 4096)
assert "another task" in str(excinfo.value)
nursery.start_soon(write_forever)
await wait_all_tasks_blocked(0.1)
await w.aclose()
async def test_pipe_fully():
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
# can't implement that on Windows
await check_one_way_stream(make_pipe, None)

View File

@@ -0,0 +1,72 @@
import ast
import astor
import pytest
import os
import sys
from shutil import copyfile
from trio._tools.gen_exports import (
get_public_methods,
create_passthrough_args,
process,
)
SOURCE = '''from _run import _public
class Test:
@_public
def public_func(self):
"""With doc string"""
@ignore_this
@_public
@another_decorator
async def public_async_func(self):
pass # no doc string
def not_public(self):
pass
async def not_public_async(self):
pass
'''
def test_get_public_methods():
methods = list(get_public_methods(ast.parse(SOURCE)))
assert {m.name for m in methods} == {"public_func", "public_async_func"}
def test_create_pass_through_args():
testcases = [
("def f()", "()"),
("def f(one)", "(one)"),
("def f(one, two)", "(one, two)"),
("def f(one, *args)", "(one, *args)"),
(
"def f(one, *args, kw1, kw2=None, **kwargs)",
"(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
),
]
for (funcdef, expected) in testcases:
func_node = ast.parse(funcdef + ":\n pass").body[0]
assert isinstance(func_node, ast.FunctionDef)
assert create_passthrough_args(func_node) == expected
def test_process(tmp_path):
modpath = tmp_path / "_module.py"
genpath = tmp_path / "_generated_module.py"
modpath.write_text(SOURCE, encoding="utf-8")
assert not genpath.exists()
with pytest.raises(SystemExit) as excinfo:
process([(str(modpath), "runner")], do_test=True)
assert excinfo.value.code == 1
process([(str(modpath), "runner")], do_test=False)
assert genpath.exists()
process([(str(modpath), "runner")], do_test=True)
# But if we change the lookup path it notices
with pytest.raises(SystemExit) as excinfo:
process([(str(modpath), "runner.io_manager")], do_test=True)
assert excinfo.value.code == 1