Implemented most of the 'unittest' executor and runner. unittest

Fri, 13 May 2022 17:23:21 +0200

author
Detlev Offenbach <detlev@die-offenbachs.de>
date
Fri, 13 May 2022 17:23:21 +0200
branch
unittest
changeset 9062
7f27bf3b50c3
parent 9061
22dab1be7953
child 9063
f1d7dd7ae471

Implemented most of the 'unittest' executor and runner.

eric7/EricNetwork/EricJsonStreamReader.py file | annotate | diff | comparison | revisions
eric7/Unittest/Interfaces/UTExecutorBase.py file | annotate | diff | comparison | revisions
eric7/Unittest/Interfaces/UTFrameworkRegistry.py file | annotate | diff | comparison | revisions
eric7/Unittest/Interfaces/UnittestExecutor.py file | annotate | diff | comparison | revisions
eric7/Unittest/Interfaces/UnittestRunner.py file | annotate | diff | comparison | revisions
eric7/Unittest/UTTestResultsTree.py file | annotate | diff | comparison | revisions
eric7/Unittest/UnittestWidget.py file | annotate | diff | comparison | revisions
eric7/Unittest/UnittestWidget.ui file | annotate | diff | comparison | revisions
--- a/eric7/EricNetwork/EricJsonStreamReader.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/EricNetwork/EricJsonStreamReader.py	Fri May 13 17:23:21 2022 +0200
@@ -94,7 +94,7 @@
         if self.__connection is not None:
             self.__connection.close()
         
-            self.__connection = connection
+        self.__connection = connection
         
         connection.readyRead.connect(self.__receiveJson)
         connection.disconnected.connect(self.__handleDisconnect)
--- a/eric7/Unittest/Interfaces/UTExecutorBase.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/Interfaces/UTExecutorBase.py	Fri May 13 17:23:21 2022 +0200
@@ -21,6 +21,7 @@
     """
     Class defining the supported result categories.
     """
+    RUNNING = 0
     FAIL = 1
     OK = 2
     SKIP = 3
@@ -32,14 +33,16 @@
     """
     Class containing the test result data.
     """
-    category: int               # result category
+    category: ResultCategory    # result category
     status: str                 # test status
     name: str                   # test name
-    message: str                # short result message
-    extra: str                  # additional information text
-    duration: float             # test duration
-    filename: str               # file name of a failed test
-    lineno: int                 # line number of a failed test
+    id: str                     # test id
+    description: str = ""       # short description of test
+    message: str = ""           # short result message
+    extra: list = None          # additional information text
+    duration: float = None      # test duration
+    filename: str = None        # file name of a failed test
+    lineno: int = None          # line number of a failed test
 
 
 @dataclass
@@ -61,43 +64,47 @@
     """
     Base class for test framework specific implementations.
     
-    @signal collected(list of str) emitted after all tests have been
-        collected
+    @signal collected(list of tuple of (str, str, str)) emitted after all tests
+        have been collected. Tuple elements are the test id, the test name and
+        a short description of the test.
     @signal collectError(list of tuple of (str, str)) emitted when errors
         are encountered during test collection. Tuple elements are the
         test name and the error message.
-    @signal startTest(list of str) emitted before tests are run
+    @signal startTest(tuple of (str, str, str) emitted before tests are run.
+        Tuple elements are test id, test name and short description.
     @signal testResult(UTTestResult) emitted when a test result is ready
     @signal testFinished(list, str) emitted when the test has finished.
         The elements are the list of test results and the captured output
         of the test worker (if any).
+    @signal testRunFinished(int, float) emitted when the test run has finished.
+        The elements are the number of tests run and the duration in seconds
     @signal stop() emitted when the test process is being stopped.
+    @signal coverageDataSaved(str) emitted after the coverage data was saved.
+        The element is the absolute path of the coverage data file.
     """
     collected = pyqtSignal(list)
     collectError = pyqtSignal(list)
-    startTest = pyqtSignal(list)
+    startTest = pyqtSignal(tuple)
     testResult = pyqtSignal(UTTestResult)
     testFinished = pyqtSignal(list, str)
+    testRunFinished = pyqtSignal(int, float)
     stop = pyqtSignal()
+    coverageDataSaved = pyqtSignal(str)
     
     module = ""
     name = ""
     runner = ""
     
-    def __init__(self, testWidget, logfile=None):
+    def __init__(self, testWidget):
         """
         Constructor
         
         @param testWidget reference to the unit test widget
         @type UnittestWidget
-        @param logfile file name to log test results to (defaults to None)
-        @type str (optional)
         """
         super().__init__(testWidget)
         
         self.__process = None
-        self._logfile = logfile
-        # TODO: add log file creation
     
     @classmethod
     def isInstalled(cls, interpreter):
--- a/eric7/Unittest/Interfaces/UTFrameworkRegistry.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/Interfaces/UTFrameworkRegistry.py	Fri May 13 17:23:21 2022 +0200
@@ -38,7 +38,7 @@
         """
         self.__frameworks[executorClass.name] = executorClass
     
-    def createExecutor(self, framework, widget, logfile=None):
+    def createExecutor(self, framework, widget):
         """
         Public method to create a test framework executor.
         
@@ -48,12 +48,11 @@
         @type str
         @param widget reference to the unit test widget
         @type UnittestWidget
-        @param logfile file name to log test results to (defaults to None)
-        @type str (optional)
         @return test framework executor object
+        @rtype UTExecutorBase
         """
         cls = self.__frameworks[framework]
-        return cls(widget, logfile=logfile)
+        return cls(widget)
     
     def getFrameworks(self):
         """
