HEX
Server: LiteSpeed
System: Linux linux31.centraldnserver.com 4.18.0-553.83.1.lve.el8.x86_64 #1 SMP Wed Nov 12 10:04:12 UTC 2025 x86_64
User: salamatk (1501)
PHP: 8.1.33
Disabled: show_source, system, shell_exec, passthru, exec, popen, proc_open
Upload Files
File: //opt/imunify360/venv/lib/python3.11/site-packages/defence360agent/model/wp_disabled_rule.py
"""WordPress-specific disabled rules data model.

This module provides a separate data model for WordPress disabled rules,
independent of the existing DisabledRule/DisabledRuleDomain models used
by modsec/ossec plugins.

Disable Behavior:
    Global and domain-level disables are independent and can coexist.
    A rule is considered effectively disabled for a given WordPress domain
    if EITHER of these conditions is true:
    - A global disable exists for the rule (applies to all domains)
    - A domain-specific disable exists for the rule and that domain

    Enabling a rule at one scope does not affect disables at the other scope.
    For example, removing a global disable leaves any domain-specific disables
    intact, and vice versa.
"""

from collections.abc import Iterator
import logging
import time

from peewee import (
    CharField,
    FloatField,
    IntegerField,
    IntegrityError,
    PrimaryKeyField,
    fn,
)

from defence360agent.model import Model, instance

logger = logging.getLogger(__name__)


