Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use shared types #75

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions realtime/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
from realtime.channel import Channel
from realtime.exceptions import NotConnectedError
from realtime.message import HEARTBEAT_PAYLOAD, PHOENIX_CHANNEL, ChannelEvents, Message


T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
from realtime.types import Callback, T_ParamSpec, T_Retval

logging.basicConfig(
format="%(asctime)s:%(levelname)s - %(message)s", level=logging.INFO)
format="%(asctime)s:%(levelname)s - %(message)s", level=logging.INFO
)


def ensure_connection(func: Callable[T_ParamSpec, T_Retval]):
def ensure_connection(func: Callback):
@wraps(func)
def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
if not args[0].connected:
Expand All @@ -31,7 +30,13 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:


class Socket:
def __init__(self, url: str, auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 5) -> None:
def __init__(
self,
url: str,
auto_reconnect: bool = False,
params: Dict[str, Any] = {},
hb_interval: int = 5,
) -> None:
"""
`Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`.
Socket-Channel has a 1-many relationship.
Expand Down Expand Up @@ -59,8 +64,7 @@ def listen(self) -> None:
:return: None
"""
loop = asyncio.get_event_loop() # TODO: replace with get_running_loop
loop.run_until_complete(asyncio.gather(
self._listen(), self._keep_alive()))
loop.run_until_complete(asyncio.gather(self._listen(), self._keep_alive()))

async def _listen(self) -> None:
"""
Expand All @@ -81,7 +85,9 @@ async def _listen(self) -> None:
cl.callback(msg.payload)
except websockets.exceptions.ConnectionClosed:
if self.auto_reconnect:
logging.info("Connection with server closed, trying to reconnect...")
logging.info(
"Connection with server closed, trying to reconnect..."
)
await self._connect()
for topic, channels in self.channels.items():
for channel in channels:
Expand Down Expand Up @@ -125,7 +131,9 @@ async def _keep_alive(self) -> None:
await asyncio.sleep(self.hb_interval)
except websockets.exceptions.ConnectionClosed:
if self.auto_reconnect:
logging.info("Connection with server closed, trying to reconnect...")
logging.info(
"Connection with server closed, trying to reconnect..."
)
await self._connect()
else:
logging.exception("Connection with the server closed.")
Expand All @@ -149,5 +157,4 @@ def summary(self) -> None:
"""
for topic, chans in self.channels.items():
for chan in chans:
print(
f"Topic: {topic} | Events: {[e for e, _ in chan.callbacks]}]")
print(f"Topic: {topic} | Events: {[e for e, _ in chan.callbacks]}]")