--- a/eric7/Unittest/Interfaces/UnittestExecutor.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/Interfaces/UnittestExecutor.py	Fri May 13 17:23:21 2022 +0200
@@ -10,10 +10,13 @@
 import contextlib
 import json
 import os
+import re
 
-from PyQt6.QtCore import QProcess
+from PyQt6.QtCore import pyqtSlot, QProcess
 
-from .UTExecutorBase import UTExecutorBase
+from EricNetwork.EricJsonStreamReader import EricJsonReader
+
+from .UTExecutorBase import UTExecutorBase, UTTestResult, ResultCategory
 
 
 class UnittestExecutor(UTExecutorBase):
@@ -25,6 +28,33 @@
     
     runner = os.path.join(os.path.dirname(__file__), "UnittestRunner.py")
     
+    def __init__(self, testWidget):
+        """
+        Constructor
+        
+        @param testWidget reference to the unit test widget
+        @type UnittestWidget
+        """
+        super().__init__(testWidget)
+        
+        self.__statusCategoryMapping = {
+            "failure": ResultCategory.FAIL,
+            "error": ResultCategory.FAIL,
+            "skipped": ResultCategory.SKIP,
+            "expected failure": ResultCategory.OK,
+            "unexpected success": ResultCategory.FAIL,
+            "success": ResultCategory.OK,
+        }
+        
+        self.__statusDisplayMapping = {
+            "failure": self.tr("Failure"),
+            "error": self.tr("Error"),
+            "skipped": self.tr("Skipped"),
+            "expected failure": self.tr("Expected Failure"),
+            "unexpected success": self.tr("Unexpected Success"),
+            "success": self.tr("Success"),
+        }
+    
     def getVersions(self, interpreter):
         """
         Public method to get the test framework version and version information
@@ -55,9 +85,125 @@
         @type UTTestConfig
         @return list of process arguments
         @rtype list of str
-        @exception NotImplementedError this method needs to be implemented by
-            derived classes
+        """
+        args = [
+            UnittestExecutor.runner,
+            "runtest",
+            self.reader.address(),
+            str(self.reader.port()),
+        ]
+        
+        if config.discover:
+            args.extend([
+                "discover",
+                "--start-directory",
+                config.discoveryStart,
+            ])
+        
+        if config.failFast:
+            args.append("--failfast")
+        
+        if config.collectCoverage:
+            args.append("--cover")
+            if config.eraseCoverage:
+                args.append("--cover-erase")
+        
+        if config.testFilename and config.testName:
+            args.append(config.testFilename)
+            args.append(config.testName)
+        
+        return args
+    
+    def start(self, config, pythonpath):
+        """
+        Public method to start the testing process.
+        
+        @param config configuration for the test execution
+        @type UTTestConfig
+        @param pythonpath list of directories to be added to the Python path
+        @type list of str
+        """
+        self.reader = EricJsonReader(name="Unittest Reader", parent=self)
+        self.reader.dataReceived.connect(self.__processData)
+        
+        super().start(config, pythonpath)
+    
+    def finished(self):
+        """
+        Public method handling the unit test process been finished.
+        
+        This method should read the results (if necessary) and emit the signal
+        testFinished.
+        """
+        self.reader.close()
+        
+        output = self.readAllOutput()
+        self.testFinished.emit([], output)
+    
+    @pyqtSlot(object)
+    def __processData(self, data):
         """
-        raise NotImplementedError
+        Private slot to process the received data.
+        
+        @param data data object received
+        @type dict
+        """
+        # error collecting tests
+        if data["event"] == "collecterror":
+            self.collectError.emit([("", data["error"])])
+        
+        # tests collected
+        elif data["event"] == "collected":
+            self.collected.emit([
+                (t["id"], t["name"], t["description"]) for t in data["tests"]
+            ])
+        
+        # test started
+        elif data["event"] == "started":
+            self.startTest.emit(
+                (data["id"], data["name"], data["description"])
+            )
         
