Java 多线程任务 CompletableFuture

说明

  • 假设:有一个 API 接口,里面有两个 SQL 查询语句,第一个查询平均耗时 10 秒,第二个查询平均耗时 8 秒。如果按照常规写法(顺序执行),无论是先查询第一个 SQL,还是先查询第二个 SQL,该 API 接口的平均耗时都是 18 秒(10 秒 + 8 秒)
    1. 最简单的办法就是同时执行这两个 SQL,理论平均耗时为:max(第一个查询平均耗时 10秒, 第二个查询平均耗时 8秒)=10秒
    2. 考虑到网络、初始化、懒加载、并发调度等原因可能带来误差,测试类中通过 java.util.concurrent.ExecutorService 提交请求,并将等待超时时间设置为理论耗时加 1 秒误差,即:在 11 秒内完成原本需要 18 秒的 API 接口,详情见下方测试类与截图

源码

接口

CompletableFutureRestController.java

package cn.com.xuxiaowei.controller;

import cn.com.xuxiaowei.vo.CompletableFutureResult;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

@RestController
@RequestMapping("/CompletableFuture")
public class CompletableFutureRestController {

	/**
	 * Runs two simulated slow tasks (standing in for two SQL queries) concurrently and reports both
	 * outcomes.
	 * <p>
	 * Because the tasks run in parallel, the request takes roughly max(time1, time2) seconds instead
	 * of time1 + time2.
	 * @param time1 simulated duration of task 1, in seconds
	 * @param time2 simulated duration of task 2, in seconds
	 * @return the requested durations plus, per task, {@code "ok"} on success or the stack trace of
	 * the failure
	 */
	@RequestMapping
	public CompletableFutureResult index(Integer time1, Integer time2) {
		CompletableFutureResult result = new CompletableFutureResult();
		result.setTime1(time1);
		result.setTime2(time2);

		// Blocking work must not go to ForkJoinPool.commonPool() (the default for supplyAsync): it is
		// shared JVM-wide and sized to the CPU count, so blocking tasks can starve each other and
		// unrelated work. One dedicated thread per task guarantees both really run in parallel.
		ExecutorService executor = Executors.newFixedThreadPool(2);
		try {
			CompletableFuture<String> task1 = CompletableFuture.supplyAsync(() -> simulateTask(time1), executor);
			CompletableFuture<String> task2 = CompletableFuture.supplyAsync(() -> simulateTask(time2), executor);

			// Both tasks are already running, so waiting on them one after the other costs only
			// max(time1, time2) in total.
			result.setTime1Result(awaitTask(task1));
			result.setTime2Result(awaitTask(task2));
		}
		finally {
			// Always release the per-request threads, even if waiting was interrupted.
			executor.shutdownNow();
		}

		return result;
	}

	/**
	 * Simulates a slow task by sleeping.
	 * @param seconds how long to sleep; a {@code null} value fails the task with a
	 * NullPointerException (unboxing), which {@link #awaitTask} then reports as a stack trace
	 * @return {@code "ok"} once the sleep completes, or the stack trace if the thread was interrupted
	 */
	private static String simulateTask(Integer seconds) {
		try {
			// Multiply in long arithmetic so large second counts cannot overflow int.
			Thread.sleep(seconds * 1000L);
			return "ok";
		}
		catch (InterruptedException e) {
			// Restore the interrupt flag so the owning executor can observe the interruption.
			Thread.currentThread().interrupt();
			return getStackTrace(e);
		}
	}

	/**
	 * Blocks until {@code task} completes.
	 * @param task the task to wait for
	 * @return the task's value, or the stack trace of whatever prevented obtaining it
	 */
	private static String awaitTask(CompletableFuture<String> task) {
		try {
			return task.get();
		}
		catch (InterruptedException e) {
			// Keep the request thread's interrupt status visible to the caller.
			Thread.currentThread().interrupt();
			return getStackTrace(e);
		}
		catch (ExecutionException e) {
			return getStackTrace(e);
		}
	}

