
Channels

In this section we will add basic channel support:

  • Joining and leaving channels
  • Sending messages to channels

We will also cover automatically decoding the payload data based on method argument type-hints.

See resulting code on GitHub

Shared code

Let's add a channel property to the Message class. It will contain the name of the channel the message is intended for.

class Message:
    ...
    channel: Optional[str] = None

Server side

Data-classes

We will add functionality to store the channel state. Add the following fields to the ChatData class:

from asyncio import Queue
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Set
from weakref import WeakSet

@dataclass(frozen=True)
class ChatData:
    ...
    channel_users: Dict[str, Set[SessionId]] = field(default_factory=lambda: defaultdict(WeakSet))
    channel_messages: Dict[str, Queue] = field(default_factory=lambda: defaultdict(Queue))

In the channel_users dict, the keys are channel names, and the value is a set of user session ids. A WeakSet is used to automatically remove logged-out users.

In the channel_messages dict, the keys are the channel names, and the value is a Queue of messages sent by users to the channel.
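The two per-channel structures can be tried on their own in the minimal sketch below. It assumes SessionId is a NewType over str, because its definition comes from an earlier section and is not shown here. A WeakSet can only hold objects that support weak references, and str does not. For that reason the sketch stores member ids in a plain set and shows the TypeError that a WeakSet raises when given a string. ChannelState is a hypothetical stand-in for the channel fields of ChatData.

```python
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from typing import DefaultDict, NewType, Set
from weakref import WeakSet

# Assumption: session ids are plain strings wrapped in a NewType.
SessionId = NewType('SessionId', str)


@dataclass(frozen=True)
class ChannelState:
    """Hypothetical stand-in for the channel fields of ChatData."""
    # channel name -> ids of the sessions that joined the channel
    channel_users: DefaultDict[str, Set[SessionId]] = field(default_factory=lambda: defaultdict(set))
    # channel name -> queue of messages waiting to be delivered
    channel_messages: DefaultDict[str, asyncio.Queue] = field(default_factory=lambda: defaultdict(asyncio.Queue))


state = ChannelState()
state.channel_users['channel1'].add(SessionId('session-a'))  # the missing key is created on first access
state.channel_users['channel1'].add(SessionId('session-b'))

# A WeakSet cannot hold str values, because str does not support weak references.
try:
    WeakSet().add(SessionId('session-a'))
    weakset_accepts_str = True
except TypeError:
    weakset_accepts_str = False
```

The frozen dataclass only prevents reassigning the fields. The dictionaries themselves stay mutable, which is what the chat server relies on.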

Helper methods

Next, we will define some helper methods for managing channel messages:

  • ensure_channel_exists: initialize the data for a new channel if it doesn't exist.
  • channel_message_delivery: an asyncio task which will deliver channel messages to all the users in a channel.

def ensure_channel_exists(channel_name: str):
    if channel_name not in chat_data.channel_users:
        chat_data.channel_users[channel_name] = WeakSet()
        chat_data.channel_messages[channel_name] = Queue()
        asyncio.create_task(channel_message_delivery(channel_name))

If the channel doesn't exist yet (Line 2), it is added to the channel_users and channel_messages dictionaries. Line 5 then starts an asyncio task (described below) that delivers messages sent to the channel to the channel's users.
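The "initialize once, then start a delivery task" pattern can be tried without RSocket in the sketch below. ChannelRegistry and its methods are hypothetical stand-ins for chat_data, ensure_channel_exists and channel_message_delivery. The sketch also keeps a reference to each created task, as the asyncio documentation for create_task recommends, because the event loop itself only holds weak references to tasks.

```python
import asyncio
from typing import Dict, List, Set, Tuple


class ChannelRegistry:
    """Hypothetical stand-in for the channel parts of chat_data."""

    def __init__(self) -> None:
        self.queues: Dict[str, asyncio.Queue] = {}
        self.received: Dict[str, List[str]] = {}
        self.tasks: Set[asyncio.Task] = set()

    async def deliver(self, channel_name: str) -> None:
        # Stand-in for channel_message_delivery: just record every message.
        queue = self.queues[channel_name]
        while True:
            self.received[channel_name].append(await queue.get())

    def ensure_channel_exists(self, channel_name: str) -> None:
        if channel_name in self.queues:
            return  # already initialized: do not start a second delivery task
        self.queues[channel_name] = asyncio.Queue()
        self.received[channel_name] = []
        task = asyncio.create_task(self.deliver(channel_name))
        self.tasks.add(task)                        # keep a strong reference
        task.add_done_callback(self.tasks.discard)  # drop it once the task ends


async def demo() -> Tuple[ChannelRegistry, int]:
    registry = ChannelRegistry()
    registry.ensure_channel_exists('channel1')
    registry.ensure_channel_exists('channel1')  # second call is a no-op
    started_tasks = len(registry.tasks)
    registry.queues['channel1'].put_nowait('hello')
    await asyncio.sleep(0.01)  # give the delivery task a chance to run
    return registry, started_tasks
```

Calling asyncio.run(demo()) starts exactly one delivery task for channel1, and that task records the queued message. asyncio.run cancels the still-running delivery task when demo returns.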

async def channel_message_delivery(channel_name: str):
    while True:
        try:
            message = await chat_data.channel_messages[channel_name].get()
            for session_id in chat_data.channel_users[channel_name]:
                # Copy the message for each user and keep the channel name, so clients can display it
                user_specific_message = Message(user=message.user,
                                                content=message.content,
                                                channel=channel_name)
                chat_data.user_session_by_id[session_id].messages.put_nowait(user_specific_message)
        except Exception as exception:
            logging.error(str(exception), exc_info=True)

The method above loops forever, waiting for the next message on the channel's channel_messages queue. When a message arrives, a copy of it is put on the personal message queue of every user in the channel.
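The fan-out can also be sketched in isolation. There is one queue per channel and one inbox queue per user session, and a task copies each channel message into every member's inbox. deliver_channel, the inboxes dict and the session ids are hypothetical. Message mirrors the fields used in this tutorial. Each copy carries the channel name, so the receiving side can tell which channel the message came from.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, Optional, Set


@dataclass(frozen=True)
class Message:  # mirrors the tutorial's Message fields
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


async def deliver_channel(channel_name: str,
                          channel_queue: asyncio.Queue,
                          members: Set[str],
                          inboxes: Dict[str, asyncio.Queue]) -> None:
    while True:
        message = await channel_queue.get()
        for session_id in members:
            # Every member gets its own copy, tagged with the channel name.
            inboxes[session_id].put_nowait(Message(user=message.user,
                                                   content=message.content,
                                                   channel=channel_name))


async def demo() -> Dict[str, Message]:
    inboxes = {'s1': asyncio.Queue(), 's2': asyncio.Queue()}
    channel_queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(
        deliver_channel('channel1', channel_queue, {'s1', 's2'}, inboxes))
    await channel_queue.put(Message(user='user1', content='hi'))
    # get() waits until the delivery task has copied the message into each inbox.
    received = {sid: await inbox.get() for sid, inbox in inboxes.items()}
    task.cancel()
    return received
```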

The final helper looks up a username by session id:

def find_username_by_session(session_id: SessionId) -> Optional[str]:
    session = chat_data.user_session_by_id.get(session_id)
    if session is None:
        return None
    return session.username

Join/Leave Channel

Now let's add the request-response endpoints for joining and leaving a channel.

class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.response('channel.join')
        async def join_channel(payload: Payload) -> Awaitable[Payload]:
            channel_name = payload.data.decode('utf-8')
            ensure_channel_exists(channel_name)
            chat_data.channel_users[channel_name].add(self._session.session_id)
            return create_response()

        @router.response('channel.leave')
        async def leave_channel(payload: Payload) -> Awaitable[Payload]:
            channel_name = payload.data.decode('utf-8')
            # .get() avoids creating an entry (without a delivery task) for an unknown channel
            channel_users = chat_data.channel_users.get(channel_name)
            if channel_users is not None:
                channel_users.discard(self._session.session_id)
            return create_response()
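Because channel_users is a defaultdict, indexing it with an unknown channel name silently creates an empty entry. ensure_channel_exists would then treat that channel as already initialized and never start its delivery task. The sketch below contrasts indexing with .get(), which never inserts a key. The function names are hypothetical and session ids are plain strings.

```python
from collections import defaultdict
from typing import DefaultDict, Set

channel_users: DefaultDict[str, Set[str]] = defaultdict(set)


def leave_by_indexing(channel_name: str, session_id: str) -> None:
    channel_users[channel_name].discard(session_id)  # inserts an empty set if missing


def leave_with_get(channel_name: str, session_id: str) -> None:
    users = channel_users.get(channel_name)  # .get() never inserts a key
    if users is not None:
        users.discard(session_id)


leave_with_get('unknown-channel', 'session-a')
created_by_get = 'unknown-channel' in channel_users       # False
leave_by_indexing('unknown-channel', 'session-a')
created_by_indexing = 'unknown-channel' in channel_users  # True
```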

Send channel message

Next, we add the ability to send a message to a channel by modifying the send_message method:

class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.response('message')
        async def send_message(payload: Payload) -> Awaitable[Payload]:
            message = Message(**json.loads(payload.data))

            logging.info('Received message for user: %s, channel: %s', message.user, message.channel)

            target_message = Message(self._session.username, message.content, message.channel)

            if message.channel is not None:
                await chat_data.channel_messages[message.channel].put(target_message)
            elif message.user is not None:
                session = find_session_by_username(message.user)
                await session.messages.put(target_message)

            return create_response()

The if/elif block at the end decides whether this is a channel message (the channel field is set) or a private message (no channel, but a target user), and puts it on the matching queue.
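The branching in send_message can be checked without a server. In this sketch, plain lists stand in for the asyncio queues and route_message is a hypothetical function. The precedence matches the code above: a message with a channel goes to that channel, and otherwise it is treated as a private message.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class Message:  # mirrors the tutorial's Message fields
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


def route_message(sender: str, message: Message,
                  channel_queues: Dict[str, List[Message]],
                  user_queues: Dict[str, List[Message]]) -> None:
    # The sender comes from the session, not from the incoming message.
    target = Message(sender, message.content, message.channel)
    if message.channel is not None:
        channel_queues.setdefault(message.channel, []).append(target)
    elif message.user is not None:
        user_queues.setdefault(message.user, []).append(target)


channels: Dict[str, List[Message]] = {}
users: Dict[str, List[Message]] = {}
route_message('user1', Message(channel='channel1', content='hi all'), channels, users)
route_message('user1', Message(user='user2', content='hi user2'), channels, users)
```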

List channels

class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.stream('channels')
        async def get_channels() -> Publisher:
            count = len(chat_data.channel_messages)
            generator = ((Payload(ensure_bytes(channel)), index == count) for (index, channel) in
                         enumerate(chat_data.channel_messages.keys(), 1))
            return StreamFromGenerator(lambda: generator)

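Lines 6-11 define an endpoint for getting a list of channels. It uses the StreamFromGenerator helper. Each generated item is a tuple of a Payload and a flag that is True only for the last item (index == count), which marks the end of the stream. Note that the argument to StreamFromGenerator is a factory function for the generator, not the generator itself.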
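The last-item flag and the factory idea can both be seen in plain Python. This sketch uses bytes in place of Payload objects, and channel_items and make_items are hypothetical. A generator can only be iterated once, which is one reason to pass a factory: calling it again yields a fresh generator. In the endpoint above, lambda: generator returns the same generator object on every call. That works as long as the factory is invoked only once per request.

```python
from typing import Iterator, List, Tuple

# Each item pairs the encoded data with an "is last item" flag.
Item = Tuple[bytes, bool]


def channel_items(channels: List[str]) -> Iterator[Item]:
    count = len(channels)
    # enumerate(..., 1) makes index == count true only for the final item
    return ((name.encode('utf-8'), index == count) for index, name in enumerate(channels, 1))


def make_items() -> Iterator[Item]:
    # A factory: builds a fresh generator on every call.
    return channel_items(['channel1', 'channel2'])


first = list(make_items())
second = list(make_items())

shared = channel_items(['channel1', 'channel2'])
consumed = list(shared)   # consumes the generator
leftover = list(shared)   # a generator cannot be replayed, so this is empty
```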

Get channel users

class ChatUserSession:

    def router_factory(self):
        router = RequestRouter()

        @router.stream('channel.users')
        async def get_channel_users(payload: Payload) -> Publisher:
            channel_name = utf8_decode(payload.data)

            if channel_name not in chat_data.channel_users:
                return EmptyStream()

            count = len(chat_data.channel_users[channel_name])
            generator = ((Payload(ensure_bytes(find_username_by_session(session_id))), index == count) for
                         (index, session_id) in
                         enumerate(chat_data.channel_users[channel_name], 1))

            return StreamFromGenerator(lambda: generator)

Lines 6-18 define an endpoint for getting a list of users in a given channel. The find_username_by_session helper method converts the session ids to usernames.

If the channel does not exist (Line 10), the EmptyStream helper can be used as the response.
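find_username_by_session returns None for a session that no longer exists. One possible refinement is to resolve and filter the usernames first, and only then count them. If the count is taken before filtering, the last-item flag can land on an item that was dropped, so the final flag would never be sent. The sketch below uses hypothetical names, a plain dict in place of the session lookup, and bytes in place of Payload.

```python
from typing import Dict, Iterator, List, Tuple


def users_stream(session_ids: List[str],
                 usernames: Dict[str, str]) -> Iterator[Tuple[bytes, bool]]:
    # Resolve ids first; sessions that are gone resolve to None and are dropped.
    names: List[str] = [name for name in (usernames.get(sid) for sid in session_ids)
                        if name is not None]
    count = len(names)  # counted after filtering, so the last flag stays correct
    return ((name.encode('utf-8'), index == count) for index, name in enumerate(names, 1))


items = list(users_stream(['s1', 'gone', 's2'], {'s1': 'user1', 's2': 'user2'}))
```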

Simplify route methods

Up until now, every routed method received the Payload as its argument and had to extract and decode the data property itself. There is a simpler way. If a payload_mapper is passed to the RequestRouter and the route method's argument has a type hint, the argument is decoded automatically. Let's modify the code to use this.

In the shared module add the following helper method, which will be used as the payload_mapper:

from rsocket.payload import Payload
from rsocket.helpers import utf8_decode

def decode_payload(cls, payload: Payload):
    data = payload.data

    if cls is bytes:
        return data

    if cls is str:
        return utf8_decode(data)

    return decode_dataclass(data, cls)

This method converts the payload data according to the type hint of the route method's argument, which is passed in as cls. It assumes that the payload data can be converted to one of:

  • bytes, in which case no transformation is applied (Lines 7-8)
  • str, in which case the data is UTF-8 decoded (Lines 10-11)
  • a dataclass, which is decoded using the previously defined decode_dataclass helper (see Private Message) (Line 13)
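The dispatch can be exercised without a running server. This sketch assumes that decode_dataclass, which is defined in an earlier section and not shown here, parses a JSON object into the target dataclass. It uses a hypothetical FakePayload with only a data field in place of rsocket's Payload. It also decodes UTF-8 directly instead of calling utf8_decode.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class FakePayload:  # hypothetical stand-in for rsocket.payload.Payload
    data: bytes


@dataclass(frozen=True)
class Message:
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


def decode_dataclass(data: bytes, cls: type) -> Any:
    # Assumption: the earlier helper decodes a JSON object into the dataclass.
    return cls(**json.loads(data))


def decode_payload(cls: type, payload: FakePayload) -> Any:
    data = payload.data
    if cls is bytes:
        return data                  # raw bytes, unchanged
    if cls is str:
        return data.decode('utf-8')  # text
    return decode_dataclass(data, cls)
```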

Next we will pass this method as an argument to the RequestRouter:

router = RequestRouter(payload_mapper=decode_payload)

Once this is done, the signatures of the routed methods can be simplified. For example, the join_channel method becomes:

@router.response('channel.join')
async def join_channel(channel_name: str) -> Awaitable[Payload]:
    # channel_name = payload.data.decode('utf-8')  # this line can be removed
    ...

and the send_message method can be simplified to:

@router.response('message')
async def send_message(message: Message) -> Awaitable[Payload]:
    # message = Message(**json.loads(payload.data))  # this line can be removed
    logging.info('Received message for user: %s, channel: %s', message.user, message.channel)
    ...
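The mechanism described at the start of this section can be illustrated with a toy router. It reads the type hint of a handler's first parameter with typing.get_type_hints and passes that type to a payload mapper. MiniRouter and simple_mapper are hypothetical and are not taken from rsocket-py's RequestRouter; they only illustrate the idea. The mapper here also receives raw bytes rather than a Payload.

```python
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Dict, get_type_hints

Handler = Callable[..., Awaitable[Any]]


class MiniRouter:
    """Hypothetical router, not rsocket-py's RequestRouter."""

    def __init__(self, payload_mapper: Callable[[type, bytes], Any]) -> None:
        self._mapper = payload_mapper
        self._routes: Dict[str, Handler] = {}

    def response(self, route: str) -> Callable[[Handler], Handler]:
        def register(handler: Handler) -> Handler:
            self._routes[route] = handler
            return handler
        return register

    async def handle(self, route: str, data: bytes) -> Any:
        handler = self._routes[route]
        params = list(inspect.signature(handler).parameters.values())
        if not params:
            return await handler()
        # Look up the first argument's type hint; default to raw bytes.
        arg_type = get_type_hints(handler).get(params[0].name, bytes)
        return await handler(self._mapper(arg_type, data))


def simple_mapper(cls: type, data: bytes) -> Any:
    return data.decode('utf-8') if cls is str else data


router = MiniRouter(payload_mapper=simple_mapper)


@router.response('channel.join')
async def join_channel(channel_name: str) -> str:
    return f'joined {channel_name}'


@router.response('raw')
async def raw(data: bytes) -> int:
    return len(data)
```

For example, asyncio.run(router.handle('channel.join', b'channel1')) hands join_channel the decoded string 'channel1'.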

Client side

Channel requests

Next, we add methods to the ChatClient class that use the new server functionality:

from typing import List

from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket
from rsocket.extensions.helpers import composite, route
from rsocket.frame_helpers import ensure_bytes
from rsocket.payload import Payload
from rsocket.helpers import utf8_decode

from shared import Message, encode_dataclass

class ChatClient:

    async def join(self, channel_name: str):
        request = Payload(ensure_bytes(channel_name), composite(route('channel.join')))
        await self._rsocket.request_response(request)
        return self

    async def leave(self, channel_name: str):
        request = Payload(ensure_bytes(channel_name), composite(route('channel.leave')))
        await self._rsocket.request_response(request)
        return self

    async def channel_message(self, channel: str, content: str):
        print(f'Sending {content} to channel {channel}')
        await self._rsocket.request_response(Payload(encode_dataclass(Message(channel=channel, content=content)),
                                                     composite(route('message'))))

    async def list_channels(self) -> List[str]:
        # The 'channels' route takes no data, only the routing metadata
        request = Payload(metadata=composite(route('channels')))
        response = await AwaitableRSocket(self._rsocket).request_stream(request)
        return list(map(lambda _: utf8_decode(_.data), response))

    async def get_users(self, channel_name: str) -> List[str]:
        request = Payload(ensure_bytes(channel_name), composite(route('channel.users')))
        users = await AwaitableRSocket(self._rsocket).request_stream(request)
        return [utf8_decode(user.data) for user in users]

The join and leave methods are both simple routed request_response calls, with the channel name as the payload data. The channel_message method sends a Message with its channel field set, using the 'message' route.

The list_channels method uses the AwaitableRSocket adapter to simplify getting the response stream as a list.

The get_users method lists a channel's users.
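The request sent by channel_message carries a JSON-encoded Message. Its user field is empty and its channel field is set, which is what makes the server treat it as a channel message. The sketch assumes that encode_dataclass serializes all dataclass fields to JSON; that helper is defined in an earlier section and not shown here. encode_message is a hypothetical equivalent built on dataclasses.asdict.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class Message:
    user: Optional[str] = None
    content: Optional[str] = None
    channel: Optional[str] = None


def encode_message(message: Message) -> bytes:
    # Assumption: similar in spirit to the tutorial's encode_dataclass helper.
    return json.dumps(asdict(message)).encode('utf-8')


wire = encode_message(Message(channel='channel1', content='hello'))
decoded = Message(**json.loads(wire))  # what the server-side send_message reconstructs
```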

Update the print_message method to include the channel:

def print_message(data: bytes):
    message = Message(**json.loads(data))
    print(f'{self._username}: from {message.user} ({message.channel}): {message.content}')

Test the new functionality

Let's test the new functionality using the following code:

async def messaging_example(user1: ChatClient, user2: ChatClient):
    user1.listen_for_messages()
    user2.listen_for_messages()

    await user1.join('channel1')
    await user2.join('channel1')

    print(f'Channels: {await user1.list_channels()}')

    await user1.private_message('user2', 'private message from user1')
    await user1.channel_message('channel1', 'channel message from user1')

    await asyncio.sleep(1)

    user1.stop_listening_for_messages()
    user2.stop_listening_for_messages()

Call the example method from the main method and pass it the two chat clients:

user1 = ChatClient(client1)
user2 = ChatClient(client2)

await user1.login('user1')
await user2.login('user2')

await messaging_example(user1, user2)