-        return []
+        # test result
+        elif data["event"] == "result":
+            fn, ln = None, None
+            tracebackLines = []
+            if "traceback" in data:
+                # get the error info
+                tracebackLines = data["traceback"].splitlines()
+                # find the last entry matching the pattern
+                for index in range(len(tracebackLines) - 1, -1, -1):
+                    fmatch = re.search(r'File "(.*?)", line (\d*?),.*',
+                                       tracebackLines[index])
+                    if fmatch:
+                        break
+                if fmatch:
+                    fn, ln = fmatch.group(1, 2)
+                
+            if "shortmsg" in data:
+                message = data["shortmsg"]
+            elif tracebackLines:
+                message = tracebackLines[-1].split(":", 1)[1].strip()
+            else:
+                message = ""
+            
+            self.testResult.emit(UTTestResult(
+                category=self.__statusCategoryMapping[data["status"]],
+                status=self.__statusDisplayMapping[data["status"]],
+                name=data["name"],
+                id=data["id"],
+                description=data["description"],
+                message=message,
+                extra=tracebackLines,
+                duration=data["duration_ms"],
+                filename=fn,
+                lineno=ln,
+            ))
+        
+        # test run finished
+        elif data["event"] == "finished":
+            self.testRunFinished.emit(data["tests"], data["duration_s"])
+        
+        # coverage data
+        elif data["event"] == "coverage":
+            self.coverageDataSaved.emit(data["file"])
--- a/eric7/Unittest/Interfaces/UnittestRunner.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/Interfaces/UnittestRunner.py	Fri May 13 17:23:21 2022 +0200
@@ -8,26 +8,391 @@
 """
 
 import json
+import os
 import sys
+import time
+import unittest
+
+
+sys.path.insert(
+    2,
+    os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+)
+
+
+class EricTestResult(unittest.TestResult):
+    """
+    Class implementing a TestResult derivative to send the data via a network
+    connection.
+    """
+    def __init__(self, writer, failfast):
+        """
+        Constructor
+        
+        @param writer reference to the object to write the results to
+        @type EricJsonWriter
+        @param failfast flag indicating to stop at the first error
+        @type bool
+        """
+        super().__init__()
+        self.__writer = writer
+        self.failfast = failfast
+        self.__testsRun = 0
+        
+        self.__currentTestStatus = {}
+    
+    def addFailure(self, test, err):
+        """
+        Public method called if a test failed.
+        
+        @param test reference to the test object
+        @type TestCase
+        @param err tuple containing the exception data like sys.exc_info
+            (exception type, exception instance, traceback)
+        @type tuple
+        """
+        super().addFailure(test, err)
+        tracebackLines = self._exc_info_to_string(err, test)
+        
+        self.__currentTestStatus.update({
+            "status": "failure",
+            "traceback": tracebackLines,
+        })
+    
+    def addError(self, test, err):
+        """
+        Public method called if a test errored.
+        
+        @param test reference to the test object
+        @type TestCase
+        @param err tuple containing the exception data like sys.exc_info
+            (exception type, exception instance, traceback)
+        @type tuple
+        """
+        super().addError(test, err)
+        tracebackLines = self._exc_info_to_string(err, test)
+        
+        self.__currentTestStatus.update({
+            "status": "error",
+            "traceback": tracebackLines,
+        })
+    
+    def addSubTest(self, test, subtest, err):
+        """
+        Public method called for each subtest to record its result.
+        
+        @param test reference to the test object
+        @type TestCase
+        @param subtest reference to the subtest object
+        @type TestCase
+        @param err tuple containing the exception data like sys.exc_info
+            (exception type, exception instance, traceback)
+        @type tuple
+        """
+        if err is not None:
+            super().addSubTest(test, subtest, err)
+            tracebackLines = self._exc_info_to_string(err, test)
+            status = (
+                "failure"
+                if issubclass(err[0], test.failureException) else
+                "error"
+            )
+            
+            self.__currentTestStatus.update({
+                "status": status,
+                "name": str(subtest),
+                "traceback": tracebackLines,
+            })
+            
+            if self.failfast:
+                self.stop()
+    
+    def addSkip(self, test, reason):
+        """
+        Public method called if a test was skipped.
+        
+        @param test reference to the test object
+        @type TestCase
+        @param reason reason for skipping the test
+        @type str
+        """
+        super().addSkip(test, reason)
+        
+        self.__currentTestStatus.update({
+            "status": "skipped",
+            "shortmsg": reason,
+        })
+    
+    def addExpectedFailure(self, test, err):
+        """
+        Public method called if a test failed expected.
+        
+        @param test reference to the test object
+        @type TestCase
+        @param err tuple containing the exception data like sys.exc_info
+            (exception type, exception instance, traceback)
+        @type tuple
+        """
+        super().addExpectedFailure(test, err)
+        tracebackLines = self._exc_info_to_string(err, test)
+        
+        self.__currentTestStatus.update({
+            "status": "expected failure",
+            "traceback": tracebackLines,
+        })
+    
+    def addUnexpectedSuccess(self, test):
+        """
+        Public method called if a test succeeded expectedly.
+        
+        @param test reference to the test object
+        @type TestCase
+        """
+        super().addUnexpectedSuccess(test)
+        
+        self.__currentTestStatus["status"] = "unexpected success"
+    
+    def startTest(self, test):
+        """
+        Public method called at the start of a test.
+        
+        @param test reference to the test object
+        @type TestCase
+        """
+        super().startTest(test)
+        
+        self.__testsRun += 1
+        self.__currentTestStatus = {
+            "event": "result",
+            "status": "success",
+            "name": str(test),
+            "id": test.id(),
+            "description": test.shortDescription(),
+        }
+        
+        self.__writer.write({
+            "event": "started",
+            "name": str(test),
+            "id": test.id(),
+            "description": test.shortDescription(),
+        })
+        
+        self.__startTime = time.monotonic_ns()
+    
+    def stopTest(self, test):
+        """
+        Public method called at the end of a test.
+        
+        @param test reference to the test object
+        @type TestCase
+        """
+        stopTime = time.monotonic_ns()
+        duration = (stopTime - self.__startTime) / 1_000_000     # ms
+        
+        super().stopTest(test)
+        
+        self.__currentTestStatus["duration_ms"] = duration
+        self.__writer.write(self.__currentTestStatus)
+    
+    def startTestRun(self):
+        """
+        Public method called once before any tests are executed.
+        """
+        self.__totalStartTime = time.monotonic_ns()
+        self.__testsRun = 0
+    
+    def stopTestRun(self):
+        """
+        Public method called once after all tests are executed.
+        """
+        stopTime = time.monotonic_ns()
+        duration = (stopTime - self.__totalStartTime) / 1_000_000_000   # s
+        
+        self.__writer.write({
+            "event": "finished",
+            "duration_s": duration,
+            "tests": self.__testsRun,
+        })
+
+
+def _assembleTestCasesList(suite):
+    """
+    Protected function to assemble a list of test cases included in a test
+    suite.
+    
+    @param suite test suite to be inspected
+    @type unittest.TestSuite
+    @return list of tuples containing the test case ID, the string
+        representation and the short description
+    @rtype list of tuples of (str, str)
+    """
+    testCases = []
+    for test in suite:
+        if isinstance(test, unittest.TestSuite):
+            testCases.extend(_assembleTestCasesList(test))
+        else:
+            testId = test.id()
+            if (
+                "ModuleImportFailure" not in testId and
+                "LoadTestsFailure" not in testId and
+                "_FailedTest" not in testId
+            ):
+                testCases.append(
+                    (testId, str(test), test.shortDescription())
+                )
+    return testCases
+
+
+def runtest(argv):
+    """
+    Function to run the tests.
+    
+    @param argv list of command line parameters.
+    @type list of str
+    """
+    from EricNetwork.EricJsonStreamWriter import EricJsonWriter
+    writer = EricJsonWriter(argv[0], int(argv[1]))
+    del argv[:2]
+    
+    # process arguments
+    if argv[0] == "discover":
+        discover = True
+        argv.pop(0)
+        if argv[0] == "--start-directory":
+            discoveryStart = argv[1]
+            del argv[:2]
+    else:
+        discover = False
+        discoveryStart = ""
+    
+    failfast = "--failfast" in argv
+    if failfast:
+        argv.remove("--failfast")
+    
+    coverage = "--cover" in argv
+    if coverage:
+        argv.remove("--cover")
+    coverageErase = "--cover-erase" in argv
+    if coverageErase:
+        argv.remove("--cover-erase")
+    
+    if not discover:
+        testFileName, testName = argv[:2]
+        del argv[:2]
+    else:
+        testFileName = testName = ""
+    
+    testCases = argv[:]
+    
+    if testFileName:
+        sys.path.insert(1, os.path.dirname(os.path.abspath(testFileName)))
+    elif discoveryStart:
+        sys.path.insert(1, os.path.abspath(discoveryStart))
+    
+    try:
+        testLoader = unittest.TestLoader()
+        if discover:
+            if testCases:
+                test = testLoader.loadTestsFromNames(testCases)
+            else:
+                test = testLoader.discover(discoveryStart)
+        else:
+            if testFileName:
+                module = __import__(os.path.splitext(
+                    os.path.basename(testFileName))[0])
+            else:
+                module = None
+            # TODO: implement 'failed only'
+#            if failedOnly and self.__failedTests:
+#                if module:
+#                    failed = [t.split(".", 1)[1]
+#                              for t in self.__failedTests]
+#                else:
+#                    failed = list(self.__failedTests)
+#                test = testLoader.loadTestsFromNames(
+#                    failed, module)
+#            else:
+            test = testLoader.loadTestsFromName(
+                testName, module)
+    except Exception as err:
+        print("Exception:", str(err))
+        writer.write({
+            "event": "collecterror",
+            "error": str(err),
+        })
+        sys.exit(1)
+    
+    collectedTests = {
+        "event": "collected",
+        "tests": [
+            {"id": id, "name": name, "description": desc}
+            for id, name, desc in _assembleTestCasesList(test)
+        ]
+    }
+    writer.write(collectedTests)
+    
+    # setup test coverage
+    if coverage:
+        if discover:
+            covname = os.path.join(discoveryStart, "unittest")
+        elif testFileName:
+            covname = os.path.splitext(
+                os.path.abspath(testFileName))[0]
+        else:
+            covname = "unittest"
+        covDataFile = "{0}.coverage".format(covname)
+        if not os.path.isabs(covDataFile):
+            covDataFile = os.path.abspath(covDataFile)
+        
+        from DebugClients.Python.coverage import coverage as cov
+        cover = cov(data_file=covDataFile)
+        if coverageErase:
+            cover.erase()
+    else:
+        cover = None
+    
+    testResult = EricTestResult(writer, failfast)
+    startTestRun = getattr(testResult, 'startTestRun', None)
+    if startTestRun is not None:
+        startTestRun()
+    try:
+        if cover:
+            cover.start()
+        test.run(testResult)
+    finally:
+        if cover:
+            cover.stop()
+            cover.save()
+            writer.write({
+                "event": "coverage",
+                "file": covDataFile,
+            })
+        stopTestRun = getattr(testResult, 'stopTestRun', None)
+        if stopTestRun is not None:
+            stopTestRun()
+    
+    writer.close()
+    sys.exit(0)
 
 if __name__ == '__main__':
-    command = sys.argv[1]
-    if command == "installed":
-        try:
-            import unittest         # __IGNORE_WARNING__
+    if len(sys.argv) > 1:
+        command = sys.argv[1]
+        if command == "installed":
             sys.exit(0)
-        except ImportError:
-            sys.exit(1)
-    
-    elif command == "versions":
-        import platform
-        versions = {
-            "name": "unittest",
-            "version": platform.python_version(),
-            "plugins": [],
-        }
-        print(json.dumps(versions))
-        sys.exit(0)
+        
+        elif command == "versions":
+            import platform
+            versions = {
+                "name": "unittest",
+                "version": platform.python_version(),
+                "plugins": [],
+            }
+            print(json.dumps(versions))
+            sys.exit(0)
+        
+        elif command == "runtest":
+            runtest(sys.argv[2:])
+            sys.exit(0)
     
     sys.exit(42)
 
--- a/eric7/Unittest/UTTestResultsTree.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/UTTestResultsTree.py	Fri May 13 17:23:21 2022 +0200
@@ -8,11 +8,23 @@
 data.
 """
 