	/**
	 * Renders the complete stack trace of an exception as a string.
	 * @param throwable the exception to render
	 * @return the text {@link Throwable#printStackTrace(PrintWriter)} would print
	 */
	public static String getStackTrace(Throwable throwable) {
		StringWriter sw = new StringWriter();
		PrintWriter pw = new PrintWriter(sw, true);
		throwable.printStackTrace(pw);
		return sw.getBuffer().toString();
	}

}

CompletableFutureResult.java

package cn.com.xuxiaowei.vo;

import lombok.Data;

/**
 * Response body of the {@code /CompletableFuture} endpoint.
 * <p>
 * Getters, setters, {@code equals}/{@code hashCode} and {@code toString} are generated by
 * Lombok {@code @Data}.
 */
@Data
public class CompletableFutureResult {

	// Requested duration of task 1, in seconds (echoed back from the request).
	private Integer time1;

	// Requested duration of task 2, in seconds (echoed back from the request).
	private Integer time2;

	// "ok" if task 1 completed normally, otherwise the stack trace of its failure.
	private String time1Result;

	// "ok" if task 2 completed normally, otherwise the stack trace of its failure.
	private String time2Result;

}

测试类

CompletableFutureRestControllerTests.java

package cn.com.xuxiaowei.controller;

import cn.com.xuxiaowei.vo.CompletableFutureResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.web.client.RestTemplate;

import java.util.concurrent.*;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
class CompletableFutureRestControllerTests {

	@LocalServerPort
	private int serverPort;

	/**
	 * Verifies that the endpoint runs its two simulated tasks concurrently.
	 * <p>
	 * The whole request must finish within max(time1, time2) plus a one-second allowance, rather
	 * than the time1 + time2 that sequential execution would need.
	 */
	@Test
	void index() {
		// Simulated duration of each task, in seconds
		int time1 = 10;
		int time2 = 8;

		String url = String.format("http://localhost:%d/CompletableFuture?time1=%s&time2=%s", serverPort, time1, time2);

		// Concurrent execution needs only as long as the slowest task; allow 1 extra second for
		// network, initialization, lazy loading and scheduling overhead.
		int timeout = Math.max(time1, time2) + 1;

		long start = System.currentTimeMillis();

		// Local executor so the HTTP call can be bounded by a timeout and its thread is always
		// released, even when the request hangs or the test fails.
		ExecutorService executor = Executors.newSingleThreadExecutor();
		CompletableFutureResult result;
		try {
			Future<CompletableFutureResult> future = executor.submit(() -> {
				CompletableFutureResult response = new RestTemplate().getForObject(url, CompletableFutureResult.class);
				log.info("result: {}", response);
				return response;
			});
			result = awaitResult(future, timeout);
		}
		finally {
			executor.shutdownNow();
			log.info("用时:{} ms", System.currentTimeMillis() - start);
		}

		assertNotNull(result);
		assertEquals(time1, result.getTime1());
		assertEquals(time2, result.getTime2());
		assertEquals("ok", result.getTime1Result());
		assertEquals("ok", result.getTime2Result());
	}

	/**
	 * Waits at most {@code timeoutSeconds} for the request to complete.
	 * <p>
	 * Every failure is turned into an {@link AssertionError} that keeps the original cause, so the
	 * test report shows why it failed instead of a bare null-check failure.
	 * @param future the in-flight request
	 * @param timeoutSeconds maximum time to wait, in seconds
	 * @return the deserialized response
	 */
	private static CompletableFutureResult awaitResult(Future<CompletableFutureResult> future, int timeoutSeconds) {
		try {
			return future.get(timeoutSeconds, TimeUnit.SECONDS);
		}
		catch (InterruptedException e) {
			log.error("任务中断", e);
			// Preserve the interrupt status for the test runner.
			Thread.currentThread().interrupt();
			throw new AssertionError("Interrupted while waiting for the request", e);
		}
		catch (ExecutionException e) {
			log.error("执行异常", e);
			throw new AssertionError("Request failed", e);
		}
		catch (TimeoutException e) {
			log.error("超时异常", e);
			// Stop the still-running HTTP call instead of leaving it behind.
			future.cancel(true);
			throw new AssertionError("Request did not finish within " + timeoutSeconds
					+ " s; the tasks are probably not running concurrently", e);
		}
	}

}