Skip to content
Prev Previous commit
Next Next commit
Added 'requests' test client
  • Loading branch information
lovelydinosaur committed Aug 15, 2016
commit 3d1fff3f26835612be17b6624d766a56520880ec
2 changes: 1 addition & 1 deletion rest_framework/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def POST(self):
if not _hasattr(self, '_data'):
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self.data
return self._data
return QueryDict('', encoding=self._request._encoding)

@property
Expand Down
86 changes: 85 additions & 1 deletion rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals

import io

from django.conf import settings
from django.core.handlers.wsgi import WSGIHandler
from django.test import testcases
from django.test.client import Client as DjangoClient
from django.test.client import RequestFactory as DjangoRequestFactory
Expand All @@ -13,6 +16,10 @@
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
from requests import Session
from requests.adapters import BaseAdapter
from requests.models import Response
from requests.structures import CaseInsensitiveDict
from requests.utils import get_encoding_from_headers

from rest_framework.settings import api_settings

Expand All @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token


class DjangoTestAdapter(BaseAdapter):
"""
A transport adaptor for `requests`, that makes requests via the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like typo:

"adaptor" -> "adapter"

Django WSGI app, rather than making actual HTTP requests ovet the network.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"ovet" -> "over"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()

def get_environ(self, request):
"""
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
"""
method = request.method
url = request.url
kwargs = {}

# Set request content, if any exists.
if request.body is not None:
kwargs['data'] = request.body
if 'content-type' in request.headers:
kwargs['content_type'] = request.headers['content-type']

# Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'):
continue
kwargs['HTTP_%s' % key] = value

return self.factory.generic(method, url, **kwargs).environ

def send(self, request, *args, **kwargs):
"""
Make an outgoing request to the Django WSGI application.
"""
response = Response()

def start_response(status, headers):
status_code, _, reason_phrase = status.partition(' ')
response.status_code = int(status_code)
response.reason = reason_phrase
response.headers = CaseInsensitiveDict(headers)
response.encoding = get_encoding_from_headers(response.headers)

environ = self.get_environ(request)
raw_bytes = self.app(environ, start_response)

response.request = request
response.url = request.url
response.raw = io.BytesIO(b''.join(raw_bytes))

return response

def close(self):
pass


class DjangoTestSession(Session):
def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs)

adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']

for hostname in hostnames:
if hostname == '*':
hostname = ''
self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter)

def request(self, method, url, *args, **kwargs):
if ':' not in url:
url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)


class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
Expand Down Expand Up @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase):

def _pre_setup(self):
super(APITestCase, self)._pre_setup()
self.requests = Session()
self.requests = DjangoTestSession()


class APISimpleTestCase(testcases.SimpleTestCase):
Expand Down
119 changes: 116 additions & 3 deletions tests/test_requests_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,128 @@

class Root(APIView):
def get(self, request):
return Response({'hello': 'world'})
return Response({
'method': request.method,
'query_params': request.query_params,
})

def post(self, request):
files = {
key: (value.name, value.read())
for key, value in request.FILES.items()
}
post = request.POST
json = None
if request.META.get('CONTENT_TYPE') == 'application/json':
json = request.data

return Response({
'method': request.method,
'query_params': request.query_params,
'POST': post,
'FILES': files,
'JSON': json
})


class Headers(APIView):
def get(self, request):
headers = {
key[5:]: value
for key, value in request.META.items()
if key.startswith('HTTP_')
}
return Response({
'method': request.method,
'headers': headers
})


urlpatterns = [
url(r'^$', Root.as_view()),
url(r'^headers/$', Headers.as_view()),
]


@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_root(self):
print self.requests.get('http://example.com')
def test_get_request(self):
response = self.requests.get('/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {}
}
assert response.json() == expected

def test_get_request_query_params_in_url(self):
response = self.requests.get('/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected

def test_get_request_query_params_by_kwarg(self):
response = self.requests.get('/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected

def test_get_with_headers(self):
response = self.requests.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'

def test_post_form_request(self):
response = self.requests.post('/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {'key': 'value'},
'FILES': {},
'JSON': None
}
assert response.json() == expected

def test_post_json_request(self):
response = self.requests.post('/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {},
'FILES': {},
'JSON': {'key': 'value'}
}
assert response.json() == expected

def test_post_multipart_request(self):
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = self.requests.post('/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
'POST': {},
'JSON': None
}
assert response.json() == expected

# cookies/session auth