说明
- 假设:有一个 API 接口,其中包含两个 SQL 查询语句,第一个查询平均耗时 10 秒,第二个查询平均耗时 8 秒。如果按照常规写法串行执行,无论先执行第一个 SQL 还是先执行第二个 SQL,该 API 接口的平均耗时都是 18 秒(10 秒 + 8 秒)
- 最简单的办法是并发执行这两个 SQL,理论平均耗时为:max(第一个查询平均耗时 10 秒, 第二个查询平均耗时 8 秒) = 10 秒
- 考虑到网络、初始化、懒加载、并发执行等原因可能带来误差,测试类通过 java.util.concurrent.ExecutorService 提交请求,并在等待结果时额外预留 1 秒的超时误差,即:在 11 秒内即可执行完原本需要 18 秒的 API 接口,详情见下方测试类与截图
源码
接口
CompletableFutureRestController.java
package cn.com.xuxiaowei.controller;
import cn.com.xuxiaowei.vo.CompletableFutureResult;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
@RestController
@RequestMapping("/CompletableFuture")
public class CompletableFutureRestController {

    /**
     * Runs two simulated tasks concurrently and reports the outcome of each.
     * <p>
     * Both tasks are started before either is awaited, so the request takes roughly
     * {@code max(time1, time2)} seconds instead of {@code time1 + time2}.
     *
     * @param time1 simulated duration of task 1, in seconds
     * @param time2 simulated duration of task 2, in seconds
     * @return the echoed durations plus, for each task, {@code "ok"} on success or the full
     *         stack trace of the failure (for example a {@code null} or negative duration)
     */
    @RequestMapping
    public CompletableFutureResult index(Integer time1, Integer time2) {
        CompletableFutureResult result = new CompletableFutureResult();
        result.setTime1(time1);
        result.setTime2(time2);

        // Start both tasks first; waiting happens only afterwards so the sleeps overlap.
        CompletableFuture<String> task1 = simulateTask(time1);
        CompletableFuture<String> task2 = simulateTask(time2);

        result.setTime1Result(awaitResult(task1));
        result.setTime2Result(awaitResult(task2));
        return result;
    }

    /**
     * Starts an asynchronous task that sleeps for {@code seconds} seconds.
     * <p>
     * NOTE(review): no executor is passed, so this runs on CompletableFuture's default async
     * executor (the common ForkJoinPool when its parallelism is at least two). Blocking sleeps
     * there can starve other users of that pool under concurrent load; consider a dedicated
     * executor with a managed lifecycle.
     *
     * @param seconds simulated duration, in seconds
     * @return a future yielding {@code "ok"}, or the stack trace if the sleep was interrupted;
     *         it completes exceptionally if {@code seconds} is {@code null} (unboxing) or
     *         negative ({@link Thread#sleep(long)} rejects negative values)
     */
    private static CompletableFuture<String> simulateTask(Integer seconds) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                // 1000L forces long arithmetic so large second counts cannot overflow int.
                Thread.sleep(seconds * 1000L);
                return "ok";
            }
            catch (InterruptedException e) {
                // Preserve the interrupt status for whoever owns this worker thread.
                Thread.currentThread().interrupt();
                return getStackTrace(e);
            }
        });
    }

    /**
     * Blocks until {@code future} completes.
     *
     * @param future the task to wait for
     * @return the task's value, or the full stack trace if the task completed exceptionally
     *         or the waiting thread was interrupted
     */
    private static String awaitResult(CompletableFuture<String> future) {
        try {
            return future.get();
        }
        catch (InterruptedException e) {
            // Restore the flag instead of silently clearing it on the request thread.
            Thread.currentThread().interrupt();
            return getStackTrace(e);
        }
        catch (ExecutionException e) {
            return getStackTrace(e);
        }
    }

    /**
     * Renders the complete stack trace of {@code throwable}, including causes, as a string.
     *
     * @param throwable the exception to render
     * @return the text {@link Throwable#printStackTrace(PrintWriter)} would print
     */
    public static String getStackTrace(Throwable throwable) {
        StringWriter sw = new StringWriter();
        PrintWriter pw = new PrintWriter(sw, true);
        throwable.printStackTrace(pw);
        return sw.getBuffer().toString();
    }
}
CompletableFutureResult.java
package cn.com.xuxiaowei.vo;
import lombok.Data;
/**
 * Response body returned by the {@code /CompletableFuture} endpoint of
 * CompletableFutureRestController.
 * <p>
 * Lombok {@code @Data} generates the accessors used elsewhere (e.g. {@code setTime1},
 * {@code getTime1Result}) along with equals/hashCode/toString. Field order here is the order
 * the generated toString prints them in.
 */
@Data
public class CompletableFutureResult {
/** Echo of the {@code time1} request parameter: simulated duration of task 1, in seconds. */
private Integer time1;
/** Echo of the {@code time2} request parameter: simulated duration of task 2, in seconds. */
private Integer time2;
/** {@code "ok"} when task 1 finished normally; otherwise the full stack trace text of its failure. */
private String time1Result;
/** {@code "ok"} when task 2 finished normally; otherwise the full stack trace text of its failure. */
private String time2Result;
}
测试类
CompletableFutureRestControllerTests.java
package cn.com.xuxiaowei.controller;
import cn.com.xuxiaowei.vo.CompletableFutureResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.web.client.RestTemplate;
import java.util.concurrent.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class CompletableFutureRestControllerTests {

    /**
     * Slack, in seconds, added to the slowest task's duration to absorb network,
     * initialization, lazy-loading and scheduling overhead.
     */
    private static final int TOLERANCE_SECONDS = 1;

    @LocalServerPort
    private int serverPort;

    /**
     * Calls {@code /CompletableFuture} with two simulated task durations and requires the
     * response within {@code max(time1, time2) + TOLERANCE_SECONDS} seconds. Running the
     * tasks one after another would need {@code time1 + time2}, so passing shows they ran
     * concurrently.
     */
    @Test
    void index() {
        // Simulated task durations, in seconds.
        int time1 = 10;
        int time2 = 8;
        String url = String.format("http://localhost:%d/CompletableFuture?time1=%s&time2=%s",
                serverPort, time1, time2);

        // Upper bound for the whole call: only the slowest task plus tolerance.
        int timeout = Math.max(time1, time2) + TOLERANCE_SECONDS;

        // The HTTP call runs on its own thread so the wait can be bounded via Future#get(timeout).
        ExecutorService executor = Executors.newSingleThreadExecutor();
        long start = System.nanoTime();
        CompletableFutureResult result;
        try {
            Future<CompletableFutureResult> future = executor.submit(() -> {
                CompletableFutureResult response = new RestTemplate().getForObject(url, CompletableFutureResult.class);
                log.info("result: {}", response);
                return response;
            });
            try {
                result = future.get(timeout, TimeUnit.SECONDS);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.error("任务中断", e);
                throw new AssertionError("任务中断", e);
            }
            catch (ExecutionException e) {
                log.error("执行异常", e);
                throw new AssertionError("执行异常", e);
            }
            catch (TimeoutException e) {
                log.error("超时异常", e);
                throw new AssertionError("超时异常", e);
            }
        }
        finally {
            // Always release the worker thread; shutdownNow also attempts to stop a still-running call.
            executor.shutdownNow();
            log.info("用时:{} ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
        }

        assertNotNull(result);
        assertEquals(time1, result.getTime1());
        assertEquals(time2, result.getTime2());
        assertEquals("ok", result.getTime1Result());
        assertEquals("ok", result.getTime2Result());
    }
}