一个简单的请求拦截器（Servlet Filter）示例代码
拦截器
package com.sfexpress.pmp.setting;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.sfexpress.pmp.dao.UserDao;
import com.sfexpress.pmp.model.user.User;
import org.glassfish.grizzly.http.server.Session;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.*;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import java.io.*;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
@Component
public class XSSRequestFilter implements Filter {
/**
 * DAO supplied by Spring's field injection. No code in this class reads it at present.
 *
 * <p>NOTE(review): the web.xml snippet registers this class directly through
 * {@code filter-class}. An instance created that way by the servlet container is not a Spring
 * bean, so {@code @Autowired} is not processed and this field stays null unless the filter is
 * routed through Spring (e.g. a {@code DelegatingFilterProxy}) — confirm the registration.
 */
@Autowired
private UserDao userDao;

/**
 * Container shutdown callback. The filter holds nothing that needs releasing.
 */
@Override
public void destroy() {
    // nothing to release
}

/**
 * Container start-up callback. The filter reads no init parameters.
 *
 * @param config the filter configuration supplied by the container (unused)
 */
@Override
public void init(FilterConfig config) throws ServletException {
    // no configuration to read
}
/**
 * Request paths used while logging in: the login page, the favicon fetched alongside it, and
 * the login-submission URL. The stale-session check below is skipped for these paths.
 */
private static final Set<String> LOGIN_PATHS = Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(
        "/resources/login.htm",
        "/favicon.ico",
        "/j_spring_security_check")));

/** Media type of requests that are passed through without body sanitising. */
private static final String MULTIPART_FORM_DATA = "multipart/form-data";

/**
 * Applies the per-request checks and then continues the filter chain.
 *
 * <ol>
 *   <li>Obtains the HTTP session (creating one if none exists) and hands it to
 *       {@link #addUserToSession}.</li>
 *   <li>Outside the login paths, answers 406 and stops when the client's JSESSIONID cookie
 *       differs from the current session id, which is treated as a sign that the user has
 *       logged out.</li>
 *   <li>Stops if {@code decodeHash} has sent a redirect for an encoded {@code #}.</li>
 *   <li>Wraps non-multipart POST/PUT requests in {@link XSSRequestWrapper} so downstream code
 *       reads a sanitised body.</li>
 * </ol>
 */
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
    HttpServletRequest httpServletRequest = (HttpServletRequest) request;
    HttpServletResponse httpServletResponse = (HttpServletResponse) response;
    Cookie[] cookies = httpServletRequest.getCookies();
    // getSession() creates a fresh session when the old one is gone, so a client still
    // presenting the old JSESSIONID cookie no longer matches the id checked below.
    HttpSession session = httpServletRequest.getSession();
    addUserToSession(session);

    boolean loginRequest = LOGIN_PATHS.contains(httpServletRequest.getServletPath());
    if (!loginRequest && hasStaleSessionCookie(cookies, session.getId())) {
        httpServletResponse.setStatus(HttpServletResponse.SC_NOT_ACCEPTABLE); // 406
        return;
    }

    if (decodeHash(httpServletRequest, httpServletResponse)) {
        return;
    }

    HttpServletRequest downstream = httpServletRequest;
    if (!isMultipart(request.getContentType()) && isSanitisedMethod(httpServletRequest.getMethod())) {
        downstream = new XSSRequestWrapper(httpServletRequest);
    }
    chain.doFilter(downstream, response);
}

/**
 * Returns {@code true} when some JSESSIONID cookie has a non-null value different from
 * {@code sessionId}.
 *
 * @param cookies   the request cookies; {@code null} when the request carried none
 * @param sessionId id of the server-side session for this request
 */
private static boolean hasStaleSessionCookie(Cookie[] cookies, String sessionId) {
    if (cookies == null) {
        return false;
    }
    for (Cookie cookie : cookies) {
        String value = cookie.getValue();
        if (value != null && "JSESSIONID".equals(cookie.getName()) && !value.equals(sessionId)) {
            return true;
        }
    }
    return false;
}

/**
 * Returns {@code true} for multipart content types. Media types are case-insensitive, so the
 * header is lower-cased first; otherwise an upload declared as "Multipart/Form-Data" would be
 * wrapped and its binary body decoded and re-encoded as text.
 *
 * @param contentType the Content-Type header value, possibly {@code null}
 */