+import contextlib
+import copy
+import locale
+from operator import attrgetter
+
 from PyQt6.QtCore import (
     pyqtSignal, pyqtSlot, Qt, QAbstractItemModel, QCoreApplication, QModelIndex
 )
+from PyQt6.QtGui import QBrush, QColor
 from PyQt6.QtWidgets import QTreeView
 
+from EricWidgets.EricApplication import ericApp
+
+import Preferences
+
+from .Interfaces.UTExecutorBase import ResultCategory
+
 TopLevelId = 2 ** 32 - 1
 
 
@@ -27,6 +39,11 @@
         QCoreApplication.translate("TestResultsModel", "Duration (ms)"),
     ]
     
+    StatusColumn = 0
+    NameColumn = 1
+    MessageColumn = 2
+    DurationColumn = 3
+    
     def __init__(self, parent=None):
         """
         Constructor
@@ -36,8 +53,107 @@
         """
         super().__init__(parent)
         
+        if ericApp().usesDarkPalette():
+            self.__backgroundColors = {
+                ResultCategory.RUNNING: None,
+                ResultCategory.FAIL: QBrush(QColor("#880000")),
+                ResultCategory.OK: QBrush(QColor("#005500")),
+                ResultCategory.SKIP: QBrush(QColor("#3f3f3f")),
+                ResultCategory.PENDING: QBrush(QColor("#004768")),
+            }
+        else:
+            self.__backgroundColors = {
+                ResultCategory.RUNNING: None,
+                ResultCategory.FAIL: QBrush(QColor("#ff8080")),
+                ResultCategory.OK: QBrush(QColor("#c1ffba")),
+                ResultCategory.SKIP: QBrush(QColor("#c5c5c5")),
+                ResultCategory.PENDING: QBrush(QColor("#6fbaff")),
+            }
+        
         self.__testResults = []
     
