BaseController.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package admin
  2. import (
  3. "strconv"
  4. "time"
  5. beego "github.com/beego/beego/v2/server/web"
  6. "github.com/dgrijalva/jwt-go"
  7. )
  8. // JWT配置
  9. var (
  10. JWTSecret = []byte("your-secret-key") // 生产环境应该使用更复杂的密钥
  11. TokenExpire = 7200 // Token有效期,单位秒(2小时)
  12. )
  // UserInfo is the user identity embedded in every issued token.
type UserInfo struct {
	Id       int    `json:"id"`       // user ID
	Username string `json:"username"` // login name
	Role     string `json:"role"`     // role name; DefaultPermissionCheck grants everything to "admin"
	// Add further user fields here as needed.
}
  20. // JWT Claims结构
  21. type Claims struct {
  22. UserInfo UserInfo `json:"user_info"`
  23. jwt.StandardClaims
  24. }
  25. // BaseController 基础控制器,提供JWT鉴权和权限控制
  26. type BaseController struct {
  27. beego.Controller
  28. UserInfo UserInfo // 当前登录用户信息
  29. IsLogin bool // 是否已登录
  30. }
  31. // 生成Token
  32. func GenerateToken(userInfo UserInfo) (string, error) {
  33. expireTime := time.Now().Add(time.Duration(TokenExpire) * time.Second)
  34. claims := &Claims{
  35. UserInfo: userInfo,
  36. StandardClaims: jwt.StandardClaims{
  37. ExpiresAt: expireTime.Unix(),
  38. IssuedAt: time.Now().Unix(),
  39. },
  40. }
  41. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  42. return token.SignedString(JWTSecret)
  43. }
  44. // 解析Token
  45. func ParseToken(tokenString string) (*Claims, error) {
  46. token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  47. return JWTSecret, nil
  48. })
  49. if err != nil {
  50. return nil, err
  51. }
  52. if claims, ok := token.Claims.(*Claims); ok && token.Valid {
  53. return claims, nil
  54. }
  55. return nil, err
  56. }
  57. // 验证Token中间件
  58. func (c *BaseController) VerifyToken() {
  59. tokenString := c.Ctx.Input.Header("Authorization")
  60. if tokenString == "" {
  61. c.Error("Authorization token required", 401)
  62. return
  63. }
  64. claims, err := ParseToken(tokenString)
  65. if err != nil {
  66. c.Error("Invalid token: "+err.Error(), 401)
  67. return
  68. }
  69. // 检查Token是否过期
  70. if time.Now().Unix() > claims.ExpiresAt {
  71. c.Error("Token expired", 401)
  72. return
  73. }
  74. // 设置用户信息
  75. c.UserInfo = claims.UserInfo
  76. c.IsLogin = true
  77. }
  78. // 权限检查函数类型
  79. type PermissionCheckFunc func(userInfo UserInfo, permission string) bool
  80. // 默认权限检查函数
  81. func DefaultPermissionCheck(userInfo UserInfo, permission string) bool {
  82. // 这里实现具体的权限逻辑
  83. // 例如:检查用户角色是否有所需权限
  84. // 简单示例:admin角色有所有权限
  85. if userInfo.Role == "admin" {
  86. return true
  87. }
  88. // 可以根据具体权限字符串检查
  89. // 实际项目中应该查询数据库或缓存中的权限配置
  90. return false
  91. }
  92. // 检查权限中间件
  93. func (c *BaseController) CheckPermission(permission string, checkFunc ...PermissionCheckFunc) {
  94. if !c.IsLogin {
  95. c.Error("Permission denied, user not logged in", 403)
  96. return
  97. }
  98. var check PermissionCheckFunc
  99. if len(checkFunc) > 0 {
  100. check = checkFunc[0]
  101. } else {
  102. check = DefaultPermissionCheck
  103. }
  104. if !check(c.UserInfo, permission) {
  105. c.Error("Permission denied for user: "+c.UserInfo.Username, 403)
  106. return
  107. }
  108. }
  109. // 返回JSON响应
  110. func (c *BaseController) JSONResponse(data interface{}, errCode ...int) {
  111. code := 200
  112. if len(errCode) > 0 {
  113. code = errCode[0]
  114. }
  115. response := map[string]interface{}{
  116. "code": code,
  117. "data": data,
  118. "msg": getMessageByCode(code),
  119. }
  120. c.Data["json"] = response
  121. c.ServeJSON()
  122. }
  123. // 成功响应
  124. func (c *BaseController) Success(data interface{}, msg ...string) {
  125. message := "success"
  126. if len(msg) > 0 {
  127. message = msg[0]
  128. }
  129. response := map[string]interface{}{
  130. "code": 200,
  131. "data": data,
  132. "msg": message,
  133. }
  134. c.Data["json"] = response
  135. c.ServeJSON()
  136. }
  137. // 失败响应
  138. func (c *BaseController) Error(msg string, code ...int) {
  139. errCode := 400
  140. if len(code) > 0 {
  141. errCode = code[0]
  142. }
  143. response := map[string]interface{}{
  144. "code": errCode,
  145. "data": nil,
  146. "msg": msg,
  147. }
  148. c.Data["json"] = response
  149. c.ServeJSON()
  150. }
  151. // 获取分页参数
  152. func (c *BaseController) GetPageParams() (page, pageSize int) {
  153. page, _ = strconv.Atoi(c.Ctx.Input.Query("page"))
  154. if page < 1 {
  155. page = 1
  156. }
  157. pageSize, _ = strconv.Atoi(c.Ctx.Input.Query("pageSize"))
  158. if pageSize < 1 {
  159. pageSize = 20 // 默认每页20条
  160. }
  161. return page, pageSize
  162. }
  163. // Prepare 在执行任何HTTP方法之前调用
  164. func (c *BaseController) Prepare() {
  165. // 自动验证Token(除了登录接口)
  166. if c.Ctx.Input.URL() != "/admin/login" {
  167. c.VerifyToken()
  168. }
  169. }
  // getMessageByCode returns the default response message for code, or
// "unknown error" for codes it does not recognize.
func getMessageByCode(code int) string {
	// A switch avoids allocating and filling a lookup map on every call.
	switch code {
	case 200:
		return "success"
	case 400:
		return "bad request"
	case 401:
		return "unauthorized"
	case 403:
		return "forbidden"
	case 404:
		return "not found"
	case 500:
		return "internal server error"
	default:
		return "unknown error"
	}
}