Skip to content

Commit 0c82c8a

Browse files
committed
Add source code
1 parent 82324da commit 0c82c8a

File tree

5 files changed

+443
-0
lines changed

5 files changed

+443
-0
lines changed

src/bodyrtc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from bodyrtc.builder import WebRTCOfferBuilder, WebRTCAnswerBuilder # noqa
2+
from bodyrtc.stream import WebRTCBaseStream, ConnectionProvider, MessageHandler # noqa

src/bodyrtc/builder.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import abc
2+
from typing import List, Dict
3+
4+
import aiortc
5+
6+
from bodyrtc.stream import WebRTCBaseStream, WebRTCOfferStream, WebRTCAnswerStream, ConnectionProvider
7+
8+
9+
class WebRTCStreamBuilder(abc.ABC):
10+
@abc.abstractmethod
11+
def stream(self) -> WebRTCBaseStream:
12+
raise NotImplementedError
13+
14+
15+
class WebRTCOfferBuilder(WebRTCStreamBuilder):
16+
def __init__(self, connection_provider: ConnectionProvider):
17+
self.connection_provider = connection_provider
18+
self.requested_camera_types: List[str] = []
19+
self.requested_audio = False
20+
self.audio_tracks: List[aiortc.MediaStreamTrack] = []
21+
self.messaging_enabled = False
22+
23+
def offer_to_receive_video_stream(self, camera_type: str):
24+
assert camera_type in ["driver", "wideRoad", "road"]
25+
self.requested_camera_types.append(camera_type)
26+
27+
def offer_to_receive_audio_stream(self):
28+
self.requested_audio = True
29+
30+
def add_audio_stream(self, track: aiortc.MediaStreamTrack):
31+
assert len(self.audio_tracks) == 0
32+
self.audio_tracks = [track]
33+
34+
def add_messaging(self):
35+
self.messaging_enabled = True
36+
37+
def stream(self) -> WebRTCBaseStream:
38+
return WebRTCOfferStream(
39+
self.connection_provider,
40+
consumed_camera_types=self.requested_camera_types,
41+
consume_audio=self.requested_audio,
42+
video_producer_tracks=[],
43+
audio_producer_tracks=self.audio_tracks,
44+
should_add_data_channel=self.messaging_enabled,
45+
)
46+
47+
48+
class WebRTCAnswerBuilder(WebRTCStreamBuilder):
49+
def __init__(self, offer_sdp: str):
50+
self.offer_sdp = offer_sdp
51+
self.video_tracks: Dict[str, aiortc.MediaStreamTrack] = dict()
52+
self.requested_audio = False
53+
self.audio_tracks: List[aiortc.MediaStreamTrack] = []
54+
55+
def offer_to_receive_audio_stream(self):
56+
self.requested_audio = True
57+
58+
def add_video_stream(self, camera_type: str, track: aiortc.MediaStreamTrack):
59+
assert camera_type not in self.video_tracks
60+
assert camera_type in ["driver", "wideRoad", "road"]
61+
self.video_tracks[camera_type] = track
62+
63+
def add_audio_stream(self, track: aiortc.MediaStreamTrack):
64+
assert len(self.audio_tracks) == 0
65+
self.audio_tracks = [track]
66+
67+
def stream(self) -> WebRTCBaseStream:
68+
description = aiortc.RTCSessionDescription(sdp=self.offer_sdp, type="offer")
69+
return WebRTCAnswerStream(
70+
description,
71+
consumed_camera_types=[],
72+
consume_audio=self.requested_audio,
73+
video_producer_tracks=list(self.video_tracks.values()),
74+
audio_producer_tracks=self.audio_tracks,
75+
should_add_data_channel=False,
76+
)

src/bodyrtc/info.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
import dataclasses
3+
4+
import aiortc
5+
6+
7+
@dataclasses.dataclass
8+
class StreamingMediaInfo:
9+
n_expected_camera_tracks: int
10+
expected_audio_track: bool
11+
incoming_audio_track: bool
12+
incoming_datachannel: bool
13+
14+
15+
def parse_info_from_offer(sdp: str) -> StreamingMediaInfo:
16+
"""
17+
helper function to parse info about outgoing and incoming streams from an offer sdp
18+
"""
19+
desc = aiortc.sdp.SessionDescription.parse(sdp)
20+
audio_tracks = [m for m in desc.media if m.kind == "audio"]
21+
video_tracks = [m for m in desc.media if m.kind == "video" and m.direction in ["recvonly", "sendrecv"]]
22+
application_tracks = [m for m in desc.media if m.kind == "application"]
23+
has_incoming_audio_track = next((t for t in audio_tracks if t.direction in ["sendonly", "sendrecv"]), None) is not None
24+
has_incoming_datachannel = len(application_tracks) > 0
25+
expects_outgoing_audio_track = next((t for t in audio_tracks if t.direction in ["recvonly", "sendrecv"]), None) is not None
26+
27+
return StreamingMediaInfo(len(video_tracks), expects_outgoing_audio_track, has_incoming_audio_track, has_incoming_datachannel)

