AvanzadoGraph MLGNNcall graphcontrol flow graphdetección de malware

Graph ML para Análisis de Malware: Relaciones y Call Graphs

Aplicación de Graph Neural Networks (GNN) al análisis de malware: construcción de call graphs, control flow graphs, grafos de comportamiento, node embeddings con GCN y GAT, y detección de familias a través de la estructura del código.

MalwareIntel Research··9 min lectura
Serie: AI/ML para Malware — Parte 14

Por qué los grafos capturan lo que los vectores pierden

Un binario de malware no es solo una secuencia de bytes o una colección de features numéricas. Es un programa con estructura: funciones que llaman a otras funciones, flujo de control con bifurcaciones y bucles, datos que fluyen entre instrucciones. Esa estructura es la esencia del programa.

Cuando un autor de malware ofusca su código (renombra variables, reordena funciones, añade código muerto), los bytes cambian. El hash cambia. Las features estadísticas cambian. Pero la estructura lógica del programa permanece: las mismas funciones siguen llamando a las mismas funciones, el mismo flujo de control sigue ejecutándose, los mismos datos siguen procesándose.

Los grafos capturan esa estructura. Y las Graph Neural Networks (GNN) aprenden a clasificar grafos por su topología, no por sus bytes.

Tipos de grafos en análisis de malware

Call Graph (CG)

El call graph representa las llamadas entre funciones:

import networkx as nx
from dataclasses import dataclass, field

@dataclass
class FunctionNode:
    address: int
    name: str
    size: int
    num_instructions: int
    num_basic_blocks: int
    imports_called: list[str] = field(default_factory=list)

def build_call_graph_from_binary(filepath: str) -> nx.DiGraph:
    """Construye el call graph de un binario PE usando angr."""
    import angr
    
    project = angr.Project(filepath, auto_load_libs=False)
    cfg = project.analyses.CFGFast()
    
    cg = nx.DiGraph()
    
    for func_addr, func in cfg.kb.functions.items():
        # Nodo: cada función
        cg.add_node(func_addr, **{
            "name": func.name,
            "size": func.size,
            "num_blocks": len(list(func.blocks)),
            "is_import": func.is_simprocedure or func.is_plt,
            "num_instructions": sum(
                len(list(block.capstone.insns)) 
                for block in func.blocks
            ) if not func.is_simprocedure else 0,
        })
        
        # Aristas: llamadas entre funciones
        for callee_addr in func.callees:
            if callee_addr in cfg.kb.functions:
                cg.add_edge(func_addr, callee_addr)
    
    return cg

Control Flow Graph (CFG)

El CFG representa el flujo de control dentro de una función:

def build_cfg_for_function(
    project,  # angr.Project
    func_addr: int
) -> nx.DiGraph:
    """Construye el CFG de una función específica."""
    
    cfg_analysis = project.analyses.CFGFast()
    func = cfg_analysis.kb.functions.get(func_addr)
    
    if func is None:
        return nx.DiGraph()
    
    cfg = nx.DiGraph()
    
    for block in func.blocks:
        # Features del bloque básico
        instructions = list(block.capstone.insns)
        
        # Distribución de tipos de instrucciones
        inst_types = {"arithmetic": 0, "logic": 0, "transfer": 0,
                      "call": 0, "compare": 0, "other": 0}
        
        for insn in instructions:
            mnemonic = insn.mnemonic
            if mnemonic in ("add", "sub", "mul", "div", "inc", "dec"):
                inst_types["arithmetic"] += 1
            elif mnemonic in ("and", "or", "xor", "not", "shl", "shr"):
                inst_types["logic"] += 1
            elif mnemonic in ("mov", "push", "pop", "lea", "xchg"):
                inst_types["transfer"] += 1
            elif mnemonic in ("call",):
                inst_types["call"] += 1
            elif mnemonic in ("cmp", "test"):
                inst_types["compare"] += 1
            else:
                inst_types["other"] += 1
        
        cfg.add_node(block.addr, **{
            "size": block.size,
            "num_instructions": len(instructions),
            **inst_types,
        })
    
    # Aristas: saltos entre bloques
    for block in func.blocks:
        for successor_addr in func.graph.successors(block.addr):
            cfg.add_edge(block.addr, successor_addr)
    
    return cfg

Behavior Graph

Los behavior graphs capturan acciones observadas durante ejecución en sandbox:

