Skip to content
2 changes: 1 addition & 1 deletion rest_framework/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def POST(self):
if not _hasattr(self, '_data'):
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self.data
return self._data
return QueryDict('', encoding=self._request._encoding)

@property
Expand Down
89 changes: 89 additions & 0 deletions rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals

import io

from django.conf import settings
from django.core.handlers.wsgi import WSGIHandler
from django.test import testcases
from django.test.client import Client as DjangoClient
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import ClientHandler
from django.utils import six
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
from requests import Session
from requests.adapters import BaseAdapter
from requests.models import Response
from requests.structures import CaseInsensitiveDict
from requests.utils import get_encoding_from_headers

from rest_framework.settings import api_settings

Expand All @@ -21,6 +29,83 @@ def force_authenticate(request, user=None, token=None):
request._force_auth_token = token


class DjangoTestAdapter(BaseAdapter):
"""
A transport adaptor for `requests`, that makes requests via the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like typo:

"adaptor" -> "adapter"

Django WSGI app, rather than making actual HTTP requests ovet the network.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"ovet" -> "over"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()

def get_environ(self, request):
"""
Given a `requests.PreparedRequest` instance, return a WSGI environ dict.
"""
method = request.method
url = request.url
kwargs = {}

# Set request content, if any exists.
if request.body is not None:
kwargs['data'] = request.body
if 'content-type' in request.headers:
kwargs['content_type'] = request.headers['content-type']

# Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'):
continue
kwargs['HTTP_%s' % key] = value

return self.factory.generic(method, url, **kwargs).environ

def send(self, request, *args, **kwargs):
"""
Make an outgoing request to the Django WSGI application.
"""
response = Response()

def start_response(status, headers):
status_code, _, reason_phrase = status.partition(' ')
response.status_code = int(status_code)
response.reason = reason_phrase
response.headers = CaseInsensitiveDict(headers)
response.encoding = get_encoding_from_headers(response.headers)

environ = self.get_environ(request)
raw_bytes = self.app(environ, start_response)

response.request = request
response.url = request.url
response.raw = io.BytesIO(b''.join(raw_bytes))

return response

def close(self):
pass


class DjangoTestSession(Session):
def __init__(self, *args, **kwargs):
super(DjangoTestSession, self).__init__(*args, **kwargs)

adapter = DjangoTestAdapter()
hostnames = list(settings.ALLOWED_HOSTS) + ['testserver']

for hostname in hostnames:
if hostname == '*':
hostname = ''
self.mount('http://%s' % hostname, adapter)
self.mount('https://%s' % hostname, adapter)

def request(self, method, url, *args, **kwargs):
if ':' not in url:
url = 'http://testserver/' + url.lstrip('/')
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)


class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
Expand Down Expand Up @@ -221,6 +306,10 @@ class APITransactionTestCase(testcases.TransactionTestCase):
class APITestCase(testcases.TestCase):
client_class = APIClient

def _pre_setup(self):
super(APITestCase, self)._pre_setup()
self.requests = DjangoTestSession()


class APISimpleTestCase(testcases.SimpleTestCase):
client_class = APIClient
Expand Down
137 changes: 137 additions & 0 deletions tests/test_requests_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from __future__ import unicode_literals

from django.conf.urls import url
from django.test import override_settings

from rest_framework.response import Response
from rest_framework.test import APITestCase
from rest_framework.views import APIView


class Root(APIView):
def get(self, request):
return Response({
'method': request.method,
'query_params': request.query_params,
})

def post(self, request):
files = {
key: (value.name, value.read())
for key, value in request.FILES.items()
}
post = request.POST
json = None
if request.META.get('CONTENT_TYPE') == 'application/json':
json = request.data

return Response({
'method': request.method,
'query_params': request.query_params,
'POST': post,
'FILES': files,
'JSON': json
})


class Headers(APIView):
def get(self, request):
headers = {
key[5:]: value
for key, value in request.META.items()
if key.startswith('HTTP_')
}
return Response({
'method': request.method,
'headers': headers
})


urlpatterns = [
url(r'^$', Root.as_view()),
url(r'^headers/$', Headers.as_view()),
]


@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_request(self):
response = self.requests.get('/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {}
}
assert response.json() == expected

def test_get_request_query_params_in_url(self):
response = self.requests.get('/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected

def test_get_request_query_params_by_kwarg(self):
response = self.requests.get('/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'GET',
'query_params': {'key': 'value'}
}
assert response.json() == expected

def test_get_with_headers(self):
response = self.requests.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'

def test_post_form_request(self):
response = self.requests.post('/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {'key': 'value'},
'FILES': {},
'JSON': None
}
assert response.json() == expected

def test_post_json_request(self):
response = self.requests.post('/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'POST': {},
'FILES': {},
'JSON': {'key': 'value'}
}
assert response.json() == expected

def test_post_multipart_request(self):
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = self.requests.post('/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
'method': 'POST',
'query_params': {},
'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
'POST': {},
'JSON': None
}
assert response.json() == expected

# cookies/session auth