zhangjidong 2 天之前
父節點
當前提交
ef2ade33f1
共有 1 個文件被更改,包括 66 次插入29 次删除
  1. 66 29
      models/aice_users.go

+ 66 - 29
models/aice_users.go

@@ -26,6 +26,32 @@ func init() {
 	orm.RegisterModel(new(AiceUsers))
 }
 
+// isValidFieldName 验证字段名是否有效,防止SQL注入
+func isValidFieldName(fieldName string) bool {
+	// AiceUsers结构体中的有效字段
+	validFields := map[string]bool{
+		"userid":   true,
+		"email":    true,
+		"token":    true,
+		"username": true,
+		"address":  true,
+		"password": true,
+	}
+
+	// 处理带isnull的情况(如:email__isnull)
+	baseField := strings.Replace(fieldName, "__isnull", "", -1)
+	baseField = strings.Replace(baseField, ".", "__", -1)
+
+	// 检查基础字段是否有效
+	for field := range validFields {
+		if strings.HasPrefix(baseField, field) || baseField == field {
+			return true
+		}
+	}
+
+	return false
+}
+
 // AddAiceUsers insert a new AiceUsers into database and returns
 // last inserted Id on success.
 func AddAiceUsers(m *AiceUsers) (id int64, err error) {
@@ -53,6 +79,10 @@ func GetAllAiceUsers(query map[string]string, fields []string, sortby []string,
 	qs := o.QueryTable(new(AiceUsers))
 	// query k=v
 	for k, v := range query {
+		// 验证字段名有效性,防止SQL注入
+		if !isValidFieldName(k) {
+			return nil, fmt.Errorf("invalid field name: %s", k)
+		}
 		// rewrite dot-notation to Object__Attribute
 		k = strings.Replace(k, ".", "__", -1)
 		if strings.Contains(k, "isnull") {
@@ -64,36 +94,31 @@ func GetAllAiceUsers(query map[string]string, fields []string, sortby []string,
 	// order by:
 	var sortFields []string
 	if len(sortby) != 0 {
-		if len(sortby) == len(order) {
-			// 1) for each sort field, there is an associated order
-			for i, v := range sortby {
-				orderby := ""
-				if order[i] == "desc" {
-					orderby = "-" + v
-				} else if order[i] == "asc" {
-					orderby = v
-				} else {
-					return nil, errors.New("Error: Invalid order. Must be either [asc|desc]")
-				}
-				sortFields = append(sortFields, orderby)
+		// 验证order参数的有效性
+		if len(order) != 0 && len(order) != 1 && len(order) != len(sortby) {
+			return nil, errors.New("Error: 'sortby', 'order' sizes mismatch or 'order' size is not 1")
+		}
+
+		// 统一处理排序逻辑,消除重复代码
+		for i, field := range sortby {
+			orderDir := "asc"
+			if len(order) == 1 {
+				orderDir = order[0]
+			} else if len(order) > 1 {
+				orderDir = order[i]
 			}
-			qs = qs.OrderBy(sortFields...)
-		} else if len(sortby) != len(order) && len(order) == 1 {
-			// 2) there is exactly one order, all the sorted fields will be sorted by this order
-			for _, v := range sortby {
-				orderby := ""
-				if order[0] == "desc" {
-					orderby = "-" + v
-				} else if order[0] == "asc" {
-					orderby = v
-				} else {
-					return nil, errors.New("Error: Invalid order. Must be either [asc|desc]")
-				}
-				sortFields = append(sortFields, orderby)
+
+			orderby := ""
+			if orderDir == "desc" {
+				orderby = "-" + field
+			} else if orderDir == "asc" {
+				orderby = field
+			} else {
+				return nil, errors.New("Error: Invalid order. Must be either [asc|desc]")
 			}
-		} else if len(sortby) != len(order) && len(order) != 1 {
-			return nil, errors.New("Error: 'sortby', 'order' sizes mismatch or 'order' size is not 1")
+			sortFields = append(sortFields, orderby)
 		}
+		qs = qs.OrderBy(sortFields...)
 	} else {
 		if len(order) != 0 {
 			return nil, errors.New("Error: unused 'order' fields")
@@ -101,7 +126,15 @@ func GetAllAiceUsers(query map[string]string, fields []string, sortby []string,
 	}
 
 	var l []AiceUsers
-	qs = qs.OrderBy(sortFields...)
+
+	// 验证分页参数
+	if limit <= 0 || limit > 1000 {
+		return nil, errors.New("Error: limit must be between 1 and 1000")
+	}
+	if offset < 0 {
+		return nil, errors.New("Error: offset must be non-negative")
+	}
+
 	if _, err = qs.Limit(limit, offset).All(&l, fields...); err == nil {
 		if len(fields) == 0 {
 			for _, v := range l {
@@ -113,7 +146,11 @@ func GetAllAiceUsers(query map[string]string, fields []string, sortby []string,
 				m := make(map[string]interface{})
 				val := reflect.ValueOf(v)
 				for _, fname := range fields {
-					m[fname] = val.FieldByName(fname).Interface()
+					field := val.FieldByName(fname)
+					if !field.IsValid() {
+						return nil, fmt.Errorf("invalid field name: %s", fname)
+					}
+					m[fname] = field.Interface()
 				}
 				ml = append(ml, m)
 			}