Backport test case improvements

Add random primary key mixin. Split test case code into mixins.
Make the view test case and the API test cases part of the same
class hierachy. Update tests that failed due to the new import
locations.

Signed-off-by: Roberto Rosario <roberto.rosario.gonzalez@gmail.com>
This commit is contained in:
Roberto Rosario
2019-04-02 02:31:35 -04:00
parent 8cbae9021b
commit 586d41eeff
13 changed files with 310 additions and 266 deletions

View File

@@ -0,0 +1,37 @@
from __future__ import unicode_literals
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from acls.models import AccessControlList
from permissions.models import Role
from permissions.tests.literals import TEST_ROLE_LABEL
from permissions.tests.mixins import RoleTestCaseMixin
from user_management.tests.literals import (
TEST_ADMIN_PASSWORD, TEST_ADMIN_USERNAME, TEST_ADMIN_EMAIL,
TEST_GROUP_NAME, TEST_USER_EMAIL, TEST_USER_USERNAME, TEST_USER_PASSWORD
)
class ACLTestCaseMixin(RoleTestCaseMixin):
def setUp(self):
super(ACLTestCaseMixin, self).setUp()
self.admin_user = get_user_model().objects.create_superuser(
username=TEST_ADMIN_USERNAME, email=TEST_ADMIN_EMAIL,
password=TEST_ADMIN_PASSWORD
)
self.user = get_user_model().objects.create_user(
username=TEST_USER_USERNAME, email=TEST_USER_EMAIL,
password=TEST_USER_PASSWORD
)
self.group = Group.objects.create(name=TEST_GROUP_NAME)
self.role = Role.objects.create(label=TEST_ROLE_LABEL)
self.group.user_set.add(self.user)
self.role.groups.add(self.group)
def grant_access(self, permission, obj):
AccessControlList.objects.grant(
permission=permission, role=self.role, obj=obj
)

View File

@@ -53,6 +53,7 @@ class ACLAPITestCase(DocumentTestMixin, BaseAPITestCase):
) )
def test_object_acl_delete_view(self): def test_object_acl_delete_view(self):
self.expected_content_type = None
self._create_acl() self._create_acl()
response = self.delete( response = self.delete(
@@ -87,6 +88,7 @@ class ACLAPITestCase(DocumentTestMixin, BaseAPITestCase):
) )
def test_object_acl_permission_delete_view(self): def test_object_acl_permission_delete_view(self):
self.expected_content_type = None
self._create_acl() self._create_acl()
permission = self.acl.permissions.first() permission = self.acl.permissions.first()

View File

