moving to scripts
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
import pytest
|
||||
|
||||
import attr
|
||||
|
||||
from ..abc import SendStream, ReceiveStream
|
||||
from .._highlevel_generic import StapledStream
|
||||
|
||||
|
||||
@attr.s
|
||||
class RecordSendStream(SendStream):
|
||||
record = attr.ib(factory=list)
|
||||
|
||||
async def send_all(self, data):
|
||||
self.record.append(("send_all", data))
|
||||
|
||||
async def wait_send_all_might_not_block(self):
|
||||
self.record.append("wait_send_all_might_not_block")
|
||||
|
||||
async def aclose(self):
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
@attr.s
|
||||
class RecordReceiveStream(ReceiveStream):
|
||||
record = attr.ib(factory=list)
|
||||
|
||||
async def receive_some(self, max_bytes=None):
|
||||
self.record.append(("receive_some", max_bytes))
|
||||
|
||||
async def aclose(self):
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
async def test_StapledStream():
|
||||
send_stream = RecordSendStream()
|
||||
receive_stream = RecordReceiveStream()
|
||||
stapled = StapledStream(send_stream, receive_stream)
|
||||
|
||||
assert stapled.send_stream is send_stream
|
||||
assert stapled.receive_stream is receive_stream
|
||||
|
||||
await stapled.send_all(b"foo")
|
||||
await stapled.wait_send_all_might_not_block()
|
||||
assert send_stream.record == [
|
||||
("send_all", b"foo"),
|
||||
"wait_send_all_might_not_block",
|
||||
]
|
||||
send_stream.record.clear()
|
||||
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["aclose"]
|
||||
send_stream.record.clear()
|
||||
|
||||
async def fake_send_eof():
|
||||
send_stream.record.append("send_eof")
|
||||
|
||||
send_stream.send_eof = fake_send_eof
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["send_eof"]
|
||||
|
||||
send_stream.record.clear()
|
||||
assert receive_stream.record == []
|
||||
|
||||
await stapled.receive_some(1234)
|
||||
assert receive_stream.record == [("receive_some", 1234)]
|
||||
assert send_stream.record == []
|
||||
receive_stream.record.clear()
|
||||
|
||||
await stapled.aclose()
|
||||
assert receive_stream.record == ["aclose"]
|
||||
assert send_stream.record == ["aclose"]
|
||||
|
||||
|
||||
async def test_StapledStream_with_erroring_close():
|
||||
# Make sure that if one of the aclose methods errors out, then the other
|
||||
# one still gets called.
|
||||
class BrokenSendStream(RecordSendStream):
|
||||
async def aclose(self):
|
||||
await super().aclose()
|
||||
raise ValueError
|
||||
|
||||
class BrokenReceiveStream(RecordReceiveStream):
|
||||
async def aclose(self):
|
||||
await super().aclose()
|
||||
raise ValueError
|
||||
|
||||
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await stapled.aclose()
|
||||
assert isinstance(excinfo.value.__context__, ValueError)
|
||||
|
||||
assert stapled.send_stream.record == ["aclose"]
|
||||
assert stapled.receive_stream.record == ["aclose"]
|
||||
Reference in New Issue
Block a user