src/bodyrtc/stream.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
import abc
2+
import asyncio
3+
import dataclasses
4+
import logging
5+
from typing import Callable, Awaitable, Dict, List, Any, Optional
6+
7+
import aiortc
8+
from aiortc.contrib.media import MediaRelay
9+
10+
11+
@dataclasses.dataclass
12+
class StreamingOffer:
13+
sdp: str
14+
video: List[str]
15+
16+
17+
ConnectionProvider = Callable[[StreamingOffer], Awaitable[aiortc.RTCSessionDescription]]
18+
MessageHandler = Callable[[bytes], Awaitable[None]]
19+
20+
21+
class WebRTCBaseStream(abc.ABC):
22+
def __init__(self,
23+
consumed_camera_types: List[str],
24+
consume_audio: bool,
25+
video_producer_tracks: List[aiortc.MediaStreamTrack],
26+
audio_producer_tracks: List[aiortc.MediaStreamTrack],
27+
should_add_data_channel: bool):
28+
self.peer_connection = aiortc.RTCPeerConnection()
29+
self.media_relay = MediaRelay()
30+
self.expected_incoming_camera_types = consumed_camera_types
31+
self.expected_incoming_audio = consume_audio
32+
self.expected_number_of_incoming_media: Optional[int] = None
33+
34+
self.incoming_camera_tracks: Dict[str, aiortc.MediaStreamTrack] = dict()
35+
self.incoming_audio_tracks: List[aiortc.MediaStreamTrack] = []
36+
self.outgoing_video_tracks: List[aiortc.MediaStreamTrack] = video_producer_tracks
37+
self.outgoing_audio_tracks: List[aiortc.MediaStreamTrack] = audio_producer_tracks
38+
39+
self.should_add_data_channel = should_add_data_channel
40+
self.messaging_channel: Optional[aiortc.RTCDataChannel] = None
41+
self.incoming_message_handlers: List[MessageHandler] = []
42+
43+
self.incoming_media_ready_event = asyncio.Event()
44+
self.connection_attempted_event = asyncio.Event()
45+
self.connection_stopped_event = asyncio.Event()
46+
47+
self.peer_connection.on("connectionstatechange", self._on_connectionstatechange)
48+
self.peer_connection.on("datachannel", self._on_incoming_datachannel)
49+
self.peer_connection.on("track", self._on_incoming_track)
50+
51+
self.logger = logging.getLogger("WebRTCStream")
52+
53+
def _log_debug(self, msg: Any, *args):
54+
self.logger.debug(f"{type(self)}() {msg}", *args)
55+
56+
@property
57+
def _number_of_incoming_media(self) -> int:
58+
media = len(self.incoming_camera_tracks) + len(self.incoming_audio_tracks)
59+
# if stream does not add data_channel, then it means its incoming
60+
media += int(self.messaging_channel is not None) if not self.should_add_data_channel else 0
61+
return media
62+
63+
def _add_consumer_transceivers(self):
64+
for _ in self.expected_incoming_camera_types:
65+
self.peer_connection.addTransceiver("video", direction="recvonly")
66+
if self.expected_incoming_audio:
67+
self.peer_connection.addTransceiver("audio", direction="recvonly")
68+
69+
def _add_producer_tracks(self):
70+
for track in self.outgoing_video_tracks:
71+
sender = self.peer_connection.addTrack(track)
72+
if hasattr(track, "codec_preference") and track.codec_preference() is not None:
73+
transceiver = next(t for t in self.peer_connection.getTransceivers() if t.sender == sender)
74+
self._force_codec(transceiver, track.codec_preference(), "video")
75+
for track in self.outgoing_audio_tracks:
76+
self.peer_connection.addTrack(track)
77+
78+
def _add_messaging_channel(self, channel: Optional[aiortc.RTCDataChannel] = None):
79+
if not channel:
80+
channel = self.peer_connection.createDataChannel("data", ordered=True)
81+
82+
for handler in self.incoming_message_handlers:
83+
channel.on("message", handler)
84+
self.messaging_channel = channel
85+
86+
def _force_codec(self, transceiver: aiortc.RTCRtpTransceiver, codec: str, stream_type: str):
87+
codec_mime = f"{stream_type}/{codec.upper()}"
88+
rtp_codecs = aiortc.RTCRtpSender.getCapabilities(stream_type).codecs
89+
rtp_codec = [c for c in rtp_codecs if c.mimeType == codec_mime]
90+
transceiver.setCodecPreferences(rtp_codec)
91+
92+
def _on_connectionstatechange(self):
93+
self._log_debug("connection state is %s", self.peer_connection.connectionState)
94+
if self.peer_connection.connectionState in ['connected', 'failed']:
95+
self.connection_attempted_event.set()
96+
if self.peer_connection.connectionState in ['disconnected', 'closed', 'failed']:
97+
self.connection_stopped_event.set()
98+
99+
def _on_incoming_track(self, track: aiortc.MediaStreamTrack):
100+
self._log_debug("got track: %s %s", track.kind, track.id)
101+
if track.kind == "video":
102+
parts = track.id.split(":") # format: "camera_type:camera_id"
103+
if len(parts) < 2:
104+
return
105+
106+
camera_type = parts[0]
107+
if camera_type in self.expected_incoming_camera_types:
108+
self.incoming_camera_tracks[camera_type] = track
109+
elif track.kind == "audio":
110+
if self.expected_incoming_audio:
111+
self.incoming_audio_tracks.append(track)
112+
self._on_after_media()
113+
114+
def _on_incoming_datachannel(self, channel: aiortc.RTCDataChannel):
115+
self._log_debug("got data channel: %s", channel.label)
116+
if channel.label == "data" and self.messaging_channel is None:
117+
self._add_messaging_channel(channel)
118+
self._on_after_media()
119+
120+
def _on_after_media(self):
121+
if self._number_of_incoming_media == self.expected_number_of_incoming_media:
122+
self.incoming_media_ready_event.set()
123+
124+
def _parse_incoming_streams(self, remote_sdp: str):
125+
desc = aiortc.sdp.SessionDescription.parse(remote_sdp)
126+
sending_medias = [m for m in desc.media if m.direction in ["sendonly", "sendrecv"]]
127+
incoming_media_count = len(sending_medias)
128+
if not self.should_add_data_channel:
129+
channel_medias = [m for m in desc.media if m.kind == "application"]
130+
incoming_media_count += len(channel_medias)
131+
self.expected_number_of_incoming_media = incoming_media_count
132+
133+
def has_incoming_video_track(self, camera_type: str) -> bool:
134+
return camera_type in self.incoming_camera_tracks
135+
136+
def has_incoming_audio_track(self) -> bool:
137+
return len(self.incoming_audio_tracks) > 0
138+
139+
def has_messaging_channel(self) -> bool:
140+
return self.messaging_channel is not None
141+
142+
def get_incoming_video_track(self, camera_type: str, buffered: bool) -> aiortc.MediaStreamTrack:
143+
assert camera_type in self.incoming_camera_tracks, "Video tracks are not enabled on this stream"
144+
assert self.is_started, "Stream must be started"
145+
146+
track = self.incoming_camera_tracks[camera_type]
147+
relay_track = self.media_relay.subscribe(track, buffered=buffered)
148+
return relay_track
149+
150+
def get_incoming_audio_track(self, buffered: bool) -> aiortc.MediaStreamTrack:
151+
assert len(self.incoming_audio_tracks) > 0, "Audio tracks are not enabled on this stream"
152+
assert self.is_started, "Stream must be started"
153+
154+
track = self.incoming_audio_tracks[0]
155+
relay_track = self.media_relay.subscribe(track, buffered=buffered)
156+
return relay_track
157+
158+
def get_messaging_channel(self) -> aiortc.RTCDataChannel:
159+
assert self.messaging_channel is not None, "Messaging channel is not enabled on this stream"
160+
assert self.is_started, "Stream must be started"
161+
162+
return self.messaging_channel
163+
164+
def set_message_handler(self, message_handler: MessageHandler):
165+
self.incoming_message_handlers.append(message_handler)
166+
if self.messaging_channel is not None:
167+
self.messaging_channel.on("message", message_handler)
168+
169+
@property
170+
def is_started(self) -> bool:
171+
return self.peer_connection is not None and \
172+
self.peer_connection.localDescription is not None and \
173+
self.peer_connection.remoteDescription is not None
174+
175+
@property
176+
def is_connected_and_ready(self) -> bool:
177+
return self.peer_connection is not None and \
178+
self.peer_connection.connectionState == "connected" and \
179+
self.expected_number_of_incoming_media != 0 and self.incoming_media_ready_event.is_set()
180+
181+
async def wait_for_connection(self):
182+
await self.connection_attempted_event.wait()
183+
if self.peer_connection.connectionState != 'connected':
184+
raise ValueError("Connection failed.")
185+
if self.expected_number_of_incoming_media:
186+
await self.incoming_media_ready_event.wait()
187+
188+
async def wait_for_disconnection(self):
189+
await self.connection_stopped_event.wait()
190+
191+
async def stop(self):
192+
await self.peer_connection.close()
193+
194+
@abc.abstractmethod
195+
async def start(self) -> aiortc.RTCSessionDescription:
196+
raise NotImplementedError
197+
198+
199+
class WebRTCOfferStream(WebRTCBaseStream):
200+
def __init__(self, session_provider: ConnectionProvider, *args, **kwargs):
201+
super().__init__(*args, **kwargs)
202+
self.session_provider = session_provider
203+
204+
async def start(self) -> aiortc.RTCSessionDescription:
205+
self._add_consumer_transceivers()
206+
if self.should_add_data_channel:
207+
self._add_messaging_channel()
208+
self._add_producer_tracks()
209+
210+
offer = await self.peer_connection.createOffer()
211+
await self.peer_connection.setLocalDescription(offer)
212+
actual_offer = self.peer_connection.localDescription
213+
214+
streaming_offer = StreamingOffer(
215+
sdp=actual_offer.sdp,
216+
video=list(self.expected_incoming_camera_types),
217+
)
218+
remote_answer = await self.session_provider(streaming_offer)
219+
self._parse_incoming_streams(remote_sdp=remote_answer.sdp)
220+
await self.peer_connection.setRemoteDescription(remote_answer)
221+
actual_answer = self.peer_connection.remoteDescription
222+
223+
return actual_answer
224+
225+
226+
class WebRTCAnswerStream(WebRTCBaseStream):
227+
def __init__(self, session: aiortc.RTCSessionDescription, *args, **kwargs):
228+
super().__init__(*args, **kwargs)
229+
self.session = session
230+
231+
def _probe_video_codecs(self) -> List[str]:
232+
codecs = []
233+
for track in self.outgoing_video_tracks:
234+
if hasattr(track, "codec_preference") and track.codec_preference() is not None:
235+
codecs.append(track.codec_preference())
236+
237+
return codecs
238+
239+
def _override_incoming_video_codecs(self, remote_sdp: str, codecs: List[str]) -> str:
240+
desc = aiortc.sdp.SessionDescription.parse(remote_sdp)
241+
codec_mimes = [f"video/{c}" for c in codecs]
242+
for m in desc.media:
243+
if m.kind != "video":
244+
continue
245+
246+
preferred_codecs: List[aiortc.RTCRtpCodecParameters] = [c for c in m.rtp.codecs if c.mimeType in codec_mimes]
247+
if len(preferred_codecs) == 0:
248+
raise ValueError(f"None of {preferred_codecs} codecs is supported in remote SDP")
249+
250+
m.rtp.codecs = preferred_codecs
251+
m.fmt = [c.payloadType for c in preferred_codecs]
252+
253+
return str(desc)
254+
255+
async def start(self) -> aiortc.RTCSessionDescription:
256+
assert self.peer_connection.remoteDescription is None, "Connection already established"
257+
258+
self._add_consumer_transceivers()
259+
260+
# since we sent already encoded frames in some cases (e.g. livestream video tracks are in H264), we need to force aiortc to actually use it
261+
# we do that by overriding supported codec information on incoming sdp
262+
preferred_codecs = self._probe_video_codecs()
263+
if len(preferred_codecs) > 0:
264+
self.session.sdp = self._override_incoming_video_codecs(self.session.sdp, preferred_codecs)
265+
266+
self._parse_incoming_streams(remote_sdp=self.session.sdp)
267+
await self.peer_connection.setRemoteDescription(self.session)
268+
269+
self._add_producer_tracks()
270+
271+
answer = await self.peer_connection.createAnswer()
272+
await self.peer_connection.setLocalDescription(answer)
273+
actual_answer = self.peer_connection.localDescription
274+
275+
return actual_answer

0 commit comments

Comments
 (0)