@@ -35,21 +35,6 @@ class DocumentCheckoutTestCase(DocumentTestMixin, BaseTestCase):
) )
) )
def test_version_creation_blocking(self):
# Silence unrelated logging
logging.getLogger('documents.models').setLevel(logging.CRITICAL)
expiration_datetime = now() + datetime.timedelta(days=1)
DocumentCheckout.objects.checkout_document(
document=self.document, expiration_datetime=expiration_datetime,
user=self.admin_user, block_new_version=True
)
with self.assertRaises(NewDocumentVersionNotAllowed):
with open(TEST_SMALL_DOCUMENT_PATH, mode='rb') as file_object:
self.document.new_version(file_object=file_object)
def test_checkin_in(self): def test_checkin_in(self):
expiration_datetime = now() + datetime.timedelta(days=1) expiration_datetime = now() + datetime.timedelta(days=1)
@@ -100,16 +85,6 @@ class DocumentCheckoutTestCase(DocumentTestMixin, BaseTestCase):
self.assertFalse(self.document.is_checked_out()) self.assertFalse(self.document.is_checked_out())
def test_blocking_new_versions(self):
# Silence unrelated logging
logging.getLogger('documents.models').setLevel(logging.CRITICAL)
NewVersionBlock.objects.block(document=self.document)
with self.assertRaises(NewDocumentVersionNotAllowed):
with open(TEST_SMALL_DOCUMENT_PATH, mode='rb') as file_object:
self.document.new_version(file_object=file_object)
@override_settings(OCR_AUTO_OCR=False) @override_settings(OCR_AUTO_OCR=False)
class NewVersionBlockTestCase(DocumentTestMixin, BaseTestCase): class NewVersionBlockTestCase(DocumentTestMixin, BaseTestCase):
@@ -121,6 +96,16 @@ class NewVersionBlockTestCase(DocumentTestMixin, BaseTestCase):
NewVersionBlock.objects.first().document, self.document NewVersionBlock.objects.first().document, self.document
) )
def test_blocking_new_versions(self):
# Silence unrelated logging
logging.getLogger('documents.models').setLevel(logging.CRITICAL)
NewVersionBlock.objects.block(document=self.document)
with self.assertRaises(NewDocumentVersionNotAllowed):
with open(TEST_SMALL_DOCUMENT_PATH, mode='rb') as file_object:
self.document.new_version(file_object=file_object)
def test_unblocking(self): def test_unblocking(self):
NewVersionBlock.objects.create(document=self.document) NewVersionBlock.objects.create(document=self.document)
@@ -140,3 +125,18 @@ class NewVersionBlockTestCase(DocumentTestMixin, BaseTestCase):
self.assertFalse( self.assertFalse(
NewVersionBlock.objects.is_blocked(document=self.document) NewVersionBlock.objects.is_blocked(document=self.document)
) )
def test_version_creation_blocking(self):
# Silence unrelated logging
logging.getLogger('documents.models').setLevel(logging.CRITICAL)
expiration_datetime = now() + datetime.timedelta(days=1)
DocumentCheckout.objects.checkout_document(
document=self.document, expiration_datetime=expiration_datetime,
user=self.admin_user, block_new_version=True
)
with self.assertRaises(NewDocumentVersionNotAllowed):
with open(TEST_SMALL_DOCUMENT_PATH, mode='rb') as file_object:
self.document.new_version(file_object=file_object)

View File

@@ -8,7 +8,7 @@ from django.utils.timezone import now
from common.literals import TIME_DELTA_UNIT_DAYS from common.literals import TIME_DELTA_UNIT_DAYS
from documents.tests import GenericDocumentViewTestCase from documents.tests import GenericDocumentViewTestCase
from sources.links import link_upload_version from sources.links import link_upload_version
from user_management.tests import ( from user_management.tests.literals import (
TEST_USER_PASSWORD, TEST_USER_USERNAME, TEST_ADMIN_PASSWORD, TEST_USER_PASSWORD, TEST_USER_USERNAME, TEST_ADMIN_PASSWORD,
TEST_ADMIN_USERNAME, TEST_ADMIN_USERNAME,
) )

View File

