silverhoof 发表于 2015-05-29 14:04

一个简单的拦截器代码

拦截器package com.sfexpress.pmp.setting;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.sfexpress.pmp.dao.UserDao;
import com.sfexpress.pmp.model.user.User;

import org.glassfish.grizzly.http.server.Session;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.*;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import java.io.*;
import java.util.Map;
import java.util.regex.Pattern;

@Component
public class XSSRequestFilter implements Filter {
   
    @Autowired
    private UserDao userDao;
   
    @Override
    public void init(FilterConfig config) throws ServletException {
    }

    @Override
    public void destroy() {
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
      HttpServletRequest httpServletRequest = (HttpServletRequest) request;
      HttpServletResponse httpServletResponse = (HttpServletResponse) response;
         
         
      Cookie[] cookies = httpServletRequest.getCookies();
      HttpSession session = httpServletRequest.getSession();
      addUserToSession(session);
      //User user = (User)session.getAttribute("user");
      String url = httpServletRequest.getServletPath();
      Map<String,String> map = Maps.newHashMap();//用于存登录的时候请求地址
      map.put("a", "/resources/login.htm");
      map.put("b", "/favicon.ico");
      map.put("c", "/j_spring_security_check");
      boolean flag=true;
      if(map.containsValue(url)){//包含登录
            flag=false;
      }
      if(null != cookies){
            for(Cookie c :cookies ){
                //当前端cookies里有jsessionid,并且前端cookies的jsessionid跟服务端的sessionId不一致,说明是有退出,登录的时候除外
                if(flag && null != c.getValue() &&"JSESSIONID".equals(c.getName()) && !c.getValue().equals(session.getId())){
                  //httpServletResponse.sendRedirect("http://asura.st.sf-express.com/");
                  httpServletResponse.setStatus(406);
                  return;
                }
            }
      }
      /*
      if(httpServletRequest.getSession(false)==null || session.getId()==null){
            System.out.println("session已过期");
      }*/
         
         
      if (decodeHash(httpServletRequest, httpServletResponse)) return;

      if (request.getContentType() != null && request.getContentType().contains("multipart/form-data")) {
            chain.doFilter(httpServletRequest, response);
      } else {
            if ((httpServletRequest).getMethod().equalsIgnoreCase("POST") || (httpServletRequest).getMethod().equalsIgnoreCase("PUT")) {
                httpServletRequest = new XSSRequestWrapper(httpServletRequest);
            }
            chain.doFilter(httpServletRequest, response);
      }
    }

   
    public void addUserToSession(HttpSession session){
         
         
    }
   
    private boolean decodeHash(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
      String requestURI = httpServletRequest.getRequestURI();
      if (requestURI.contains("%23")) {
            String replacedURL = requestURI.replace("%23", "#");
            httpServletResponse.sendRedirect(replacedURL);
            return true;
      }
      return false;
    }

    private static class XSSRequestWrapper extends HttpServletRequestWrapper {
      private HttpServletRequest request;

      private static Map<Pattern, String> FILTER_MAPPING = Maps.newLinkedHashMap(new ImmutableMap.Builder<Pattern, String>()
               /*.put(Pattern.compile("&"), "&amp;")
                .put(Pattern.compile("<"), "&lt;")
                .put(Pattern.compile(">"), "&gt;")*/
                .put(Pattern.compile("&"), "&")
                .put(Pattern.compile("<"), "〈")
                .put(Pattern.compile(">"), "〉")
                .put(Pattern.compile("&amp;"), "&")
                .put(Pattern.compile("&lt;"), "〈")
                .put(Pattern.compile("&gt;"), "〉")
                .build());

      public XSSRequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            this.request = request;
      }

      @Override
      public ServletInputStream getInputStream() throws IOException {
            XSSServletInputStream xssServletInputStream = new XSSServletInputStream();
            ServletInputStream inputStream = request.getInputStream();

            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
            try {
                int ch;
                while ((ch = inputStream.read()) != -1) {
                  byteStream.write(ch);
                }
            } finally {
                inputStream.close();
            }
            xssServletInputStream.stream = new ByteArrayInputStream(filterRequestBody(new String(byteStream.toByteArray(),"utf-8")).getBytes("utf-8"));
            return xssServletInputStream;
      }

      @Override
      public BufferedReader getReader() throws IOException {
            return new BufferedReader(new InputStreamReader(getInputStream(), request.getCharacterEncoding()));
      }

      private String filterRequestBody(String requestBody){
            String filterResult = readUnicodeStr2(requestBody);
            for (Map.Entry<Pattern, String> patternStringEntry : FILTER_MAPPING.entrySet()) {
                Pattern pattern = patternStringEntry.getKey();
                filterResult = pattern.matcher(filterResult).replaceAll(patternStringEntry.getValue());
            }
            return filterResult;
      }

      private class XSSServletInputStream extends ServletInputStream {
            private InputStream stream;

            @Override
            public int read() throws IOException {
                return stream.read();
            }
      }

      public static String readUnicodeStr2(String unicodeStr) {
            StringBuilder buf = new StringBuilder();

            for (int i = 0; i < unicodeStr.length(); i++) {
                char char1 = unicodeStr.charAt(i);
                if (char1 == '\\' && isUnicode(unicodeStr, i)) {
                  char char2 = unicodeStr.charAt(i-1);
                  if(char2 == '\\'){
                        buf.append(char1);
                        continue;
                  }
                  String cStr = unicodeStr.substring(i + 2, i + 6);
                  int cInt = Integer.parseInt(cStr,16);
                  buf.append((char) cInt);
                  // 跨过当前unicode码,因为还有i++,所以这里i加5,而不是6
                  i = i + 5;
                } else {
                  buf.append(char1);
                }
            }
            return buf.toString();
      }

      // 判断以index从i开始的串,是不是unicode码
      private static boolean isUnicode(String unicodeStr, int i) {
            int len = unicodeStr.length();
            int remain = len - i;
            // unicode码,反斜杠后还有5个字符 uxxxx
            if (remain < 5)
                return false;

            char flag2 = unicodeStr.charAt(i + 1);
            if (flag2 != 'u')
                return false;
            String nextFour = unicodeStr.substring(i + 2, i + 6);
            return isHexStr(nextFour);
      }

      /** hex str 1-9 a-f A-F */
      private static boolean isHexStr(String str) {
            for (int i = 0; i < str.length(); i++) {
                char ch = str.charAt(i);
                boolean isHex = (ch >= '0' && ch <= '9' || ch >= 'a' && ch <= 'f' || ch >= 'A' && ch <= 'F');
                if (!isHex)
                  return false;
            }
            return true;
      }
    }

}<filter>
    <filter-name>XSSFilter</filter-name>
    <filter-class>com.sfexpress.pmp.setting.XSSRequestFilter</filter-class>
</filter>

<filter-mapping>
    <filter-name>XSSFilter</filter-name>
    <url-pattern>/*</url-pattern>
</filter-mapping>
页: [1]
查看完整版本: 一个简单的拦截器代码