"""
Kontext提示词小秘 - 安全版本节点实现
包含API密钥保护和安全功能
"""

import json
import requests
import time
import os
import hashlib
from typing import Dict, Any, Tuple, Optional


class SecureKontextPromptAssistant:
    """
    安全版本的Kontext提示词小秘节点
    包含API密钥保护功能
    """
    
    def __init__(self):
        self.api_key_cache = {}  # 内存中的密钥缓存
    
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "api_base_url": ("STRING", {
                    "default": "https://api.openai.com/v1",
                    "multiline": False,
                    "placeholder": "API基础URL"
                }),
                "api_key": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "API密钥 (将自动遮蔽显示)",
                    "password": True  # 启用密码模式
                }),
                "model": ("STRING", {
                    "default": "gpt-3.5-turbo",
                    "multiline": False,
                    "placeholder": "模型名称"
                }),
                "system_prompt": ("STRING", {
                    "default": "你是一个有用的AI助手。",
                    "multiline": True,
                    "placeholder": "系统提示词"
                }),
                "user_input": ("STRING", {
                    "default": "",
                    "multiline": True,
                    "placeholder": "用户输入内容"
                }),
                "temperature": ("FLOAT", {
                    "default": 0.7,
                    "min": 0.0,
                    "max": 2.0,
                    "step": 0.1,
                    "display": "slider"
                }),
                "max_tokens": ("INT", {
                    "default": 1000,
                    "min": 1,
                    "max": 4000,
                    "step": 1
                }),
            },
            "optional": {
                "use_env_key": ("BOOLEAN", {
                    "default": False,
                    "label_on": "使用环境变量",
                    "label_off": "使用输入密钥"
                }),
                "env_key_name": ("STRING", {
                    "default": "OPENAI_API_KEY",
                    "multiline": False,
                    "placeholder": "环境变量名称"
                }),
                "timeout": ("INT", {
                    "default": 30,
                    "min": 5,
                    "max": 120,
                    "step": 1
                }),
                "retry_count": ("INT", {
                    "default": 3,
                    "min": 1,
                    "max": 10,
                    "step": 1
                }),
            }
        }
    
    RETURN_TYPES = ("STRING", "STRING", "STRING")
    RETURN_NAMES = ("generated_text", "masked_response", "status_info")
    FUNCTION = "generate_response_secure"
    CATEGORY = "Kontext/提示词工具"
    
    def mask_api_key(self, api_key: str) -> str:
        """
        遮蔽API密钥显示
        """
        if not api_key or len(api_key) < 8:
            return "***"
        
        # 显示前4位和后4位，中间用星号替代
        return f"{api_key[:4]}{'*' * (len(api_key) - 8)}{api_key[-4:]}"
    
    def get_api_key(self, input_key: str, use_env: bool, env_name: str) -> Tuple[str, str]:
        """
        安全获取API密钥
        
        Returns:
            Tuple[str, str]: (实际密钥, 状态信息)
        """
        if use_env:
            # 从环境变量获取
            env_key = os.getenv(env_name)
            if env_key:
                return env_key, f"✅ 从环境变量 {env_name} 获取密钥"
            else:
                return "", f"❌ 环境变量 {env_name} 未设置"
        else:
            # 使用输入的密钥
            if input_key.strip():
                return input_key.strip(), f"✅ 使用输入密钥 ({self.mask_api_key(input_key)})"
            else:
                return "", "❌ API密钥不能为空"
    
    def mask_sensitive_data(self, response_text: str, api_key: str) -> str:
        """
        遮蔽响应中的敏感信息
        """
        if not response_text:
            return response_text
        
        # 替换可能出现的API密钥
        if api_key and len(api_key) > 8:
            masked_key = self.mask_api_key(api_key)
            response_text = response_text.replace(api_key, masked_key)
        
        # 遮蔽其他敏感信息模式
        import re
        
        # 遮蔽类似API密钥的字符串
        response_text = re.sub(r'sk-[a-zA-Z0-9]{20,}', 'sk-****', response_text)
        response_text = re.sub(r'Bearer [a-zA-Z0-9]{20,}', 'Bearer ****', response_text)
        
        return response_text
    
    def generate_response_secure(
        self,
        api_base_url: str,
        api_key: str,
        model: str,
        system_prompt: str,
        user_input: str,
        temperature: float,
        max_tokens: int,
        use_env_key: bool = False,
        env_key_name: str = "OPENAI_API_KEY",
        timeout: int = 30,
        retry_count: int = 3
    ) -> Tuple[str, str, str]:
        """
        安全版本的API调用
        """
        
        # 安全获取API密钥
        actual_key, key_status = self.get_api_key(api_key, use_env_key, env_key_name)
        
        if not actual_key:
            return "", "", key_status
        
        # 参数验证
        if not user_input.strip():
            return "", "", "错误: 用户输入不能为空"
        
        # 构建请求
        api_url = api_base_url.rstrip('/') + '/chat/completions'
        
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {actual_key}',
            'User-Agent': 'Kontext-ComfyUI-Plugin/1.0'
        }
        
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]
        
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": False
        }
        
        # 执行API调用
        for attempt in range(retry_count):
            try:
                print(f"[Kontext提示词小秘-安全版] 尝试第 {attempt + 1} 次API调用...")
                
                response = requests.post(
                    api_url,
                    headers=headers,
                    json=payload,
                    timeout=timeout
                )
                
                if response.status_code == 200:
                    response_data = response.json()
                    
                    if 'choices' in response_data and len(response_data['choices']) > 0:
                        generated_text = response_data['choices'][0]['message']['content']
                        
                        # 遮蔽敏感信息的响应
                        masked_response = self.mask_sensitive_data(
                            json.dumps(response_data, ensure_ascii=False, indent=2),
                            actual_key
                        )
                        
                        status_info = f"✅ 成功生成 {len(generated_text)} 个字符 | {key_status}"
                        
                        print(f"[Kontext提示词小秘-安全版] API调用成功")
                        return generated_text, masked_response, status_info
                    else:
                        error_msg = "错误: API响应格式异常"
                        return "", "", error_msg
                
                else:
                    try:
                        error_data = response.json()
                        error_msg = f"API错误 {response.status_code}: {error_data.get('error', {}).get('message', '未知错误')}"
                    except:
                        error_msg = f"HTTP错误 {response.status_code}"
                    
                    if attempt < retry_count - 1:
                        print(f"[Kontext提示词小秘-安全版] {error_msg}, 等待重试...")
                        time.sleep(2 ** attempt)
                        continue
                    else:
                        return "", "", f"{error_msg} | {key_status}"
            
            except requests.exceptions.Timeout:
                error_msg = f"请求超时 (超过 {timeout} 秒)"
                if attempt < retry_count - 1:
                    print(f"[Kontext提示词小秘-安全版] {error_msg}, 等待重试...")
                    time.sleep(2 ** attempt)
                    continue
                else:
                    return "", "", f"{error_msg} | {key_status}"
            
            except Exception as e:
                error_msg = f"未知错误: {str(e)}"
                if attempt < retry_count - 1:
                    print(f"[Kontext提示词小秘-安全版] {error_msg}, 等待重试...")
                    time.sleep(2 ** attempt)
                    continue
                else:
                    return "", "", f"{error_msg} | {key_status}"
        
        return "", "", f"所有重试都失败了 | {key_status}"