class WPDisabledRule(Model):
    """Stores disabled WordPress protection rules.

    Uses a scope-based design:
    - scope='global', scope_value=NULL: Rule disabled for all domains (root only)
    - scope='domain', scope_value='example.com': Rule disabled for specific domain

    Note: the unique index on (rule_id, scope, scope_value) cannot prevent
    duplicate *global* rows, because SQL (SQLite included) treats NULL values
    as distinct from each other in UNIQUE indexes. Global entries are
    therefore de-duplicated explicitly in ``_create_if_not_exists``, and
    readers of global entries tolerate duplicates that may already exist.
    """

    class Meta:
        database = instance.db
        db_table = "wp_disabled_rules"
        indexes = ((("rule_id", "scope", "scope_value"), True),)

    id = PrimaryKeyField()
    # The rule identifier (e.g., "CVE-2025-001")
    rule_id = CharField(null=False)
    # The scope type: "global" or "domain"
    scope = CharField(null=False)
    # The scope value: NULL for global, domain name for domain scope
    scope_value = CharField(null=True)
    # Unix timestamp when the rule was disabled
    disabled_at = FloatField(null=False)
    # Origin of the disable action: "wordpress" (from wordpress admin ui) or "agent" (from CLI/RPC)
    source = CharField(null=False)
    # UID of the user who disabled the rule (0 for root)
    created_by_user_id = IntegerField(null=False)

    # Scope constants
    SCOPE_GLOBAL = "global"
    SCOPE_DOMAIN = "domain"

    # Source constants
    SOURCE_WORDPRESS = "wordpress"
    SOURCE_AGENT = "agent"

    @classmethod
    def store(
        cls,
        rule_id: str,
        domains: list[str] | None,
        source: str,
        user_id: int,
        timestamp: float | None = None,
    ) -> int:
        """
        Disable a rule globally or for specific domains.

        Args:
            rule_id: The rule identifier (e.g., "CVE-2025-001")
            domains: List of domains to disable for, or None/empty for global disable
            source: Origin of the action ("wordpress" or "agent")
            user_id: UID of the user performing the action (0 for root)
            timestamp: Unix timestamp for when the rule was disabled.
                       If None, uses current time.

        Returns:
            Number of new entries created (0 if all were no-ops).
        """
        if timestamp is None:
            timestamp = time.time()

        if domains:
            return cls._disable_for_domains(
                rule_id, domains, timestamp, source, user_id
            )
        return cls._disable_globally(rule_id, timestamp, source, user_id)

    @classmethod
    def _disable_globally(
        cls,
        rule_id: str,
        timestamp: float,
        source: str,
        user_id: int,
    ) -> int:
        """Disable a rule globally (independent of domain-specific entries).

        Returns:
            1 if a new global entry was created, 0 if one already existed.
        """
        created = cls._create_if_not_exists(
            rule_id=rule_id,
            scope=cls.SCOPE_GLOBAL,
            scope_value=None,
            disabled_at=timestamp,
            source=source,
            user_id=user_id,
        )
        if created:
            logger.debug(
                "Disabled rule %s globally (source=%s, user_id=%s)",
                rule_id,
                source,
                user_id,
            )
        return int(created)

    @classmethod
    def _disable_for_domains(
        cls,
        rule_id: str,
        domains: list[str],
        timestamp: float,
        source: str,
        user_id: int,
    ) -> int:
        """Disable a rule for specific domains (independent of global state).

        Returns:
            Number of domain entries actually created; domains that were
            already disabled for this rule are skipped.
        """
        count = 0
        for domain in domains:
            created = cls._create_if_not_exists(
                rule_id=rule_id,
                scope=cls.SCOPE_DOMAIN,
                scope_value=domain,
                disabled_at=timestamp,
                source=source,
                user_id=user_id,
            )
            if created:
                count += 1
                logger.debug(
                    "Disabled rule %s for domain %s (source=%s, user_id=%s)",
                    rule_id,
                    domain,
                    source,
                    user_id,
                )
        return count

    @classmethod
    def _create_if_not_exists(
        cls,
        rule_id: str,
        scope: str,
        scope_value: str | None,
        disabled_at: float,
        source: str,
        user_id: int,
    ) -> bool:
        """
        Create a new disabled rule entry if it doesn't already exist.

        For non-NULL ``scope_value`` the unique index enforces uniqueness and
        a duplicate insert surfaces as ``IntegrityError``. For NULL
        ``scope_value`` (global scope) the index does not apply, because
        NULLs compare as distinct, so an explicit existence check is done
        first.

        Returns:
            True if a new entry was created, False if it already existed (no-op)
        """
        if scope_value is None and (
            cls.select()
            .where(
                cls.rule_id == rule_id,
                cls.scope == scope,
                cls.scope_value.is_null(),
            )
            .exists()
        ):
            # Already disabled for this (NULL-valued) scope - no-op.
            # NOTE(review): check-then-insert is not atomic; concurrent
            # writers could still race here.
            return False

        try:
            cls.insert(
                rule_id=rule_id,
                scope=scope,
                scope_value=scope_value,
                disabled_at=disabled_at,
                source=source,
                created_by_user_id=user_id,
            ).execute()
        except IntegrityError:
            # Rule already disabled for this scope - no-op
            return False
        return True

    @classmethod
    def remove(cls, rule_id: str, domains: list[str] | None) -> int:
        """
        Re-enable a rule globally or for specific domains.

        Enabling at one scope never touches entries of the other scope.

        Args:
            rule_id: The rule identifier
            domains: List of domains to enable for, or None/empty to enable globally

        Returns:
            Number of rows deleted
        """
        if not domains:
            # Enable globally - remove ONLY the global entry (or entries,
            # if duplicates were stored before de-duplication was enforced)
            count = (
                cls.delete()
                .where(
                    cls.rule_id == rule_id,
                    cls.scope == cls.SCOPE_GLOBAL,
                )
                .execute()
            )
            if count:
                logger.debug("Enabled rule %s globally", rule_id)
        else:
            # Enable for specific domains
            count = (
                cls.delete()
                .where(
                    cls.rule_id == rule_id,
                    cls.scope == cls.SCOPE_DOMAIN,
                    cls.scope_value.in_(domains),
                )
                .execute()
            )
            if count:
                logger.debug(
                    "Enabled rule %s for %d domain(s)",
                    rule_id,
                    count,
                )
        return count

    @classmethod
    def is_rule_disabled(cls, rule_id: str, domain: str | None = None) -> bool:
        """
        Check if a rule is disabled globally or for a specific domain.

        Args:
            rule_id: The rule identifier
            domain: The domain to check. If None, only checks global disable.

        Returns:
            True if the rule is disabled, False otherwise
        """
        if domain is None:
            return (
                cls.select()
                .where(
                    cls.rule_id == rule_id,
                    cls.scope == cls.SCOPE_GLOBAL,
                )
                .exists()
            )

        # Effective disable: a global entry OR a matching domain entry.
        return (
            cls.select()
            .where(
                cls.rule_id == rule_id,
                (
                    (cls.scope == cls.SCOPE_GLOBAL)
                    | (
                        (cls.scope == cls.SCOPE_DOMAIN)
                        & (cls.scope_value == domain)
                    )
                ),
            )
            .exists()
        )

    @classmethod
    def get_domain_disabled(
        cls, domain: str, include_global: bool = False
    ) -> list[str]:
        """
        Get all rule IDs that are disabled for a specific domain.

        Args:
            domain: The domain to get disabled rules for
            include_global: If True, also include globally disabled rules.
                           If False (default), only return domain-specific disables.

        Returns:
            List of rule IDs that are disabled for the domain
        """
        if include_global:
            # DISTINCT: a rule may be disabled both globally and for the domain.
            query = (
                cls.select(cls.rule_id)
                .where(
                    (cls.scope == cls.SCOPE_GLOBAL)
                    | (
                        (cls.scope == cls.SCOPE_DOMAIN)
                        & (cls.scope_value == domain)
                    )
                )
                .distinct()
            )
        else:
            query = cls.select(cls.rule_id).where(
                cls.scope == cls.SCOPE_DOMAIN,
                cls.scope_value == domain,
            )
        return [row.rule_id for row in query]

    @classmethod
    def get_global_disabled(cls) -> Iterator[str]:
        """
        Get all rule IDs that are disabled globally.

        The query is DISTINCT so that duplicate global rows (which the
        unique index cannot prevent, see the class docstring) do not
        produce repeated rule IDs.

        Returns:
            Lazy iterator of globally disabled rule IDs
        """
        query = (
            cls.select(cls.rule_id)
            .where(cls.scope == cls.SCOPE_GLOBAL)
            .distinct()
        )
        return (row.rule_id for row in query)

    @classmethod
    def _build_filter_condition(
        cls,
        user_domains: list[str] | None,
        include_global: bool,
    ):
        """
        Build the WHERE condition for filtering rules.

        - user_domains given: match domain entries for those domains,
          plus global entries if include_global.
        - user_domains None: all domain entries, plus global entries if
          include_global (in which case no filter is needed at all).

        Returns:
            A Peewee expression for the WHERE clause, or None if no filter needed.
        """
        if user_domains is not None:
            domain_match = (cls.scope == cls.SCOPE_DOMAIN) & (
                cls.scope_value.in_(user_domains)
            )
            if include_global:
                return (cls.scope == cls.SCOPE_GLOBAL) | domain_match
            return domain_match
        if not include_global:
            return cls.scope == cls.SCOPE_DOMAIN
        return None

    @classmethod
    def fetch(
        cls,
        limit: int,
        offset: int = 0,
        user_domains: list[str] | None = None,
        include_global: bool = False,
    ) -> tuple[int, list[dict]]:
        """
        List disabled rules with aggregation by rule_id.

        Multiple domain entries for the same rule are aggregated into a single
        result with a list of domains. Results are ordered by most recently
        disabled first (using the latest disabled_at timestamp per rule_id),
        with rule_id as a tie-breaker so pagination is deterministic.

        Uses a two-pass approach for efficiency:
        1. First pass: Get rule_ids ordered by latest disabled_at with pagination
        2. Second pass: Fetch only rows for the paginated rule_ids

        Args:
            limit: Maximum number of rules to return
            offset: Number of rules to skip
            user_domains: If provided, only return rules for these domains.
                         If None, return all rules (for root users).
            include_global: Whether to include global rules in the result

        Returns:
            Tuple of (total_count, list of rule dicts)
            Each dict has: {"rule_id": str, "is_global": bool, "domains": list[str]}
            is_global is True if rule has a global disable, domains lists domain-specific disables
        """
        # Build filter condition
        condition = cls._build_filter_condition(user_domains, include_global)

        # First pass: get rule_ids ordered by latest disabled_at (most recent
        # first). rule_id breaks ties so equal timestamps don't make rules
        # jump between pages.
        rule_ids_query = (
            cls.select(cls.rule_id)
            .group_by(cls.rule_id)
            .order_by(fn.MAX(cls.disabled_at).desc(), cls.rule_id)
        )
        if condition is not None:
            rule_ids_query = rule_ids_query.where(condition)

        # Get total count of distinct rule_ids
        total_count = rule_ids_query.count()

        # Apply pagination at DB level
        paginated_rule_ids = [
            row.rule_id for row in rule_ids_query.offset(offset).limit(limit)
        ]
        if not paginated_rule_ids:
            return total_count, []

        # Second pass: fetch rows for the paginated rule_ids
        rows_query = cls.select().where(cls.rule_id.in_(paginated_rule_ids))
        if condition is not None:
            rows_query = rows_query.where(condition)

        # Aggregate domains by rule_id
        rules_by_id: dict[str, dict] = {}
        for row in rows_query:
            rule_data = rules_by_id.setdefault(
                row.rule_id,
                {
                    "rule_id": row.rule_id,
                    "is_global": False,
                    "domains": [],
                },
            )
            if row.scope == cls.SCOPE_GLOBAL:
                rule_data["is_global"] = True
            elif row.scope == cls.SCOPE_DOMAIN:
                rule_data["domains"].append(row.scope_value)

        # Build result in order from first query (preserves DB ordering)
        result = []
        for rule_id in paginated_rule_ids:
            rule_data = rules_by_id.get(rule_id)
            if rule_data is None:
                # The rule's rows may have been removed between the two
                # passes; skip it rather than raising KeyError.
                continue
            rule_data["domains"] = sorted(rule_data["domains"])
            result.append(rule_data)

        return total_count, result