diff --git a/internal/api/auth.go b/internal/api/auth.go index c5d8ba47..a02e014a 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -2,7 +2,6 @@ package api import ( "NodePassDash/internal/auth" - "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -400,6 +399,8 @@ func (h *AuthHandler) HandleOAuth2Callback(c *gin.Context) { h.handleGitHubOAuth(c, code) case "cloudflare": h.handleCloudflareOAuth(c, code) + case "custom": + h.handleCustomOIDC(c, code) default: c.JSON(http.StatusOK, gin.H{ "success": false, @@ -425,7 +426,10 @@ func (h *AuthHandler) handleGitHubOAuth(c *gin.Context, code string) { RedirectURI string `json:"redirectUri"` } var cfg ghCfg - _ = json.Unmarshal([]byte(cfgStr), &cfg) + if err := auth.UnmarshalConfig(cfgStr, &cfg); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("配置解析失败: %v", err)}) + return + } if cfg.ClientID == "" || cfg.ClientSecret == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "GitHub OAuth2 配置不完整"}) @@ -452,7 +456,12 @@ func (h *AuthHandler) handleGitHubOAuth(c *gin.Context, code string) { cfg.ClientID, redirectURI, cfg.TokenURL) fmt.Printf("🔍 请求体: %s\n", form.Encode()) - tokenReq, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode())) + tokenReq, err := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + fmt.Printf("❌ GitHub Token 请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 Token 请求失败"}) + return + } tokenReq.Header.Set("Accept", "application/json") tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -467,13 +476,23 @@ func (h *AuthHandler) handleGitHubOAuth(c *gin.Context, code string) { defer resp.Body.Close() if resp.StatusCode >= 400 { - bodyBytes, _ := ioutil.ReadAll(resp.Body) + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ GitHub Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 GitHub Token 响应失败"}) + return + } fmt.Printf("❌ GitHub Token 错误 %d: %s\n", resp.StatusCode, string(bodyBytes)) c.JSON(http.StatusBadGateway, gin.H{"error": "GitHub Token 接口返回错误"}) return } - body, _ := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ GitHub Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 GitHub Token 响应失败"}) + return + } fmt.Printf("🔑 GitHub Token 响应: %s\n", string(body)) var tokenRes struct { @@ -481,14 +500,22 @@ func (h *AuthHandler) handleGitHubOAuth(c *gin.Context, code string) { Scope string `json:"scope"` TokenType string `json:"token_type"` } - _ = json.Unmarshal(body, &tokenRes) + if err := auth.UnmarshalBytes(body, &tokenRes); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 Token 响应失败"}) + return + } if tokenRes.AccessToken == "" { c.JSON(http.StatusBadGateway, gin.H{"error": "获取 AccessToken 失败"}) return } // 获取用户信息 - userReq, _ := http.NewRequest("GET", cfg.UserInfoURL, nil) + userReq, err := http.NewRequest("GET", cfg.UserInfoURL, nil) + if err != nil { + fmt.Printf("❌ GitHub 用户信息请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户信息请求失败"}) + return + } userReq.Header.Set("Authorization", "token "+tokenRes.AccessToken) userReq.Header.Set("Accept", "application/json") @@ -499,11 +526,19 @@ func (h *AuthHandler) handleGitHubOAuth(c *gin.Context, code string) { return } defer userResp.Body.Close() - userBody, _ := ioutil.ReadAll(userResp.Body) + userBody, err := ioutil.ReadAll(userResp.Body) + if err != nil { + fmt.Printf("❌ GitHub 用户信息读取失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 GitHub 用户信息失败"}) + return + } fmt.Printf("👤 GitHub 用户信息: %s\n", string(userBody)) var userData map[string]interface{} - _ = json.Unmarshal(userBody, &userData) + if err := auth.UnmarshalBytes(userBody, &userData); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析用户信息失败"}) + return + } providerID := fmt.Sprintf("%v", userData["id"]) login := fmt.Sprintf("%v", userData["login"]) @@ -580,7 +615,10 @@ func (h *AuthHandler) handleCloudflareOAuth(c *gin.Context, code string) { RedirectURI string `json:"redirectUri"` } var cfg cfCfg - _ = json.Unmarshal([]byte(cfgStr), &cfg) + if err := auth.UnmarshalConfig(cfgStr, &cfg); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("配置解析失败: %v", err)}) + return + } if cfg.ClientID == "" || cfg.ClientSecret == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "Cloudflare OAuth2 配置不完整"}) @@ -604,7 +642,12 @@ func (h *AuthHandler) handleCloudflareOAuth(c *gin.Context, code string) { } form.Set("redirect_uri", redirectURI) - tokenReq, _ := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode())) + tokenReq, err := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + fmt.Printf("❌ Cloudflare Token 请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 Token 请求失败"}) + return + } tokenReq.Header.Set("Accept", "application/json") tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -618,13 +661,23 @@ func (h *AuthHandler) handleCloudflareOAuth(c *gin.Context, code string) { defer resp.Body.Close() if resp.StatusCode >= 400 { - bodyBytes, _ := ioutil.ReadAll(resp.Body) + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ Cloudflare Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 Cloudflare Token 响应失败"}) + return + } fmt.Printf("❌ Cloudflare Token 错误 %d: %s\n", resp.StatusCode, string(bodyBytes)) c.JSON(http.StatusBadGateway, gin.H{"error": "Cloudflare Token 接口返回错误"}) return } - body, _ := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ Cloudflare Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 Cloudflare Token 响应失败"}) + return + } fmt.Printf("🔑 Cloudflare Token 响应: %s\n", string(body)) var tokenRes struct { @@ -633,42 +686,48 @@ func (h *AuthHandler) handleCloudflareOAuth(c *gin.Context, code string) { Scope string `json:"scope"` TokenType string `json:"token_type"` } - _ = json.Unmarshal(body, &tokenRes) + if err := auth.UnmarshalBytes(body, &tokenRes); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 Token 响应失败"}) + return + } if tokenRes.AccessToken == "" { c.JSON(http.StatusBadGateway, gin.H{"error": "获取 AccessToken 失败"}) return } - var userData map[string]interface{} + // 获取用户信息 + if cfg.UserInfoURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cloudflare OAuth2 配置缺少 userInfoUrl"}) + return + } - if cfg.UserInfoURL != "" { - // 调用用户信息端点 - userReq, _ := http.NewRequest("GET", cfg.UserInfoURL, nil) - userReq.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken) - userReq.Header.Set("Accept", "application/json") - - // 使用支持代理的HTTP客户端 - userResp, err := proxyClient.Do(userReq) - if err == nil { - defer userResp.Body.Close() - bodyBytes, _ := ioutil.ReadAll(userResp.Body) - _ = json.Unmarshal(bodyBytes, &userData) - fmt.Printf("👤 Cloudflare 用户信息: %s\n", string(bodyBytes)) - } + userReq, err := http.NewRequest("GET", cfg.UserInfoURL, nil) + if err != nil { + fmt.Printf("❌ Cloudflare 用户信息请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户信息请求失败"}) + return } + userReq.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken) + userReq.Header.Set("Accept", "application/json") - // 若未获取到用户信息且 id_token 存在,则解析 id_token - if len(userData) == 0 && tokenRes.IdToken != "" { - parts := strings.Split(tokenRes.IdToken, ".") - if len(parts) >= 2 { - payload, _ := base64.RawURLEncoding.DecodeString(parts[1]) - _ = json.Unmarshal(payload, &userData) - fmt.Printf("👤 Cloudflare id_token payload: %s\n", string(payload)) - } + // 使用支持代理的HTTP客户端 + userResp, err := proxyClient.Do(userReq) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "获取 Cloudflare 用户信息失败"}) + return + } + defer userResp.Body.Close() + bodyBytes, err := ioutil.ReadAll(userResp.Body) + if err != nil { + fmt.Printf("❌ Cloudflare 用户信息读取失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 Cloudflare 用户信息失败"}) + return } + fmt.Printf("👤 Cloudflare 用户信息: %s\n", string(bodyBytes)) - if len(userData) == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": "无法获取 Cloudflare 用户信息"}) + var userData map[string]interface{} + if err := auth.UnmarshalBytes(bodyBytes, &userData); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 Cloudflare 用户信息失败"}) return } @@ -801,6 +860,13 @@ func (h *AuthHandler) HandleOAuth2Config(c *gin.Context) { c.JSON(http.StatusOK, resp) case http.MethodPost: + // 1. 验证会话(仅管理员可配置) + sessionID, err := c.Cookie("session") + if err != nil || !h.authService.ValidateSession(sessionID) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "需要管理员权限"}) + return + } + var req OAuth2ConfigRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) @@ -811,6 +877,42 @@ func (h *AuthHandler) HandleOAuth2Config(c *gin.Context) { return } + // 2. Custom OIDC 执行 discovery + if req.Provider == "custom" { + issuerURL, ok := req.Config["issuerUrl"].(string) + if !ok || issuerURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 issuerUrl"}) + return + } + + // 3. Discovery(强制使用 HTTPS,支持内网 IP) + validator := &auth.URLValidator{ + AllowPrivateIP: true, // 支持内网 IP(使用 HTTPS) + } + + discoveredConfig, err := auth.SecureDiscoverOIDC(issuerURL, validator) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{ + "success": false, + "error": fmt.Sprintf("OIDC Discovery 失败: %v", err), + }) + return + } + + // 4. 自动填充端点 + req.Config["authUrl"] = discoveredConfig.AuthorizationEndpoint + req.Config["tokenUrl"] = discoveredConfig.TokenEndpoint + req.Config["userInfoUrl"] = discoveredConfig.UserinfoEndpoint + req.Config["issuer"] = discoveredConfig.Issuer + } + + // 5. 添加 redirectUri + scheme := "http" + if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" { + scheme = "https" + } + req.Config["redirectUri"] = fmt.Sprintf("%s://%s/api/oauth2/callback", scheme, c.Request.Host) + cfgBytes, _ := json.Marshal(req.Config) if err := h.authService.SetSystemConfig("oauth2_config", string(cfgBytes)); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "save config failed"}) @@ -896,7 +998,8 @@ func (h *AuthHandler) HandleOAuth2Login(c *gin.Context) { q.Set("scope", scopes) } - if provider == "cloudflare" { + // Cloudflare 和 Custom OIDC 需要设置 response_type=code(OIDC 标准) + if provider == "cloudflare" || provider == "custom" { q.Set("response_type", "code") } @@ -906,14 +1009,292 @@ func (h *AuthHandler) HandleOAuth2Login(c *gin.Context) { c.Redirect(http.StatusFound, loginURL) } +// handleCustomOIDC 处理 Custom OIDC 回调 +func (h *AuthHandler) handleCustomOIDC(c *gin.Context, code string) { + // 读取配置 + cfgStr, err := h.authService.GetSystemConfig("oauth2_config") + if err != nil || cfgStr == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 未配置"}) + return + } + + type customCfg struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AuthURL string `json:"authUrl"` + TokenURL string `json:"tokenUrl"` + UserInfoURL string `json:"userInfoUrl"` + RedirectURI string `json:"redirectUri"` + Scopes []string `json:"scopes"` + UserIDPath string `json:"userIdPath"` + UsernamePath string `json:"usernamePath"` + DisplayName string `json:"displayName"` + } + var cfg customCfg + if err := auth.UnmarshalConfig(cfgStr, &cfg); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("配置解析失败: %v", err)}) + return + } + + if cfg.ClientID == "" || cfg.ClientSecret == "" || cfg.TokenURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 配置不完整"}) + return + } + + // 设置默认值 + if cfg.UserIDPath == "" { + cfg.UserIDPath = "sub" + } + if cfg.UsernamePath == "" { + cfg.UsernamePath = "preferred_username" + } + if cfg.DisplayName == "" { + cfg.DisplayName = "OIDC" + } + + // 交换 access token + form := url.Values{} + form.Set("client_id", cfg.ClientID) + form.Set("client_secret", cfg.ClientSecret) + form.Set("code", code) + form.Set("grant_type", "authorization_code") + + // 设置 redirect_uri + redirectURI := cfg.RedirectURI + if redirectURI == "" { + baseURL := fmt.Sprintf("%s://%s", "http", c.Request.Host) + redirectURI = baseURL + "/api/oauth2/callback" + } + form.Set("redirect_uri", redirectURI) + + fmt.Printf("🔍 Custom OIDC Token 请求: token_url=%s, redirect_uri=%s\n", cfg.TokenURL, redirectURI) + + tokenReq, err := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + fmt.Printf("❌ Custom OIDC Token 请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 Token 请求失败"}) + return + } + tokenReq.Header.Set("Accept", "application/json") + tokenReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // 使用支持代理的HTTP客户端 + proxyClient := h.createProxyClient() + resp, err := proxyClient.Do(tokenReq) + if err != nil { + fmt.Printf("❌ Custom OIDC Token 请求错误: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "请求 OIDC Token 失败"}) + return + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ Custom OIDC Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 OIDC Token 响应失败"}) + return + } + fmt.Printf("❌ Custom OIDC Token 错误 %d: %s\n", resp.StatusCode, string(bodyBytes)) + c.JSON(http.StatusBadGateway, gin.H{"error": "OIDC Token 接口返回错误"}) + return + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Printf("❌ Custom OIDC Token 读取响应失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 OIDC Token 响应失败"}) + return + } + fmt.Printf("🔑 Custom OIDC Token 响应: %s\n", string(body)) + + var tokenRes struct { + AccessToken string `json:"access_token"` + IdToken string `json:"id_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` + } + if err := auth.UnmarshalBytes(body, &tokenRes); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 Token 响应失败"}) + return + } + if tokenRes.AccessToken == "" { + c.JSON(http.StatusBadGateway, gin.H{"error": "获取 AccessToken 失败"}) + return + } + + // 获取用户信息 + if cfg.UserInfoURL == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Custom OIDC 配置缺少 userInfoUrl"}) + return + } + + userReq, err := http.NewRequest("GET", cfg.UserInfoURL, nil) + if err != nil { + fmt.Printf("❌ Custom OIDC 用户信息请求创建失败: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建用户信息请求失败"}) + return + } + userReq.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken) + userReq.Header.Set("Accept", "application/json") + + userResp, err := proxyClient.Do(userReq) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "获取 OIDC 用户信息失败"}) + return + } + defer userResp.Body.Close() + bodyBytes, err := ioutil.ReadAll(userResp.Body) + if err != nil { + fmt.Printf("❌ Custom OIDC 用户信息读取失败: %v\n", err) + c.JSON(http.StatusBadGateway, gin.H{"error": "读取 OIDC 用户信息失败"}) + return + } + fmt.Printf("👤 Custom OIDC 用户信息: %s\n", string(bodyBytes)) + + var userData map[string]interface{} + if err := auth.UnmarshalBytes(bodyBytes, &userData); err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "解析 OIDC 用户信息失败"}) + return + } + + // 提取用户 ID(使用配置的 userIdPath) + providerID := h.extractFieldFromUserData(userData, cfg.UserIDPath) + if providerID == "" { + // 回退到常用字段 + providerID = h.extractFieldFromUserData(userData, "sub") + if providerID == "" { + providerID = h.extractFieldFromUserData(userData, "id") + } + } + + if providerID == "" { + c.JSON(http.StatusBadGateway, gin.H{"error": "无法获取 OIDC 用户唯一标识"}) + return + } + + // 提取用户名(使用配置的 usernamePath) + login := h.extractFieldFromUserData(userData, cfg.UsernamePath) + if login == "" { + // 回退到常用字段 + login = h.extractFieldFromUserData(userData, "preferred_username") + if login == "" { + login = h.extractFieldFromUserData(userData, "email") + } + if login == "" { + login = h.extractFieldFromUserData(userData, "name") + } + if login == "" { + login = providerID // 最后回退到使用 providerID + } + } + + username := "custom:" + login + + // 保存用户信息 + dataJSON, _ := json.Marshal(userData) + if err := h.authService.SaveOAuthUser("custom", providerID, username, string(dataJSON)); err != nil { + fmt.Printf("❌ 保存 Custom OIDC 用户失败: %v\n", err) + // 重定向到错误页面 + baseURL := "" + if cfg.RedirectURI != "" { + baseURL = strings.Replace(cfg.RedirectURI, "/api/oauth2/callback", "", 1) + } else { + scheme := "http" + if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" { + scheme = "https" + } + baseURL = fmt.Sprintf("%s://%s", scheme, c.Request.Host) + } + errorURL := fmt.Sprintf("%s/oauth-error?error=%s&provider=custom", + baseURL, url.QueryEscape(err.Error())) + c.Redirect(http.StatusFound, errorURL) + return + } + + // 创建会话 (24小时有效期) + sessionID, err := h.authService.CreateSession(username, 24*time.Hour) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "创建会话失败"}) + return + } + + // 设置 cookie + c.SetCookie("session", sessionID, 24*60*60, "/", "", false, true) + + // 重定向到 dashboard + redirectURL := c.Query("redirect") + if redirectURL == "" { + redirectURL = strings.Replace(cfg.RedirectURI, "/api/oauth2/callback", "/dashboard", 1) + } + + accept := c.GetHeader("Accept") + if strings.Contains(accept, "text/html") || strings.Contains(accept, "application/xhtml+xml") || redirectURL != "" { + c.Redirect(http.StatusFound, redirectURL) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "provider": "custom", + "username": username, + "message": "登录成功", + }) +} + +// extractFieldFromUserData 从用户数据中提取字段(支持简单的点号路径) +func (h *AuthHandler) extractFieldFromUserData(data map[string]interface{}, path string) string { + if path == "" { + return "" + } + + parts := strings.Split(path, ".") + current := data + + for i, part := range parts { + if val, ok := current[part]; ok { + if i == len(parts)-1 { + // 最后一个部分,转换为字符串 + return fmt.Sprintf("%v", val) + } + // 不是最后一个部分,继续深入 + if nested, ok := val.(map[string]interface{}); ok { + current = nested + } else { + return "" + } + } else { + return "" + } + } + return "" +} + // HandleOAuth2Provider 仅返回当前绑定的 OAuth2 provider(用于登录页) func (h *AuthHandler) HandleOAuth2Provider(c *gin.Context) { provider, _ := h.authService.GetSystemConfig("oauth2_provider") disableLogin, _ := h.authService.GetSystemConfig("disable_login") - c.JSON(http.StatusOK, gin.H{ + resp := gin.H{ "success": true, "provider": provider, "disableLogin": disableLogin == "true", - }) + } + + // 如果是 custom provider,返回 displayName + if provider == "custom" { + cfgStr, _ := h.authService.GetSystemConfig("oauth2_config") + if cfgStr != "" { + var cfg map[string]interface{} + if err := auth.UnmarshalConfig(cfgStr, &cfg); err == nil { + displayName := auth.SafeStringAssert(cfg["displayName"], "") + if displayName != "" { + resp["displayName"] = displayName + } + } + } + } + + c.JSON(http.StatusOK, resp) } + diff --git a/internal/auth/helpers.go b/internal/auth/helpers.go new file mode 100644 index 00000000..79344817 --- /dev/null +++ b/internal/auth/helpers.go @@ -0,0 +1,39 @@ +package auth + +import ( + "encoding/json" + "fmt" +) + +// UnmarshalConfig 安全地解析配置 JSON 字符串 +func UnmarshalConfig(data string, v interface{}) error { + if data == "" { + return fmt.Errorf("配置为空") + } + if err := json.Unmarshal([]byte(data), v); err != nil { + return fmt.Errorf("解析配置失败: %w", err) + } + return nil +} + +// UnmarshalBytes 安全地解析字节数据为 JSON +func UnmarshalBytes(data []byte, v interface{}) error { + if len(data) == 0 { + return fmt.Errorf("数据为空") + } + if err := json.Unmarshal(data, v); err != nil { + return fmt.Errorf("解析数据失败: %w", err) + } + return nil +} + +// SafeStringAssert 安全地断言为字符串,失败返回默认值 +func SafeStringAssert(v interface{}, fallback string) string { + if v == nil { + return fallback + } + if s, ok := v.(string); ok && s != "" { + return s + } + return fallback +} diff --git a/internal/auth/security.go b/internal/auth/security.go new file mode 100644 index 00000000..5f6dae91 --- /dev/null +++ b/internal/auth/security.go @@ -0,0 +1,155 @@ +package auth + +import ( + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// URLValidator URL 验证器 +type URLValidator struct { + AllowPrivateIP bool // 是否允许私有 IP (默认为 false,即仅允许公网 IP) +} + +// ValidateURL 验证 URL 的安全性 +func (v *URLValidator) ValidateURL(rawURL string) error { + // 解析 URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("无效的 URL 格式: %w", err) + } + + // 检查 scheme,强制使用 HTTPS + scheme := strings.ToLower(parsedURL.Scheme) + if scheme != "https" { + return fmt.Errorf("仅允许 HTTPS 协议,当前协议: %s", scheme) + } + + // 检查 host 是否为空 + if parsedURL.Host == "" { + return fmt.Errorf("URL 缺少 host") + } + + // 如果不允许私有 IP,则检查 IP 地址 + if !v.AllowPrivateIP { + // 提取 hostname (去除端口) + hostname := parsedURL.Hostname() + + // 解析 IP + ip := net.ParseIP(hostname) + if ip == nil { + // 如果不是 IP,尝试解析域名 + ips, err := net.LookupIP(hostname) + if err == nil && len(ips) > 0 { + ip = ips[0] + } + } + + // 检查是否为私有 IP + if ip != nil && v.isPrivateIP(ip) { + return fmt.Errorf("不允许访问私有 IP 地址: %s", ip.String()) + } + } + + return nil +} + +// isPrivateIP 检查 IP 是否为私有地址 +func (v *URLValidator) isPrivateIP(ip net.IP) bool { + // 检查 loopback + if ip.IsLoopback() { + return true + } + + // 检查私有 IP 范围 + privateRanges := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "fc00::/7", // IPv6 unique local addresses + } + + for _, cidr := range privateRanges { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + // 硬编码的 CIDR 不应该失败,但为了安全起见还是检查 + continue + } + if ipnet.Contains(ip) { + return true + } + } + + return false +} + +// OIDCConfig OIDC 配置 +type OIDCConfig struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` +} + +// SecureDiscoverOIDC 安全地执行 OIDC Discovery +func SecureDiscoverOIDC(issuer string, validator *URLValidator) (*OIDCConfig, error) { + // 构造 Discovery URL + issuer = strings.TrimSuffix(issuer, "/") + discoveryURL := issuer + "/.well-known/openid-configuration" + + // 验证 Discovery URL + if err := validator.ValidateURL(discoveryURL); err != nil { + return nil, fmt.Errorf("Discovery URL 验证失败: %w", err) + } + + // 发起 HTTP 请求 + client := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequest("GET", discoveryURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Discovery 请求失败,状态码: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 解析配置 + var config OIDCConfig + if err := json.Unmarshal(body, &config); err != nil { + return nil, fmt.Errorf("解析配置失败: %w", err) + } + + // 验证 issuer 一致性 + responseIssuer := strings.TrimSuffix(config.Issuer, "/") + expectedIssuer := strings.TrimSuffix(issuer, "/") + if responseIssuer != expectedIssuer { + return nil, fmt.Errorf("issuer 不匹配: 请求 %s, 响应 %s", expectedIssuer, responseIssuer) + } + + // 验证必要端点 + if config.AuthorizationEndpoint == "" || config.TokenEndpoint == "" { + return nil, fmt.Errorf("配置不完整,缺少必要端点") + } + + return &config, nil +} diff --git a/web/src/components/settings/security-settings.tsx b/web/src/components/settings/security-settings.tsx index 81d64f00..5d4a9847 100644 --- a/web/src/components/settings/security-settings.tsx +++ b/web/src/components/settings/security-settings.tsx @@ -39,6 +39,17 @@ const securitySettingsSchema = z.object({ type SecuritySettingsForm = z.infer; +// 辅助函数: 根据配置状态返回当前 OAuth Provider +const getCurrentProvider = ( + isGitHub: boolean, + isCloudflare: boolean, +): "github" | "cloudflare" | "custom" => { + if (isGitHub) return "github"; + if (isCloudflare) return "cloudflare"; + + return "custom"; +}; + // OAuth2 配置类型 interface OAuth2Config { clientId: string; @@ -115,13 +126,31 @@ const SecuritySettings = forwardRef((props, ref) => { scopes: ["openid", "profile"], }); + // Custom OIDC 配置 + interface CustomOIDCConfig extends OAuth2Config { + issuerUrl: string; + displayName: string; + usernamePath: string; + } + + const [customConfig, setCustomConfig] = useState({ + issuerUrl: "", + clientId: "", + clientSecret: "", + userIdPath: "sub", + usernamePath: "preferred_username", + scopes: ["openid", "profile", "email"], + displayName: "", + }); + // 模拟的配置状态(实际应该从后端获取) const [isGitHubConfigured, setIsGitHubConfigured] = useState(false); const [isCloudflareConfigured, setIsCloudflareConfigured] = useState(false); + const [isCustomConfigured, setIsCustomConfigured] = useState(false); // 在 state 部分添加 selectedProvider 和 provider select disclosure const [selectedProvider, setSelectedProvider] = useState< - "github" | "cloudflare" | null + "github" | "cloudflare" | "custom" | null >(null); const { isOpen: isSelectOpen, @@ -129,6 +158,13 @@ const SecuritySettings = forwardRef((props, ref) => { onOpenChange: onSelectOpenChange, } = useDisclosure(); + // Custom OIDC 配置模态框 + const { + isOpen: isCustomOpen, + onOpen: onCustomOpen, + onOpenChange: onCustomOpenChange, + } = useDisclosure(); + // 初始化表单 const { register, @@ -150,7 +186,11 @@ const SecuritySettings = forwardRef((props, ref) => { if (!data.success) return; - const curProvider = data.provider as "github" | "cloudflare" | ""; + const curProvider = data.provider as + | "github" + | "cloudflare" + | "custom" + | ""; if (!curProvider) return; // 未绑定 @@ -162,6 +202,9 @@ const SecuritySettings = forwardRef((props, ref) => { } else if (curProvider === "cloudflare") { setCloudflareConfig((prev: any) => ({ ...prev, ...cfgData.config })); setIsCloudflareConfigured(true); + } else if (curProvider === "custom") { + setCustomConfig((prev: any) => ({ ...prev, ...cfgData.config })); + setIsCustomConfigured(true); } } catch (e) { console.error("初始化 OAuth2 配置失败", e); @@ -404,8 +447,94 @@ const SecuritySettings = forwardRef((props, ref) => { } }; + // Custom OIDC 配置保存 + const handleSaveCustomConfig = async () => { + // 验证必填字段 + if (!customConfig.issuerUrl) { + addToast({ + title: "配置不完整", + description: "请填写 Issuer URL", + color: "warning", + }); + return; + } + + if (!customConfig.clientId || !customConfig.clientSecret) { + addToast({ + title: "配置不完整", + description: "请填写 Client ID 和 Client Secret", + color: "warning", + }); + return; + } + + if (!customConfig.displayName) { + addToast({ + title: "配置不完整", + description: "请填写显示名称(如 Keycloak、Authentik 等)", + color: "warning", + }); + return; + } + + try { + setIsSubmitting(true); + + // 只发送用户配置的字段,端点由后端自动填充 + const payload = { + provider: "custom", + config: { + issuerUrl: customConfig.issuerUrl, + clientId: customConfig.clientId, + clientSecret: customConfig.clientSecret, + displayName: customConfig.displayName, + userIdPath: customConfig.userIdPath || "sub", + usernamePath: customConfig.usernamePath || "preferred_username", + scopes: customConfig.scopes || ["openid", "profile", "email"], + }, + }; + + const res = await fetch(buildApiUrl("/api/oauth2/config"), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }); + + const data = await res.json(); + + if (!res.ok) { + throw new Error(data.error || "保存失败"); + } + + addToast({ + title: "配置保存成功", + description: `${customConfig.displayName} OIDC 配置已成功保存`, + color: "success", + }); + + setIsCustomConfigured(true); + onCustomOpenChange(); + } catch (error) { + console.error("保存 Custom OIDC 配置失败:", error); + const errorMessage = + error instanceof Error + ? error.message + : "无法连接到 OIDC 服务器,请检查 Issuer URL"; + + addToast({ + title: "保存失败", + description: errorMessage, + color: "danger", + }); + } finally { + setIsSubmitting(false); + } + }; + // 解绑处理 - const handleUnbindProvider = async (provider: "github" | "cloudflare") => { + const handleUnbindProvider = async ( + provider: "github" | "cloudflare" | "custom", + ) => { try { setIsSubmitting(true); const res = await fetch(buildApiUrl("/api/oauth2/config"), { @@ -419,7 +548,8 @@ const SecuritySettings = forwardRef((props, ref) => { color: "success", }); if (provider === "github") setIsGitHubConfigured(false); - else setIsCloudflareConfigured(false); + else if (provider === "cloudflare") setIsCloudflareConfigured(false); + else if (provider === "custom") setIsCustomConfigured(false); } catch (e) { console.error("解绑失败", e); addToast({ @@ -675,7 +805,9 @@ const SecuritySettings = forwardRef((props, ref) => { - {isGitHubConfigured || isCloudflareConfigured ? ( + {isGitHubConfigured || + isCloudflareConfigured || + isCustomConfigured ? ( // 已绑定状态
@@ -702,6 +834,15 @@ const SecuritySettings = forwardRef((props, ref) => { Cloudflare{" "} )} + {isCustomConfigured && ( + <> + {" "} + {" "} + + {customConfig.displayName || "Custom OIDC"} + {" "} + + )} 已绑定 @@ -715,6 +856,7 @@ const SecuritySettings = forwardRef((props, ref) => { // 打开对应配置模态框 if (isGitHubConfigured) onGitHubOpen(); else if (isCloudflareConfigured) onCloudflareOpen(); + else if (isCustomConfigured) onCustomOpen(); }} > 配置 @@ -728,7 +870,10 @@ const SecuritySettings = forwardRef((props, ref) => { } onPress={() => handleUnbindProvider( - isGitHubConfigured ? "github" : "cloudflare", + getCurrentProvider( + isGitHubConfigured, + isCloudflareConfigured, + ), ) } > @@ -795,6 +940,18 @@ const SecuritySettings = forwardRef((props, ref) => { > Cloudflare +
@@ -974,6 +1131,153 @@ const SecuritySettings = forwardRef((props, ref) => { )} + + {/* Custom OIDC 配置模态框 */} + + + {(onClose) => ( + <> + + Custom OIDC 配置 + + +
+ {/* Issuer URL */} + + setCustomConfig((prev) => ({ + ...prev, + issuerUrl: e.target.value, + })) + } + /> + + + + {/* 显示名称 */} + + setCustomConfig((prev) => ({ + ...prev, + displayName: e.target.value, + })) + } + /> + + {/* Client ID / Secret */} + + setCustomConfig((prev) => ({ + ...prev, + clientId: e.target.value, + })) + } + /> + + setCustomConfig((prev) => ({ + ...prev, + clientSecret: e.target.value, + })) + } + /> + + {/* Callback URL (只读) */} + + + + + {/* Scopes 和字段映射 */} + { + const scopesStr = e.target.value; + const scopesArr = scopesStr + .split(/\s+/) + .filter((s) => s.length > 0); + + setCustomConfig((prev) => ({ + ...prev, + scopes: + scopesArr.length > 0 + ? scopesArr + : ["openid", "profile", "email"], + })); + }} + /> + + setCustomConfig((prev) => ({ + ...prev, + userIdPath: e.target.value || "sub", + })) + } + /> + + setCustomConfig((prev) => ({ + ...prev, + usernamePath: e.target.value || "preferred_username", + })) + } + /> +
+
+ + + + + + )} +
+
); }); diff --git a/web/src/pages/login/index.tsx b/web/src/pages/login/index.tsx index 12a0f406..4a69a839 100644 --- a/web/src/pages/login/index.tsx +++ b/web/src/pages/login/index.tsx @@ -38,7 +38,8 @@ export default function LoginPage() { // OAuth2 配置状态 const [oauthProviders, setOauthProviders] = useState<{ - provider?: "github" | "cloudflare"; + provider?: "github" | "cloudflare" | "custom"; + displayName?: string; config?: any; }>({}); // 是否禁用用户名密码登录 @@ -58,7 +59,7 @@ export default function LoginPage() { */ const fetchCurrentProvider = async () => { try { - const res = await fetch("/api/auth/oauth2"); // 仅返回 provider 和 disableLogin + const res = await fetch("/api/auth/oauth2"); // 仅返回 provider、disableLogin 和 displayName(custom 时) const data = await res.json(); if (data.success) { @@ -66,9 +67,12 @@ export default function LoginPage() { const loginDisabled = data.disableLogin === true; if (data.provider) { - const cur = data.provider as "github" | "cloudflare"; + const cur = data.provider as "github" | "cloudflare" | "custom"; - setOauthProviders({ provider: cur }); + setOauthProviders({ + provider: cur, + displayName: data.displayName || undefined, + }); } // 设置是否禁用用户名密码登录 @@ -326,6 +330,20 @@ export default function LoginPage() { 使用 Cloudflare 登录 )} + {oauthProviders.provider === "custom" && ( + + )}
)}