#!/usr/bin/env python3
"""
ComfyUI工作流安全清理工具
自动清除工作流JSON中的敏感信息
"""

import json
import os
import re
import sys
from typing import Any, Dict, List, Optional


class WorkflowCleaner:
    """Detect and scrub secrets (API keys, tokens, passwords) in a workflow.

    Operates on already-parsed JSON (nested dicts, lists and scalars) using
    two complementary checks:

    * field names -- a dict key containing any of ``sensitive_fields``
      (case-insensitive substring match) marks its whole value as secret;
    * string contents -- a string matching any of ``sensitive_patterns`` is
      treated as containing a credential.
    """

    def __init__(self):
        # Regexes for credential-looking substrings.
        # Key bodies may contain "-" and "_" (e.g. "sk-proj-...",
        # "sk-ant-api03-..."), and bearer tokens use the RFC 6750 b64token
        # alphabet (JWTs contain "."). Restricting these to [a-zA-Z0-9] only
        # matched a fragment of such keys and left the rest in the output.
        # The lookbehind keeps "sk-" from matching inside words like "mask-".
        self.sensitive_patterns = [
            r'(?<![A-Za-z0-9])sk-[A-Za-z0-9_-]{20,}',  # OpenAI API keys
            r'(?<![A-Za-z0-9])sk-ant-[A-Za-z0-9_-]{20,}',  # Claude API keys
            r'[Bb]earer\s+[A-Za-z0-9._~+/-]{20,}=*',  # Bearer tokens (incl. JWTs)
            r'[a-zA-Z0-9]{32,}',  # Long alphanumeric runs (possibly keys)
        ]

        # Dict keys whose values are cleared wholesale (substring match on the
        # lowercased key). NOTE(review): substring matching also catches
        # non-secret keys such as "max_tokens" or "tokenizer" -- confirm that
        # blanking those is acceptable for the workflows being shared.
        self.sensitive_fields = [
            'api_key',
            'apikey',
            'token',
            'password',
            'secret',
            'authorization'
        ]

    def _is_sensitive_key(self, key: Any) -> bool:
        """Return True if dict key *key* names a secret-bearing field."""
        lowered = str(key).lower()  # str() tolerates non-string keys
        return any(field in lowered for field in self.sensitive_fields)

    def detect_sensitive_info(self, data: Any, path: str = "") -> List[str]:
        """Recursively scan *data* and describe every suspected secret.

        Args:
            data: Parsed JSON value (dict, list, str or other scalar).
            path: Dotted/indexed location of *data*, used in the messages.

        Returns:
            One human-readable finding per suspected secret; secret values
            are passed through ``mask_value`` before being included.
        """
        issues = []

        if isinstance(data, dict):
            for key, value in data.items():
                current_path = f"{path}.{key}" if path else str(key)

                if self._is_sensitive_key(key):
                    if value and str(value).strip():
                        issues.append(f"敏感字段: {current_path} = {self.mask_value(str(value))}")
                    # clean_sensitive_data blanks the whole field, so pattern
                    # hits inside it would only duplicate this finding.
                    continue

                issues.extend(self.detect_sensitive_info(value, current_path))

        elif isinstance(data, list):
            for i, item in enumerate(data):
                current_path = f"{path}[{i}]" if path else f"[{i}]"
                issues.extend(self.detect_sensitive_info(item, current_path))

        elif isinstance(data, str):
            # One finding per string is enough, hence any() instead of a count.
            if any(re.search(pattern, data) for pattern in self.sensitive_patterns):
                issues.append(f"敏感模式: {path} 包含可能的API密钥")

        return issues

    def mask_value(self, value: str, visible: int = 4) -> str:
        """Return a redacted form of *value* that is safe to print.

        Up to *visible* leading and trailing characters are kept so the user
        can tell which key was found, but only when at least two thirds of
        the value stays hidden; anything shorter collapses to ``"***"``.
        """
        if visible <= 0 or len(value) < 6 * visible:
            return "***"
        return f"{value[:visible]}***{value[-visible:]}"

    def clean_sensitive_data(self, data: Any) -> Any:
        """Return a deep copy of *data* with secrets removed.

        Values under sensitive keys become ``""``; every pattern match inside
        other strings is replaced by ``"***REMOVED***"``. Non-container,
        non-string values are returned unchanged. *data* is not modified.
        """
        if isinstance(data, dict):
            return {
                key: "" if self._is_sensitive_key(key) else self.clean_sensitive_data(value)
                for key, value in data.items()
            }

        if isinstance(data, list):
            return [self.clean_sensitive_data(item) for item in data]

        if isinstance(data, str):
            cleaned_str = data
            for pattern in self.sensitive_patterns:
                cleaned_str = re.sub(pattern, "***REMOVED***", cleaned_str)
            return cleaned_str

        return data

    def clean_kontext_nodes(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Blank the API-key widget of Kontext prompt-assistant nodes.

        In UI-format workflows widget values are stored positionally in
        ``widgets_values``, so the key-name check of ``clean_sensitive_data``
        cannot see them. Also attaches a setup note to each such node.
        Mutates *data* in place and returns it; malformed entries are skipped.
        """
        nodes = data.get('nodes') if isinstance(data, dict) else None
        if not isinstance(nodes, list):
            return data

        kontext_types = ('KontextPromptAssistant', 'SecureKontextPromptAssistant')
        for node in nodes:
            if not isinstance(node, dict) or node.get('type') not in kontext_types:
                continue

            widgets = node.get('widgets_values')
            # Assumes the API key is the widget at index 1 for these node types.
            if isinstance(widgets, list) and len(widgets) > 1:
                widgets[1] = ""

            properties = node.get('properties')
            if not isinstance(properties, dict):
                properties = node['properties'] = {}
            properties['_security_note'] = "请在使用前配置API密钥"

        return data

    def add_security_metadata(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Record in ``data['extra']['security_info']`` that it was cleaned.

        Mutates *data* in place and returns it. A missing or non-dict
        ``extra`` is replaced with a fresh dict; non-dict *data* is returned
        unchanged because there is nowhere to attach the metadata.
        """
        if not isinstance(data, dict):
            return data

        extra = data.get('extra')
        if not isinstance(extra, dict):
            extra = data['extra'] = {}

        extra['security_info'] = {
            "cleaned": True,
            "cleaner_version": "1.0",
            "instructions": "此工作流已清理敏感信息，使用前请配置API密钥",
            "setup_guide": "参见 SECURITY_GUIDE.md"
        }

        return data


def clean_workflow_file(input_file: str, output_file: Optional[str] = None) -> bool:
    """Clean a ComfyUI workflow JSON file and write a sanitized copy.

    Scans *input_file* for secrets and prints the findings. It then writes the
    cleaned workflow to *output_file* and re-scans the result as a sanity
    check, printing any leftovers.

    Args:
        input_file: Path of the workflow JSON to clean.
        output_file: Destination path; defaults to ``<name>_cleaned<ext>``
            next to the input file.

    Returns:
        True when processing completed, False on any error (missing input,
        invalid JSON, I/O failure). Errors are reported via print, never
        raised.
    """
    if not os.path.isfile(input_file):
        print(f"❌ 文件不存在: {input_file}")
        return False

    if output_file is None:
        name, ext = os.path.splitext(input_file)
        output_file = f"{name}_cleaned{ext}"

    cleaner = WorkflowCleaner()

    try:
        # "utf-8-sig" also accepts files saved with a UTF-8 BOM (common on
        # Windows); plain "utf-8" passes the BOM to json.load, which rejects it.
        with open(input_file, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)

        print(f"📂 正在处理: {input_file}")

        # Report what was found before cleaning (first 10 findings only).
        issues = cleaner.detect_sensitive_info(data)
        if issues:
            print("🔍 发现敏感信息:")
            for issue in issues[:10]:
                print(f"  ⚠️ {issue}")
            if len(issues) > 10:
                print(f"  ... 还有 {len(issues) - 10} 个问题")
        else:
            print("✅ 未发现明显的敏感信息")

        print("🧹 正在清理敏感信息...")
        cleaned_data = cleaner.clean_sensitive_data(data)
        cleaned_data = cleaner.clean_kontext_nodes(cleaned_data)
        cleaned_data = cleaner.add_security_metadata(cleaned_data)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(cleaned_data, f, ensure_ascii=False, indent=2)

        print(f"✅ 清理完成: {output_file}")

        # Re-scan the cleaned data so anything the cleaner missed is visible.
        verification_issues = cleaner.detect_sensitive_info(cleaned_data)
        if verification_issues:
            print("⚠️ 清理后仍有潜在问题:")
            for issue in verification_issues:
                print(f"  🔸 {issue}")
        else:
            print("🎉 清理验证通过，可以安全分享！")

    except json.JSONDecodeError as e:
        print(f"❌ JSON格式错误: {e}")
        return False
    except Exception as e:  # top-level CLI boundary: report instead of crashing
        print(f"❌ 处理过程中出错: {e}")
        return False

    return True


def main():
    """Command-line entry point: ``<script> <workflow.json> [output.json]``.

    Prints a banner, then either usage help (when no input file is given) or
    cleans the given workflow file and prints follow-up security advice.
    """
    argv = sys.argv

    for banner_line in ("🔒 ComfyUI工作流安全清理工具", "=" * 50):
        print(banner_line)

    if len(argv) < 2:
        script = argv[0]
        usage_lines = (
            "使用方法:",
            f"  python {script} <工作流文件.json> [输出文件.json]",
            "\n示例:",
            f"  python {script} my_workflow.json",
            f"  python {script} my_workflow.json clean_workflow.json",
        )
        for usage_line in usage_lines:
            print(usage_line)
        return

    source_path, *extra_args = argv[1:]
    target_path = extra_args[0] if extra_args else None
    clean_workflow_file(source_path, target_path)

    advice_lines = (
        "\n💡 使用建议:",
        "1. 分享工作流前请使用此工具清理",
        "2. 接收工作流后请重新配置API密钥",
        "3. 定期检查和更新API密钥",
        "4. 查看 SECURITY_GUIDE.md 了解更多安全信息",
    )
    for advice_line in advice_lines:
        print(advice_line)


if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    main()