def build_behavior_graph(sandbox_report: dict) -> nx.DiGraph:
    """Construye grafo de comportamiento desde un reporte de sandbox."""
    
    bg = nx.DiGraph()
    
    # Nodos: entidades (procesos, ficheros, claves de registro, IPs)
    for proc in sandbox_report.get("processes", []):
        bg.add_node(f"proc_{proc['pid']}", 
                    type="process", name=proc["name"])
    
    for file_op in sandbox_report.get("file_operations", []):
        bg.add_node(f"file_{file_op['path']}", 
                    type="file", path=file_op["path"])
        bg.add_edge(
            f"proc_{file_op['pid']}", 
            f"file_{file_op['path']}",
            action=file_op["action"]  # create, write, delete
        )
    
    for reg_op in sandbox_report.get("registry_operations", []):
        bg.add_node(f"reg_{reg_op['key']}", 
                    type="registry", key=reg_op["key"])
        bg.add_edge(
            f"proc_{reg_op['pid']}", 
            f"reg_{reg_op['key']}",
            action=reg_op["action"]  # set, delete, query
        )
    
    for net_op in sandbox_report.get("network_connections", []):
        bg.add_node(f"net_{net_op['dst_ip']}", 
                    type="network", ip=net_op["dst_ip"])
        bg.add_edge(
            f"proc_{net_op['pid']}", 
            f"net_{net_op['dst_ip']}",
            action="connect", port=net_op["dst_port"]
        )
    
    return bg

Features de grafos para ML clásico

Antes de las GNN, se extraían features topológicas del grafo para alimentar modelos clásicos:

def extract_graph_features(G: nx.DiGraph) -> dict:
    """Extrae features topológicas de un grafo."""
    
    if G.number_of_nodes() == 0:
        return {k: 0 for k in [
            "num_nodes", "num_edges", "density",
            "avg_degree", "max_degree",
            "num_connected_components", "avg_clustering",
            "avg_betweenness", "avg_closeness",
        ]}
    
    features = {
        "num_nodes": G.number_of_nodes(),
        "num_edges": G.number_of_edges(),
        "density": nx.density(G),
    }
    
    # Grados
    degrees = [d for _, d in G.degree()]
    in_degrees = [d for _, d in G.in_degree()]
    out_degrees = [d for _, d in G.out_degree()]
    
    features["avg_degree"] = sum(degrees) / len(degrees)
    features["max_degree"] = max(degrees)
    features["std_degree"] = float(np.std(degrees))
    features["max_in_degree"] = max(in_degrees)
    features["max_out_degree"] = max(out_degrees)
    
    # Componentes conexas (en versión no dirigida)
    G_undirected = G.to_undirected()
    components = list(nx.connected_components(G_undirected))
    features["num_connected_components"] = len(components)
    features["largest_component_ratio"] = (
        max(len(c) for c in components) / G.number_of_nodes()
    )
    
    # Clustering coefficient
    features["avg_clustering"] = nx.average_clustering(G_undirected)
    
    # Centralidad (en subgrafo para eficiencia)
    if G.number_of_nodes() < 5000:
        betweenness = nx.betweenness_centrality(G)
        features["avg_betweenness"] = sum(betweenness.values()) / len(betweenness)
        features["max_betweenness"] = max(betweenness.values())
    
    return features

Graph Neural Networks para malware

Las GNN aprenden representaciones de nodos y grafos directamente de la topología, sin necesidad de extraer features manuales.

Graph Convolutional Network (GCN)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader

class MalwareGCN(nn.Module):
    """GCN para clasificación de grafos de malware."""
    
    def __init__(
        self,
        num_node_features: int,
        hidden_dim: int = 128,
        num_classes: int = 2,
        num_layers: int = 3,
        dropout: float = 0.3,
    ):
        super().__init__()
        
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        
        # Primera capa
        self.convs.append(GCNConv(num_node_features, hidden_dim))
        self.bns.append(nn.BatchNorm1d(hidden_dim))
        
        # Capas intermedias
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
        
        # Clasificador
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        )
        
        self.dropout = dropout
    
    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Capas GCN
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Graph-level pooling
        graph_embedding = global_mean_pool(x, batch)
        
        # Clasificación
        output = self.classifier(graph_embedding)
        return output

Graph Attention Network (GAT)

GAT añade mecanismos de atención que ponderan la importancia de cada vecino:

from torch_geometric.nn import GATConv

class MalwareGAT(nn.Module):
    """GAT para clasificación de malware con mecanismo de atención."""
    
    def __init__(
        self,
        num_node_features: int,
        hidden_dim: int = 64,
        num_heads: int = 4,
        num_classes: int = 2,
        dropout: float = 0.3,
    ):
        super().__init__()
        
        # Capa GAT 1: multi-head attention
        self.gat1 = GATConv(
            num_node_features, hidden_dim,
            heads=num_heads, dropout=dropout
        )
        
        # Capa GAT 2: single head para reducir dimensionalidad
        self.gat2 = GATConv(
            hidden_dim * num_heads, hidden_dim,
            heads=1, concat=False, dropout=dropout
        )
        
        # Readout y clasificador
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(32, num_classes),
        )
        
        self.dropout = dropout
    
    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        
        # Pooling
        graph_embedding = global_mean_pool(x, batch)
        
        return self.classifier(graph_embedding)

