package com.bxm.newidea.component.log;

import com.bxm.newidea.component.config.ComponentWebConfigurationProperties;
import com.bxm.newidea.component.util.WebUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.logging.log4j.ThreadContext;
import org.springframework.util.AntPathMatcher;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

/**
 * 日志上下文过滤器，用于设置当前请求环境的上下文，用于后续的日志记录与现实
 *
 * @author liujia 2018/3/30 15:20
 */
@WebFilter(filterName = "log4j2ContextFilter", urlPatterns = "/**")
@Slf4j
public class LogContextFilter implements Filter {

    public LogContextFilter(ComponentWebConfigurationProperties properties) {
        this.properties = properties;
        antPathMatcher = new AntPathMatcher();
    }

    private ComponentWebConfigurationProperties properties;

    private AntPathMatcher antPathMatcher;

    @Override
    public void init(FilterConfig filterConfig) {
        if (log.isDebugEnabled()) {
            log.debug("init log context filter");
        }
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        //修改request，使得可以读取多次requestBody
        if (!WebUtils.isMultipartRequest(request) && !(request instanceof ContentCachingRequestWrapper)) {
            request = new ContentCachingRequestWrapper(request);
        }

        if (properties.isEnableRequestLog()) {
            String url = request.getRequestURI();

            for (String logUrl : properties.getIncludeUrlList()) {
                if (antPathMatcher.match(logUrl, url)) {
                    String requestParam = WebUtils.getRequestParam(request);
                    log.info("请求地址：[{}]，请求参数：[{}]", url, requestParam);
                }
            }
        }

        //设置当前请求的IP到Thread中，方便记录时保持IP信息
        ThreadContext.put(LogConstant.REQUEST_IP, WebUtils.getIpAddr(request));
        chain.doFilter(request, response);
    }

    @Override
    public void destroy() {
        ThreadContext.clearAll();
    }

    /**
     * 保存请求中的requestBody
     */
    private class ContentCachingRequestWrapper extends HttpServletRequestWrapper {

        private byte[] body;

        private BufferedReader reader;

        private ServletInputStream inputStream;

        private Map<String, String[]> paramMap;

        private ContentCachingRequestWrapper(HttpServletRequest request) throws IOException {
            super(request);
            loadBody(request);
            paramMap = new HashMap<>(request.getParameterMap());
        }

        private void loadBody(HttpServletRequest request) throws IOException {
            body = IOUtils.toByteArray(request.getInputStream());
            inputStream = new RequestCachingInputStream(body);
        }

        @Override
        public ServletInputStream getInputStream() throws IOException {
            return new RequestCachingInputStream(body);
        }

        @Override
        public BufferedReader getReader() throws IOException {
            if (reader == null) {
                reader = new BufferedReader(new InputStreamReader(inputStream, getCharacterEncoding()));
            }
            return reader;
        }

        @Override
        public String getParameter(String name) {
            String[] strings = paramMap.get(name);

            if (null != strings && strings.length > 0) {
                return strings[0];
            }
            return null;
        }

        @Override
        public Map<String, String[]> getParameterMap() {
            return paramMap;
        }

        @Override
        public Enumeration<String> getParameterNames() {
            return Collections.enumeration(paramMap.keySet());
        }

        @Override
        public String[] getParameterValues(String name) {
            return paramMap.get(name);
        }

        private class RequestCachingInputStream extends ServletInputStream {

            private final ByteArrayInputStream inputStream;

            private RequestCachingInputStream(byte[] bytes) {
                inputStream = new ByteArrayInputStream(bytes);
            }

            @Override
            public int read() {
                return inputStream.read();
            }

            @Override
            public int read(byte[] b) throws IOException {
                return inputStream.read(b);
            }

            @Override
            public boolean isFinished() {
                return inputStream.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readlistener) {
            }
        }
    }
}