@@ -1,30 +1,23 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from django.conf.urls import url
from django.contrib.auth import get_user_model
from django.http import HttpResponse
from django.template import Context, Template
from django.test import TestCase from django.test import TestCase
from django.test.utils import ContextList
from django.urls import clear_url_caches, reverse
from django_downloadview import assert_download_response from django_downloadview import assert_download_response
from acls.tests.mixins import ACLTestCaseMixin
from permissions.classes import Permission from permissions.classes import Permission
from smart_settings.classes import Namespace from smart_settings.classes import Namespace
from user_management.tests import ( from user_management.tests.mixins import UserTestCaseMixin
TEST_ADMIN_PASSWORD, TEST_ADMIN_USERNAME, TEST_USER_USERNAME,
TEST_USER_PASSWORD
)
from .literals import TEST_VIEW_NAME, TEST_VIEW_URL
from .mixins import ( from .mixins import (
ContentTypeCheckMixin, DatabaseConversionMixin, OpenFileCheckMixin, ClientMethodsTestCaseMixin, ContentTypeCheckTestCaseMixin,
TempfileCheckMixin, UserMixin DatabaseConversionMixin, OpenFileCheckTestCaseMixin,
RandomPrimaryKeyModelMonkeyPatchMixin, TempfileCheckTestCasekMixin,
TestViewTestCaseMixin
) )
class BaseTestCase(DatabaseConversionMixin, UserMixin, ContentTypeCheckMixin, OpenFileCheckMixin, TempfileCheckMixin, TestCase): class BaseTestCase(RandomPrimaryKeyModelMonkeyPatchMixin, DatabaseConversionMixin, ACLTestCaseMixin, OpenFileCheckTestCaseMixin, TempfileCheckTestCasekMixin, TestCase):
""" """
This is the most basic test case class any test in the project should use. This is the most basic test case class any test in the project should use.
""" """
@@ -36,81 +29,9 @@ class BaseTestCase(DatabaseConversionMixin, UserMixin, ContentTypeCheckMixin, Op
Permission.invalidate_cache() Permission.invalidate_cache()
class GenericViewTestCase(BaseTestCase): class GenericViewTestCase(ClientMethodsTestCaseMixin, ContentTypeCheckTestCaseMixin, TestViewTestCaseMixin, UserTestCaseMixin, BaseTestCase):
def setUp(self): """
super(GenericViewTestCase, self).setUp() A generic view test case built on top of the base test case providing
self.has_test_view = False a single, user customizable view to test object resolution and shorthand
HTTP method functions.
def tearDown(self): """
from mayan.urls import urlpatterns
self.client.logout()
if self.has_test_view:
urlpatterns.pop(0)
super(GenericViewTestCase, self).tearDown()
def add_test_view(self, test_object):
from mayan.urls import urlpatterns
def test_view(request):
template = Template('{{ object }}')
context = Context(
{'object': test_object, 'resolved_object': test_object}
)
return HttpResponse(template.render(context=context))
urlpatterns.insert(0, url(TEST_VIEW_URL, test_view, name=TEST_VIEW_NAME))
clear_url_caches()
self.has_test_view = True
def get_test_view(self):
response = self.get(TEST_VIEW_NAME)
if isinstance(response.context, ContextList):
# template widget rendering causes test client response to be
# ContextList rather than RequestContext. Typecast to dictionary
# before updating.
result = dict(response.context).copy()
result.update({'request': response.wsgi_request})
return Context(result)
else:
response.context.update({'request': response.wsgi_request})
return Context(response.context)
def get(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.get(
path=path, data=data, follow=follow
)
def login(self, username, password):
logged_in = self.client.login(username=username, password=password)
user = get_user_model().objects.get(username=username)
self.assertTrue(logged_in)
self.assertTrue(user.is_authenticated)
def login_user(self):
self.login(username=TEST_USER_USERNAME, password=TEST_USER_PASSWORD)
def login_admin_user(self):
self.login(username=TEST_ADMIN_USERNAME, password=TEST_ADMIN_PASSWORD)
def logout(self):
self.client.logout()
def post(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.post(
path=path, data=data, follow=follow
)

View File

@@ -2,19 +2,21 @@ from __future__ import unicode_literals
import glob import glob
import os import os
import random
from furl import furl
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.conf.urls import url
from django.contrib.auth.models import Group
from django.core import management from django.core import management
from django.db import models
from django.db.models.signals import post_save, pre_save
from django.http import HttpResponse
from django.template import Context, Template
from django.test.utils import ContextList
from django.urls import clear_url_caches, reverse
from acls.models import AccessControlList from .literals import TEST_VIEW_NAME, TEST_VIEW_URL
from permissions.models import Role
from permissions.tests.literals import TEST_ROLE_LABEL
from user_management.tests import (
TEST_ADMIN_PASSWORD, TEST_ADMIN_USERNAME, TEST_ADMIN_EMAIL,
TEST_GROUP_NAME, TEST_USER_EMAIL, TEST_USER_USERNAME, TEST_USER_PASSWORD
)
from ..settings import setting_temporary_directory from ..settings import setting_temporary_directory
@@ -22,24 +24,75 @@ if getattr(settings, 'COMMON_TEST_FILE_HANDLES', False):
import psutil import psutil
class ContentTypeCheckMixin(object): class ClientMethodsTestCaseMixin(object):
def _build_verb_kwargs(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
query = kwargs.pop('query', {})
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
path = furl(url=path)
path.args.update(query)
return {'follow': follow, 'data': data, 'path': path.tostr()}
def delete(self, viewname=None, path=None, *args, **kwargs):
return self.client.delete(
**self._build_verb_kwargs(
path=path, viewname=viewname, *args, **kwargs
)
)
def get(self, viewname=None, path=None, *args, **kwargs):
return self.client.get(
**self._build_verb_kwargs(
path=path, viewname=viewname, *args, **kwargs
)
)
def patch(self, viewname=None, path=None, *args, **kwargs):
return self.client.patch(
**self._build_verb_kwargs(
path=path, viewname=viewname, *args, **kwargs
)
)
def post(self, viewname=None, path=None, *args, **kwargs):
return self.client.post(
**self._build_verb_kwargs(
path=path, viewname=viewname, *args, **kwargs
)
)
def put(self, viewname=None, path=None, *args, **kwargs):
return self.client.put(
**self._build_verb_kwargs(
path=path, viewname=viewname, *args, **kwargs
)
)
class ContentTypeCheckTestCaseMixin(object):
expected_content_type = 'text/html; charset=utf-8' expected_content_type = 'text/html; charset=utf-8'
def _pre_setup(self): def _pre_setup(self):
super(ContentTypeCheckMixin, self)._pre_setup() super(ContentTypeCheckTestCaseMixin, self)._pre_setup()
test_instance = self test_instance = self
class CustomClient(self.client_class): class CustomClient(self.client_class):
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
response = super(CustomClient, self).request(*args, **kwargs) response = super(CustomClient, self).request(*args, **kwargs)
content_type = response._headers['content-type'][1] content_type = response._headers.get('content-type', [None, ''])[1]
test_instance.assertEqual( if test_instance.expected_content_type:
content_type, test_instance.expected_content_type, test_instance.assertEqual(
msg='Unexpected response content type: {}, expected: {}.'.format( content_type, test_instance.expected_content_type,
content_type, test_instance.expected_content_type msg='Unexpected response content type: {}, expected: {}.'.format(
content_type, test_instance.expected_content_type
)
) )
)
return response return response
@@ -53,7 +106,7 @@ class DatabaseConversionMixin(object):
) )
class OpenFileCheckMixin(object): class OpenFileCheckTestCaseMixin(object):
def _get_descriptor_count(self): def _get_descriptor_count(self):
process = psutil.Process() process = psutil.Process()
return process.num_fds() return process.num_fds()
@@ -63,7 +116,7 @@ class OpenFileCheckMixin(object):
return process.open_files() return process.open_files()
def setUp(self): def setUp(self):
super(OpenFileCheckMixin, self).setUp() super(OpenFileCheckTestCaseMixin, self).setUp()
if getattr(settings, 'COMMON_TEST_FILE_HANDLES', False): if getattr(settings, 'COMMON_TEST_FILE_HANDLES', False):
self._open_files = self._get_open_files() self._open_files = self._get_open_files()
@@ -78,10 +131,84 @@ class OpenFileCheckMixin(object):
self._skip_file_descriptor_test = False self._skip_file_descriptor_test = False
super(OpenFileCheckMixin, self).tearDown() super(OpenFileCheckTestCaseMixin, self).tearDown()
class TempfileCheckMixin(object): class RandomPrimaryKeyModelMonkeyPatchMixin(object):
random_primary_key_random_floor = 100
random_primary_key_random_ceiling = 10000
random_primary_key_maximum_attempts = 100
@staticmethod
def get_unique_primary_key(model):
pk_list = model._meta.default_manager.values_list('pk', flat=True)
attempts = 0
while True:
primary_key = random.randint(
RandomPrimaryKeyModelMonkeyPatchMixin.random_primary_key_random_floor,
RandomPrimaryKeyModelMonkeyPatchMixin.random_primary_key_random_ceiling
)
if primary_key not in pk_list:
break
attempts = attempts + 1
if attempts > RandomPrimaryKeyModelMonkeyPatchMixin.random_primary_key_maximum_attempts:
raise Exception(
'Maximum number of retries for an unique random primary '
'key reached.'
)
return primary_key
def setUp(self):
self.method_save_original = models.Model.save
def method_save_new(instance, *args, **kwargs):
if instance.pk:
return self.method_save_original(instance, *args, **kwargs)
else:
# Set meta.auto_created to True to have the original save_base
# not send the pre_save signal which would normally send
# the instance without a primary key. Since we assign a random
# primary key any pre_save signal handler that relies on an
# empty primary key will fail.
# The meta.auto_created and manual pre_save sending emulates
# the original behavior. Since meta.auto_created also disables
# the post_save signal we must also send it ourselves.
# This hack work with Django 1.11 .save_base() but can break
# in future versions if that method is updated.
pre_save.send(
sender=instance.__class__, instance=instance, raw=False,
update_fields=None,
)
instance._meta.auto_created = True
instance.pk = RandomPrimaryKeyModelMonkeyPatchMixin.get_unique_primary_key(
model=instance._meta.model
)
instance.id = instance.pk
result = instance.save_base(force_insert=True)
instance._meta.auto_created = False
post_save.send(
sender=instance.__class__, instance=instance, created=True,
update_fields=None, raw=False
)
return result
setattr(models.Model, 'save', method_save_new)
super(RandomPrimaryKeyModelMonkeyPatchMixin, self).setUp()
def tearDown(self):
models.Model.save = self.method_save_original
super(RandomPrimaryKeyModelMonkeyPatchMixin, self).tearDown()
class TempfileCheckTestCasekMixin(object):
# Ignore the jvmstat instrumentation and GitLab's CI .config files # Ignore the jvmstat instrumentation and GitLab's CI .config files
# Ignore LibreOffice fontconfig cache dir # Ignore LibreOffice fontconfig cache dir
ignore_globs = ('hsperfdata_*', '.config', '.cache') ignore_globs = ('hsperfdata_*', '.config', '.cache')
@@ -106,7 +233,7 @@ class TempfileCheckMixin(object):
) - set(ignored_result) ) - set(ignored_result)
def setUp(self): def setUp(self):
super(TempfileCheckMixin, self).setUp() super(TempfileCheckTestCasekMixin, self).setUp()
if getattr(settings, 'COMMON_TEST_TEMP_FILES', False): if getattr(settings, 'COMMON_TEST_TEMP_FILES', False):
self._temporary_items = self._get_temporary_entries() self._temporary_items = self._get_temporary_entries()
@@ -121,33 +248,43 @@ class TempfileCheckMixin(object):
','.join(final_temporary_items - self._temporary_items) ','.join(final_temporary_items - self._temporary_items)
) )
) )
super(TempfileCheckMixin, self).tearDown() super(TempfileCheckTestCasekMixin, self).tearDown()
class UserMixin(object): class TestViewTestCaseMixin(object):
def setUp(self): has_test_view = False
super(UserMixin, self).setUp()
self.admin_user = get_user_model().objects.create_superuser(
username=TEST_ADMIN_USERNAME, email=TEST_ADMIN_EMAIL,
password=TEST_ADMIN_PASSWORD
)
self.user = get_user_model().objects.create_user( def tearDown(self):
username=TEST_USER_USERNAME, email=TEST_USER_EMAIL, from mayan.urls import urlpatterns
password=TEST_USER_PASSWORD
)
self.group = Group.objects.create(name=TEST_GROUP_NAME) self.client.logout()
self.role = Role.objects.create(label=TEST_ROLE_LABEL) if self.has_test_view:
self.group.user_set.add(self.user) urlpatterns.pop(0)
self.role.groups.add(self.group) super(TestViewTestCaseMixin, self).tearDown()
def grant_access(self, permission, obj): def add_test_view(self, test_object):
AccessControlList.objects.grant( from mayan.urls import urlpatterns
permission=permission, role=self.role, obj=obj
)
def grant_permission(self, permission): def test_view(request):
self.role.permissions.add( template = Template('{{ object }}')
permission.stored_permission context = Context(
) {'object': test_object, 'resolved_object': test_object}
)
return HttpResponse(template.render(context=context))
urlpatterns.insert(0, url(TEST_VIEW_URL, test_view, name=TEST_VIEW_NAME))
clear_url_caches()
self.has_test_view = True
def get_test_view(self):
response = self.get(TEST_VIEW_NAME)
if isinstance(response.context, ContextList):
# template widget rendering causes test client response to be
# ContextList rather than RequestContext. Typecast to dictionary
# before updating.
result = dict(response.context).copy()
result.update({'request': response.wsgi_request})
return Context(result)
else:
response.context.update({'request': response.wsgi_request})
return Context(response.context)

