File: //opt/imunify360/venv/lib64/python3.11/site-packages/defence360agent/model/wp_disabled_rule.py
"""WordPress-specific disabled rules data model.
This module provides a separate data model for WordPress disabled rules,
independent of the existing DisabledRule/DisabledRuleDomain models used
by modsec/ossec plugins.
Disable Behavior:
Global and domain-level disables are independent and can coexist.
A rule is considered effectively disabled for a given WordPress domain
if EITHER of these conditions is true:
- A global disable exists for the rule (applies to all domains)
- A domain-specific disable exists for the rule and that domain
Enabling a rule at one scope does not affect disables at the other scope.
For example, removing a global disable leaves any domain-specific disables
intact, and vice versa.
"""
from collections.abc import Iterator
import logging
import time
from peewee import (
CharField,
FloatField,
IntegerField,
IntegrityError,
PrimaryKeyField,
fn,
)
from defence360agent.model import Model, instance
# Module-level logger, named after this module's import path via __name__.
logger = logging.getLogger(__name__)
class WPDisabledRule(Model):
    """Stores disabled WordPress protection rules.

    Uses a scope-based design:

    - scope='global', scope_value=NULL: rule disabled for all domains
      (root only)
    - scope='domain', scope_value='example.com': rule disabled for a
      specific domain

    Global and domain rows for the same rule are independent of each other:
    creating or removing one never touches the other.
    """

    class Meta:
        database = instance.db
        db_table = "wp_disabled_rules"
        # One row per (rule_id, scope, scope_value).
        # NOTE: SQLite (and most SQL databases) treat NULLs as distinct inside
        # a unique index, so this index does not by itself reject a second
        # global row (scope_value IS NULL) for the same rule.  That case is
        # guarded explicitly in _create_if_not_exists().
        indexes = ((("rule_id", "scope", "scope_value"), True),)

    id = PrimaryKeyField()
    # The rule identifier (e.g., "CVE-2025-001")
    rule_id = CharField(null=False)
    # The scope type: "global" or "domain"
    scope = CharField(null=False)
    # The scope value: NULL for global, domain name for domain scope
    scope_value = CharField(null=True)
    # Unix timestamp when the rule was disabled
    disabled_at = FloatField(null=False)
    # Origin of the disable action: "wordpress" (from wordpress admin ui)
    # or "agent" (from CLI/RPC)
    source = CharField(null=False)
    # UID of the user who disabled the rule (0 for root)
    created_by_user_id = IntegerField(null=False)

    # Scope constants
    SCOPE_GLOBAL = "global"
    SCOPE_DOMAIN = "domain"

    # Source constants
    SOURCE_WORDPRESS = "wordpress"
    SOURCE_AGENT = "agent"

    @classmethod
    def store(
        cls,
        rule_id: str,
        domains: list[str] | None,
        source: str,
        user_id: int,
        timestamp: float | None = None,
    ) -> int:
        """
        Disable a rule globally or for specific domains.

        Args:
            rule_id: The rule identifier (e.g., "CVE-2025-001")
            domains: List of domains to disable for, or None/empty for a
                global disable
            source: Origin of the action ("wordpress" or "agent")
            user_id: UID of the user performing the action (0 for root)
            timestamp: Unix timestamp for when the rule was disabled.
                If None, uses current time.

        Returns:
            Number of new entries created (0 if all were no-ops).
        """
        if timestamp is None:
            timestamp = time.time()
        if domains:
            return cls._disable_for_domains(
                rule_id, domains, timestamp, source, user_id
            )
        return cls._disable_globally(rule_id, timestamp, source, user_id)

    @classmethod
    def _disable_globally(
        cls,
        rule_id: str,
        timestamp: float,
        source: str,
        user_id: int,
    ) -> int:
        """Disable a rule globally (independent of domain-specific entries).

        Returns:
            1 if a global entry was created, 0 if one already existed.
        """
        created = cls._create_if_not_exists(
            rule_id=rule_id,
            scope=cls.SCOPE_GLOBAL,
            scope_value=None,
            disabled_at=timestamp,
            source=source,
            user_id=user_id,
        )
        if created:
            logger.debug(
                "Disabled rule %s globally (source=%s, user_id=%s)",
                rule_id,
                source,
                user_id,
            )
        return int(created)

    @classmethod
    def _disable_for_domains(
        cls,
        rule_id: str,
        domains: list[str],
        timestamp: float,
        source: str,
        user_id: int,
    ) -> int:
        """Disable a rule for specific domains (independent of global state).

        Returns:
            Number of domains for which a new entry was created; domains that
            already had an entry (including repeats within ``domains``) are
            not counted.
        """
        count = 0
        for domain in domains:
            created = cls._create_if_not_exists(
                rule_id=rule_id,
                scope=cls.SCOPE_DOMAIN,
                scope_value=domain,
                disabled_at=timestamp,
                source=source,
                user_id=user_id,
            )
            if not created:
                continue
            count += 1
            logger.debug(
                "Disabled rule %s for domain %s (source=%s, user_id=%s)",
                rule_id,
                domain,
                source,
                user_id,
            )
        return count

    @classmethod
    def _create_if_not_exists(
        cls,
        rule_id: str,
        scope: str,
        scope_value: str | None,
        disabled_at: float,
        source: str,
        user_id: int,
    ) -> bool:
        """
        Create a new disabled rule entry if it doesn't already exist.

        Returns:
            True if a new entry was created, False if it already existed
            (no-op)
        """
        if scope_value is None:
            # The unique index cannot be relied on here: NULL scope_value
            # rows are typically considered distinct from one another, so the
            # INSERT below would succeed and silently create a duplicate.
            # Look for an existing row explicitly instead.
            already_disabled = (
                cls.select()
                .where(
                    cls.rule_id == rule_id,
                    cls.scope == scope,
                    cls.scope_value.is_null(),
                )
                .exists()
            )
            if already_disabled:
                return False
        try:
            cls.insert(
                rule_id=rule_id,
                scope=scope,
                scope_value=scope_value,
                disabled_at=disabled_at,
                source=source,
                created_by_user_id=user_id,
            ).execute()
        except IntegrityError:
            # Rule already disabled for this scope - no-op
            return False
        return True

    @classmethod
    def remove(cls, rule_id: str, domains: list[str] | None) -> int:
        """
        Re-enable a rule globally or for specific domains.

        Args:
            rule_id: The rule identifier
            domains: List of domains to enable for, or None/empty to enable
                globally

        Returns:
            Number of rows deleted
        """
        if not domains:
            # Enable globally - remove ONLY the global entry; domain-specific
            # disables for the same rule are left in place.
            count = (
                cls.delete()
                .where(
                    cls.rule_id == rule_id,
                    cls.scope == cls.SCOPE_GLOBAL,
                )
                .execute()
            )
            if count:
                logger.debug("Enabled rule %s globally", rule_id)
            return count

        # Enable for specific domains; a global disable (if any) is untouched.
        count = (
            cls.delete()
            .where(
                cls.rule_id == rule_id,
                cls.scope == cls.SCOPE_DOMAIN,
                cls.scope_value.in_(domains),
            )
            .execute()
        )
        if count:
            logger.debug(
                "Enabled rule %s for %d domain(s)",
                rule_id,
                count,
            )
        return count

    @classmethod
    def is_rule_disabled(cls, rule_id: str, domain: str | None = None) -> bool:
        """
        Check if a rule is disabled globally or for a specific domain.

        Args:
            rule_id: The rule identifier
            domain: The domain to check. If None, only checks global disable.

        Returns:
            True if the rule is disabled, False otherwise
        """
        # A global disable always counts; a domain disable counts only when
        # the caller asked about that particular domain.
        scope_match = cls.scope == cls.SCOPE_GLOBAL
        if domain is not None:
            scope_match = scope_match | (
                (cls.scope == cls.SCOPE_DOMAIN) & (cls.scope_value == domain)
            )
        return (
            cls.select().where(cls.rule_id == rule_id, scope_match).exists()
        )

    @classmethod
    def get_domain_disabled(
        cls, domain: str, include_global: bool = False
    ) -> list[str]:
        """
        Get all rule IDs that are disabled for a specific domain.

        Args:
            domain: The domain to get disabled rules for
            include_global: If True, also include globally disabled rules.
                If False (default), only return domain-specific disables.

        Returns:
            List of rule IDs that are disabled for the domain
        """
        if include_global:
            # A rule may be disabled both globally and for this domain;
            # DISTINCT collapses those into a single rule ID.
            query = (
                cls.select(cls.rule_id)
                .where(
                    (cls.scope == cls.SCOPE_GLOBAL)
                    | (
                        (cls.scope == cls.SCOPE_DOMAIN)
                        & (cls.scope_value == domain)
                    )
                )
                .distinct()
            )
        else:
            query = cls.select(cls.rule_id).where(
                cls.scope == cls.SCOPE_DOMAIN,
                cls.scope_value == domain,
            )
        return [row.rule_id for row in query]

    @classmethod
    def get_global_disabled(cls) -> Iterator[str]:
        """
        Get all rule IDs that are disabled globally.

        Returns:
            Iterator of globally disabled rule IDs (each ID at most once)
        """
        # DISTINCT guards against duplicate global rows, which the unique
        # index may not prevent because their scope_value is NULL.
        query = (
            cls.select(cls.rule_id)
            .where(cls.scope == cls.SCOPE_GLOBAL)
            .distinct()
        )
        return (row.rule_id for row in query)

    @classmethod
    def _build_filter_condition(
        cls,
        user_domains: list[str] | None,
        include_global: bool,
    ):
        """
        Build the WHERE condition for filtering rules.

        Args:
            user_domains: Domains to restrict domain-scoped rows to, or None
                for no domain restriction.
            include_global: Whether global rows should pass the filter.

        Returns:
            A Peewee expression for the WHERE clause, or None if no filter
            is needed (no domain restriction and global rows included).
        """
        if user_domains is None:
            if include_global:
                return None
            return cls.scope == cls.SCOPE_DOMAIN

        domain_match = (cls.scope == cls.SCOPE_DOMAIN) & (
            cls.scope_value.in_(user_domains)
        )
        if include_global:
            return (cls.scope == cls.SCOPE_GLOBAL) | domain_match
        return domain_match

    @classmethod
    def fetch(
        cls,
        limit: int,
        offset: int = 0,
        user_domains: list[str] | None = None,
        include_global: bool = False,
    ) -> tuple[int, list[dict]]:
        """
        List disabled rules with aggregation by rule_id.

        Multiple domain entries for the same rule are aggregated into a single
        result with a list of domains. Results are ordered by most recently
        disabled first (using the latest disabled_at timestamp per rule_id);
        rules sharing the same timestamp are ordered by rule_id so that
        pagination is deterministic.

        Uses a two-pass approach for efficiency:

        1. First pass: Get rule_ids ordered by latest disabled_at with
           pagination
        2. Second pass: Fetch only rows for the paginated rule_ids

        Args:
            limit: Maximum number of rules to return
            offset: Number of rules to skip
            user_domains: If provided, only return rules for these domains.
                If None, return all rules (for root users).
            include_global: Whether to include global rules in the result

        Returns:
            Tuple of (total_count, list of rule dicts)
            Each dict has:
            {"rule_id": str, "is_global": bool, "domains": list[str]}
            is_global is True if rule has a global disable, domains lists
            domain-specific disables
        """
        condition = cls._build_filter_condition(user_domains, include_global)

        # First pass: one row per rule_id, most recently disabled first.
        # rule_id is a tie-breaker: without it, rules with identical
        # timestamps (e.g. disabled in one batch) have no defined order and
        # could be repeated or skipped between pages.
        rule_ids_query = (
            cls.select(cls.rule_id)
            .group_by(cls.rule_id)
            .order_by(fn.MAX(cls.disabled_at).desc(), cls.rule_id)
        )
        if condition is not None:
            rule_ids_query = rule_ids_query.where(condition)

        # Total number of distinct rule_ids matching the filter (unpaginated)
        total_count = rule_ids_query.count()

        # Apply pagination at DB level
        paginated_rule_ids = [
            row.rule_id for row in rule_ids_query.offset(offset).limit(limit)
        ]
        if not paginated_rule_ids:
            return total_count, []

        # Second pass: fetch rows for the paginated rule_ids, applying the
        # same filter so rows outside the caller's view are not aggregated.
        rows_query = cls.select().where(cls.rule_id.in_(paginated_rule_ids))
        if condition is not None:
            rows_query = rows_query.where(condition)

        # Aggregate scopes by rule_id
        rules_by_id: dict[str, dict] = {}
        for row in rows_query:
            entry = rules_by_id.setdefault(
                row.rule_id,
                {
                    "rule_id": row.rule_id,
                    "is_global": False,
                    "domains": [],
                },
            )
            if row.scope == cls.SCOPE_GLOBAL:
                entry["is_global"] = True
            elif row.scope == cls.SCOPE_DOMAIN:
                entry["domains"].append(row.scope_value)

        # Build result in order from first query (preserves DB ordering)
        result = []
        for rule_id in paginated_rule_ids:
            rule_data = rules_by_id[rule_id]
            rule_data["domains"] = sorted(rule_data["domains"])
            result.append(rule_data)
        return total_count, result