Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,11 @@ public static <T> RestResult<T> fail(String errorMessage) {
return new RestResult<>(400, errorMessage, null, false);
}

public static <T> RestResult<T> limit() {
return new RestResult<>(408, "访问频繁,请稍后再试");
}

public static <T> RestResult<T> notFound() {
return new RestResult<>(404, "NOT FOUND");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
*/
public class ResponseUtils {

private final static String DATE_FORMAT = "yyyy-MM-dd HH:mm:ss";

public static void objectToJson(HttpServletResponse response, Object object) throws IOException {
response.setCharacterEncoding("UTF-8");
response.setContentType("application/json");
response.setStatus(HttpServletResponse.SC_OK);
ObjectMapper mapper = new ObjectMapper();
// todo 注释上说这是线程不安全的,可能需要修改
mapper.setDateFormat(new SimpleDateFormat(DATE_FORMAT));
// 注释上说这是线程不安全的,可能需要修改(ThreadLocal解决线程安全问题)
mapper.setDateFormat(TimeUtil.getDateTimeFormat());
mapper.writeValue(response.getWriter(), object);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package cn.sticki.common.web.utils;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.DateFormat;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.text.SimpleDateFormat;
import java.util.Date;

/**
* @Author if
* @Description: 时间日期工具类并通过ThreadLocal解决DateFormat的线程安全问题
* ThreadLocal通过保存各个线程的SimpleDateFormat类对象的副本
* 使每个线程在运行时,各自使用自身绑定的SimpleDateFormat对象
* 互不干扰,执行性能比较高,推荐在高并发的生产环境使用。
* @Date 2022-01-08 下午 03:20
*/
public class TimeUtil {

/**
* 时间格式(yyyy-MM-dd)
*/
public final static String DATE_PATTERN = "yyyy-MM-dd";
private static final ThreadLocal<DateFormat> DATE_LOCAL = ThreadLocal.withInitial(() -> new SimpleDateFormat(DATE_PATTERN));

public static DateFormat getDateFormat() {
return DATE_LOCAL.get();
}

/**
* 时间格式(yyyy-MM-dd HH:mm:ss)
*/
public final static String DATE_TIME_PATTERN = "yyyy-MM-dd HH:mm:ss";
private static final ThreadLocal<DateFormat> DATE_TIME_LOCAL = ThreadLocal.withInitial(() -> new SimpleDateFormat(DATE_TIME_PATTERN));

public static DateFormat getDateTimeFormat() {
return DATE_TIME_LOCAL.get();
}


/**
* 日期格式化
*
* @param date 日期
* @return 默认返回yyyy-MM-dd HH:mm:ss格式日期
*/
public static String format(Date date) {
return format(date, DATE_TIME_PATTERN);
}

/**
* 计算相差分钟数,并保留两位小数
*
* @param start 起始时间
* @param end 截止时间
* @return 保留两位小数
*/
public static Double getSubMinutes(Date start, Date end) {
long seconds = getSubSeconds(start, end);
BigDecimal b = new BigDecimal(seconds / 60.0D);
return b.setScale(2, RoundingMode.HALF_UP).doubleValue();
}

/**
* 计算相差秒数
*
* @param start 起始时间
* @param end 截止时间
* @return 相差秒数
*/
public static Long getSubSeconds(Date start, Date end) {
LocalDateTime startTime = LocalDateTime.parse(format(start), DateTimeFormatter.ofPattern(DATE_TIME_PATTERN));
LocalDateTime endTime = LocalDateTime.parse(format(end), DateTimeFormatter.ofPattern(DATE_TIME_PATTERN));
return Duration.between(startTime, endTime).toMillis() / 1000;
}

/**
* 日期格式化
*
* @param date 日期
* @param pattern 格式,如:DateUtils.DATE_TIME_PATTERN
* @return 返回yyyy-MM-dd格式日期
*/
public static String format(Date date, String pattern) {
if (date == null) {
return null;
}
switch (pattern) {
case (DATE_PATTERN):
return DATE_LOCAL.get().format(date);
case (DATE_TIME_PATTERN):
return DATE_TIME_LOCAL.get().format(date);
default:
return null;
}
}

/**
* 判断今天是否在两个日期范围之内
*
* @param startTime 开始时间
* @param endTime 结束时间
*/
public static Boolean isTodayBetweenTwoDays(Date startTime, Date endTime) {
return isSomeDayBetweenTwoDays(new Date(), startTime, endTime);
}

/**
* 判断某个日期是否在两个日期范围之内
*
* @param someTime 某个日期
* @param startTime 开始时间
* @param endTime 结束时间
*/
public static Boolean isSomeDayBetweenTwoDays(Date someTime, Date startTime, Date endTime) {
return someTime.getTime() >= startTime.getTime() && someTime.getTime() <= endTime.getTime();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
Expand All @@ -13,19 +14,17 @@
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static cn.sticki.gateway.utils.RedisConstants.GATEWAY_IPLIMIT_KEY;
import static cn.sticki.gateway.utils.RedisConstants.GATEWAY_IPLIMIT_TTL;

/**
* ip访问限制过滤器
*
Expand All @@ -35,51 +34,63 @@
@Component
public class IpLimitFilter implements GlobalFilter, Ordered {

private final static int LIMIT_COUNT = 30;
private final ObjectMapper objectMapper = new ObjectMapper();

private final static int LIMIT_TIME = 10;
@Value("${gateway.ip.limit.key:gateway:ipLimit:}")
private String gatewayIpLimitKey;

private final ObjectMapper objectMapper = new ObjectMapper();
@Value("${gateway.ip.limit.count:30}")
private int gatewayIpLimitCount;

@Value("${gateway.ip.limit.time:5}")
private long gatewayIpLimitTime;

@Value("${gateway.ip.limit.time:10}")
private long gatewayIpLimitTtl;

@Resource
private RedisTemplate<String, Integer> redisTemplate;

@SneakyThrows(IOException.class)
/**
* 限制接口频繁访问
* todo 有待优化,目前逻辑为,5s内连续访问30次,需要冷却10s
*/
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
// 0. 获取请求,响应,ip与uri
ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();
// 0.打印访问情况
log.info("{} {} {} {}", request.getRemoteAddress(), request.getMethod(), response.getStatusCode(), request.getPath());
// todo 有待优化,目前逻辑为,5s内连续访问30次,需要冷却10s
// 1. 获取当前ip
String ip = RequestUtils.getIpAddress(request);
// 2. 获取ip计数
String key = GATEWAY_IPLIMIT_KEY + ip;
Integer ipCount = redisTemplate.opsForValue().get(key);
if (ipCount == null) {
ipCount = 0;
}
// 3.判断ip访问次数
if (ipCount > LIMIT_COUNT) {
// 3.1 超过限制,禁止访问
// 3.10 重置当前ip的ttl为限制时长
redisTemplate.opsForValue().set(key, ipCount, LIMIT_TIME, TimeUnit.SECONDS);
// 3.11 设置状态码和响应类型
response.setStatusCode(HttpStatus.LOCKED);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
// 3.12 构造返回体
DataBufferFactory bufferFactory = response.bufferFactory();
log.info("ip limit : {} {}", ip, exchange.getRequest().getPath());
DataBuffer wrap = bufferFactory.wrap(objectMapper.writeValueAsBytes(new RestResult<>(408, "访问频繁,请稍后再试")));
return response.writeWith(Mono.fromSupplier(() -> wrap));
} else {
// 3.2 未超过限制,正常访问,访问次数+1
ipCount++;
redisTemplate.opsForValue().set(key, ipCount, GATEWAY_IPLIMIT_TTL, TimeUnit.SECONDS);
// 3.3 放行
return chain.filter(exchange);
RequestPath uri = request.getPath();

// 1. 打印访问情况
log.info("ip={}, method={}, status={}, uri={}", ip, request.getMethod(), response.getStatusCode(), uri);

// 2. 获取ip计数,缓存中没有则给0
String key = gatewayIpLimitKey + ip;
int ipCount = Optional.ofNullable(redisTemplate.opsForValue().get(key)).orElse(0);

// 3. 超过限制,禁止访问
if (ipCount > gatewayIpLimitCount) {
return response.writeWith(Mono.fromSupplier(() -> buildWrap(response, ip, uri, key)));
}

// 4. ip计数写入redis并放行
redisTemplate.opsForValue().set(key, ++ipCount, gatewayIpLimitTime, TimeUnit.SECONDS);
return chain.filter(exchange);
}

@SneakyThrows
private DataBuffer buildWrap(ServerHttpResponse response, String ip, RequestPath uri, String key) {
// 设置状态码和响应类型
response.setStatusCode(HttpStatus.LOCKED);
response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
// 构造返回体
DataBufferFactory bufferFactory = response.bufferFactory();
// 设置冷却时间
redisTemplate.expire(key, gatewayIpLimitTtl, TimeUnit.SECONDS);
log.info("ip limit : ip={}, uri={}", ip, uri);
return bufferFactory.wrap(objectMapper.writeValueAsBytes(RestResult.limit()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public Mono<Void> handle(ServerWebExchange exchange, Throwable ex) {
DataBufferFactory bufferFactory = response.bufferFactory();
try {
log.warn("Error Gateway : {} {}", ex.getMessage(), exchange.getRequest().getPath());
return bufferFactory.wrap(objectMapper.writeValueAsBytes(new RestResult<>(404, "NOT FOUND")));
return bufferFactory.wrap(objectMapper.writeValueAsBytes(RestResult.notFound()));
} catch (JsonProcessingException e) {
log.error("Error writing response", ex);
return bufferFactory.wrap(new byte[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ public class RedisConstants {

public static final String GATEWAY_IPLIMIT_KEY = "gateway:ipLimit:";

public static final Long GATEWAY_IPLIMIT_TTL = 5L;
public static final Long GATEWAY_IPLIMIT = 30L;

public static final Long GATEWAY_IPLIMIT_TTL = 10L;

}