Convirtiendo call graphs a formato PyG

def callgraph_to_pyg(
    cg: nx.DiGraph,
    label: int,
    node_feature_dim: int = 10,
) -> Data:
    """Convierte un call graph NetworkX a un objeto PyG Data."""
    
    # Mapear nodos a índices consecutivos
    node_mapping = {node: i for i, node in enumerate(cg.nodes())}
    
    # Features de nodos
    node_features = []
    for node in cg.nodes():
        attrs = cg.nodes[node]
        features = [
            attrs.get("size", 0),
            attrs.get("num_blocks", 0),
            attrs.get("num_instructions", 0),
            attrs.get("is_import", 0),
            cg.in_degree(node),
            cg.out_degree(node),
        ]
        # Pad a dimensión fija
        while len(features) < node_feature_dim:
            features.append(0)
        node_features.append(features[:node_feature_dim])
    
    x = torch.tensor(node_features, dtype=torch.float)
    
    # Aristas
    edges = [
        [node_mapping[u], node_mapping[v]]
        for u, v in cg.edges()
        if u in node_mapping and v in node_mapping
    ]
    
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    
    # Label del grafo
    y = torch.tensor([label], dtype=torch.long)
    
    return Data(x=x, edge_index=edge_index, y=y)

Entrenamiento y evaluación

from torch_geometric.loader import DataLoader as PyGDataLoader
from sklearn.metrics import classification_report

def train_gnn(
    model: nn.Module,
    train_graphs: list[Data],
    val_graphs: list[Data],
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 32,
):
    """Entrena un modelo GNN sobre grafos de malware."""
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    train_loader = PyGDataLoader(
        train_graphs, batch_size=batch_size, shuffle=True
    )
    val_loader = PyGDataLoader(
        val_graphs, batch_size=batch_size
    )
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        # Train
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                pred = model(batch).argmax(dim=1)
                correct += (pred == batch.y).sum().item()
                total += batch.y.size(0)
        
        val_acc = correct / total
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_gnn_model.pt")
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}, "
                  f"val_acc={val_acc:.4f}")
    
    print(f"Mejor val accuracy: {best_val_acc:.4f}")

Resultados y comparativas

Los resultados publicados en papers recientes muestran:

MétodoTipo de grafoDatasetF1
Graph features + RFCall GraphMalicia0.92
GCNCall GraphMalNet-Tiny0.94
GATCFGCustom PE0.95
GraphSAGEBehavior GraphCuckoo0.91
DGCNNCall Graph + CFGMalNet-Large0.96

Las GNN muestran ventaja sobre los métodos basados en features manuales del grafo, especialmente cuando el dataset contiene familias con variantes muy similares que solo se distinguen por sutilezas topológicas.

Limitaciones prácticas

  1. Coste de construcción del grafo. Desensamblar un binario y construir su call graph tarda entre 1 segundo y varios minutos, dependiendo del tamaño y la complejidad. No es viable para escaneo masivo en tiempo real.

  2. Evasión por reestructuración. Si bien los grafos son más resistentes que los bytes a la ofuscación superficial, técnicas como function inlining, control flow flattening y opaque predicates pueden alterar significativamente la topología del grafo.

  3. Escalabilidad. Los call graphs de programas grandes pueden tener miles de nodos. Las GNN con message passing global son costosas en memoria para grafos grandes.

  4. Dependencia del desensamblador. La calidad del call graph depende del desensamblador (IDA Pro, Ghidra, angr, Binary Ninja). Diferentes herramientas producen grafos diferentes para el mismo binario.

Conclusión

Los grafos aportan una dimensión de análisis que las features vectoriales no capturan: la estructura lógica del programa. Las GNN son la herramienta natural para operar sobre esa estructura, aprendiendo patrones topológicos que distinguen familias de malware.

En la práctica, los grafos complementan (no reemplazan) los métodos basados en features estáticas y dinámicas. Un sistema de detección robusto combina LightGBM sobre features PE para velocidad, CNN/MalConv para análisis de bytes, y GNN sobre call graphs para resistencia a la ofuscación. Cada capa cubre las debilidades de las otras.

Preguntas frecuentes

Artículos relacionados

Este contenido tiene fines exclusivamente educativos y de investigación en ciberseguridad defensiva. No se proporcionan binarios maliciosos ni payloads ejecutables. El uso indebido de esta información es responsabilidad exclusiva del usuario. Leer disclaimer completo.