Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add tests
Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>
  • Loading branch information
TG1999 committed Jan 6, 2025
commit da1862ea65884c7412249d61330fbf7ea0fd1806
14 changes: 9 additions & 5 deletions vulnerabilities/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#


from django.db.models import Prefetch
from django_filters import rest_framework as filters
from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.utils import extend_schema
Expand All @@ -20,8 +21,6 @@
from rest_framework.response import Response
from rest_framework.reverse import reverse

from vulnerabilities.api import PackageFilterSet
from vulnerabilities.api import VulnerabilitySeveritySerializer
from vulnerabilities.models import Package
from vulnerabilities.models import Vulnerability
from vulnerabilities.models import VulnerabilityReference
Expand Down Expand Up @@ -198,9 +197,8 @@ def get_affected_by_vulnerabilities(self, obj):
"""
Return a dictionary with vulnerabilities as keys and their details, including fixed_by_packages.
"""
vulnerabilities = obj.affected_by_vulnerabilities.prefetch_related("fixed_by_packages")
result = {}
for vuln in vulnerabilities:
for vuln in getattr(obj, "prefetched_affected_vulnerabilities", []):
fixed_by_package = vuln.fixed_by_packages.first()
purl = None
if fixed_by_package:
Expand Down Expand Up @@ -247,7 +245,13 @@ class PackageV2FilterSet(filters.FilterSet):


class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Package.objects.all()
queryset = Package.objects.all().prefetch_related(
Prefetch(
"affected_by_vulnerabilities",
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
to_attr="prefetched_affected_vulnerabilities",
)
)
serializer_class = PackageV2Serializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageV2FilterSet
Expand Down
135 changes: 102 additions & 33 deletions vulnerabilities/tests/test_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# See https://aboutcode.org for more information about nexB OSS projects.
#

