Skip to content
Prev Previous commit
Next Next commit
Add get_requests_client
  • Loading branch information
lovelydinosaur committed Aug 18, 2016
commit 0cc3f5008fcd41d1597776c43dc57502ed2a7542
12 changes: 5 additions & 7 deletions rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def request(self, method, url, *args, **kwargs):
return super(DjangoTestSession, self).request(method, url, *args, **kwargs)


def get_requests_client():
assert requests is not None, 'requests must be installed'
return DjangoTestSession()


class APIRequestFactory(DjangoRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
Expand Down Expand Up @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase):
class APITestCase(testcases.TestCase):
client_class = APIClient

@property
def requests(self):
if not hasattr(self, '_requests'):
assert requests is not None, 'requests must be installed'
self._requests = DjangoTestSession()
return self._requests


class APISimpleTestCase(testcases.SimpleTestCase):
client_class = APIClient
Expand Down
37 changes: 23 additions & 14 deletions tests/test_requests_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from rest_framework.compat import is_authenticated, requests
from rest_framework.response import Response
from rest_framework.test import APITestCase
from rest_framework.test import APITestCase, get_requests_client
from rest_framework.views import APIView


Expand Down Expand Up @@ -103,7 +103,8 @@ def post(self, request):
@override_settings(ROOT_URLCONF='tests.test_requests_client')
class RequestsClientTests(APITestCase):
def test_get_request(self):
response = self.requests.get('/')
client = get_requests_client()
response = client.get('/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -113,7 +114,8 @@ def test_get_request(self):
assert response.json() == expected

def test_get_request_query_params_in_url(self):
response = self.requests.get('/?key=value')
client = get_requests_client()
response = client.get('/?key=value')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -123,7 +125,8 @@ def test_get_request_query_params_in_url(self):
assert response.json() == expected

def test_get_request_query_params_by_kwarg(self):
response = self.requests.get('/', params={'key': 'value'})
client = get_requests_client()
response = client.get('/', params={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -133,14 +136,16 @@ def test_get_request_query_params_by_kwarg(self):
assert response.json() == expected

def test_get_with_headers(self):
response = self.requests.get('/headers/', headers={'User-Agent': 'example'})
client = get_requests_client()
response = client.get('/headers/', headers={'User-Agent': 'example'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
headers = response.json()['headers']
assert headers['USER-AGENT'] == 'example'

def test_post_form_request(self):
response = self.requests.post('/', data={'key': 'value'})
client = get_requests_client()
response = client.post('/', data={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -153,7 +158,8 @@ def test_post_form_request(self):
assert response.json() == expected

def test_post_json_request(self):
response = self.requests.post('/', json={'key': 'value'})
client = get_requests_client()
response = client.post('/', json={'key': 'value'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -166,10 +172,11 @@ def test_post_json_request(self):
assert response.json() == expected

def test_post_multipart_request(self):
client = get_requests_client()
files = {
'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
}
response = self.requests.post('/', files=files)
response = client.post('/', files=files)
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -182,27 +189,29 @@ def test_post_multipart_request(self):
assert response.json() == expected

def test_session(self):
response = self.requests.get('/session/')
client = get_requests_client()
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {}
assert response.json() == expected

response = self.requests.post('/session/', json={'example': 'abc'})
response = client.post('/session/', json={'example': 'abc'})
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected

response = self.requests.get('/session/')
response = client.get('/session/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {'example': 'abc'}
assert response.json() == expected

def test_auth(self):
# Confirm session is not authenticated
response = self.requests.get('/auth/')
client = get_requests_client()
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand All @@ -217,7 +226,7 @@ def test_auth(self):
user.save()

# Perform a login
response = self.requests.post('/auth/', json={
response = client.post('/auth/', json={
'username': 'tom',
'password': 'password'
}, headers={'X-CSRFToken': csrftoken})
Expand All @@ -229,7 +238,7 @@ def test_auth(self):
assert response.json() == expected

# Confirm session is authenticated
response = self.requests.get('/auth/')
response = client.get('/auth/')
assert response.status_code == 200
assert response.headers['Content-Type'] == 'application/json'
expected = {
Expand Down