SecureBERT Vulnerability Classifier (CVSS & CWE Flat Classifier)

This model automatically analyzes raw vulnerability descriptions (e.g., CVE reports, bug bounty submissions) and predicts CVSS v3.1 metrics alongside a 4-level CWE taxonomy (Pillar, Class, Base, Variant).

It is a fine-tuned version of the domain-specific cisco-ai/SecureBERT2.0 utilizing a Multi-Task Learning (MTL) architecture with flat classification heads.

🎯 Intended Use

The primary use case is automating the initial vulnerability triage process. Given an unstructured threat narrative, the model returns:

  • 8 CVSS v3.1 Metrics: Attack Vector, Attack Complexity, Privileges Required, User Interaction, Scope, Confidentiality, Integrity, and Availability.
  • CWE Classification: Probabilistic mapping to the MITRE CWE tree across 4 levels of abstraction (Top-K predictions).
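
For orientation, a prediction from the usage snippet below has roughly the following shape; the metric key names and all values here are illustrative, not guaranteed outputs of the model:

```json
{
  "cvss": {
    "attack_vector": {"value": "Network", "confidence": 0.97}
  },
  "cwe": {
    "base": [
      {"id": 89, "name": "SQL Injection", "score": 0.91}
    ]
  }
}
```

Each of the 8 CVSS metrics appears under "cvss", and each of the 4 CWE levels appears under "cwe" with a Top-K list.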

🧠 Model Architecture

The model uses a shared SecureBERT2.0 backbone with 12 distinct classification heads attached to the pooled outputs:

  • CVSS Heads (8): Multi-Layer Perceptrons (MLPs) of the form LayerNorm -> Dropout -> (Linear -> GELU -> Dropout) x2 -> Linear -> Softmax. They use the [CLS] token embedding to predict the nominal and ordinal CVSS categories.
  • CWE Heads (4): MLPs with the same structure but without the final Softmax (they emit raw logits). These heads use the mean-pooled token embeddings.
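
The two pooling strategies can be illustrated on a toy tensor (shapes and values are made up for the example; the real hidden size is 768):

```python
import torch

# Toy batch: 1 sequence, 3 tokens, hidden size 2; the last token is padding.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
mask = torch.tensor([[1, 1, 0]])

# [CLS] pooling (used by the CVSS heads): take the first token's embedding.
cls_emb = hidden[:, 0, :]  # -> [[1., 2.]]

# Masked mean pooling (used by the CWE heads): average only the real tokens.
m = mask.unsqueeze(-1).expand(hidden.size()).float()
mean_emb = (hidden * m).sum(1) / m.sum(1).clamp(min=1e-9)  # -> [[2., 3.]]
```

The padding token is excluded from the mean, which is why the mask is applied before summing.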

📂 Repository Structure & Custom Config

Unlike standard Hugging Face models, this repository ships a customized config.json that both defines the head dimensions and carries the label-decoding maps:

  • cvss_map: Contains the exact string labels for all 8 CVSS metrics (e.g., ["Network", "Adjacent", "Local", "Physical"]).
  • cwe_labels: Contains ID-to-Name mappings for all supported CWEs across pillar, class, base, and variant levels.
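
A trimmed sketch of what that layout implies (the metric key name and the CWE entry are illustrative; only the base_model value is taken from the loader defaults in the snippet below), with the class, base, and variant levels following the same pattern as pillar:

```json
{
  "base_model": "cisco-ai/SecureBERT2.0-biencoder",
  "cvss_map": {
    "attack_vector": ["Network", "Adjacent", "Local", "Physical"]
  },
  "cwe_labels": {
    "pillar": [{"id": "CWE-707", "name": "Improper Neutralization"}]
  }
}
```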

Note: Because of the custom multi-head architecture, you cannot use the default AutoModelForSequenceClassification. You must define the custom PyTorch class provided in the usage snippet below.

💻 Usage & Inference

Below is a complete, standalone Python snippet to load the model, tokenizer, and configuration directly from this Hugging Face repository and perform predictions.

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# 1. Define the Custom Architecture
class SecureBERTFlatClassifier(nn.Module):
    def __init__(self, model_name, cvss_map, class_counts):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name)
        # Disable the compiled reference forward pass if the config exposes it
        if hasattr(config, "reference_compile"):
            config.reference_compile = False
        self.bert = AutoModel.from_pretrained(model_name, config=config)
        
        def make_head(out_features, is_cvss=False):
            layers = [
                nn.LayerNorm(768), nn.Dropout(0.1),
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1),
                nn.Linear(768, out_features)
            ]
            if is_cvss:
                layers.append(nn.Softmax(dim=1))  # CVSS heads output probabilities directly
            return nn.Sequential(*layers)

        self.cvss_heads = nn.ModuleDict({k: make_head(len(v), True) for k, v in cvss_map.items()})
        self.cwe_heads = nn.ModuleDict({k: make_head(v) for k, v in class_counts.items()})

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_emb = out[:, 0, :]
        mask = attention_mask.unsqueeze(-1).expand(out.size()).float()
        mean_emb = torch.sum(out * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
        
        res = {}
        for k, head in self.cvss_heads.items(): res[k] = head(cls_emb)
        for k, head in self.cwe_heads.items(): res[k] = head(mean_emb)
        return res

# 2. Inference Wrapper
class VulnPredictor:
    def __init__(self, repo_id):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        conf_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        
        with open(conf_path, "r") as f: self.config = json.load(f)
            
        base_model = self.config.get("base_model", "cisco-ai/SecureBERT2.0-biencoder")
        counts = {k: len(v) for k, v in self.config.get("cwe_labels", {}).items()}
        
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = SecureBERTFlatClassifier(base_model, self.config["cvss_map"], counts)
        # strict=False tolerates incidental key mismatches between checkpoint and module
        self.model.load_state_dict(torch.load(model_path, map_location=self.device), strict=False)
        self.model.to(self.device).eval()

    def predict(self, text, top_k=3):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            out = self.model(inputs['input_ids'], inputs['attention_mask'])
        
        res = {'cvss': {}, 'cwe': {}}
        for task, labels in self.config.get("cvss_map", {}).items():
            score, idx = torch.max(out[task], dim=1)
            res['cvss'][task] = {"value": labels[idx.item()], "confidence": round(score.item(), 4)}
            
        for lv, cwe_data in self.config.get("cwe_labels", {}).items():
            if lv in out:
                probs = F.softmax(out[lv], dim=1)
                scores, idxs = torch.topk(probs, k=min(top_k, probs.size(1)))
                res['cwe'][lv] = [
                    {"id": int(str(cwe_data[i.item()]['id']).replace('CWE-', '')),
                     "name": cwe_data[i.item()]['name'],
                     "score": round(s.item(), 4)} for s, i in zip(scores[0], idxs[0])
                ]
        return res

# 3. Quickstart
if __name__ == "__main__":
    REPO_ID = "bziemba/SecureBERT2.0-final" 
    
    predictor = VulnPredictor(REPO_ID)
    
    sample_cve = "An issue was discovered in the login panel allowing attackers to bypass authentication via crafted SQL queries."
    results = predictor.predict(sample_cve)
    
    print(json.dumps(results, indent=2))