File: //opt/imunify360/venv/lib64/python3.11/site-packages/imav/plugins/send_malware_infection_state.py
"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Copyright © 2019 Cloud Linux Software Inc.
This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import asyncio
import datetime
import logging
import pwd
import time
from asyncio import AbstractEventLoop
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Iterable, cast
from defence360agent.contracts.messages import (
Message,
Reportable,
ReportTarget,
)
from defence360agent.contracts.myimunify_id import get_myimunify_users
from defence360agent.contracts.plugins import MessageSink, MessageSource
from defence360agent.subsys.panels.base import AbstractPanel
from defence360agent.subsys.panels.hosting_panel import HostingPanel
from defence360agent.subsys.persistent_state import load_state, save_state
from defence360agent.utils import Scope, recurring_check, split_for_chunk
from imav.malwarelib.config import MalwareHitStatus
from imav.malwarelib.model import MalwareHit
logger = logging.getLogger(__name__)

# Seconds between wake-ups of the recurring task that decides whether a
# snapshot is due (five minutes).
RECURRING_CHECK_INTERVAL = 5 * 60
# Minimum number of whole seconds between two snapshot sends (one day).
SEND_INTERVAL = datetime.timedelta(days=1) // datetime.timedelta(seconds=1)
# One malware-hit record: a plain dict keyed by field name (e.g. "file",
# "status", "hash").
MalwareHitDict = dict[str, Any]
@dataclass(slots=True)
class ProcessedHits:
"""Result of processing malware hits."""
files: list[dict[str, Any]] = field(default_factory=list)
infected_count: int = 0
cleaned_count: int = 0
class MalwareInfectionState(Message, Reportable):
    """Reportable snapshot of malware infection state for one user/docroot."""

    # Report target for this message type.
    TARGET = ReportTarget.API
    # Method name under which the snapshot is reported.
    DEFAULT_METHOD = "MALWARE_STATE_SNAPSHOTS"
