SpringBoot防止XSS

前言:什么是XSS

在跨站脚本(XSS)攻击中,攻击者可以在受害者的浏览器中执行恶意脚本。这种攻击通常是通过向网页中注入恶意代码(JavaScript)来完成的。攻击成功后,攻击者一般能够:

  • 修改网页内容
  • 将用户重定向到其他网站
  • 访问用户的 Cookie 并利用此信息来冒充用户
  • 访问有关用户系统的关键信息,例如地理位置,网络摄像头,文件系统
  • 将木马功能注入应用程序

如果被攻击的用户在应用程序中拥有较高的权限,攻击者甚至可以完全控制应用程序,并破坏所有用户及其数据。

演示XSS

说明:本文不太严谨,直接将密码字段作为演示XSS攻击的列,请各位读者见谅。本文仅用于了解和防范XSS。

数据表

-- Demo table for this article; user_pwd is the column the XSS payload is stored in.
create table mybatis.user
(
    id          int auto_increment
        primary key,
    user_name   varchar(255)                       null,
    user_pwd    varchar(255)                       null,
    -- Defaults to the current timestamp when an insert does not supply a value.
    create_time datetime default CURRENT_TIMESTAMP null
)
    collate = utf8mb4_bin;
View Code

依赖:

<dependencies>
	<!-- Hutool: provides StrUtil, HtmlUtil and JSONUtil used by the XSS request wrapper -->
	<dependency>
		<groupId>cn.hutool</groupId>
		<artifactId>hutool-all</artifactId>
		<version>5.8.3</version>
	</dependency>
	<!-- Spring MVC (@RestController, @RequestBody, ...) and the servlet container -->
	<dependency>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-web</artifactId>
	</dependency>
	<!-- MyBatis-Plus: BaseMapper and the @TableName/@TableId/@TableField annotations -->
	<dependency>
		<groupId>com.baomidou</groupId>
		<artifactId>mybatis-plus-boot-starter</artifactId>
		<version>3.5.1</version>
	</dependency>
	<!-- MySQL JDBC driver; no version here, presumably managed by the Spring Boot parent POM (verify) -->
	<dependency>
		<groupId>mysql</groupId>
		<artifactId>mysql-connector-java</artifactId>
	</dependency>
	<!-- Lombok: @Data on the entity; optional so it is not propagated to dependents -->
	<dependency>
		<groupId>org.projectlombok</groupId>
		<artifactId>lombok</artifactId>
		<optional>true</optional>
	</dependency>
</dependencies>

配置文件

# JDBC driver class
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver
# Data source name
spring.datasource.name=defaultDataSource
# JDBC URL. The Connector/J property is "useSSL"; the former "userSSL" was not a recognized property and had no effect.
spring.datasource.url=jdbc:mysql://localhost:3306/mybatis?useSSL=true&useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
# Database username & password
spring.datasource.username=***
spring.datasource.password=***
# HTTP port of the web application
server.port=8080
# Location of the MyBatis mapper XML files
mybatis-plus.mapper-locations=classpath:mapper/*.xml
# MyBatis-Plus: map snake_case columns to camelCase properties
mybatis-plus.configuration.map-underscore-to-camel-case=true
# Log executed SQL to the console (MyBatis' built-in stdout logger)
mybatis-plus.configuration.log-impl=org.apache.ibatis.logging.stdout.StdOutImpl

实体类

import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;

import java.io.Serializable;
import java.time.LocalDateTime;

import com.fasterxml.jackson.annotation.JsonFormat;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.datatype.jsr310.deser.LocalDateTimeDeserializer;
import com.fasterxml.jackson.datatype.jsr310.ser.LocalDateTimeSerializer;
import lombok.Data;
import org.springframework.format.annotation.DateTimeFormat;
import org.springframework.web.bind.annotation.RequestParam;

/**
 * Entity mapped to the {@code user} table.
 *
 * <p>{@code equals}, {@code hashCode} and {@code toString} are written by hand; Lombok's
 * {@code @Data} does not generate methods that already exist, so these are the ones used.
 */
@TableName(value = "user")
@Data
public class User implements Serializable {

    /** Primary key, generated by the database (auto-increment). */
    @TableId(value = "id", type = IdType.AUTO)
    private Integer id;

    /** Value of the {@code user_name} column. */
    @TableField(value = "user_name")
    private String userName;

    /** Value of the {@code user_pwd} column; masked in {@link #toString()}. */
    @TableField(value = "user_pwd")
    private String userPwd;