from django.db.models import Prefetch
from django.urls import reverse
from packageurl import PackageURL
from rest_framework import status
Expand Down Expand Up @@ -67,6 +68,8 @@ def test_list_vulnerabilities(self):
"""
url = reverse("vulnerability-v2-list")
response = self.client.get(url, format="json")
with self.assertNumQueries(5):
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("results", response.data)
self.assertIn("vulnerabilities", response.data["results"])
Expand All @@ -80,7 +83,8 @@ def test_retrieve_vulnerability_detail(self):
Test retrieving vulnerability details by vulnerability_id.
"""
url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"})
response = self.client.get(url, format="json")
with self.assertNumQueries(8):
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["vulnerability_id"], "VCID-1234")
self.assertEqual(response.data["summary"], "Test vulnerability 1")
Expand All @@ -93,7 +97,8 @@ def test_filter_vulnerability_by_vulnerability_id(self):
Test filtering vulnerabilities by vulnerability_id.
"""
url = reverse("vulnerability-v2-list")
response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json")
with self.assertNumQueries(4):
response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["vulnerability_id"], "VCID-1234")

Expand All @@ -102,7 +107,8 @@ def test_filter_vulnerability_by_alias(self):
Test filtering vulnerabilities by alias.
"""
url = reverse("vulnerability-v2-list")
response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json")
with self.assertNumQueries(5):
response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("results", response.data)
self.assertIn("vulnerabilities", response.data["results"])
Expand All @@ -116,7 +122,8 @@ def test_filter_vulnerabilities_multiple_ids(self):
Test filtering vulnerabilities by multiple vulnerability_ids.
"""
url = reverse("vulnerability-v2-list")
response = self.client.get(
with self.assertNumQueries(5):
response = self.client.get(
url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -127,7 +134,8 @@ def test_filter_vulnerabilities_multiple_aliases(self):
Test filtering vulnerabilities by multiple aliases.
"""
url = reverse("vulnerability-v2-list")
response = self.client.get(
with self.assertNumQueries(5):
response = self.client.get(
url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand All @@ -139,7 +147,8 @@ def test_invalid_vulnerability_id(self):
Should return 404 Not Found.
"""
url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"})
response = self.client.get(url, format="json")
with self.assertNumQueries(5):
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_get_url_in_serializer(self):
Expand Down Expand Up @@ -207,7 +216,8 @@ def test_list_packages(self):
Should return a list of packages with their details and associated vulnerabilities.
"""
url = reverse("package-v2-list")
response = self.client.get(url, format="json")
with self.assertNumQueries(31):
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("results", response.data)
self.assertIn("packages", response.data["results"])
Expand All @@ -228,7 +238,8 @@ def test_filter_packages_by_purl(self):
Test filtering packages by one or more PURLs.
"""
url = reverse("package-v2-list")
response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json")
with self.assertNumQueries(19):
response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]["packages"]), 1)
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2")
Expand All @@ -238,7 +249,8 @@ def test_filter_packages_by_affected_vulnerability(self):
Test filtering packages by affected_by_vulnerability.
"""
url = reverse("package-v2-list")
response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json")
with self.assertNumQueries(19):
response = self.client.get(url, {"affected_by_vulnerability": "VCID-1234"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]["packages"]), 1)
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:pypi/django@3.2")
Expand All @@ -248,29 +260,59 @@ def test_filter_packages_by_fixing_vulnerability(self):
Test filtering packages by fixing_vulnerability.
"""
url = reverse("package-v2-list")
response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json")
with self.assertNumQueries(18):
response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]["packages"]), 1)
self.assertEqual(response.data["results"]["packages"][0]["purl"], "pkg:npm/lodash@4.17.20")

def test_package_serializer_fields(self):
"""
Test that the PackageV2Serializer returns the correct fields.
Test that the PackageV2Serializer returns the correct fields and formats them correctly.
"""
# Fetch the package
package = Package.objects.get(package_url="pkg:pypi/django@3.2")

# Ensure prefetched data is available for the serializer
package = (
Package.objects.filter(package_url="pkg:pypi/django@3.2")
.prefetch_related(
Prefetch(
"affected_by_vulnerabilities",
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
to_attr="prefetched_affected_vulnerabilities",
)
)
.first()
)

# Serialize the package
serializer = PackageV2Serializer(package)
data = serializer.data

# Verify the presence of required fields
self.assertIn("purl", data)
self.assertIn("affected_by_vulnerabilities", data)
self.assertIn("fixing_vulnerabilities", data)
self.assertIn("next_non_vulnerable_version", data)
self.assertIn("latest_non_vulnerable_version", data)
self.assertIn("risk_score", data)

# Verify field values
self.assertEqual(data["purl"], "pkg:pypi/django@3.2")
self.assertEqual(
data["affected_by_vulnerabilities"],
{"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}},
)
self.assertEqual(data["fixing_vulnerabilities"], [])
self.assertEqual(data["next_non_vulnerable_version"], None)
self.assertEqual(data["latest_non_vulnerable_version"], None)
self.assertEqual(data["risk_score"], None)

# Verify affected_by_vulnerabilities structure
expected_affected_by_vulnerabilities = {
"VCID-1234": {"vulnerability_id": "VCID-1234", "fixed_by_packages": None}
}
self.assertEqual(data["affected_by_vulnerabilities"], expected_affected_by_vulnerabilities)

# Verify fixing_vulnerabilities structure
expected_fixing_vulnerabilities = []
self.assertEqual(data["fixing_vulnerabilities"], expected_fixing_vulnerabilities)

def test_list_packages_pagination(self):
"""
Expand Down Expand Up @@ -303,7 +345,8 @@ def test_invalid_vulnerability_filter(self):
Should return an empty list.
"""
url = reverse("package-v2-list")
response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json")
with self.assertNumQueries(4):
response = self.client.get(url, {"affected_by_vulnerability": "VCID-9999"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]["packages"]), 0)

Expand All @@ -313,15 +356,27 @@ def test_invalid_purl_filter(self):
Should return an empty list.
"""
url = reverse("package-v2-list")
response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json")
with self.assertNumQueries(4):
response = self.client.get(url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]["packages"]), 0)

def test_get_affected_by_vulnerabilities(self):
"""
Test the get_affected_by_vulnerabilities method in the serializer.
"""
package = Package.objects.get(package_url="pkg:pypi/django@3.2")
package = (
Package.objects.filter(package_url="pkg:pypi/django@3.2")
.prefetch_related(
Prefetch(
"affected_by_vulnerabilities",
queryset=Vulnerability.objects.prefetch_related("fixed_by_packages"),
to_attr="prefetched_affected_vulnerabilities",
)
)
.first()
)

