diff --git a/cve_bin_tool/version_compare.py b/cve_bin_tool/version_compare.py index 8bde559ff4..8ca605f711 100644 --- a/cve_bin_tool/version_compare.py +++ b/cve_bin_tool/version_compare.py @@ -20,6 +20,12 @@ class CannotParseVersionException(Exception): """ +class UnknownVersion(Exception): + """ + Thrown if version is null or "unknown". + """ + + def parse_version(version_string: str): """ Splits a version string into an array for comparison. @@ -27,13 +33,20 @@ def parse_version(version_string: str): e.g. 1.1.1a would become [1, 1, 1, a] """ - versionString = version_string + + if not version_string or version_string.lower() == "unknown": + raise UnknownVersion(f"version string = {version_string}") + + versionString = version_string.strip() versionArray = [] # convert - and _ to be treated like . below # we could switch to a re split but it seems to leave blanks so this is less hassle versionString = versionString.replace("-", ".") versionString = versionString.replace("_", ".") + # Note: there may be other non-alphanumeric characters we want to add here in the + # future, but we'd like to look at those cases before adding them in case the version + # logic is very different. # Attempt a split split_version = versionString.split(".") @@ -52,6 +65,7 @@ def parse_version(version_string: str): versionArray.append(section) # if it looks like 42a split out the letters and numbers + # We will treat 42a as coming *after* version 42. elif re.match(number_letter, section): result = re.findall(number_letter, section) @@ -65,6 +79,7 @@ def parse_version(version_string: str): # if it looks like rc1 or dev7 we'll leave it together as it may be some kind of pre-release # and we'll probably want to handle it specially in the compare. + # We need to threat 42dev7 as coming *before* version 42. elif re.match(letter_number, section): versionArray.append(section) @@ -101,10 +116,12 @@ def version_compare(v1: str, v2: str): # This might be a bad choice in some cases: Do we want ag < z? # I suspect projects using letters in version names may not use ranges in nvd # for this reason (e.g. openssl) + # Converting to lower() so that 3.14a == 3.14A + # but this may not be ideal in all cases elif v1_array[i].isalpha() and v2_array[i].isalpha(): - if v1_array[i] > v2_array[i]: + if v1_array[i].lower() > v2_array[i].lower(): return 1 - if v1_array[i] < v2_array[i]: + if v1_array[i].lower() < v2_array[i].lower(): return -1 else: @@ -170,25 +187,33 @@ class Version(str): """ def __cmp__(self, other): + """compare""" return version_compare(self, other) def __lt__(self, other): + """<""" return bool(version_compare(self, other) < 0) def __le__(self, other): + """<=""" return bool(version_compare(self, other) <= 0) def __gt__(self, other): + """>""" return bool(version_compare(self, other) > 0) def __ge__(self, other): + """>=""" return bool(version_compare(self, other) >= 0) def __eq__(self, other): + """==""" return bool(version_compare(self, other) == 0) def __ne__(self, other): + """!=""" return bool(version_compare(self, other) != 0) def __repr__(self): + """print the version string""" return f"Version: {self}" diff --git a/test/test_version_compare.py b/test/test_version_compare.py index 3ec5e493e0..8532db340d 100644 --- a/test/test_version_compare.py +++ b/test/test_version_compare.py @@ -1,14 +1,20 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: GPL-3.0-or-later +import pytest -from cve_bin_tool.version_compare import Version +from cve_bin_tool.version_compare import UnknownVersion, Version class TestVersionCompare: + """Test the cve_bin_tool.version_compare functionality""" + def test_eq(self): """Make sure == works between versions""" assert Version("1.2") == Version("1.2") + assert Version("1.1a") == Version("1.1A") + assert Version("4.4.A") == Version("4.4.a") + assert Version("5.6 ") == Version("5.6") def test_lt(self): """Make sure < works between versions, including some with unusual version schemes""" @@ -23,6 +29,7 @@ def test_lt(self): assert Version("1.2.post8") < Version("1.2.1") assert Version("rc5") < Version("rc10") assert Version("9.10") < Version("9.10.post") + assert Version("5.3.9") < Version("5.4") def test_gt(self): """Make sure > works between versions, including some with unusual version schemes""" @@ -34,3 +41,10 @@ def test_gt(self): assert Version("10.2.3.rc1") > Version("10.2.3.rc0") assert Version("10.2.3.rc10") > Version("10.2.3.rc2") assert Version("9.10.post") > Version("9.10") + assert Version("5.5") > Version("5.4.1") + + def test_error(self): + """Make sure 'unknown' and blank strings raise appropriate errors""" + with pytest.raises(UnknownVersion): + Version("6") > Version("unknown") + Version("") > Version("6")