1 import java.util.concurrent.ExecutionException;
2 import java.util.concurrent.ForkJoinPool;
3 import java.util.concurrent.ForkJoinTask;
4 import java.util.concurrent.RecursiveTask;
5
/**
 * Sums the consecutive integers in the closed range {@code [start, end]} by splitting the range
 * recursively and evaluating the pieces concurrently with the fork/join framework.
 *
 * <p>Ranges whose span {@code end - start} is at most {@link #THRESHOLD} are summed directly;
 * larger ranges are halved into two subtasks. An empty range ({@code start > end}) sums to 0.
 * As with plain {@code int} addition, the result wraps on overflow.
 */
public class CountTask extends RecursiveTask<Integer> {

    private static final long serialVersionUID = 1L;

    /** Maximum span ({@code end - start}) that is summed directly instead of being split. */
    private static final int THRESHOLD = 2;

    /** First value of the range (inclusive). */
    private final int start;

    /** Last value of the range (inclusive). */
    private final int end;

    /**
     * Creates a task that sums every integer from {@code start} to {@code end}, inclusive.
     *
     * @param start first value of the range (inclusive)
     * @param end last value of the range (inclusive)
     */
    public CountTask(int start, int end) {
        this.start = start;
        this.end = end;
    }

    /**
     * Computes the sum of the range, splitting it into two subtasks when its span exceeds
     * {@link #THRESHOLD}.
     *
     * @return the sum (with int wrap-around) of all integers in {@code [start, end]}, or 0 if
     *     the range is empty
     */
    @Override
    protected Integer compute() {
        // Widen to long: end - start overflows int for ranges wider than Integer.MAX_VALUE.
        boolean canCompute = (long) end - start <= THRESHOLD;
        if (canCompute) {
            return sumDirectly();
        }
        // Midpoint in long arithmetic so start + end cannot overflow. Because end - start > 2
        // here, the midpoint lies in [start, end - 1]: it fits in an int and middle + 1 is safe.
        int middle = (int) (((long) start + end) / 2);
        CountTask leftTask = new CountTask(start, middle);
        CountTask rightTask = new CountTask(middle + 1, end);
        // Fork only one half and compute the other in this thread, rather than forking both
        // and then blocking on join with nothing to do.
        leftTask.fork();
        int rightResult = rightTask.compute();
        int leftResult = leftTask.join();
        // Merge the two partial sums.
        return leftResult + rightResult;
    }

    /** Sums the range with a plain loop; used once the range is small enough. */
    private int sumDirectly() {
        int sum = 0;
        // A long counter is required: with an int counter, i <= end is always true when
        // end == Integer.MAX_VALUE, and the loop would never terminate.
        for (long i = start; i <= end; i++) {
            sum += (int) i;
        }
        return sum;
    }

    /**
     * Demo entry point: sums 1 + 2 + 3 + 4 on a dedicated pool and prints the result.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        try {
            // Build a task that computes 1 + 2 + 3 + 4.
            CountTask task = new CountTask(1, 4);
            // Submit the task and wait for its result.
            ForkJoinTask<Integer> result = forkJoinPool.submit(task);
            System.out.println(result.get());
        } catch (InterruptedException e) {
            // Restore the interrupt status so it is not silently lost.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (ExecutionException e) {
            e.printStackTrace();
        } finally {
            // Release the pool's worker threads once the demo is done.
            forkJoinPool.shutdown();
        }
    }
}