class SendMalwareInfectionState(MessageSink, MessageSource):
    """Periodically emit per-user malware infection snapshots.

    A recurring task runs ``_check_and_send`` every
    ``RECURRING_CHECK_INTERVAL`` seconds. Once ``SEND_INTERVAL`` seconds have
    elapsed since the last attempt (persisted under ``STATE_KEY``), it sends
    ``MalwareInfectionState`` messages to the sink: a user-level aggregate for
    every user with hits, plus a breakdown per matched document root. Each is
    split into chunks of at most ``CHUNK_SIZE`` files.
    """

    SCOPE = Scope.AV_IM360
    CHUNK_SIZE = 1000
    STATE_KEY = "SendMalwareInfectionState"

    def __init__(self):
        self._task = None
        # Set by create_source(); initialised here so the attribute always exists.
        self._sink = None
        self._last_send_timestamp = 0

    async def create_sink(self, loop: AbstractEventLoop):
        """Nothing to set up on the sink side."""
        pass

    async def create_source(self, loop: AbstractEventLoop, sink):
        """Store *sink*, restore the persisted timestamp, start the check loop."""
        self._sink = sink
        self._last_send_timestamp = self._load_last_send_timestamp()
        self._task = loop.create_task(
            recurring_check(RECURRING_CHECK_INTERVAL)(self._check_and_send)()
        )

    async def shutdown(self):
        """Cancel the recurring task (if started) and wait for it to finish."""
        if self._task is not None:
            self._task, t = None, self._task
            t.cancel()
            with suppress(asyncio.CancelledError):
                await t

    @staticmethod
    def _is_valid_timestamp(timestamp: Any) -> bool:
        """Return True if *timestamp* is a finite, non-negative number.

        ``bool`` is rejected even though it subclasses ``int``. NaN and
        infinity are rejected too: an infinite "last send" time would
        suppress sending forever.
        """
        return (
            isinstance(timestamp, (int, float))
            and not isinstance(timestamp, bool)
            and 0 <= timestamp < float("inf")
        )

    def _save_last_send_timestamp(self, ts: int | float | None = None):
        """Persist *ts* (default: the in-memory last send time) if valid."""
        timestamp = self._last_send_timestamp if ts is None else ts
        if not self._is_valid_timestamp(timestamp):
            logger.warning(
                "Invalid timestamp to save: %s, skipping", timestamp
            )
            return
        save_state(self.STATE_KEY, {"last_send_timestamp": timestamp})

    def _load_last_send_timestamp(self) -> float:
        """Return the persisted last send time, or 0 if missing/invalid."""
        timestamp = load_state(self.STATE_KEY).get("last_send_timestamp")
        if not self._is_valid_timestamp(timestamp):
            logger.info(
                "No valid last send timestamp found, will send on first check"
            )
            return 0
        return cast(float, timestamp)

    async def _check_and_send(self):
        """Send the malware state if at least ``SEND_INTERVAL`` has elapsed.

        ``time.time()`` can go backwards, e.g. after a clock correction. A
        last send time in the future is therefore treated as overdue rather
        than blocking sends until the clock catches up.

        A send that fails with an ordinary exception is still recorded as an
        attempt, to avoid retrying a failing send every check. A send that is
        cancelled (e.g. by ``shutdown()``) is not recorded, so it is retried
        instead of being skipped for a whole interval.
        """
        elapsed = time.time() - self._last_send_timestamp
        if 0 <= elapsed < SEND_INTERVAL:
            return
        if elapsed < 0:
            logger.warning(
                "Last send timestamp %s is in the future, sending now",
                self._last_send_timestamp,
            )
        try:
            await self._send_malware_state()
        except asyncio.CancelledError:
            # Do not mark an interrupted send as done; propagate cancellation.
            raise
        except Exception:
            logger.exception("Failed to send malware infection state")
        # Update timestamp even on error to avoid repeated failures
        self._last_send_timestamp = time.time()
        self._save_last_send_timestamp()

    async def _send_malware_state(self):
        """Forward every generated snapshot message to the sink."""
        async for message in self._generate_messages():
            await self._sink.process_message(message)

    @staticmethod
    def _map_status(status: str) -> str:
        """Map MalwareHitStatus to output status (infected/cleaned/restored)."""
        if status in (
            MalwareHitStatus.CLEANUP_DONE,
            MalwareHitStatus.CLEANUP_REMOVED,
        ):
            return "cleaned"
        if status == MalwareHitStatus.RESTORED_FROM_BACKUP:
            return "restored"
        return "infected"

    def _hit_to_file_dict(self, hit: MalwareHitDict) -> dict[str, Any]:
        """Convert a malware hit to the per-file dictionary sent in messages."""
        return {
            "orig_file": hit.get("file", ""),
            "hash": hit.get("hash"),
            "type": hit.get("type"),
            "size": hit.get("size"),
            "status": self._map_status(hit.get("status", "")),
            "cleaned_at": hit.get("cleaned_at"),
            "resource_type": hit.get("resource_type"),
            "app_name": hit.get("app_name"),
        }

    def _process_hits(self, hits: list[MalwareHitDict]) -> ProcessedHits:
        """Convert *hits* to file dicts and count infected/cleaned ones.

        "restored" hits are included in ``files`` but in neither counter.
        """
        result = ProcessedHits()
        for hit in hits:
            file_dict = self._hit_to_file_dict(hit)
            status = file_dict["status"]
            if status == "infected":
                result.infected_count += 1
            elif status == "cleaned":
                result.cleaned_count += 1
            result.files.append(file_dict)
        return result

    def _create_message(
        self,
        now: float,
        username: str,
        uid: int,
        files_chunk: list[dict[str, Any]],
        infected_count: int,
        cleaned_count: int,
        domain_document_root: str = "",
        domain_names: list[str] | None = None,
    ) -> MalwareInfectionState:
        """Build one MalwareInfectionState message.

        The counts describe the whole hit set; ``files_chunk`` is one slice
        of its files.
        """
        msg = MalwareInfectionState()
        msg["ctimestamp"] = now
        msg["owner"] = username
        msg["uid"] = uid
        msg["domain_document_root"] = domain_document_root
        msg["domain_names"] = domain_names or []
        msg["current_malware_count"] = infected_count
        msg["current_cleaned_count"] = cleaned_count
        msg["files"] = files_chunk
        return msg

    async def _generate_chunked_messages(
        self,
        hits: list[MalwareHitDict],
        now: float,
        username: str,
        uid: int,
        domain_document_root: str = "",
        domain_names: list[str] | None = None,
    ) -> AsyncIterator[MalwareInfectionState]:
        """Yield one message per ``CHUNK_SIZE`` slice of *hits*' files.

        Every chunk carries the totals for the full *hits* list.
        """
        processed = self._process_hits(hits)
        for chunk in split_for_chunk(
            processed.files, chunk_size=self.CHUNK_SIZE
        ):
            yield self._create_message(
                now=now,
                username=username,
                uid=uid,
                files_chunk=chunk,
                infected_count=processed.infected_count,
                cleaned_count=processed.cleaned_count,
                domain_document_root=domain_document_root,
                domain_names=domain_names,
            )

    @staticmethod
    def _find_matching_docroot(file_path: str, docroots: Iterable[str]) -> str:
        """Return the longest docroot that *file_path* lies under, or ""."""
        matched_docroot = ""
        for docroot in docroots:
            # Ensure docroot ends with / for proper path boundary matching
            normalized = docroot if docroot.endswith("/") else docroot + "/"
            if file_path.startswith(normalized):
                if len(docroot) > len(matched_docroot):
                    matched_docroot = docroot
        return matched_docroot

    async def _generate_messages(self) -> AsyncIterator[MalwareInfectionState]:
        """Yield user-level and docroot-level messages for every user.

        Users without a resolvable UID or without hits are skipped. All
        messages share one ``ctimestamp`` taken at the start of the run.
        """
        hp = HostingPanel()
        users = await get_myimunify_users()
        if not users:
            logger.info("No users found, skip sending malware infection state")
            return

        now = time.time()
        for user_info in users:
            username = user_info["username"]

            # Get user uid
            try:
                user_pwd = pwd.getpwnam(username)
                uid = user_pwd.pw_uid
            except KeyError:
                logger.info("User %s doesn't exist, skipping", username)
                continue
            except Exception as e:
                logger.error(
                    "Error retrieving UID for user %s: %s", username, str(e)
                )
                continue

            # Get malware hits for this user
            _, hits = MalwareHit.malicious_list(user=username)
            if not hits:
                logger.info("User %s has no hits", username)
                continue

            async for message in self._generate_user_messages(
                hits, now, username, uid
            ):
                yield message

            async for message in self._generate_docroot_messages(
                hits, now, username, uid, hp
            ):
                yield message

    async def _generate_user_messages(
        self, hits: list[MalwareHitDict], now: float, username: str, uid: int
    ) -> AsyncIterator[MalwareInfectionState]:
        """Yield messages aggregated over all of the user's hits."""
        async for message in self._generate_chunked_messages(
            hits=hits,
            now=now,
            username=username,
            uid=uid,
            domain_document_root="",
            domain_names=[],
        ):
            yield message

    async def _generate_docroot_messages(
        self,
        hits: list[MalwareHitDict],
        now: float,
        username: str,
        uid: int,
        hp: AbstractPanel,
    ) -> AsyncIterator[MalwareInfectionState]:
        """Yield messages aggregated per document root.

        If the panel lookup fails for this user, the failure is logged and
        this user's breakdown is skipped. This keeps one user's panel error
        from aborting the run for every remaining user.
        """
        try:
            domain_details = await hp.get_user_domains_details(username)
        except Exception:
            logger.exception(
                "Failed to get domain details for user %s, "
                "skipping document root breakdown",
                username,
            )
            return

        # Build mapping: docroot -> list of domain names
        docroot_to_domains: defaultdict[str, list[str]] = defaultdict(list)
        for domain_data in domain_details:
            docroot_to_domains[domain_data.docroot].append(domain_data.domain)

        # Group hits by document root. Hits under no known docroot go to "".
        # NOTE(review): the "" group's messages carry the same
        # domain_document_root/domain_names as the user-level aggregate;
        # confirm the receiver can tell them apart.
        docroot_to_hits: defaultdict[str, list[MalwareHitDict]] = defaultdict(
            list
        )
        for hit in hits:
            file_path = hit.get("file", "")
            matched_docroot = self._find_matching_docroot(
                file_path, docroot_to_domains.keys()
            )
            docroot_to_hits[matched_docroot].append(hit)

        # Generate a message for each docroot with hits
        for docroot, docroot_hits in docroot_to_hits.items():
            async for message in self._generate_chunked_messages(
                hits=docroot_hits,
                now=now,
                username=username,
                uid=uid,
                domain_document_root=docroot,
                domain_names=docroot_to_domains.get(docroot, []),
            ):
                yield message