+    def index(self, row, column, parent=QModelIndex()):
+        """
+        Public method to generate an index for the given row and column to
+        identify the item.
+        
+        @param row row for the index
+        @type int
+        @param column column for the index
+        @type int
+        @param parent index of the parent item (defaults to QModelIndex())
+        @type QModelIndex (optional)
+        @return index for the item
+        @rtype QModelIndex
+        """
+        if not self.hasIndex(row, column, parent):  # check bounds etc.
+            return QModelIndex()
+        
+        if not parent.isValid():
+            # top level item
+            return self.createIndex(row, column, TopLevelId)
+        else:
+            testResultIndex = parent.row()
+            return self.createIndex(row, column, testResultIndex)
+    
+    def data(self, index, role):
+        """
+        Public method to get the data for the various columns and roles.
+        
+        @param index index of the data to be returned
+        @type QModelIndex
+        @param role role designating the data to return
+        @type Qt.ItemDataRole
+        @return requested data item
+        @rtype Any
+        """
+        if not index.isValid():
+            return None
+        
+        row = index.row()
+        column = index.column()
+        idx = index.internalId()
+        
+        if role == Qt.ItemDataRole.DisplayRole:
+            if idx != TopLevelId:
+                if bool(self.__testResults[idx].extra):
+                    return self.__testResults[idx].extra[index.row()]
+                else:
+                    return None
+            elif column == TestResultsModel.StatusColumn:
+                return self.__testResults[row].status
+            elif column == TestResultsModel.NameColumn:
+                return self.__testResults[row].name
+            elif column == TestResultsModel.MessageColumn:
+                return self.__testResults[row].message
+            elif column == TestResultsModel.DurationColumn:
+                duration = self.__testResults[row].duration
+                return (
+                    ''
+                    if duration is None else
+                    locale.format_string("%.2f", duration, grouping=True)
+                )
+        elif role == Qt.ItemDataRole.ToolTipRole:
+            if idx == TopLevelId and column == TestResultsModel.NameColumn:
+                return self.testresults[row].name
+        elif role == Qt.ItemDataRole.FontRole:
+            if idx != TopLevelId:
+                return Preferences.getEditorOtherFonts("MonospacedFont")
+        elif role == Qt.ItemDataRole.BackgroundRole:
+            if idx == TopLevelId:
+                testResult = self.__testResults[row]
+                with contextlib.suppress(KeyError):
+                    return self.__backgroundColors[testResult.category]
+        elif role == Qt.ItemDataRole.TextAlignmentRole:
+            if idx == TopLevelId and column == TestResultsModel.DurationColumn:
+                return Qt.AlignmentFlag.AlignRight
+        elif role == Qt.ItemDataRole.UserRole:      # __IGNORE_WARNING_Y102__
+            if idx == TopLevelId:
+                testresult = self.testresults[row]
+                return (testresult.filename, testresult.lineno)
+        
+        return None
+    
     def headerData(self, section, orientation,
                    role=Qt.ItemDataRole.DisplayRole):
         """
@@ -60,6 +176,24 @@
         else:
             return None
     
+    def parent(self, index):
+        """
+        Public method to get the parent of the item pointed to by index.
+        
+        @param index index of the item
+        @type QModelIndex
+        @return index of the parent item
+        @rtype QModelIndex
+        """
+        if not index.isValid():
+            return QModelIndex()
+        
+        idx = index.internalId()
+        if idx == TopLevelId:
+            return QModelIndex()
+        else:
+            return self.index(idx, 0)
+    
     def rowCount(self, parent=QModelIndex()):
         """
         Public method to get the number of row for a given parent index.
@@ -72,7 +206,11 @@
         if not parent.isValid():
             return len(self.__testResults)
         
-        if parent.internalId() == TopLevelId and parent.column() == 0:
+        if (
+            parent.internalId() == TopLevelId and
+            parent.column() == 0 and
+            self.__testResults[parent.row()].extra is not None
+        ):
             return len(self.__testResults[parent.row()].extra)
         
         return 0
@@ -98,6 +236,100 @@
         self.beginResetModel()
         self.__testResults.clear()
         self.endResetModel()