class KontextAPIKeyManager:
    """
    API密钥管理节点
    用于安全管理和配置API密钥
    """
    
    def __init__(self):
        pass
    
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "action": (["设置环境变量", "验证密钥", "清除缓存"], {
                    "default": "验证密钥"
                }),
                "api_key": ("STRING", {
                    "default": "",
                    "multiline": False,
                    "placeholder": "API密钥",
                    "password": True
                }),
            },
            "optional": {
                "env_var_name": ("STRING", {
                    "default": "OPENAI_API_KEY",
                    "multiline": False,
                    "placeholder": "环境变量名称"
                }),
                "test_url": ("STRING", {
                    "default": "https://api.openai.com/v1",
                    "multiline": False,
                    "placeholder": "测试API URL"
                }),
            }
        }
    
    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("masked_key", "status")
    FUNCTION = "manage_api_key"
    CATEGORY = "Kontext/提示词工具"
    
    def mask_api_key(self, api_key: str) -> str:
        """遮蔽API密钥"""
        if not api_key or len(api_key) < 8:
            return "***"
        return f"{api_key[:4]}{'*' * (len(api_key) - 8)}{api_key[-4:]}"
    
    def verify_api_key(self, api_key: str, test_url: str) -> Tuple[bool, str]:
        """验证API密钥有效性"""
        try:
            headers = {
                'Authorization': f'Bearer {api_key}',
                'Content-Type': 'application/json'
            }
            
            # 简单的验证请求
            test_payload = {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "test"}],
                "max_tokens": 1
            }
            
            response = requests.post(
                f"{test_url.rstrip('/')}/chat/completions",
                headers=headers,
                json=test_payload,
                timeout=10
            )
            
            if response.status_code == 200:
                return True, "✅ API密钥验证成功"
            elif response.status_code == 401:
                return False, "❌ API密钥无效或已过期"
            elif response.status_code == 429:
                return True, "⚠️ API密钥有效，但达到速率限制"
            else:
                return False, f"❌ 验证失败: HTTP {response.status_code}"
                
        except Exception as e:
            return False, f"❌ 验证过程出错: {str(e)}"
    
    def manage_api_key(
        self,
        action: str,
        api_key: str,
        env_var_name: str = "OPENAI_API_KEY",
        test_url: str = "https://api.openai.com/v1"
    ) -> Tuple[str, str]:
        """
        管理API密钥
        """
        
        masked_key = self.mask_api_key(api_key)
        
        if action == "验证密钥":
            if not api_key.strip():
                return masked_key, "❌ 请输入API密钥"
            
            is_valid, message = self.verify_api_key(api_key, test_url)
            return masked_key, message
        
        elif action == "设置环境变量":
            if not api_key.strip():
                return masked_key, "❌ 请输入API密钥"
            
            try:
                os.environ[env_var_name] = api_key
                return masked_key, f"✅ 已设置环境变量 {env_var_name}"
            except Exception as e:
                return masked_key, f"❌ 设置环境变量失败: {str(e)}"
        
        elif action == "清除缓存":
            # 清除可能的缓存
            if env_var_name in os.environ:
                del os.environ[env_var_name]
            return "***", f"✅ 已清除环境变量 {env_var_name}"
        
        return masked_key, "❓ 未知操作"


# 更新节点映射，包含安全版本
SECURE_NODE_CLASS_MAPPINGS = {
    "SecureKontextPromptAssistant": SecureKontextPromptAssistant,
    "KontextAPIKeyManager": KontextAPIKeyManager,
}

SECURE_NODE_DISPLAY_NAME_MAPPINGS = {
    "SecureKontextPromptAssistant": "Kontext提示词小秘(安全版)",
    "KontextAPIKeyManager": "Kontext密钥管理器",
}