无状态shiro认证组件(禁用默认session)

准备内容

简单的shiro无状态认证

  无状态认证拦截器

import com.hjzgg.stateless.shiroSimpleWeb.Constants;
import com.hjzgg.stateless.shiroSimpleWeb.realm.StatelessToken;
import org.apache.shiro.web.filter.AccessControlFilter;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**

 * <p>Version: 1.0
 */
public class StatelessAuthcFilter extends AccessControlFilter {

    /**
     * Never grants access up front, so every request goes through
     * {@link #onAccessDenied(ServletRequest, ServletResponse)}, which performs the digest login.
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
        return false;
    }

    /**
     * Builds a {@link StatelessToken} from the request parameters and logs the subject in with it.
     *
     * @return {@code true} to continue the filter chain when login succeeds; {@code false} after a
     *         401 response has been written when it fails
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        // Every request parameter except the digest itself is covered by the digest.
        Map<String, String[]> signedParams = new HashMap<String, String[]>(request.getParameterMap());
        signedParams.remove(Constants.PARAM_DIGEST);

        // Claimed identity + the parameters + the digest the client computed over them.
        StatelessToken token = new StatelessToken(
                request.getParameter(Constants.PARAM_USERNAME),
                signedParams,
                request.getParameter(Constants.PARAM_DIGEST));

        boolean loggedIn;
        try {
            // The realm recomputes the digest and the credentials matcher compares both.
            getSubject(request, response).login(token);
            loggedIn = true;
        } catch (Exception e) {
            e.printStackTrace();
            loggedIn = false;
        }
        if (!loggedIn) {
            onLoginFail(response);
        }
        return loggedIn;
    }

    /** Writes a plain-text 401 response for a failed login. */
    private void onLoginFail(ServletResponse response) throws IOException {
        HttpServletResponse http = (HttpServletResponse) response;
        http.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
        http.getWriter().write("login error");
    }
}
View Code

  Subject工厂

import org.apache.shiro.subject.Subject;
import org.apache.shiro.subject.SubjectContext;
import org.apache.shiro.web.mgt.DefaultWebSubjectFactory;

/**

 * <p>Version: 1.0
 */
public class StatelessDefaultSubjectFactory extends DefaultWebSubjectFactory {

    /**
     * Creates the subject exactly like {@link DefaultWebSubjectFactory}, except that session
     * creation is switched off first so that no session is created for the subject.
     */
    @Override
    public Subject createSubject(SubjectContext context) {
        // Stateless mode: forbid session creation before delegating.
        context.setSessionCreationEnabled(false);
        Subject created = super.createSubject(context);
        return created;
    }
}
View Code

  注意,这里禁用了session

  无状态Realm

import com.hjzgg.stateless.shiroSimpleWeb.codec.HmacSHA256Utils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;

/**

 * <p>Version: 1.0
 */
public class StatelessRealm extends AuthorizingRealm {

    /** Only handles {@link StatelessToken}; any other token type is not supported by this realm. */
    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof StatelessToken;
    }

    /**
     * Grants every authenticated principal the {@code admin} role.
     * Placeholder: replace with a real role lookup keyed on the primary principal.
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        SimpleAuthorizationInfo authorizationInfo = new SimpleAuthorizationInfo();
        authorizationInfo.addRole("admin");
        return authorizationInfo;
    }

    /**
     * Recomputes the HMAC-SHA256 digest of the request parameters with the user's key and returns
     * it as the expected credential; the client digest carried by the token is then matched
     * against it.
     *
     * @throws AuthenticationException if no signing key is configured for the user
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
        StatelessToken statelessToken = (StatelessToken) token;
        String username = statelessToken.getUsername();
        // Same key the client used to sign the request.
        String key = getKey(username);
        if (key == null) {
            // Fail with a meaningful authentication error instead of an NPE inside the HMAC code.
            throw new AuthenticationException("No signing key for user [" + username + "]");
        }
        // Digests are deliberately not printed: the server digest is exactly the credential.
        String serverDigest = HmacSHA256Utils.digest(key, statelessToken.getParams());
        return new SimpleAuthenticationInfo(
                username,
                serverDigest,
                getName());
    }

    /** Returns the signing key for the user, or {@code null} if unknown (hard-coded demo key). */
    private String getKey(String username) {
        if ("admin".equals(username)) {
            return "dadadswdewq2ewdwqdwadsadasd";
        }
        return null;
    }
}
View Code

  无状态Token

import org.apache.shiro.authc.AuthenticationToken;
import org.springframework.beans.*;
import org.springframework.validation.DataBinder;

import java.util.HashMap;
import java.util.Map;

/**

 * <p>Version: 1.0
 */
public class StatelessToken implements AuthenticationToken {

    /** Identity claimed by the client; exposed as the principal. */
    private String username;
    /** Request parameters covered by the client digest. */
    private Map<String, ?> params;
    /** Digest computed by the client; exposed as the credentials. */
    private String clientDigest;

    public StatelessToken(String username,  Map<String, ?> params, String clientDigest) {
        this.clientDigest = clientDigest;
        this.params = params;
        this.username = username;
    }

    public String getUsername() {
        return this.username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public Map<String, ?> getParams() {
        return this.params;
    }

    public void setParams(Map<String, ?> params) {
        this.params = params;
    }

    public String getClientDigest() {
        return this.clientDigest;
    }

    public void setClientDigest(String clientDigest) {
        this.clientDigest = clientDigest;
    }

    /** @return the claimed user name */
    @Override
    public Object getPrincipal() {
        return this.username;
    }

    /** @return the client-side digest */
    @Override
    public Object getCredentials() {
        return this.clientDigest;
    }

    /** Intentionally empty entry point. */
    public static void main(String[] args) {
    }

    /** Demo: sets {@code username} through Spring's {@link BeanWrapperImpl} and prints it. */
    public static void test1() {
        StatelessToken target = new StatelessToken(null, null, null);
        new BeanWrapperImpl(target).setPropertyValue(new PropertyValue("username", "hjzgg"));
        System.out.println(target.getUsername());
    }

    /** Demo: binds {@code username} through Spring's {@link DataBinder} and prints it. */
    public static void test2() {
        StatelessToken target = new StatelessToken(null, null, null);
        DataBinder binder = new DataBinder(target);
        Map<String, Object> values = new HashMap<>();
        values.put("username", "hjzgg");
        PropertyValues propertyValues = new MutablePropertyValues(values);
        binder.bind(propertyValues);
        System.out.println(target.getUsername());
    }
}
View Code

  shiro配置文件

<?xml version="1.0" encoding="UTF-8"?>
<!--
  Spring wiring for stateless Shiro authentication: every URL is handled by statelessAuthcFilter,
  session creation is disabled by the subject factory and the subject DAO does not store state in a session.
-->
<beans xmlns="http://www.springframework.org/schema/beans"
       xmlns:util="http://www.springframework.org/schema/util"
       xmlns:aop="http://www.springframework.org/schema/aop"
       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
       xsi:schemaLocation="
       http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd
       http://www.springframework.org/schema/util http://www.springframework.org/schema/util/spring-util.xsd
       http://www.springframework.org/schema/aop http://www.springframework.org/schema/aop/spring-aop.xsd">

    <!-- Realm implementation; realm caching is turned off (cachingEnabled=false). -->
    <bean id="statelessRealm" class="com.hjzgg.stateless.shiroSimpleWeb.realm.StatelessRealm">
        <property name="cachingEnabled" value="false"/>
    </bean>

    <!-- Subject factory that disables session creation for every subject it builds. -->
    <bean id="subjectFactory" class="com.hjzgg.stateless.shiroSimpleWeb.mgt.StatelessDefaultSubjectFactory"/>

    <!-- Session manager with the periodic session-validation scheduler switched off.
         NOTE(review): DefaultSessionManager is the native (non-web) implementation; confirm this
         is intended for a DefaultWebSecurityManager. -->
    <bean id="sessionManager" class="org.apache.shiro.session.mgt.DefaultSessionManager">
        <property name="sessionValidationSchedulerEnabled" value="false"/>
    </bean>

    <!-- Security manager. sessionStorageEnabled=false stops the subject DAO from saving subject
         state into a session. -->
    <bean id="securityManager" class="org.apache.shiro.web.mgt.DefaultWebSecurityManager">
        <property name="realm" ref="statelessRealm"/>
        <property name="subjectDAO.sessionStorageEvaluator.sessionStorageEnabled" value="false"/>
        <property name="subjectFactory" ref="subjectFactory"/>
        <property name="sessionManager" ref="sessionManager"/>
    </bean>

    <!-- Equivalent to calling SecurityUtils.setSecurityManager(securityManager) -->
    <bean class="org.springframework.beans.factory.config.MethodInvokingFactoryBean">
        <property name="staticMethod" value="org.apache.shiro.SecurityUtils.setSecurityManager"/>
        <property name="arguments" ref="securityManager"/>
    </bean>

    <!-- The stateless authentication filter, registered below under the name "statelessAuthc". -->
    <bean id="statelessAuthcFilter" class="com.hjzgg.stateless.shiroSimpleWeb.filter.StatelessAuthcFilter"/>

    <!-- Shiro web filter. Its bean id is the name the DelegatingFilterProxy in web.xml looks up.
         The chain definition routes every path through statelessAuthc. -->
    <bean id="shiroFilter" class="org.apache.shiro.spring.web.ShiroFilterFactoryBean">
        <property name="securityManager" ref="securityManager"/>
        <property name="filters">
            <util:map>
                <entry key="statelessAuthc" value-ref="statelessAuthcFilter"/>
            </util:map>
        </property>
        <property name="filterChainDefinitions">
            <value>
                /**=statelessAuthc
            </value>
        </property>
    </bean>

    <!-- Shiro lifecycle processor (invokes Shiro init/destroy callbacks on beans).-->
    <bean id="lifecycleBeanPostProcessor" class="org.apache.shiro.spring.LifecycleBeanPostProcessor"/>

</beans>
View Code

  这里关闭了会话验证调度器(sessionValidationSchedulerEnabled=false),并禁用了subjectDAO的session存储(sessionStorageEnabled=false)

  web.xml配置

<?xml version="1.0" encoding="UTF-8"?>
<!-- Servlet 3.0 deployment descriptor: Spring root context, Shiro filter and the Spring MVC servlet. -->
<web-app
        xmlns="http://java.sun.com/xml/ns/javaee"
        xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
        xsi:schemaLocation="http://java.sun.com/xml/ns/javaee http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd"
        version="3.0"
        metadata-complete="false">

    <display-name>shiro-example-chapter20</display-name>

    <!-- Spring root context: loaded from spring-config-shiro.xml by ContextLoaderListener. -->
    <context-param>
        <param-name>contextConfigLocation</param-name>
        <param-value>
            classpath:spring-config-shiro.xml
        </param-value>
    </context-param>
    <listener>
        <listener-class>org.springframework.web.context.ContextLoaderListener</listener-class>
    </listener>
    <!-- End of Spring root context configuration -->

    <!-- Shiro security filter: DelegatingFilterProxy delegates to the Spring bean named like the
         filter-name ("shiroFilter"); targetFilterLifecycle=true also forwards init/destroy to it. -->
    <filter>
        <filter-name>shiroFilter</filter-name>
        <filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
        <async-supported>true</async-supported>
        <init-param>
            <param-name>targetFilterLifecycle</param-name>
            <param-value>true</param-value>
        </init-param>
    </filter>

    <!-- Applied to every path, for REQUEST dispatches only (FORWARD/INCLUDE/ERROR are not filtered). -->
    <filter-mapping>
        <filter-name>shiroFilter</filter-name>
        <url-pattern>/*</url-pattern>
        <dispatcher>REQUEST</dispatcher>
    </filter-mapping>

    <!-- Spring MVC front controller, configured from spring-mvc.xml and mapped as the default servlet. -->
    <servlet>
        <servlet-name>spring</servlet-name>
        <servlet-class>org.springframework.web.servlet.DispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>classpath:spring-mvc.xml</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
        <async-supported>true</async-supported>
    </servlet>
    <servlet-mapping>
        <servlet-name>spring</servlet-name>
        <url-pattern>/</url-pattern>
    </servlet-mapping>


</web-app>
View Code

  token生成工具类

import org.apache.commons.codec.binary.Hex;

import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import java.util.List;
import java.util.Map;

/**

 * <p>Version: 1.0
 */
public class HmacSHA256Utils {

    private static final String ALGORITHM = "HmacSHA256";

    private static final java.nio.charset.Charset UTF_8 = java.nio.charset.StandardCharsets.UTF_8;

    /**
     * Computes the HMAC-SHA256 of {@code content} under {@code key}, both encoded as UTF-8.
     *
     * @param key     shared secret
     * @param content message to authenticate
     * @return the MAC as a lowercase hex string
     * @throws NullPointerException     if {@code key} or {@code content} is null
     * @throws IllegalArgumentException if {@code key} is empty
     * @throws IllegalStateException    if the MAC cannot be computed (algorithm unavailable, key rejected)
     */
    public static String digest(String key, String content) {
        java.util.Objects.requireNonNull(key, "key");
        java.util.Objects.requireNonNull(content, "content");
        try {
            Mac mac = Mac.getInstance(ALGORITHM);
            SecretKey secret = new SecretKeySpec(key.getBytes(UTF_8), ALGORITHM);
            mac.init(secret);
            return Hex.encodeHexString(mac.doFinal(content.getBytes(UTF_8)));
        } catch (java.security.GeneralSecurityException e) {
            throw new IllegalStateException("Cannot compute " + ALGORITHM, e);
        }
    }

    /**
     * Computes the HMAC-SHA256 of a parameter map.
     * <p>
     * The values are concatenated in ascending key order (a {@code null} key first), so the client
     * and the server sign the same string no matter which {@code Map} implementation each side
     * builds. The previous version used the map's own iteration order, which differs between a
     * {@code HashMap} and an insertion-ordered map. {@code String[]} and {@code List} values
     * contribute each element in turn; any other value contributes its {@code toString()}.
     * <p>
     * NOTE(review): keys are not covered and values are joined without a separator, so e.g.
     * {a=xy, b=""} and {a=x, b=y} yield the same digest. Fixing that changes the wire protocol
     * shared with clients.
     *
     * @param key shared secret
     * @param map parameters to sign
     * @return the MAC as a lowercase hex string
     */
    public static String digest(String key, Map<String, ?> map) {
        Map<String, Object> ordered = new java.util.TreeMap<String, Object>(
                java.util.Comparator.nullsFirst(java.util.Comparator.<String>naturalOrder()));
        ordered.putAll(map);
        StringBuilder s = new StringBuilder();
        for (Object value : ordered.values()) {
            appendValue(s, value);
        }
        return digest(key, s.toString());
    }

    /** Appends one parameter value: each element of an array/list, otherwise the value itself. */
    private static void appendValue(StringBuilder s, Object value) {
        if (value instanceof String[]) {
            for (String element : (String[]) value) {
                s.append(element);
            }
        } else if (value instanceof List) {
            for (Object element : (List<?>) value) {
                s.append(element);
            }
        } else {
            s.append(value);
        }
    }

}
View Code

  简单测试一下

import com.alibaba.fastjson.JSONObject;
import com.hjzgg.stateless.shiroSimpleWeb.codec.HmacSHA256Utils;
import com.hjzgg.stateless.shiroSimpleWeb.utils.RestTemplateUtils;
import org.junit.Test;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * <p>Version: 1.0
 */
public class ClientTest {

    private static final String WEB_URL = "http://localhost:8080/shiro/hello";

    /** HMAC key shared with the server for the "admin" user. */
    private static final String KEY = "dadadswdewq2ewdwqdwadsadasd";

    private static final String USERNAME = "admin";

    /**
     * Sends a correctly signed request and prints the response.
     * NOTE(review): JSONObject is a Map, so the second put of "param1" replaces the first; only
     * "param12" is actually signed and sent.
     */
    @Test
    public void testServiceHelloSuccess() {
        JSONObject params = new JSONObject();
        params.put(Constants.PARAM_USERNAME, USERNAME);
        params.put("param1", "param11");
        params.put("param1", "param12");
        params.put("param2", "param2");
        // Sign everything sent so far, then attach the digest itself.
        String digest = HmacSHA256Utils.digest(KEY, params);
        params.put(Constants.PARAM_DIGEST, digest);

        System.out.println(RestTemplateUtils.get(WEB_URL, params));
    }

    /**
     * Signs the parameters, then tampers with "param2" so the digest no longer matches.
     * NOTE(review): the URL is only built; no request is sent and nothing is asserted. It also
     * lacks the "/shiro" context path used by WEB_URL.
     */
    @Test
    public void testServiceHelloFail() {
        MultiValueMap<String, String> params = new LinkedMultiValueMap<String, String>();
        params.add(Constants.PARAM_USERNAME, USERNAME);
        params.add("param1", "param11");
        params.add("param1", "param12");
        params.add("param2", "param2");
        params.add(Constants.PARAM_DIGEST, HmacSHA256Utils.digest(KEY, params));
        // Tamper after signing.
        params.set("param2", "param2" + "1");

        String url = UriComponentsBuilder.fromHttpUrl("http://localhost:8080/hello")
                .queryParams(params)
                .build()
                .toUriString();
    }
}
View Code

  补充Spring中多重属性赋值处理

  以上参考 开涛老师的博文

相对复杂一点的shiro无状态认证

  *加入session,放入redis的hash结构中(user_name作为key,token作为hash字段(field),当前登录时间(秒)作为value)

  *用户登录互斥操作:如果互斥,清除redis中该用户对应的状态,重新写入新的状态;如果不互斥,写入新的状态,刷新key值,并检测该用户其他的状态是否已经超时(根据key值获取到所有的 key和hashKey的组合,判断value[登入时间]+timeout[超时时间] >= curtime[当前时间]),如果超时则清除状态。

  *使用esapi进行token的生成

  *认证信息,如果是web端则从cookie中获取,ajax从header中获取;如果是移动端也是从header中获取

  session manager逻辑

import com.hjzgg.stateless.auth.token.ITokenProcessor;
import com.hjzgg.stateless.auth.token.TokenFactory;
import com.hjzgg.stateless.auth.token.TokenGenerator;
import com.hjzgg.stateless.common.cache.RedisCacheTemplate;
import com.hjzgg.stateless.common.esapi.EncryptException;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
public class ShiroSessionManager {

    @Autowired
    private RedisCacheTemplate redisCacheTemplate;

    /** When true, a new login removes every other online session of the same user. */
    @Value("${sessionMutex}")
    private boolean sessionMutex = false;

    /** Redis key under which the system-wide token seed is stored. */
    public static final String TOKEN_SEED = "token_seed";

    public static final String DEFAULT_CHARSET = "UTF-8";

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Process-local copy of the token seed; volatile so all request threads see it once set. */
    private static volatile String localSeedValue = null;

    /**
     * Returns the system's token seed: the local copy if present, otherwise the value stored in
     * Redis under {@link #TOKEN_SEED}; when Redis has none, a new seed is generated and stored.
     * The result is cached locally in every case (previously only a freshly generated seed was
     * cached, so a seed found in Redis was re-read on every call).
     * <p>
     * NOTE(review): two nodes starting at the same time can each generate a seed; the last put
     * wins in Redis while each node keeps its own copy. An atomic set-if-absent would close that gap.
     */
    public String findSeed() throws EncryptException {
        String cached = localSeedValue;
        if (cached != null) {
            return cached;
        }
        String seed = getSeedValue(TOKEN_SEED);
        if (StringUtils.isBlank(seed)) {
            seed = TokenGenerator.genSeed();
            redisCacheTemplate.put(TOKEN_SEED, seed);
        }
        localSeedValue = seed;
        return seed;
    }

    /** Reads a seed value from Redis. */
    public String getSeedValue(String key) {
        return (String) redisCacheTemplate.get(key);
    }


    /**
     * Deletes a session cache entry.
     *
     * @param sid mock session id
     */
    public void removeSessionCache(String sid) {
        redisCacheTemplate.delete(sid);
    }


    /** Timeout in seconds configured for the token's type; non-positive means no expiry. */
    private int getTimeout(String sid){
        return TokenFactory.getTokenInfo(sid).getIntegerExpr();
    }

    /** Current epoch time in seconds, as a long so it does not overflow in 2038. */
    private static long currentTimeSeconds() {
        return System.currentTimeMillis() / 1000;
    }

    private String getCurrentTimeSeconds() {
        return String.valueOf(currentTimeSeconds());
    }

    /**
     * Registers a new online session: hash {@code userName} -> field {@code token} -> login time.
     * In mutex mode all other sessions of the user are removed first; otherwise only the expired ones.
     */
    public void registOnlineSession(final String userName, final String token, final ITokenProcessor processor) {
        final String key = userName;
        logger.debug("token processor id is {}, key is {}, sessionMutex is {}!" , processor.getId(), key, sessionMutex);

        if (sessionMutex) {
            // Mutually exclusive logins: kick every existing session of this user.
            deleteUserSession(key);
        } else {
            // Drop sessions that expired without a logout (crash, closed browser, ...).
            clearOnlineSession(key);
        }

        redisCacheTemplate.hPut(key, token, getCurrentTimeSeconds());
        int timeout = getTimeout(token);
        if (timeout > 0) {
            // The hash lives under the user name, so that key gets the TTL (consistent with
            // validateOnlineSession). Previously a key named after the token was expired, and this
            // class never writes such a key.
            redisCacheTemplate.expire(key, timeout);
        }
    }

    /** Removes the user's session entries whose login time + timeout lies in the past. */
    private void clearOnlineSession(final String key) {
        long now = currentTimeSeconds();
        redisCacheTemplate.hKeys(key).forEach((obj) -> {
            String hashKey = (String) obj;
            int timeout = getTimeout(hashKey);
            if (timeout <= 0) {
                return; // entries without a positive timeout are never considered expired here
            }
            String loginTs = (String) redisCacheTemplate.hGet(key, hashKey);
            if (StringUtils.isBlank(loginTs)) {
                return; // removed concurrently (e.g. logout) since hKeys was read
            }
            if (Long.parseLong(loginTs) + timeout < now) {
                redisCacheTemplate.hDel(key, hashKey);
            }
        });
    }

    /**
     * Checks whether the session {@code key}/{@code hashKey} is still online. A live session gets
     * its login time and the key's TTL refreshed; an expired one is removed.
     */
    public boolean validateOnlineSession(final String key, final String hashKey) {
        int timeout = getTimeout(hashKey);
        if (timeout <= 0) {
            return redisCacheTemplate.hGet(key, hashKey) != null;
        }
        String loginTs = (String) redisCacheTemplate.hGet(key, hashKey);
        if (StringUtils.isEmpty(loginTs)) {
            return false;
        }
        if (Long.parseLong(loginTs) + timeout >= currentTimeSeconds()) {
            // Sliding expiration: refresh the entry and the key's TTL.
            redisCacheTemplate.hPut(key, hashKey, getCurrentTimeSeconds());
            redisCacheTemplate.expire(key, timeout);
            return true;
        }
        redisCacheTemplate.hDel(key, hashKey);
        return false;
    }

    /** Called on logout: removes one session of the user. */
    public void delOnlineSession(final String key, final String hashKey){
        redisCacheTemplate.hDel(key, hashKey);
    }

    /** Called when a user is disabled or deleted: removes all of the user's sessions. */
    public void deleteUserSession(final String key){
        redisCacheTemplate.delete(key);
    }
}
View Code

  无状态认证过滤器

package com.hjzgg.stateless.auth.shiro;

import com.alibaba.fastjson.JSONObject;
import com.hjzgg.stateless.auth.token.ITokenProcessor;
import com.hjzgg.stateless.auth.token.TokenFactory;
import com.hjzgg.stateless.auth.token.TokenParameter;
import com.hjzgg.stateless.common.constants.AuthConstants;
import com.hjzgg.stateless.common.utils.CookieUtil;
import com.hjzgg.stateless.common.utils.InvocationInfoProxy;
import com.hjzgg.stateless.common.utils.MapToStringUtil;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.AccessControlFilter;
import org.apache.shiro.web.util.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.*;

public class StatelessAuthcFilter extends AccessControlFilter {

    private static final Logger log = LoggerFactory.getLogger(StatelessAuthcFilter.class);

    /** Status set on authentication failure. NOTE(review): 306 is a reserved/unused HTTP status. */
    public static final int HTTP_STATUS_AUTH = 306;

    /** Extra comma-separated URI suffixes that bypass authentication (property filterExclude). */
    @Value("${filterExclude}")
    private String exeludeStr;

    @Autowired
    private TokenFactory tokenFactory;

    /** Built-in URI suffixes that bypass authentication: login/logout endpoints and static resources. */
    private String[] esc = new String[] {
        "/logout","/login","/formLogin",".jpg",".png",".gif",".css",".js",".jpeg"
    };

    /**
     * Cookie names NOT copied into InvocationInfoProxy parameters; pre-filled with every static
     * String constant of AuthConstants by the instance initializer at the end of this class.
     */
    private List<String> excludCongtextKeys = new ArrayList<>();

    public void setTokenFactory(TokenFactory tokenFactory) {
        this.tokenFactory = tokenFactory;
    }

    public void setEsc(String[] esc) {
        this.esc = esc;
    }

    public void setExcludCongtextKeys(List<String> excludCongtextKeys) {
        this.excludCongtextKeys = excludCongtextKeys;
    }

    /** Never grants access up front; every request is handled by onAccessDenied. */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
        return false;
    }

    /**
     * Authenticates the request from its token/user-name credentials and, on success, fills the
     * invocation context (InvocationInfoProxy, MDC).
     *
     * @return {@code true} to continue the chain; {@code false} after a failure response was written
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // Whitelisted URIs skip authentication. This is checked first so a malformed Authority
        // header can never break the login page or static resources.
        if (include(httpRequest)) {
            return true;
        }

        boolean isAjax = isAjax(request);

        // 1. Credentials sent by the client: the Authority header wins over real cookies.
        Cookie[] cookies = resolveCookies(httpRequest);
        String tokenStr = CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_TOKEN);
        String cookieUserName = CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_USERNAME);
        String loginTs = CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_LOGINTS);

        // 2. Claimed identity: request parameter first, cookie as fallback.
        String userName = request.getParameter(AuthConstants.PARAM_USERNAME);
        if (userName == null && StringUtils.isNotBlank(cookieUserName)) {
            userName = cookieUserName;
        }

        if (StringUtils.isEmpty(tokenStr) || StringUtils.isEmpty(userName)) {
            if (isAjax) {
                onAjaxAuthFail(request, response);
            } else {
                onLoginFail(request, response);
            }
            return false;
        }

        // 3. Request parameters, carried along for finer-grained checks.
        Map<String, String[]> params = new HashMap<String, String[]>(request.getParameterMap());

        ITokenProcessor tokenProcessor = tokenFactory.getTokenProcessor(tokenStr);
        TokenParameter tp = tokenProcessor.getTokenParameterFromCookie(cookies);
        // 4. Build the stateless token.
        StatelessToken token = new StatelessToken(userName, tokenProcessor, tp, params, tokenStr);

        try {
            // 5. Delegate to the realm.
            getSubject(request, response).login(token);

            // Publish the invocation context.
            InvocationInfoProxy.setUserName(userName);
            InvocationInfoProxy.setLoginTs(loginTs);
            InvocationInfoProxy.setToken(tokenStr);
            initExtendParams(cookies);

            initMDC();
            afterValidate(httpRequest);
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            // 6. Ajax callers get a JSON body for authentication errors; everything else is redirected to login.
            if (isAjax && e instanceof AuthenticationException) {
                onAjaxAuthFail(request, response);
            } else {
                onLoginFail(request, response);
            }
            return false;
        }
        return true;
    }

    /** Returns the cookies parsed from the Authority header if it is present, else the request cookies. */
    private static Cookie[] resolveCookies(HttpServletRequest request) {
        String authority = request.getHeader("Authority");
        return StringUtils.isNotBlank(authority) ? parseAuthorityHeader(authority) : request.getCookies();
    }

    /**
     * Parses a cookie-style header ({@code name=value; name2=value2}) into cookies, preserving the
     * header order. Segments without a value or with an invalid cookie name are skipped instead of
     * failing the request.
     */
    private static Cookie[] parseAuthorityHeader(String authority) {
        List<Cookie> parsed = new ArrayList<>();
        for (String segment : authority.split(";")) {
            // Split on the first '=' only: token values are often Base64 and may end in '=' padding.
            String[] nameValue = segment.split("=", 2);
            if (nameValue.length < 2) {
                continue;
            }
            String name = StringUtils.trim(nameValue[0]);
            String value = StringUtils.trim(nameValue[1]);
            if (StringUtils.isEmpty(name)) {
                continue;
            }
            try {
                parsed.add(new Cookie(name, value));
            } catch (IllegalArgumentException e) {
                log.warn("Ignoring invalid entry [{}] in Authority header", name);
            }
        }
        return parsed.toArray(new Cookie[0]);
    }

    /** Detects ajax requests via the X-Requested-With header. */
    private boolean isAjax(ServletRequest request) {
        if (!(request instanceof HttpServletRequest)) {
            return false;
        }
        return "XMLHttpRequest".equals(((HttpServletRequest) request).getHeader("X-Requested-With"));
    }

    /** Writes a JSON error body with status {@link #HTTP_STATUS_AUTH}. */
    protected void onAjaxAuthFail(ServletRequest request, ServletResponse resp) throws IOException {
        HttpServletResponse response = (HttpServletResponse) resp;
        JSONObject json = new JSONObject();
        json.put("msg", "auth check error!");
        response.setStatus(HTTP_STATUS_AUTH);
        response.setContentType("application/json;charset=UTF-8");
        response.getWriter().write(json.toString());
    }

    /**
     * Sets status {@link #HTTP_STATUS_AUTH} and redirects to the login page.
     * NOTE(review): the redirect normally replaces the status with a 3xx code.
     */
    protected void onLoginFail(ServletRequest request, ServletResponse response) throws IOException {
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        httpResponse.setStatus(HTTP_STATUS_AUTH);
        request.setAttribute("msg", "auth check error!");
        redirectToLogin(request, httpResponse);
    }

    /**
     * Redirects to the login URL with {@code r} = URL-safe Base64 of the original URI (including its
     * query string) and, when a {@code msg} attribute is set, a URL-encoded {@code msg} parameter.
     * Before this fix the query string was dropped and {@code msg} was appended without a '?'.
     */
    @Override
    protected void redirectToLogin(ServletRequest request, ServletResponse response) throws IOException {
        HttpServletRequest hReq = (HttpServletRequest) request;
        Object msgAttr = request.getAttribute("msg");
        String msg = msgAttr == null ? null : String.valueOf(msgAttr);
        String encodedMsg = StringUtils.isEmpty(msg)
                ? null
                : java.net.URLEncoder.encode(msg, java.nio.charset.StandardCharsets.UTF_8.name());

        // Return URL: original URI + original query string + optional msg.
        StringBuilder returnUrl = new StringBuilder(hReq.getRequestURI());
        String query = hReq.getQueryString();
        boolean hasQuery = StringUtils.isNotEmpty(query);
        if (hasQuery) {
            returnUrl.append('?').append(query);
        }
        if (encodedMsg != null) {
            returnUrl.append(hasQuery ? '&' : '?').append("msg=").append(encodedMsg);
        }
        String r = Base64.encodeBase64URLSafeString(
                returnUrl.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8));

        String loginUrl = getLoginUrl();
        StringBuilder target = new StringBuilder(loginUrl)
                .append(loginUrl.contains("?") ? '&' : '?')
                .append("r=").append(r);
        if (encodedMsg != null) {
            target.append("&msg=").append(encodedMsg);
        }
        WebUtils.issueRedirect(request, response, target.toString());
    }

    /** Returns true when the request URI ends with a built-in or configured excluded suffix. */
    public boolean include(HttpServletRequest request) {
        String uri = request.getRequestURI();
        if (endsWithAny(uri, esc)) {
            return true;
        }
        return StringUtils.isNotBlank(exeludeStr) && endsWithAny(uri, exeludeStr.split(","));
    }

    /**
     * Suffix match against a list of (trimmed) suffixes. Blank entries are ignored: an empty
     * suffix matches every URI and would silently disable authentication (e.g. "a,,b" or "a, ").
     */
    private static boolean endsWithAny(String uri, String[] suffixes) {
        if (suffixes == null) {
            return false;
        }
        for (String raw : suffixes) {
            String suffix = StringUtils.trim(raw);
            if (StringUtils.isNotEmpty(suffix) && uri.endsWith(suffix)) {
                return true;
            }
        }
        return false;
    }

    /** Clears the per-request invocation context and MDC after the request completes. */
    @Override
    public void afterCompletion(ServletRequest request, ServletResponse response, Exception exception) throws Exception {
        super.afterCompletion(request, response, exception);
        InvocationInfoProxy.reset();
        clearMDC();
    }

    /**
     * Copies every cookie whose name is not in excludCongtextKeys into InvocationInfoProxy's
     * parameters (used when the context is propagated over REST).
     */
    private void initExtendParams(Cookie[] cookies) {
        if (cookies == null) {
            return;
        }
        for (Cookie cookie : cookies) {
            String cname = cookie.getName();
            if (!excludCongtextKeys.contains(cname)) {
                InvocationInfoProxy.setParameter(cname, cookie.getValue());
            }
        }
    }

    /** Puts the current principal (or "") and the invocation context into the MDC. */
    private void initMDC() {
        String userName = "";
        Subject subject = SecurityUtils.getSubject();
        if (subject != null && subject.getPrincipal() != null) {
            userName = (String) subject.getPrincipal();
        }
        MDC.put(AuthConstants.PARAM_USERNAME, userName);

        initCustomMDC();
    }

    protected void initCustomMDC() {
        MDC.put("InvocationInfoProxy", MapToStringUtil.toEqualString(InvocationInfoProxy.getResources(), ';'));
    }

    /** Hook for subclasses, called after successful authentication. */
    protected void afterValidate(HttpServletRequest hReq){
    }

    protected void clearMDC() {
        MDC.remove(AuthConstants.PARAM_USERNAME);

        clearCustomMDC();
    }

    protected void clearCustomMDC() {
        MDC.remove("InvocationInfoProxy");
    }

    // Pre-fills excludCongtextKeys with the values of all static String fields of AuthConstants.
    // Must stay textually after the excludCongtextKeys field initializer.
    {
        for (Field field : AuthConstants.class.getDeclaredFields()) {
            if (field.getType() != String.class || !Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            try {
                field.setAccessible(true);
                excludCongtextKeys.add((String) field.get(null));
            } catch (IllegalAccessException e) {
                log.error("Cannot read AuthConstants." + field.getName(), e);
            }
        }
    }
}
View Code

  dubbo服务调用时上下文的传递问题

  思路:认证过滤器中 通过MDC将上下文信息写入到InheritableThreadLocal中,写一个dubbo的过滤器。在过滤器中判断,如果是消费一方,则将MDC中的上下文取出来放入dubbo的context变量中;如果是服务方,则从dubbo的context中拿出上下文,解析并放入MDC以及InvocationInfoProxy(下面会提到)类中

  Subject工厂

import org.apache.shiro.subject.Subject;
import org.apache.shiro.subject.SubjectContext;
import org.apache.shiro.web.mgt.DefaultWebSubjectFactory;

public class StatelessDefaultSubjectFactory extends DefaultWebSubjectFactory {

    /**
     * Delegates to {@link DefaultWebSubjectFactory} after turning off session creation in the
     * context, so subjects built here never create a session.
     */
    @Override
    public Subject createSubject(SubjectContext context) {
        // Stateless mode: no session may be created for this subject.
        context.setSessionCreationEnabled(false);
        Subject created = super.createSubject(context);
        return created;
    }
}
View Code

  同样禁用掉session的创建

  无状态Realm

import com.hjzgg.stateless.auth.session.ShiroSessionManager;
import com.hjzgg.stateless.auth.token.ITokenProcessor;
import com.hjzgg.stateless.auth.token.TokenParameter;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.ArrayList;
import java.util.List;

public class StatelessRealm extends AuthorizingRealm {

    private static final Logger logger = LoggerFactory.getLogger(StatelessRealm.class);

    @Autowired
    private ShiroSessionManager shiroSessionManager;

    /** Only handles {@link StatelessToken}. */
    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof StatelessToken;
    }

    /** Returns authorization info with an empty (non-null) role set. */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
        List<String> noRoles = new ArrayList<String>();
        info.addRoles(noRoles);
        return info;
    }

    /**
     * Regenerates the expected token from the token parameters and accepts it only if it is
     * registered as an online session of the user. The regenerated token is returned as the expected
     * credential, which the client token string is matched against.
     *
     * @throws AuthenticationException if no token could be generated or the session is not online
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken atoken) throws AuthenticationException {
        StatelessToken statelessToken = (StatelessToken) atoken;
        String userName = (String) statelessToken.getPrincipal();
        String expectedToken = statelessToken.getTokenProcessor().generateToken(statelessToken.getTp());

        boolean online = expectedToken != null
                && shiroSessionManager.validateOnlineSession(userName, expectedToken);
        if (!online) {
            logger.error("User [{}] authenticate fail in System, maybe session timeout!", userName);
            throw new AuthenticationException("User " + userName + " authenticate fail in System");
        }
        return new SimpleAuthenticationInfo(userName, expectedToken, getName());
    }

}
View Code

  这里使用自定义 session manager去校验

  无状态token

import com.hjzgg.stateless.auth.token.ITokenProcessor;
import com.hjzgg.stateless.auth.token.TokenParameter;
import org.apache.shiro.authc.AuthenticationToken;

import java.util.Map;

/**
 * Shiro {@link AuthenticationToken} used for stateless (session-less) requests.
 * <p>
 * The principal is the user name and the credentials are the client digest. The token also
 * carries the {@link ITokenProcessor} and {@link TokenParameter} that a realm needs to
 * regenerate the server-side token string.
 */
public class StatelessToken implements AuthenticationToken {

    private String userName;
    // Reserved parameter collection, meant for validating more complex permissions.
    private Map<String, ?> params;
    private String clientDigest;
    // Package-private (not private) on purpose; kept that way for compatibility.
    ITokenProcessor tokenProcessor;
    TokenParameter tp;

    /**
     * @param userName       user identity, returned as the principal
     * @param tokenProcessor processor used to regenerate the token
     * @param tp             parameters the token is generated from
     * @param params         reserved request parameters
     * @param clientDigest   client-supplied digest, returned as the credentials
     */
    public StatelessToken(String userName, ITokenProcessor tokenProcessor, TokenParameter tp,
                          Map<String, ?> params, String clientDigest) {
        this.userName = userName;
        this.tokenProcessor = tokenProcessor;
        this.tp = tp;
        this.params = params;
        this.clientDigest = clientDigest;
    }

    /** @return the user name */
    @Override
    public Object getPrincipal() {
        return userName;
    }

    /** @return the client digest */
    @Override
    public Object getCredentials() {
        return clientDigest;
    }

    public String getUserName() {
        return userName;
    }

    public void setUserName(String userName) {
        this.userName = userName;
    }

    public Map<String, ?> getParams() {
        return params;
    }

    public void setParams(Map<String, ?> params) {
        this.params = params;
    }

    public String getClientDigest() {
        return clientDigest;
    }

    public void setClientDigest(String clientDigest) {
        this.clientDigest = clientDigest;
    }

    public ITokenProcessor getTokenProcessor() {
        return tokenProcessor;
    }

    public void setTokenProcessor(ITokenProcessor tokenProcessor) {
        this.tokenProcessor = tokenProcessor;
    }

    public TokenParameter getTp() {
        return tp;
    }

    public void setTp(TokenParameter tp) {
        this.tp = tp;
    }
}
View Code

  token处理器

import com.hjzgg.stateless.auth.session.ShiroSessionManager;
import com.hjzgg.stateless.common.constants.AuthConstants;
import com.hjzgg.stateless.common.esapi.EncryptException;
import com.hjzgg.stateless.common.esapi.IYCPESAPI;
import com.hjzgg.stateless.common.utils.CookieUtil;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

import javax.servlet.http.Cookie;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;

/**
 * Default token processor: converts between cookies and {@link TokenParameter} and generates
 * token strings.
 * <p>
 * Multiple instances may be registered, each with its own {@link #getId() id}, cookie
 * {@link #getPath() path}/{@link #getDomain() domain} and {@link #getExpr() expr}.
 * </p>
 *
 * @author li
 *
 */
public class DefaultTokenPorcessor implements ITokenProcessor {
    private static final Logger log = LoggerFactory.getLogger(DefaultTokenPorcessor.class);

    /**
     * 3 when {@code javax.servlet.annotation.WebServlet} (Servlet 3.0+) is on the classpath,
     * otherwise 2. {@link Cookie#setHttpOnly(boolean)} is only called when this is 3.
     */
    private static final int HTTPVERSION = detectServletVersion();

    private String id;
    private String domain;
    private String path = "/";
    private Integer expr;
    // Default number of hash iterations.
    private int hashIterations = 2;
    // Names of extension cookies that take part in the token summary; setExacts may set this to null.
    private List<String> exacts = new ArrayList<String>();

    @Autowired
    private ShiroSessionManager shiroSessionManager;

    /** Detects the Servlet API level by probing for a Servlet 3.0-only class. */
    private static int detectServletVersion() {
        URL res = DefaultTokenPorcessor.class.getClassLoader()
                .getResource("javax/servlet/annotation/WebServlet.class");
        return res == null ? 2 : 3;
    }

    @Override
    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getDomain() {
        return domain;
    }

    public void setDomain(String domain) {
        this.domain = domain;
    }

    public String getPath() {
        return path;
    }

    public void setPath(String path) {
        this.path = path;
    }

    public Integer getExpr() {
        return expr;
    }

    public void setExpr(Integer expr) {
        this.expr = expr;
    }

    public void setExacts(List<String> exacts) {
        this.exacts = exacts;
    }

    public int getHashIterations() {
        return hashIterations;
    }

    public void setHashIterations(int hashIterations) {
        this.hashIterations = hashIterations;
    }

    /**
     * Generates the token string for {@code tp}.
     * <p>
     * The ESAPI hash is computed over {@code id + userName + loginTs + summary + expr}. It uses the
     * seed from {@code ShiroSessionManager.findSeed()} and {@link #getHashIterations()} iterations.
     * The returned value is the URL-safe Base64 encoding of {@code "id,expr,hash"}.
     *
     * @param tp the token parameters
     * @return the encoded token string
     * @throws IllegalArgumentException if hashing fails (the cause is preserved)
     */
    @Override
    public String generateToken(TokenParameter tp) {
        try {
            String seed = shiroSessionManager.findSeed();
            String token = IYCPESAPI.encryptor().hash(
                            this.id + tp.getUserName() + tp.getLoginTs() + getSummary(tp) + getExpr(),
                            seed,
                            getHashIterations());
            token = this.id + "," + getExpr() + "," + token;
            return Base64.encodeBase64URLSafeString(org.apache.commons.codec.binary.StringUtils.getBytesUtf8(token));
        } catch (EncryptException e) {
            log.error("TokenParameter is not validate!", e);
            throw new IllegalArgumentException("TokenParameter is not validate!", e);
        }
    }

    /**
     * Builds the login cookies for {@code tp} and registers the token as an online session.
     * <p>
     * The cookies are:
     * <ul>
     *   <li>the token cookie, HttpOnly on Servlet 3+;</li>
     *   <li>the URL-encoded user name;</li>
     *   <li>the URL-encoded login timestamp;</li>
     *   <li>one cookie per entry of {@code tp.getExt()}.</li>
     * </ul>
     * If URL encoding fails, the error is logged and the user name/timestamp cookies may be missing.
     *
     * @param tp the token parameters of the user who just logged in
     * @return the cookies to add to the response
     */
    @Override
    public Cookie[] getCookieFromTokenParameter(TokenParameter tp) {
        List<Cookie> cookies = new ArrayList<Cookie>();
        String tokenStr = generateToken(tp);
        Cookie token = newCookie(AuthConstants.PARAM_TOKEN, tokenStr);
        if (HTTPVERSION == 3) {
            // Cookie.setHttpOnly only exists from Servlet 3.0 on.
            token.setHttpOnly(true);
        }
        cookies.add(token);

        try {
            cookies.add(newCookie(AuthConstants.PARAM_USERNAME, URLEncoder.encode(tp.getUserName(), "UTF-8")));
            // Login timestamp.
            cookies.add(newCookie(AuthConstants.PARAM_LOGINTS, URLEncoder.encode(tp.getLoginTs(), "UTF-8")));
        } catch (UnsupportedEncodingException e) {
            log.error("encode error!", e);
        }

        for (Entry<String, String> ext : tp.getExt().entrySet()) {
            cookies.add(newCookie(ext.getKey(), ext.getValue()));
        }

        shiroSessionManager.registOnlineSession(tp.getUserName(), tokenStr, this);

        return cookies.toArray(new Cookie[0]);
    }

    /**
     * Rebuilds a {@link TokenParameter} from request cookies.
     * <p>
     * It reads the user name and login timestamp, plus every cookie whose name is in {@code exacts}.
     * A null value from such a cookie is stored as {@code ""}.
     *
     * @param cookies the request cookies
     * @return the reconstructed parameters
     * @throws IllegalArgumentException if the token's expr does not match this processor's expr
     *         (a missing expr on either side counts as a mismatch)
     */
    @Override
    public TokenParameter getTokenParameterFromCookie(Cookie[] cookies) {
        TokenParameter tp = new TokenParameter();
        String token = CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_TOKEN);
        TokenInfo ti = TokenFactory.getTokenInfo(token);
        Integer tokenExpr = ti.getIntegerExpr();
        Integer expectedExpr = this.getExpr();
        if (tokenExpr == null || expectedExpr == null || tokenExpr.intValue() != expectedExpr.intValue()) {
            throw new IllegalArgumentException("illegal token!");
        }
        tp.setUserName(CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_USERNAME));
        tp.setLoginTs(CookieUtil.findCookieValue(cookies, AuthConstants.PARAM_LOGINTS));

        if (exacts != null && !exacts.isEmpty()) {
            for (Cookie cookie : cookies) {
                String name = cookie.getName();
                if (exacts.contains(name)) {
                    tp.getExt().put(name, cookie.getValue() == null ? "" : cookie.getValue());
                }
            }
        }
        return tp;
    }

    /**
     * Joins the {@code tp.getExt()} values of all {@code exacts} names with {@code "#"}, in
     * {@code exacts} order. Missing values become {@code ""}.
     *
     * @param tp the token parameters
     * @return the summary, or {@code ""} when no exacts are configured
     */
    protected String getSummary(TokenParameter tp) {
        if (exacts == null || exacts.isEmpty()) {
            return "";
        }
        String[] values = new String[exacts.size()];
        for (int i = 0; i < values.length; i++) {
            String value = tp.getExt().get(exacts.get(i));
            values[i] = value == null ? "" : value;
        }
        return StringUtils.join(values, "#");
    }

    /**
     * Builds cookies that delete the login cookies from the browser and removes the online
     * session for {@code uid}/{@code tokenStr}.
     * <p>
     * The deleted cookies are the token, user name and login timestamp cookies, plus every
     * configured exact cookie.
     *
     * @param tokenStr the token being logged out
     * @param uid      the user id whose online session is removed
     * @return the deletion cookies to add to the response
     */
    @Override
    public Cookie[] getLogoutCookie(String tokenStr, String uid) {
        List<Cookie> cookies = new ArrayList<Cookie>();
        cookies.add(newExpiredCookie(AuthConstants.PARAM_TOKEN));
        cookies.add(newExpiredCookie(AuthConstants.PARAM_USERNAME));
        // Login timestamp.
        cookies.add(newExpiredCookie(AuthConstants.PARAM_LOGINTS));
        if (exacts != null) {
            for (String exact : exacts) {
                cookies.add(newExpiredCookie(exact));
            }
        }

        shiroSessionManager.delOnlineSession(uid, tokenStr);

        return cookies.toArray(new Cookie[0]);
    }

    /** Creates a cookie scoped to this processor's path and, when configured, its domain. */
    private Cookie newCookie(String name, String value) {
        Cookie cookie = new Cookie(name, value);
        if (StringUtils.isNotEmpty(domain)) {
            cookie.setDomain(domain);
        }
        cookie.setPath(path);
        return cookie;
    }

    /**
     * Creates an empty, same-scoped cookie with max-age 0 so the browser deletes it.
     * Without an explicit max-age (default -1), the browser would only overwrite the existing
     * cookie with an empty session cookie.
     */
    private Cookie newExpiredCookie(String name) {
        Cookie cookie = newCookie(name, null);
        cookie.setMaxAge(0);
        return cookie;
    }
}
View Code

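  将必需字段(处理器标识、用户名、登录时间戳、过期时间)和扩展字段拼接后,通过 ESAPI 的 hash 算法(使用 ShiroSessionManager 提供的 seed,迭代 hashIterations 次)生成摘要,即原 token;最终的 token = Base64URL(token处理器标识 + "," + 过期时间 + "," + 原token)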

  shiro配置文件

<?xml version="1.0" encoding="UTF-8"?>
<!--
  Spring configuration for stateless Shiro authentication. It defines:
  - the stateless realm and subject factory;
  - the token processors and the token factory;
  - a security manager with Subject session storage disabled;
  - the Shiro web filter chain.
-->
<beans xmlns="http://www.springframework.org/schema/beans"
       xmlns:util="http://www.springframework.org/schema/util"
       xmlns:aop="http://www.springframework.org/schema/aop"
       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
       xsi:schemaLocation="
       http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd
       http://www.springframework.org/schema/util http://www.springframework.org/schema/util/spring-util.xsd
       http://www.springframework.org/schema/aop http://www.springframework.org/schema/aop/spring-aop.xsd">

    <!-- Realm that authenticates StatelessToken; realm caching is turned off. -->
    <bean id="statelessRealm" class="com.hjzgg.stateless.auth.shiro.StatelessRealm">
        <property name="cachingEnabled" value="false" />
    </bean>

    <!-- Subject factory -->
    <bean id="subjectFactory"
          class="com.hjzgg.stateless.auth.shiro.StatelessDefaultSubjectFactory" />

    <!--
      Token processor "web": cookie path ${context.name}, expr taken from ${sessionTimeout}.
      The "userType" cookie is part of the token summary.
    -->
    <bean id="webTokenProcessor" class="com.hjzgg.stateless.auth.token.DefaultTokenPorcessor">
        <property name="id" value="web"></property>
        <property name="path" value="${context.name}"></property>
        <property name="expr" value="${sessionTimeout}"></property>
        <property name="exacts">
            <list>
                <value type="java.lang.String">userType</value>
            </list>
        </property>
    </bean>
    <!-- Token processor "ma": same cookie path and exacts as "web", but a fixed expr of -1. -->
    <bean id="maTokenProcessor" class="com.hjzgg.stateless.auth.token.DefaultTokenPorcessor">
        <property name="id" value="ma"></property>
        <property name="path" value="${context.name}"></property>
        <property name="expr" value="-1"></property>
        <property name="exacts">
            <list>
                <value type="java.lang.String">userType</value>
            </list>
        </property>
    </bean>

    <!-- Token factory configured with both token processors. -->
    <bean id="tokenFactory" class="com.hjzgg.stateless.auth.token.TokenFactory">
        <property name="processors">
            <list>
                <ref bean="webTokenProcessor" />
                <ref bean="maTokenProcessor" />
            </list>
        </property>
    </bean>


    <!-- Session manager; the periodic session-validation scheduler is disabled. -->
    <bean id="sessionManager" class="org.apache.shiro.session.mgt.DefaultSessionManager">
        <property name="sessionValidationSchedulerEnabled" value="false" />
    </bean>

    <!--
      Security manager. sessionStorageEnabled=false stops Shiro from storing Subject state
      in a session, which keeps authentication stateless.
    -->
    <bean id="securityManager" class="org.apache.shiro.web.mgt.DefaultWebSecurityManager">
        <property name="realms">
            <list>
                <ref bean="statelessRealm" />
            </list>
        </property>
        <property name="subjectDAO.sessionStorageEvaluator.sessionStorageEnabled"
                  value="false" />
        <property name="subjectFactory" ref="subjectFactory" />
        <property name="sessionManager" ref="sessionManager" />
    </bean>

    <!-- Equivalent to calling SecurityUtils.setSecurityManager(securityManager) -->
    <bean
            class="org.springframework.beans.factory.config.MethodInvokingFactoryBean">
        <property name="staticMethod"
                  value="org.apache.shiro.SecurityUtils.setSecurityManager" />
        <property name="arguments" ref="securityManager" />
    </bean>

    <!-- Stateless authentication filter, registered below as "statelessAuthc" for "/**". -->
    <bean id="statelessAuthcFilter" class="com.hjzgg.stateless.auth.shiro.StatelessAuthcFilter">
        <property name="tokenFactory" ref="tokenFactory" />
    </bean>

    <!--
      Custom logout filter. NOTE(review): it is not listed in the "filters" map below. It is
      presumably picked up by bean name for the "/logout = logout" rule; confirm.
    -->
    <bean id="logout" class="com.hjzgg.stateless.auth.shiro.LogoutFilter"></bean>

    <!-- Shiro web filter; the first matching rule in filterChainDefinitions wins. -->
    <bean id="shiroFilter" class="org.apache.shiro.spring.web.ShiroFilterFactoryBean">
        <property name="securityManager" ref="securityManager" />
        <property name="loginUrl" value="/login" />
        <property name="filters">
            <util:map>
                <entry key="statelessAuthc" value-ref="statelessAuthcFilter" />
            </util:map>
        </property>
        <property name="filterChainDefinitions">
            <value>
                <!--swagger-->
                /webjars/** = anon
                /v2/api-docs/** = anon
                /swagger-resources/** = anon

                /login/** = anon
                /logout = logout
                /static/** = anon
                /css/** = anon
                /images/** = anon
                /trd/** = anon
                /js/** = anon
                /api/** = anon
                /cxf/** = anon
                /jaxrs/** = anon
                /** = statelessAuthc
            </value>
        </property>
    </bean>
    <!-- Shiro lifecycle bean post-processor -->
    <bean id="lifecycleBeanPostProcessor" class="org.apache.shiro.spring.LifecycleBeanPostProcessor" />
</beans>
View Code

  通过InvocationInfoProxy这个类(基于ThreadLocal的),可以拿到用户相关的参数信息

import com.hjzgg.stateless.common.constants.AuthConstants;

import java.util.HashMap;
import java.util.Map;

/**
 * Per-thread holder for request user information and extra string parameters.
 * <p>
 * It holds the user name, login timestamp and token. Values are kept in a {@link ThreadLocal}
 * map. Call {@link #reset()} when a request finishes so that values do not carry over to the
 * next task run on a pooled thread.
 *
 * Created by hujunzheng on 2017/7/18.
 */
public class InvocationInfoProxy {
    private static final ThreadLocal<Map<String, Object>> resources =
        ThreadLocal.withInitial(() -> {
            Map<String, Object> initialValue = new HashMap<>();
            initialValue.put(AuthConstants.ExtendConstants.PARAM_PARAMETER, new HashMap<String, String>());
            return initialValue;
        }
    );

    public static String getUserName() {
        return (String) resources.get().get(AuthConstants.PARAM_USERNAME);
    }

    public static void setUserName(String userName) {
        resources.get().put(AuthConstants.PARAM_USERNAME, userName);
    }

    public static String getLoginTs() {
        return (String) resources.get().get(AuthConstants.PARAM_LOGINTS);
    }

    public static void setLoginTs(String loginTs) {
        resources.get().put(AuthConstants.PARAM_LOGINTS, loginTs);
    }

    public static String getToken() {
        return (String) resources.get().get(AuthConstants.PARAM_TOKEN);
    }

    public static void setToken(String token) {
        resources.get().put(AuthConstants.PARAM_TOKEN, token);
    }

    /** Stores an extra parameter for the current thread. */
    public static void setParameter(String key, String value) {
        parameters().put(key, value);
    }

    /** @return the extra parameter stored for the current thread, or {@code null} */
    public static String getParameter(String key) {
        return parameters().get(key);
    }

    /** Clears all values of the current thread; the next access starts from a fresh map. */
    public static void reset() {
        resources.remove();
    }

    /** Returns the current thread's extra-parameter map, created by the ThreadLocal initializer. */
    @SuppressWarnings("unchecked") // the initializer stores a HashMap<String, String> under this key
    private static Map<String, String> parameters() {
        return (Map<String, String>) resources.get().get(AuthConstants.ExtendConstants.PARAM_PARAMETER);
    }
}
View Code

  ESAPI 和缓存(cache)相关的代码请到项目中查看

项目地址

  欢迎访问,无状态shiro认证组件

参考链接

    ESAPI入门使用方法

   Spring MVC 4.2 增加 CORS 支持

  HTTP访问控制(CORS)

  Slf4j MDC 使用和 基于 Logback 的实现分析

 

posted @ 2017-07-24 15:49  hjzqyx  阅读(19210)  评论(1编辑  收藏  举报