+    
+    def sort(self, column, order):
+        """
+        Public method to sort the model data by column in order.
+        
+        @param column sort column number
+        @type int
+        @param order sort order
+        @type Qt.SortOrder
+        """             # __IGNORE_WARNING_D234r__
+        def durationKey(result):
+            """
+            Function to generate a key for duration sorting
+            
+            @param result result object
+            @type UTTestResult
+            @return sort key
+            @rtype float
+            """
+            return result.duration or -1.0
+
+        self.beginResetModel()
+        reverse = order == Qt.SortOrder.DescendingOrder
+        if column == TestResultsModel.StatusColumn:
+            self.__testResults.sort(key=attrgetter('category', 'status'),
+                                    reverse=reverse)
+        elif column == TestResultsModel.NameColumn:
+            self.__testResults.sort(key=attrgetter('name'), reverse=reverse)
+        elif column == TestResultsModel.MessageColumn:
+            self.__testResults.sort(key=attrgetter('message'), reverse=reverse)
+        elif column == TestResultsModel.DurationColumn:
+            self.__testResults.sort(key=durationKey, reverse=reverse)
+        self.endResetModel()
+    
+    def getTestResults(self):
+        """
+        Public method to get the list of test results managed by the model.
+        
+        @return list of test results managed by the model
+        @rtype list of UTTestResult
+        """
+        return copy.deepcopy(self.__testResults)
+    
+    def setTestResults(self, testResults):
+        """
+        Public method to set the list of test results of the model.
+        
+        @param testResults test results to be managed by the model
+        @type list of UTTestResult
+        """
+        self.beginResetModel()
+        self.__testResults = copy.deepcopy(testResults)
+        self.endResetModel()
+    
+    def addTestResults(self, testResults):
+        """
+        Public method to add test results to the ones already managed by the
+        model.
+        
+        @param testResults test results to be added to the model
+        @type list of UTTestResult
+        """
+        firstRow = len(self.__testResults)
+        lastRow = firstRow + len(testResults) - 1
+        self.beginInsertRows(QModelIndex(), firstRow, lastRow)
+        self.__testResults.extend(testResults)
+        self.endInsertRows()
+    
+    def updateTestResults(self, testResults):
+        """
+        Public method to update the data of managed test result items.
+        
+        @param testResults test results to be updated
+        @type list of UTTestResult
+        """
+        minIndex = None
+        maxIndex = None
+        
+        for testResult in testResults:
+            for (index, currentResult) in enumerate(self.__testResults):
+                if currentResult.id == testResult.id:
+                    self.__testResults[index] = testResult
+                    if minIndex is None:
+                        minIndex = index
+                        maxIndex = index
+                    else:
+                        minIndex = min(minIndex, index)
+                        maxIndex = max(maxIndex, index)
+        
+        if minIndex is not None:
+            self.dataChanged.emit(
+                self.index(minIndex, 0),
+                self.index(maxIndex, len(TestResultsModel.Headers) - 1)
+            )
 
 
 class TestResultsTreeView(QTreeView):
@@ -132,6 +364,51 @@
         self.header().sortIndicatorChanged.connect(
             lambda column, order: self.header().setSortIndicatorShown(True))
     
+    def reset(self):
+        """
+        Public method to reset the internal state of the view.
+        """
+        super().reset()
+        
+        self.resizeColumns()
+        self.spanFirstColumn(0, self.model().rowCount() - 1)
+    
+    def rowsInserted(self, parent, startRow, endRow):
+        """
+        Public method called when rows are inserted.
+        
+        @param parent model index of the parent item
+        @type QModelIndex
+        @param startRow first row been inserted
+        @type int
+        @param endRow last row been inserted
+        @type int
+        """
+        super().rowsInserted(parent, startRow, endRow)
+        
+        self.resizeColumns()
+        self.spanFirstColumn(startRow, endRow)
+    
+    def dataChanged(self, topLeft, bottomRight, roles=[]):
+        """
+        Public method called when the model data has changed.
+        
+        @param topLeft index of the top left element
+        @type QModelIndex
+        @param bottomRight index of the bottom right element
+        @type QModelIndex
+        @param roles list of roles changed (defaults to [])
+        @type list of Qt.ItemDataRole (optional)
+        """
+        super().dataChanged(topLeft, bottomRight, roles)
+        
+        self.resizeColumns()
+        while topLeft.parent().isValid():
+            topLeft = topLeft.parent()
+        while bottomRight.parent().isValid():
+            bottomRight = bottomRight.parent()
+        self.spanFirstColumn(topLeft.row(), bottomRight.row())
+    
     @pyqtSlot(QModelIndex)
     def __gotoTestDefinition(self, index):
         """
@@ -140,8 +417,33 @@
         @param index index for the double-clicked item
         @type QModelIndex
         """
-        # TODO: not implemented yet
+        # TODO: not implemented yet (__gotoTestDefinition)
         pass
+    
+    def resizeColumns(self):
+        """
+        Public method to resize the columns to their contents.
+        """
+        for column in range(self.model().columnCount()):
+            self.resizeColumnToContents(column)
+    
+    def spanFirstColumn(self, startRow, endRow):
+        """
+        Public method to make the first column span the row for second level
+        items.
+        
+        These items contain the test results.
+        
+        @param startRow index of the first row to span
+        @type QModelIndex
+        @param endRow index of the last row (including) to span
+        @type QModelIndex
+        """
+        model = self.model()
+        for row in range(startRow, endRow + 1):
+            index = model.index(row, 0)
+            for i in range(model.rowCount(index)):
+                self.setFirstColumnSpanned(i, index, True)
 
 #