serializer = PackageV2Serializer()
vulnerabilities = serializer.get_affected_by_vulnerabilities(package)
self.assertEqual(
Expand All @@ -345,7 +400,8 @@ def test_bulk_lookup_with_valid_purls(self):
"""
url = reverse("package-v2-bulk-lookup")
data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(28):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("packages", response.data)
self.assertIn("vulnerabilities", response.data)
Expand All @@ -369,7 +425,8 @@ def test_bulk_lookup_with_invalid_purls(self):
"""
url = reverse("package-v2-bulk-lookup")
data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(4):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Since the packages don't exist, the response should be empty
self.assertEqual(len(response.data["packages"]), 0)
Expand All @@ -382,7 +439,8 @@ def test_bulk_lookup_with_empty_purls(self):
"""
url = reverse("package-v2-bulk-lookup")
data = {"purls": []}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(3):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
self.assertIn("message", response.data)
Expand All @@ -395,7 +453,8 @@ def test_bulk_search_with_valid_purls(self):
"""
url = reverse("package-v2-bulk-search")
data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(28):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("packages", response.data)
self.assertIn("vulnerabilities", response.data)
Expand All @@ -422,7 +481,8 @@ def test_bulk_search_with_purl_only_true(self):
"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"],
"purl_only": True,
}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(17):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Since purl_only=True, response should be a list of PURLs
self.assertIsInstance(response.data, list)
Expand All @@ -448,7 +508,8 @@ def test_bulk_search_with_plain_purl_true(self):
"purls": ["pkg:pypi/django@3.2", "pkg:pypi/django@3.2?extension=tar.gz"],
"plain_purl": True,
}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(16):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("packages", response.data)
self.assertIn("vulnerabilities", response.data)
Expand All @@ -468,7 +529,8 @@ def test_bulk_search_with_purl_only_and_plain_purl_true(self):
"purl_only": True,
"plain_purl": True,
}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(11):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Response should be a list of plain PURLs
self.assertIsInstance(response.data, list)
Expand All @@ -483,7 +545,8 @@ def test_bulk_search_with_invalid_purls(self):
"""
url = reverse("package-v2-bulk-search")
data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(4):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Since the packages don't exist, the response should be empty
self.assertEqual(len(response.data["packages"]), 0)
Expand All @@ -496,7 +559,8 @@ def test_bulk_search_with_empty_purls(self):
"""
url = reverse("package-v2-bulk-search")
data = {"purls": []}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(3):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
self.assertIn("message", response.data)
Expand All @@ -507,7 +571,8 @@ def test_all_vulnerable_packages(self):
Test the 'all' endpoint that returns all vulnerable package URLs.
"""
url = reverse("package-v2-all")
response = self.client.get(url, format="json")
with self.assertNumQueries(4):
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Since package1 is vulnerable, it should be returned
expected_purls = ["pkg:pypi/django@3.2"]
Expand All @@ -520,7 +585,8 @@ def test_lookup_with_valid_purl(self):
"""
url = reverse("package-v2-lookup")
data = {"purl": "pkg:pypi/django@3.2"}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(12):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(1, len(response.data))
self.assertIn("purl", response.data[0])
Expand All @@ -542,7 +608,8 @@ def test_lookup_with_invalid_purl(self):
"""
url = reverse("package-v2-lookup")
data = {"purl": "pkg:pypi/nonexistent@1.0.0"}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(4):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# No packages or vulnerabilities should be returned
self.assertEqual(len(response.data), 0)
Expand All @@ -554,7 +621,8 @@ def test_lookup_with_missing_purl(self):
"""
url = reverse("package-v2-lookup")
data = {}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(3):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.data)
self.assertIn("message", response.data)
Expand All @@ -567,7 +635,8 @@ def test_lookup_with_invalid_purl_format(self):
"""
url = reverse("package-v2-lookup")
data = {"purl": "invalid_purl_format"}
response = self.client.post(url, data, format="json")
with self.assertNumQueries(4):
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# No packages or vulnerabilities should be returned
self.assertEqual(len(response.data), 0)