# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Common test-helper code involving workflows."""

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, cast
from unittest import mock

from django.db import connection
from django.test import override_settings

from debusine.artifacts.models import TaskTypes
from debusine.db.models import Artifact, WorkRequest
from debusine.server.scheduler import schedule
from debusine.server.workflows import SbuildWorkflow, Workflow, workflow_utils
from debusine.server.workflows.models import BaseWorkflowData
from debusine.tasks.models import BaseDynamicTaskData
from debusine.tasks.tests.helper_mixin import SampleBaseTask
from debusine.test.django import TestCase
from debusine.utils import extract_generic_type_arguments


class SampleWorkflow[WD: BaseWorkflowData, DTD: BaseDynamicTaskData](
    Workflow[WD, DTD], SampleBaseTask[WD, DTD]
):
    """
    Common test implementation of Workflow methods.

    Combines :py:class:`Workflow` with the ``SampleBaseTask`` test helper,
    passing the same workflow-data (``WD``) and dynamic-task-data (``DTD``)
    type parameters to both.  ``Workflow`` is listed first, so it precedes
    ``SampleBaseTask`` in the method resolution order.  This class adds no
    methods of its own.

    NOTE(review): presumably ``SampleBaseTask`` supplies trivial
    implementations of methods that concrete tasks must otherwise provide,
    so tests can instantiate workflows without boilerplate — confirm in
    ``debusine.tasks.tests.helper_mixin``.
    """


class WorkflowTestBase[W: Workflow[Any, Any]](TestCase):
    """Base class for workflow tests."""

    workflow_type: type[W]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Extract the type argument."""
        super().__init_subclass__(**kwargs)
        [cls.workflow_type] = extract_generic_type_arguments(
            cls, WorkflowTestBase
        )

    @classmethod
    def get_workflow(cls, work_request: WorkRequest) -> W:
        """Create and return a suitable Workflow for this work request."""
        task = work_request.get_task()
        assert isinstance(task, cls.workflow_type)
        return cast(W, task)

    @contextmanager
    def schedule_and_run(self) -> Generator[list[WorkRequest]]:
        """
        Run the scheduler and any resulting Celery callbacks.

        Yields the list of scheduled work requests so that the caller can
        assert things about them before they are run.
        """
        start_count = len(connection.run_on_commit)
        with self.captureOnCommitCallbacks() as on_commit:
            yield schedule()
        with (
            override_settings(CELERY_TASK_ALWAYS_EAGER=True),
            mock.patch("pgtransaction.Atomic.execute_set_transaction_modes"),
        ):
            for callback in on_commit:
                callback()
            del connection.run_on_commit[
                start_count : start_count + len(on_commit)
            ]

    def schedule_and_run_workflow_callback(self, parent: WorkRequest) -> None:
        """
        Run the scheduler and a workflow callback.

        :param parent: Expect the workflow callback to be a child of this
          workflow.
        """
        with self.schedule_and_run() as [work_request]:
            self.assertEqual(work_request.task_type, TaskTypes.INTERNAL)
            self.assertEqual(work_request.task_name, "workflow")
            self.assertEqual(work_request.parent, parent)

    def schedule_and_run_workflow(self, workflow: WorkRequest) -> None:
        """
        Run the scheduler and a workflow.

        :param workflow: Expect this workflow to be scheduled.
        """
        with self.schedule_and_run() as work_requests:
            self.assertIn(workflow, work_requests)

    def simulate_sbuild_workflow_completion(
        self,
        sbuild_workflow: WorkRequest,
        only_architecture: str | None = None,
        section: str = "misc",
        priority: str = "optional",
    ) -> dict[str, list[Artifact]]:
        """
        Simulate completion of an sbuild workflow.

        :param sbuild_workflow: The workflow to complete.
        :param only_architecture: If given, only complete the sbuild task
          for this architecture.
        :param section: ``Section`` field for binary artifacts.
        :param priority: ``Priority`` field for binary artifacts.
        :return: A dictionary mapping architectures to the artifacts created
          for them.
        """
        sbuild_workflow_task = sbuild_workflow.get_task()
        assert isinstance(sbuild_workflow_task, SbuildWorkflow)
        source_data = workflow_utils.source_package_data(
            sbuild_workflow_task, configuration_key="input.source_artifact"
        )

        binaries: dict[str, list[Artifact]] = {}
        for sbuild in sbuild_workflow.children.filter(
            task_type=TaskTypes.WORKER,
            task_name="sbuild",
            workflow_data_json__step__startswith="build-",
        ).order_by("workflow_data_json__step"):
            assert sbuild.workflow_data.step is not None
            arch = sbuild.workflow_data.step.removeprefix("build-")
            if only_architecture is not None and arch != only_architecture:
                continue
            sbuild = sbuild_workflow.children.get(
                task_type=TaskTypes.WORKER,
                task_name="sbuild",
                workflow_data_json__step=f"build-{arch}",
            )
            upload = self.playground.create_upload_artifacts(
                src_name=source_data.name,
                version=source_data.version,
                binary=True,
                source=False,
                binaries=[(source_data.name, arch)],
                section=section,
                priority=priority,
                workspace=sbuild.workspace,
                work_request=sbuild,
            )
            binaries[arch] = upload.binaries
            self.playground.advance_work_request(
                sbuild, result=WorkRequest.Results.SUCCESS
            )
        return binaries