-# eflag: noqa = M822
+# eflag: noqa = M821, M822
--- a/eric7/Unittest/UnittestWidget.py	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/UnittestWidget.py	Fri May 13 17:23:21 2022 +0200
@@ -8,6 +8,7 @@
 """
 
 import enum
+import locale
 import os
 
 from PyQt6.QtCore import pyqtSlot, Qt, QEvent, QCoreApplication
@@ -24,7 +25,9 @@
 
 from .UTTestResultsTree import TestResultsModel, TestResultsTreeView
 from .Interfaces import Frameworks
-from .Interfaces.UTExecutorBase import UTTestConfig, UTTestResult
+from .Interfaces.UTExecutorBase import (
+    UTTestConfig, UTTestResult, ResultCategory
+)
 from .Interfaces.UTFrameworkRegistry import UTFrameworkRegistry
 
 import Preferences
@@ -46,6 +49,8 @@
     STOPPED = 2         # test run finished
 
 
+# TODO: add a "Show Coverage" function using PyCoverageDialog
+
 class UnittestWidget(QWidget, Ui_UnittestWidget):
     """
     Class implementing a widget to orchestrate unit test execution.
@@ -175,7 +180,7 @@
                 self.__insertDiscovery("")
         else:
             self.__insertDiscovery("")
-        self.__insertProg(testfile)
+        self.__insertTestFile(testfile)
         self.__insertTestName("")
         
         self.clearHistoriesButton.clicked.connect(self.clearRecent)
@@ -253,7 +258,7 @@
         widget.addItems(history)
         
         if current:
-            widget.setText(current)
+            widget.setEditText(current)
     
     @pyqtSlot(str)
     def __insertDiscovery(self, start):
@@ -268,7 +273,7 @@
                              start)
     
     @pyqtSlot(str)
-    def __insertProg(self, prog):
+    def __insertTestFile(self, prog):
         """
         Private slot to insert a test file name into the testsuitePicker
         object.
@@ -392,13 +397,31 @@
             self.__startButton.setDefault(False)
         
         # Start Failed button
-        # TODO: not implemented yet
+        # TODO: not implemented yet (Start Failed button)
         
         # Stop button
         self.__stopButton.setEnabled(
             self.__mode == UnittestWidgetModes.RUNNING)
         self.__stopButton.setDefault(
             self.__mode == UnittestWidgetModes.RUNNING)
+        
+        # Close button
+        self.buttonBox.button(
+            QDialogButtonBox.StandardButton.Close
+        ).setEnabled(self.__mode in (
+            UnittestWidgetModes.IDLE, UnittestWidgetModes.STOPPED
+        ))
+    
+    def __updateProgress(self):
+        """
+        Private method update the progress indicators.
+        """
+        self.progressCounterRunCount.setText(
+            str(self.__runCount))
+        self.progressCounterRemCount.setText(
+            str(self.__totalCount - self.__runCount))
+        self.progressProgressBar.setMaximum(self.__totalCount)
+        self.progressProgressBar.setValue(self.__runCount)
     
     def __setIdleMode(self):
         """
@@ -406,20 +429,37 @@
         """
         self.__mode = UnittestWidgetModes.IDLE
         self.__updateButtonBoxButtons()
+        self.tabWidget.setCurrentIndex(0)
     
     def __setRunningMode(self):
         """
         Private method to switch the widget to running mode.
         """
-        # TODO: not implemented yet
-        pass
+        self.__mode = UnittestWidgetModes.RUNNING
+        
+        self.__totalCount = 0
+        self.__runCount = 0
+        
+        self.__coverageFile = ""
+        # TODO: implement the handling of the 'Show Coverage' button
+        
+        self.sbLabel.setText(self.tr("Running"))
+        self.tabWidget.setCurrentIndex(1)
+        self.__updateButtonBoxButtons()
+        self.__updateProgress()
+        
+        self.__resultsModel.clear()
     
     def __setStoppedMode(self):
         """
         Private method to switch the widget to stopped mode.
         """
-        # TODO: not implemented yet
-        pass
+        self.__mode = UnittestWidgetModes.STOPPED
+        
+        self.__updateButtonBoxButtons()
+        
+        self.raise_()
+        self.activateWindow()
     
     @pyqtSlot(QAbstractButton)
     def on_buttonBox_clicked(self, button):
@@ -429,10 +469,6 @@
         @param button button that was clicked
         @type QAbstractButton
         """
-##        if button == self.discoverButton:
-##            self.__discover()
-##            self.__saveRecent()
-##        elif button == self.__startButton:
         if button == self.__startButton:
             self.startTests()
             self.__saveRecent()
@@ -523,13 +559,16 @@
             discoveryStart = ""
             testFileName = self.testsuitePicker.currentText()
             if testFileName:
-                self.__insertProg(testFileName)
+                self.__insertTestFile(testFileName)
             testName = self.testComboBox.currentText()
             if testName:
-                self.insertTestName(testName)
+                self.__insertTestName(testName)
             if testFileName and not testName:
                 testName = "suite"
         