    /**
     * Creation time ({@code create_time} column).
     */
    @TableField(value = "create_time")
    @JsonDeserialize(using = LocalDateTimeDeserializer.class) // JSON -> LocalDateTime
    @JsonSerialize(using = LocalDateTimeSerializer.class)     // LocalDateTime -> JSON
    // Pattern used by Jackson for JSON request bodies and JSON responses.
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
    // Pattern used by Spring when binding query/form parameters (not JSON bodies).
    @DateTimeFormat(pattern = "yyyy-MM-dd HH:mm:ss")
    private LocalDateTime createTime;


    @TableField(exist = false)
    private static final long serialVersionUID = 1L;

    /**
     * Two users are equal when they are of the same class and all four fields are equal.
     */
    @Override
    public boolean equals(Object that) {
        if (this == that) {
            return true;
        }
        if (that == null || getClass() != that.getClass()) {
            return false;
        }
        User other = (User) that;
        return java.util.Objects.equals(getId(), other.getId())
                && java.util.Objects.equals(getUserName(), other.getUserName())
                && java.util.Objects.equals(getUserPwd(), other.getUserPwd())
                && java.util.Objects.equals(getCreateTime(), other.getCreateTime());
    }

    /**
     * Hash over the same fields as {@link #equals(Object)}; same values as the classic
     * 31-multiplier formula (nulls contribute 0).
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(getId(), getUserName(), getUserPwd(), getCreateTime());
    }

    /**
     * Debug representation. The password is masked so it never ends up in logs;
     * a null password is still shown as {@code null}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName());
        sb.append(" [");
        sb.append("Hash = ").append(hashCode());
        sb.append(", id=").append(id);
        sb.append(", userName=").append(userName);
        sb.append(", userPwd=").append(userPwd == null ? null : "******");
        sb.append(", createTime=").append(createTime);
        sb.append(", serialVersionUID=").append(serialVersionUID);
        sb.append("]");
        return sb.toString();
    }
}
View Code

controller

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.zhixi.pojo.User;
import com.zhixi.service.UserService;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.time.LocalDateTime;
/**
 * Demo REST endpoints for the {@code user} table.
 *
 * <p>{@link #getUserById} returns the stored password verbatim as a String. Without the XSS
 * filter, a script stored through {@link #insertUser} can run when a browser renders that
 * response; this is the vulnerability the article demonstrates.
 *
 * @author zhangzhixi
 * @since 2022-12-05
 */
@RestController
@RequestMapping("user")
public class UserController {

    @Resource
    private UserService userService;

    /**
     * Inserts a user.
     *
     * @param user entity deserialized from the JSON request body
     * @return a success message containing the generated id, or a failure message
     */
    @PostMapping("/insertUser")
    public String insertUser(@RequestBody User user) {
        BaseMapper<User> baseMapper = userService.getBaseMapper();
        int insert = baseMapper.insert(user);
        return insert == 1 ? "插入成功,id为" + user.getId() : "数据插入失败";
    }