private static boolean isMultipart(String contentType) {
    return contentType != null && contentType.toLowerCase(Locale.ENGLISH).contains(MULTIPART_FORM_DATA);
}

/**
 * Returns {@code true} for the HTTP methods whose body is sanitised (POST and PUT).
 * Null-safe, so an absent method is simply not wrapped.
 */
private static boolean isSanitisedMethod(String method) {
    return "POST".equalsIgnoreCase(method) || "PUT".equalsIgnoreCase(method);
}
/**
 * Hook that {@code doFilter} calls with the request's session on every request.
 * It currently does nothing.
 *
 * @param session the current HTTP session
 */
public void addUserToSession(HttpSession session) {
    // intentionally empty
}

/**
 * Redirects requests whose URI contains an encoded hash ({@code %23}) to the same URI with
 * every {@code %23} turned back into {@code #}, so the remainder becomes a client-side fragment.
 *
 * <p>Because {@code %23} occurs in the path, any query string followed it in the URL the
 * client meant. The query string is therefore re-appended after the decoded path; previously
 * it was dropped. Leading slashes and backslashes are collapsed to one {@code /} so that a raw
 * URI such as {@code //evil.example/%23} cannot become a protocol-relative redirect to
 * another host.
 *
 * @param httpServletRequest  the incoming request
 * @param httpServletResponse the response used to send the redirect
 * @return {@code true} if a redirect was sent and the caller must stop processing
 * @throws IOException if sending the redirect fails
 */
private boolean decodeHash(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
    String requestURI = httpServletRequest.getRequestURI();
    if (!requestURI.contains("%23")) {
        return false;
    }
    StringBuilder location = new StringBuilder(toSingleLeadingSlash(requestURI.replace("%23", "#")));
    String queryString = httpServletRequest.getQueryString();
    if (queryString != null) {
        location.append('?').append(queryString);
    }
    httpServletResponse.sendRedirect(location.toString());
    return true;
}

/**
 * Replaces any run of leading '/' or '\' characters with a single '/', keeping the redirect
 * target a path on this host.
 */
private static String toSingleLeadingSlash(String path) {
    int start = 0;
    while (start < path.length() && (path.charAt(start) == '/' || path.charAt(start) == '\\')) {
        start++;
    }
    return "/" + path.substring(start);
}
/**
 * Request wrapper that sanitises the request body before downstream code reads it.
 *
 * <p>The wrapped request's body is read once and decoded as UTF-8. Java/JSON-style Unicode
 * escapes (a backslash, the letter u, four hex digits) are decoded so an escaped angle
 * bracket cannot slip past the filter. Every '<' and '>' is then replaced with the CJK angle
 * brackets U+3008 and U+3009, and the result is re-encoded as UTF-8.
 *
 * <p>NOTE(review): only {@code getInputStream()} and {@code getReader()} are overridden, so
 * values the container parses itself (query parameters, form fields read via
 * {@code getParameter}) are not sanitised by this wrapper.
 *
 * <p>NOTE(review): the body is decoded as UTF-8 whatever charset the request declares. Also,
 * decoding escapes of a quote or backslash inside a JSON body changes its syntax. Confirm that
 * clients send UTF-8 and do not depend on those escapes.
 */
private static class XSSRequestWrapper extends HttpServletRequestWrapper {

    /**
     * Ordered replacements applied to the escape-decoded body. An '&' entry that mapped to
     * itself and a duplicate copy of every entry used to be listed; they changed nothing and
     * were removed. Replacement strings contain no '$' or '\', so replaceAll uses them literally.
     */
    private static final Map<Pattern, String> FILTER_MAPPING = ImmutableMap.of(
            Pattern.compile("<"), "〈",
            Pattern.compile(">"), "〉");

    /** Charset used to decode the raw body and to encode the sanitised body. */
    private static final Charset BODY_CHARSET = Charset.forName("UTF-8");

    private final HttpServletRequest request;

    /** Stream over the sanitised body, built on first access and returned on later calls. */
    private ServletInputStream sanitisedStream;

