Skip to content

Commit fc9653e

Browse files
committed
Add face_detection example utilizing webrtcd
1 parent fb11792 commit fc9653e

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

examples/face_detection.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import argparse
2+
import asyncio
3+
4+
import aiortc
5+
import aiohttp
6+
import cv2
7+
import pygame
8+
9+
from bodyrtc import WebRTCOfferBuilder, StreamingOffer
10+
11+
12+
def pygame_should_quit():
13+
for event in pygame.event.get():
14+
if event.type == pygame.QUIT:
15+
return True
16+
return False
17+
18+
19+
class WebrtcdConnectionProvider:
20+
"""
21+
Connection provider reaching webrtcd server on comma three
22+
"""
23+
def __init__(self, host, port=5001):
24+
self.url = f"http://{host}:{port}/stream"
25+
26+
async def __call__(self, offer: StreamingOffer) -> aiortc.RTCSessionDescription:
27+
async with aiohttp.ClientSession() as session:
28+
body = {'sdp': offer.sdp, 'cameras': offer.video, 'bridge_services_in': [], 'bridge_services_out': []}
29+
async with session.post(self.url, json=body) as resp:
30+
payload = await resp.json()
31+
answer = aiortc.RTCSessionDescription(**payload)
32+
return answer
33+
34+
35+
class FaceDetector:
36+
"""
37+
Simple face detector using opencv
38+
"""
39+
def __init__(self):
40+
self.classifier = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
41+
42+
def detect(self, array):
43+
gray_array = cv2.cvtColor(array, cv2.COLOR_RGB2GRAY)
44+
faces = self.classifier.detectMultiScale(gray_array, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
45+
return faces
46+
47+
def draw(self, array, faces):
48+
for (x, y, w, h) in faces:
49+
cv2.rectangle(array, (x, y), (x + w, y + h), (0, 255, 0), 2)
50+
return array
51+
52+
53+
async def run_face_detection(stream):
54+
# setup pygame window
55+
pygame.init()
56+
screen_width, screen_height = 1280, 720
57+
screen = pygame.display.set_mode((screen_width, screen_height))
58+
pygame.display.set_caption("Face detection demo")
59+
surface = pygame.Surface((screen_width, screen_height))
60+
61+
# get the driver camera video track from the stream
62+
# generally its better to reuse the track object instead of getting it every time
63+
track = stream.get_incoming_video_track("driver", buffered=False)
64+
# cv2 face detector
65+
detector = FaceDetector()
66+
while stream.is_connected_and_ready and not pygame_should_quit():
67+
try:
68+
# receive frame as pyAV VideoFrame, convert to rgb24 numpy array
69+
frame = await track.recv()
70+
array = frame.to_ndarray(format="rgb24")
71+
72+
# detect faces and draw rects around them
73+
resized_array = cv2.resize(array, (screen_width, screen_height))
74+
faces = detector.detect(resized_array)
75+
detector.draw(resized_array, faces)
76+
77+
# display the image
78+
pygame.surfarray.blit_array(surface, resized_array.swapaxes(0, 1))
79+
screen.blit(surface, (0, 0))
80+
pygame.display.flip()
81+
82+
print("Received frame from", "driver", frame.time)
83+
except aiortc.mediastreams.MediaStreamError:
84+
break
85+
86+
pygame.quit()
87+
await stream.stop()
88+
89+
90+
async def run(args):
91+
# build your own the offer stream
92+
builder = WebRTCOfferBuilder(WebrtcdConnectionProvider(args.host))
93+
# request video stream from drivers camera
94+
builder.offer_to_receive_video_stream("driver")
95+
# add cereal messaging streaming support
96+
builder.add_messaging()
97+
98+
stream = builder.stream()
99+
100+
# start the stream then wait for connection
101+
# server will receive the offer and attempt to fulfill it
102+
await stream.start()
103+
await stream.wait_for_connection()
104+
105+
assert stream.has_incoming_video_track("driver") and stream.has_messaging_channel()
106+
107+
# run face detection loop on the drivers camera
108+
await run_face_detection(stream)
109+
110+
111+
if __name__=='__main__':
112+
parser = argparse.ArgumentParser()
113+
parser.add_argument("--host", default="localhost", help="Host for webrtcd server")
114+
115+
args = parser.parse_args()
116+
asyncio.run(run(args))

0 commit comments

Comments
 (0)