    /**
     * Returns the password of the user with the given id.
     *
     * @param id user id
     * @return the stored password, or a "user not found" message when no row has this id
     */
    @GetMapping("/getUserById/{id}")
    public String getUserById(@PathVariable Integer id) {
        BaseMapper<User> baseMapper = userService.getBaseMapper();
        User user = baseMapper.selectById(id);
        if (user == null) {
            // selectById yields null for a missing row; dereferencing it caused an NPE (HTTP 500).
            return "未找到id为" + id + "的用户";
        }
        return user.getUserPwd();
    }
}
View Code

测试

发送POST请求插入数据,例如向 /user/insertUser 提交请求体 {"userName":"test","userPwd":"<script>alert('XSS')</script>"}:

在浏览器中发送GET请求(/user/getUserById/{id})获取该用户的密码,会发现浏览器执行了其中的JS代码。这是因为接口原样返回了存入的字符串,浏览器将其当作HTML解析并执行了脚本,这显然是不应该发生的。

解决XSS(一):使用Filter过滤器 

代码:

1、编写装饰器类(包装 HttpServletRequest),对请求参数、请求头以及JSON请求体中的字符串进行HTML过滤:

package com.zhixi.config.xss;

import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HtmlUtil;
import cn.hutool.json.JSONUtil;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * Request wrapper (decorator pattern) that strips HTML from request data using Hutool's
 * {@link HtmlUtil#filter(String)} before it reaches the controllers.
 *
 * <p>Filtered: {@link #getParameter}, {@link #getParameterValues}, {@link #getParameterMap},
 * {@link #getHeader}, and JSON request bodies read through {@link #getInputStream()}. In JSON
 * bodies, every string value is filtered at any nesting depth.
 *
 * <p>Not filtered: {@code getHeaders}, {@code getReader}, and bodies whose Content-Type is not
 * JSON; those are passed through unchanged.
 *
 * @author zhangzhixi
 * @since 2022/8/18
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    @Override
    public String getParameter(String name) {
        return stripHtml(super.getParameter(name));
    }

    @Override
    public String[] getParameterValues(String name) {
        return stripHtml(super.getParameterValues(name));
    }

    /**
     * Returns a new map with filtered copies of the parameter arrays; the container's
     * own arrays are left untouched.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> parameters = super.getParameterMap();
        Map<String, String[]> map = new LinkedHashMap<>();
        if (parameters != null) {
            for (Map.Entry<String, String[]> entry : parameters.entrySet()) {
                map.put(entry.getKey(), stripHtml(entry.getValue()));
            }
        }
        return map;
    }

    @Override
    public String getHeader(String name) {
        return stripHtml(super.getHeader(name));
    }

    /**
     * For JSON requests, reads the body (UTF-8), filters every string value, and returns the
     * re-serialized body. Other content types are returned unchanged.
     *
     * @throws IOException if reading the underlying body fails
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (!StrUtil.containsIgnoreCase(getContentType(), "json")) {
            // Form, multipart, plain text, ...: not JSON, so parsing it would throw.
            return super.getInputStream();
        }
        String body = readBody(super.getInputStream());
        String filtered = StrUtil.isBlank(body) ? body : filterJson(body);
        // Encode with the same charset used for decoding (platform default may differ).
        return toServletInputStream(filtered.getBytes(StandardCharsets.UTF_8));
    }

    /** Removes HTML from {@code value}; null and empty strings are returned unchanged. */
    private static String stripHtml(String value) {
        return StrUtil.isEmpty(value) ? value : HtmlUtil.filter(value);
    }

    /** Returns a filtered copy of {@code values}; null stays null. */
    private static String[] stripHtml(String[] values) {
        if (values == null) {
            return null;
        }
        String[] result = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            result[i] = stripHtml(values[i]);
        }
        return result;
    }

    /** Reads the whole stream as UTF-8 text (line breaks preserved) and closes it. */
    private static String readBody(InputStream in) throws IOException {
        StringBuilder body = new StringBuilder();
        try (Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            char[] buffer = new char[4096];
            int count;
            while ((count = reader.read(buffer)) != -1) {
                body.append(buffer, 0, count);
            }
        }
        return body.toString();
    }

    /** Parses a JSON object or array, filters all string values recursively, and re-serializes it. */
    private static String filterJson(String body) {
        String trimmed = body.trim();
        Object json = trimmed.startsWith("[") ? JSONUtil.parseArray(trimmed) : JSONUtil.parseObj(trimmed);
        return JSONUtil.toJsonStr(filterJsonValue(json));
    }

    /** Filters strings; copies maps and collections with filtered contents; returns other values as-is. */
    private static Object filterJsonValue(Object value) {
        if (value instanceof String) {
            return HtmlUtil.filter((String) value);
        }
        if (value instanceof Map) {
            Map<Object, Object> copy = new LinkedHashMap<>();
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                copy.put(entry.getKey(), filterJsonValue(entry.getValue()));
            }
            return copy;
        }
        if (value instanceof Collection) {
            List<Object> copy = new ArrayList<>();
            for (Object element : (Collection<?>) value) {
                copy.add(filterJsonValue(element));
            }
            return copy;
        }
        return value;
    }

    /** Wraps an in-memory byte array as a blocking {@link ServletInputStream}. */
    private static ServletInputStream toServletInputStream(byte[] bytes) {
        final ByteArrayInputStream source = new ByteArrayInputStream(bytes);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return source.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                return source.read(b, off, len);
            }

            @Override
            public boolean isFinished() {
                return source.available() == 0;
            }

            @Override
            public boolean isReady() {
                // All data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // Non-blocking (async) reading is not supported; kept as a no-op as before.
            }
        };
    }
}
View Code

2、编写全局过滤器

package com.zhixi.config.xss;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

/**
 * Servlet filter mapped to every URL ({@code /*}). It wraps each incoming request in an
 * {@link XssHttpServletRequestWrapper}, so downstream code reads HTML-filtered parameters,
 * headers and JSON bodies.
 *
 * @author zhangzhixi
 * @since 2022/8/18
 */
@WebFilter(urlPatterns = "/*")
public class XssFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialization needed.
    }

    /**
     * Replaces the request with its XSS-filtering wrapper and continues the chain.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) servletRequest;
        filterChain.doFilter(new XssHttpServletRequestWrapper(httpRequest), servletResponse);
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }
}
View Code

3、在主启动类上添加以下注解,使Filter生效:

// Place on the main application class: scans this package so the @WebFilter-annotated XssFilter is registered.
@ServletComponentScan(basePackages = {"com.zhixi.config.xss"})

测试:

 

posted @ 2022-12-05 22:10  Java小白的搬砖路  阅读(741)  评论(0编辑  收藏  举报