    public XSSRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.request = request;
    }

    /**
     * Returns a stream over the sanitised body.
     *
     * <p>The wrapped request's stream is consumed and closed on the first call only. Later
     * calls return the same stream instead of reading the already-closed original again.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (sanitisedStream == null) {
            String rawBody = new String(readFully(request.getInputStream()), BODY_CHARSET);
            byte[] cleanBody = filterRequestBody(rawBody).getBytes(BODY_CHARSET);
            sanitisedStream = new XSSServletInputStream(new ByteArrayInputStream(cleanBody));
        }
        return sanitisedStream;
    }

    /**
     * Returns a reader over the sanitised body.
     *
     * <p>{@link #getInputStream()} always produces UTF-8 bytes, so they are decoded as UTF-8
     * here. The request's declared encoding may be null, which made {@code InputStreamReader}
     * throw, or may differ from UTF-8, which garbled the text.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(getInputStream(), BODY_CHARSET));
    }

    /** Reads {@code in} to the end in chunks, closes it, and returns the bytes read. */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        try {
            int count;
            while ((count = in.read(chunk)) != -1) {
                out.write(chunk, 0, count);
            }
        } finally {
            in.close();
        }
        return out.toByteArray();
    }

    /**
     * Decodes Unicode escapes in {@code requestBody}, then applies every
     * {@link #FILTER_MAPPING} replacement in order.
     */
    private String filterRequestBody(String requestBody) {
        String filterResult = readUnicodeStr2(requestBody);
        for (Map.Entry<Pattern, String> replacement : FILTER_MAPPING.entrySet()) {
            filterResult = replacement.getKey().matcher(filterResult).replaceAll(replacement.getValue());
        }
        return filterResult;
    }

    /** {@link ServletInputStream} that delegates to an in-memory stream. */
    private static final class XSSServletInputStream extends ServletInputStream {
        private final InputStream stream;

        XSSServletInputStream(InputStream stream) {
            this.stream = stream;
        }

        @Override
        public int read() throws IOException {
            return stream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return stream.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            return stream.available();
        }
    }

    /**
     * Decodes Java/JSON-style Unicode escapes (a backslash, the letter u, four hex digits)
     * into the characters they denote. All other text is copied unchanged.
     *
     * <p>An escaped backslash (two backslashes) is copied verbatim as a pair, so the text after
     * it is never mistaken for an escape. Tracking whole pairs, rather than peeking at the
     * single previous character, also decodes an escape that follows an escaped backslash.
     * Otherwise such an encoded '<' would pass the filter and be restored later by a JSON
     * parser. An escape at the very start of the input is handled too.
     *
     * @param unicodeStr text to decode; must not be null
     * @return the decoded text
     */
    public static String readUnicodeStr2(String unicodeStr) {
        int length = unicodeStr.length();
        StringBuilder buf = new StringBuilder(length);
        int i = 0;
        while (i < length) {
            char current = unicodeStr.charAt(i);
            if (current != '\\') {
                buf.append(current);
                i++;
            } else if (i + 1 < length && unicodeStr.charAt(i + 1) == '\\') {
                // Escaped backslash: keep both characters and skip past the pair.
                buf.append(current).append('\\');
                i += 2;
            } else if (isUnicode(unicodeStr, i)) {
                buf.append((char) Integer.parseInt(unicodeStr.substring(i + 2, i + 6), 16));
                // Skip the whole six-character escape.
                i += 6;
            } else {
                buf.append(current);
                i++;
            }
        }
        return buf.toString();
    }

    /**
     * Returns {@code true} if a complete Unicode escape starts at index {@code i}. The caller
     * has already checked that the character at {@code i} is a backslash.
     */
    private static boolean isUnicode(String unicodeStr, int i) {
        // A complete escape needs six characters from i: backslash, 'u', four hex digits.
        // (The former "remain < 5" let a five-character tail reach substring() and throw.)
        if (unicodeStr.length() - i < 6) {
            return false;
        }
        return unicodeStr.charAt(i + 1) == 'u' && isHexStr(unicodeStr.substring(i + 2, i + 6));
    }

    /** Returns {@code true} if every character of {@code str} is an ASCII hex digit: 0-9, a-f, A-F. */
    private static boolean isHexStr(String str) {
        for (int i = 0; i < str.length(); i++) {
            char ch = str.charAt(i);
            boolean isHex = (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F');
            if (!isHex) {
                return false;
            }
        }
        return true;
    }
}
}
<filter>
<filter-name>XSSFilter</filter-name>
<filter-class>com.sfexpress.pmp.setting.XSSRequestFilter</filter-class>
</filter>
<filter-mapping>
<filter-name>XSSFilter</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
页:
[1]