diff --git a/middleware/log.go b/middleware/log.go index b61b38f..615cea3 100644 --- a/middleware/log.go +++ b/middleware/log.go @@ -15,6 +15,7 @@ import ( "sweet-cms/form/response" "sweet-cms/model" "sweet-cms/service" + "sweet-cms/utils" "time" ) @@ -46,9 +47,9 @@ func LogHandler(logService *service.LogService) gin.HandlerFunc { Ip: c.ClientIP(), Locality: "", Url: c.Request.URL.Path, - Body: string(bodyStr), - Query: string(queryStr), - Response: responseBody, + Body: utils.SanitizeInput(string(bodyStr)), + Query: utils.SanitizeInput(string(queryStr)), + Response: utils.SanitizeInput(responseBody), } err := logService.CreateAccessLog(c, accessLog) if err != nil { @@ -57,9 +58,9 @@ func LogHandler(logService *service.LogService) gin.HandlerFunc { zap.L().Info("用户访问日志:", zap.String("uri", c.Request.URL.Path), zap.String("method", c.Request.Method), - zap.Any("query", c.Request.URL.Query()), - zap.Any("body", c.Request.Body), - zap.Any("response", responseBody), + zap.Any("query", accessLog.Query), + zap.Any("body", accessLog.Body), + zap.String("response", accessLog.Response), zap.String("ip", c.ClientIP()), zap.String("duration", fmt.Sprintf("%.4f seconds", duration.Seconds()))) zap.L().Info("Access Log end") diff --git a/utils/tools.go b/utils/tools.go index 8b2c93b..7b08bf4 100644 --- a/utils/tools.go +++ b/utils/tools.go @@ -9,6 +9,7 @@ import ( ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" "github.com/pkg/errors" + "html" "io" "math/rand" "net/http" @@ -296,6 +297,7 @@ func ValidatorBody[T any](ctx *gin.Context, data *T, translator ut.Translator) e } return err } + cleanData(data) return nil } @@ -326,6 +328,7 @@ func ValidatorQuery[T any](ctx *gin.Context, data *T, translator ut.Translator) } return err } + cleanData(data) return nil } @@ -358,3 +361,70 @@ func BuildMenuTree(menus []model.SysMenu, pid int) []model.SysMenu { } return tree } + +func cleanData(data any) { + val := reflect.ValueOf(data).Elem() + if val.Kind() != reflect.Struct { + return + } + + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + + switch field.Kind() { + case reflect.String: + // Sanitize the string by escaping HTML special characters + escapedStr := html.EscapeString(field.String()) + field.SetString(escapedStr) + case reflect.Struct: + // Recursively sanitize nested structs + fieldValue := field.Addr().Interface() + cleanData(fieldValue) + + case reflect.Slice: + // Recursively sanitize elements of slice if it's a slice of struct or string + if field.Type().Elem().Kind() == reflect.String { + for j := 0; j < field.Len(); j++ { + escapedStr := html.EscapeString(field.Index(j).String()) + field.Index(j).SetString(escapedStr) + } + } else if field.Type().Elem().Kind() == reflect.Struct { + for j := 0; j < field.Len(); j++ { + element := field.Index(j).Addr().Interface() + cleanData(element) + } + } + case reflect.Map: + // Recursively sanitize elements of map if it's a map of strings + if field.Type().Key().Kind() == reflect.String && field.Type().Elem().Kind() == reflect.String { + iter := field.MapRange() + for iter.Next() { + key := iter.Key() + val := iter.Value() + escapedVal := html.EscapeString(val.String()) + field.SetMapIndex(key, reflect.ValueOf(escapedVal)) + } + } + } + } +} + +func SanitizeInput(input string) string { + replacements := map[string]string{ + "\n": "\\n", + "\r": "\\r", + "\t": "\\t", + } + for old, new := range replacements { + input = strings.ReplaceAll(input, old, new) + } + cleaned := "" + for _, r := range input { + if r >= 32 && r <= 126 { + cleaned += string(r) + } else { + cleaned += fmt.Sprintf("\\x%x", r) + } + } + return cleaned +}