View File

@@ -135,6 +135,7 @@ class DocumentTypeAPITestCase(BaseAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_document_type_delete_with_access(self): def test_document_type_delete_with_access(self):
self.expected_content_type = None
self.document_type = DocumentType.objects.create( self.document_type = DocumentType.objects.create(
label=TEST_DOCUMENT_TYPE_LABEL label=TEST_DOCUMENT_TYPE_LABEL
) )
@@ -522,6 +523,8 @@ class TrashedDocumentAPITestCase(DocumentTestMixin, BaseAPITestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_document_move_to_trash_with_access(self): def test_document_move_to_trash_with_access(self):
self.expected_content_type = None
self.document = self.upload_document() self.document = self.upload_document()
self.grant_access( self.grant_access(
permission=permission_document_trash, obj=self.document permission=permission_document_trash, obj=self.document
@@ -546,6 +549,8 @@ class TrashedDocumentAPITestCase(DocumentTestMixin, BaseAPITestCase):
self.assertEqual(Document.trash.count(), 1) self.assertEqual(Document.trash.count(), 1)
def test_trashed_document_delete_from_trash_with_access(self): def test_trashed_document_delete_from_trash_with_access(self):
self.expected_content_type = None
self.document = self.upload_document() self.document = self.upload_document()
self.document.delete() self.document.delete()
self.grant_access(permission=permission_document_delete, obj=self.document) self.grant_access(permission=permission_document_delete, obj=self.document)

View File

@@ -0,0 +1,8 @@
from __future__ import unicode_literals
class RoleTestCaseMixin(object):
def grant_permission(self, permission):
self.role.permissions.add(
permission.stored_permission
)

View File

@@ -1,101 +1,19 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from common.tests.mixins import UserMixin from common.tests import GenericViewTestCase
from permissions.classes import Permission from permissions.classes import Permission
from smart_settings.classes import Namespace from smart_settings.classes import Namespace
from user_management.tests import (
TEST_ADMIN_PASSWORD, TEST_ADMIN_USERNAME, TEST_USER_USERNAME,
TEST_USER_PASSWORD
)
class BaseAPITestCase(UserMixin, APITestCase): class BaseAPITestCase(APITestCase, GenericViewTestCase):
""" """
API test case class that invalidates permissions and smart settings API test case class that invalidates permissions and smart settings
""" """
expected_content_type = None
def setUp(self): def setUp(self):
super(BaseAPITestCase, self).setUp() super(BaseAPITestCase, self).setUp()
Namespace.invalidate_cache_all() Namespace.invalidate_cache_all()
Permission.invalidate_cache() Permission.invalidate_cache()
def tearDown(self):
self.client.logout()
super(BaseAPITestCase, self).tearDown()
def delete(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.delete(
path=path, data=data, follow=follow
)
def get(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.get(
path=path, data=data, follow=follow
)
def login(self, username, password):
logged_in = self.client.login(username=username, password=password)
user = get_user_model().objects.get(username=username)
self.assertTrue(logged_in)
self.assertTrue(user.is_authenticated)
return user.is_authenticated
def login_user(self):
self.login(username=TEST_USER_USERNAME, password=TEST_USER_PASSWORD)
def login_admin_user(self):
self.login(username=TEST_ADMIN_USERNAME, password=TEST_ADMIN_PASSWORD)
def logout(self):
self.client.logout()
def patch(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.patch(
path=path, data=data, follow=follow
)
def post(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.post(
path=path, data=data, follow=follow
)
def put(self, viewname=None, path=None, *args, **kwargs):
data = kwargs.pop('data', {})
follow = kwargs.pop('follow', False)
if viewname:
path = reverse(viewname=viewname, *args, **kwargs)
return self.client.put(
path=path, data=data, follow=follow
)

View File

@@ -1 +0,0 @@
from .literals import * # NOQA

View File

@@ -1,11 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
__all__ = (
'TEST_ADMIN_EMAIL', 'TEST_ADMIN_PASSWORD', 'TEST_ADMIN_USERNAME',
'TEST_GROUP_NAME', 'TEST_GROUP_NAME_EDITED', 'TEST_USER_EMAIL',
'TEST_USER_PASSWORD', 'TEST_USER_PASSWORD_EDITED', 'TEST_USER_USERNAME'
)
TEST_ADMIN_EMAIL = 'admin@example.com' TEST_ADMIN_EMAIL = 'admin@example.com'
TEST_ADMIN_PASSWORD = 'test admin password' TEST_ADMIN_PASSWORD = 'test admin password'
TEST_ADMIN_USERNAME = 'test_admin' TEST_ADMIN_USERNAME = 'test_admin'

View File

@@ -3,7 +3,9 @@ from __future__ import unicode_literals
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from .literals import ( from .literals import (
TEST_USER_2_EMAIL, TEST_USER_2_PASSWORD, TEST_USER_2_USERNAME TEST_ADMIN_PASSWORD, TEST_ADMIN_USERNAME, TEST_USER_PASSWORD,
TEST_USER_USERNAME, TEST_USER_2_EMAIL, TEST_USER_2_PASSWORD,
TEST_USER_2_USERNAME
) )
@@ -18,3 +20,22 @@ class UserTestMixin(object):
self.user_2 = get_user_model().objects.create( self.user_2 = get_user_model().objects.create(
username=TEST_USER_2_USERNAME username=TEST_USER_2_USERNAME
) )
class UserTestCaseMixin(object):
def tearDown(self):
self.client.logout()
super(UserTestCaseMixin, self).tearDown()
def login(self, username, password):
logged_in = self.client.login(username=username, password=password)
return logged_in
def login_user(self):
self.login(username=TEST_USER_USERNAME, password=TEST_USER_PASSWORD)
def login_admin_user(self):
self.login(username=TEST_ADMIN_USERNAME, password=TEST_ADMIN_PASSWORD)
def logout(self):
self.client.logout()

View File

@@ -87,10 +87,11 @@ class UserManagementViewTestCase(UserTestMixin, GenericViewTestCase):
self.logout() self.logout()
with self.assertRaises(AssertionError): result = self.login(
self.login( username=TEST_USER_2_USERNAME, password=TEST_USER_PASSWORD_EDITED
username=TEST_USER_2_USERNAME, password=TEST_USER_PASSWORD_EDITED )
)
self.assertFalse(result)
response = self.get('common:current_user_details') response = self.get('common:current_user_details')
@@ -134,10 +135,11 @@ class UserManagementViewTestCase(UserTestMixin, GenericViewTestCase):
self.logout() self.logout()
with self.assertRaises(AssertionError): result = self.login(
self.login( username=TEST_USER_2_USERNAME, password=TEST_USER_PASSWORD_EDITED
username=TEST_USER_2_USERNAME, password=TEST_USER_PASSWORD_EDITED )
)
self.assertFalse(result)
response = self.get('common:current_user_details') response = self.get('common:current_user_details')
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)