+        self.sbLabel.setText(self.tr("Preparing Testsuite"))
+        QCoreApplication.processEvents()
+        
         interpreter = self.__venvManager.getVirtualenvInterpreter(
             self.__recentEnvironment)
         config = UTTestConfig(
@@ -546,27 +585,47 @@
         self.__resultsModel.clear()
         self.__testExecutor = self.__frameworkRegistry.createExecutor(
             self.__recentFramework, self)
-        self.__testExecutor.collected.connect(self.__testCollected)
+        self.__testExecutor.collected.connect(self.__testsCollected)
         self.__testExecutor.collectError.connect(self.__testsCollectError)
-        self.__testExecutor.startTest.connect(self.__testsStarted)
+        self.__testExecutor.startTest.connect(self.__testStarted)
         self.__testExecutor.testResult.connect(self.__processTestResult)
         self.__testExecutor.testFinished.connect(self.__testProcessFinished)
+        self.__testExecutor.testRunFinished.connect(self.__testRunFinished)
         self.__testExecutor.stop.connect(self.__testsStopped)
-        self.__testExecutor.start(config, [])
+        self.__testExecutor.coverageDataSaved.connect(self.__coverageData)
         
-        # TODO: not yet implemented
-        pass
+        self.__setRunningMode()
+        self.__testExecutor.start(config, [])
+    
+    @pyqtSlot()
+    def __stopTests(self):
+        """
+        Private slot to stop the current test run.
+        """
+        self.__testExecutor.stopIfRunning()
     
     @pyqtSlot(list)
-    def __testCollected(self, testNames):
+    def __testsCollected(self, testNames):
         """
         Private slot handling the 'collected' signal of the executor.
         
-        @param testNames list of names of collected tests
-        @type list of str
+        @param testNames list of tuples containing the test id and test name
+            of collected tests
+        @type list of tuple of (str, str)
         """
-        # TODO: not implemented yet
-        pass
+        testResults = [
+            UTTestResult(
+                category=ResultCategory.PENDING,
+                status=self.tr("pending"),
+                name=name,
+                id=id,
+                message=desc,
+            ) for id, name, desc in testNames
+        ]
+        self.__resultsModel.setTestResults(testResults)
+        
+        self.__totalCount = len(testResults)
+        self.__updateProgress()
     
     @pyqtSlot(list)
     def __testsCollectError(self, errors):
@@ -577,19 +636,49 @@
             of the error
         @type list of tuple of (str, str)
         """
-        # TODO: not implemented yet
-        pass
+        testResults = []
+        
+        for testFile, error in errors:
+            if testFile:
+                testResults.append(UTTestResult(
+                    category=ResultCategory.FAIL,
+                    status=self.tr("Failure"),
+                    name=testFile,
+                    id=testFile,
+                    message=self.tr("Collection Error"),
+                    extra=error.splitlines()
+                ))
+            else:
+                EricMessageBox.critical(
+                    self,
+                    self.tr("Collection Error"),
+                    self.tr(
+                        "<p>There was an error while collecting unit tests."
+                        "</p><p>{0}</p>"
+                    ).format("<br/>".join(error.splitlines()))
+                )
+        
+        if testResults:
+            self.__resultsModel.addTestResults(testResults)
     
-    @pyqtSlot(list)
-    def __testsStarted(self, testNames):
+    @pyqtSlot(tuple)
+    def __testStarted(self, test):
         """
         Private slot handling the 'startTest' signal of the executor.
         
-        @param testNames list of names of tests about to be run
-        @type list of str
+        @param test tuple containing the id, name and short description of the
+            tests about to be run
+        @type tuple of (str, str, str)
         """
-        # TODO: not implemented yet
-        pass
+        self.__resultsModel.updateTestResults([
+            UTTestResult(
+                category=ResultCategory.RUNNING,
+                status=self.tr("running"),
+                id=test[0],
+                name=test[1],
+                message="" if test[2] is None else test[2],
+            )
+        ])
     
     @pyqtSlot(UTTestResult)
     def __processTestResult(self, result):
@@ -599,8 +688,10 @@
         @param result test result object
         @type UTTestResult
         """
-        # TODO: not implemented yet
-        pass
+        self.__runCount += 1
+        self.__updateProgress()
+        
+        self.__resultsModel.updateTestResults([result])
     
     @pyqtSlot(list, str)
     def __testProcessFinished(self, results, output):
@@ -613,16 +704,47 @@
         @param output string containing the test process output (if any)
         @type str
         """
-        # TODO: not implemented yet
-        pass
+        self.__setStoppedMode()
+        self.__testExecutor = None
+    
+    @pyqtSlot(int, float)
+    def __testRunFinished(self, noTests, duration):
+        """
+        Private slot to handle the 'testRunFinished' signal of the executor.
+        
+        @param noTests number of tests run by the executor
+        @type int
+        @param duration time needed in seconds to run the tests
+        @type float
+        """
+        self.sbLabel.setText(
+            self.tr("Ran %n test(s) in {0}s", "", noTests).format(
+                locale.format_string("%.3f", duration, grouping=True)
+            )
+        )
+        
+        self.__setStoppedMode()
     
     @pyqtSlot()
     def __testsStopped(self):
         """
         Private slot to handle the 'stop' signal of the executor.
         """
-        # TODO: not implemented yet
-        pass
+        self.sbLabel.setText(self.tr("Ran %n test(s)", "", self.__runCount))
+        
+        self.__setStoppedMode()
+    
+    @pyqtSlot(str)
+    def __coverageData(self, coverageFile):
+        """
+        Private slot to handle the 'coverageData' signal of the executor.
+        
+        @param coverageFile file containing the coverage data
+        @type str
+        """
+        self.__coverageFile = coverageFile
+        
+        # TODO: implement the handling of the 'Show Coverage' button
 
 
 class UnittestWindow(EricMainWindow):
--- a/eric7/Unittest/UnittestWidget.ui	Thu May 12 09:00:35 2022 +0200
+++ b/eric7/Unittest/UnittestWidget.ui	Fri May 13 17:23:21 2022 +0200
@@ -6,7 +6,7 @@
    <rect>
     <x>0</x>
     <y>0</y>
-    <width>650</width>
+    <width>850</width>
     <height>700</height>
    </rect>
   </property>
@@ -17,7 +17,7 @@
    <item>
     <widget class="QTabWidget" name="tabWidget">
      <property name="currentIndex">
-      <number>0</number>
+      <number>1</number>
      </property>
      <widget class="QWidget" name="parametersTab">
       <attribute name="title">
@@ -381,7 +381,7 @@
    <item>
     <layout class="QHBoxLayout" name="_4">
      <item>
-      <widget class="QLabel" name="sbLabel_2">
+      <widget class="QLabel" name="sbLabel">
        <property name="sizePolicy">
         <sizepolicy hsizetype="Preferred" vsizetype="Preferred">
          <horstretch>0</horstretch>

eric ide

mercurial