Java时间轮算法的实现代码示例

考虑这样一个场景:现在有 5000 个任务,要让这 5000 个任务每隔 5 分钟触发某个操作,该怎么实现这个需求?大部分人首先想到的是使用定时器,但 5000 个任务就要用 5000 个定时器,而一个定时器(例如 java.util.Timer)就对应一个后台线程,5000 个线程的开销显然无法接受,所以这种方法肯定行不通。

正是针对这类场景,才有了时间轮算法。时间轮到底是什么?可以自行搜索了解。这里也推荐一篇时间轮的介绍文章(原文此处为链接),看其中的文字和图示即可;那篇文章里的代码运行不起来,不必参考。

看好了介绍,我们就开始动手吧。

开发环境:idea + jdk1.8 + maven

新建一个maven工程

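创建如下的目录结构(所有源文件都位于包 com.tanghuachun.timer 下):

src/main/java/com/tanghuachun/timer/
    Timeout.java
    Timer.java
    TimerTask.java
    TimerWheel.java
    Main.java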

不要忘了在 pom.xml 中添加 netty 依赖(TimerWheel 用到了 netty 的 io.netty.util 工具类):

<dependencies>

<!-- Netty provides the io.netty.util classes imported by TimerWheel.java
     (ResourceLeakDetector, ResourceLeakDetectorFactory, ResourceLeak, PlatformDependent,
     StringUtil, InternalLogger, InternalLoggerFactory).
     NOTE(review): several of these live in io.netty.util.internal, which is not public API;
     confirm they still exist with the same signatures before changing the version below. -->
<dependency>

<groupId>io.netty</groupId>

<artifactId>netty-all</artifactId>

<version>4.1.5.Final</version>

</dependency>

</dependencies>

代码如下

Timeout.java

package com.tanghuachun.timer;

/**
 * Handle to a single task scheduled on a {@link Timer}.
 *
 * <p>Instances are returned by {@link Timer#newTimeout} and are also passed back to
 * {@link TimerTask#run} when the task fires.
 */
public interface Timeout {

    /** Returns the {@link Timer} that created this handle. */
    Timer timer();

    /** Returns the task that will be (or was) run when this handle fires. */
    TimerTask task();

    /**
     * Tries to cancel the scheduled task.
     *
     * @return {@code true} if this call cancelled the task; {@code false} if the task had already
     *         expired or been cancelled
     */
    boolean cancel();

    /** Returns {@code true} once {@link #cancel()} has succeeded for this handle. */
    boolean isCancelled();

    /** Returns {@code true} once the timer has fired this handle's task. */
    boolean isExpired();
}

Timer.java

package com.tanghuachun.timer;

import java.util.Set;

import java.util.concurrent.TimeUnit;

/**
 * Schedules {@link TimerTask}s to run once after a delay, each with a caller-supplied string
 * argument.
 */
public interface Timer {

    /**
     * Releases the timer's resources and cancels every task that has not fired yet.
     *
     * @return the handles of scheduled tasks that neither expired nor were cancelled
     */
    Set<Timeout> stop();

    /**
     * Schedules {@code task} to run once after {@code delay}.
     *
     * @param task  the task to run
     * @param delay how long to wait before running the task
     * @param unit  the unit of {@code delay}
     * @param argv  value handed to {@link TimerTask#run(Timeout, String)} when the task fires
     * @return a handle that can be used to cancel or inspect the scheduled task
     */
    Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv);
}

TimerTask.java

package com.tanghuachun.timer;

/**
 * A task executed by a {@link Timer} once its scheduled delay has elapsed.
 *
 * <p>Marked {@link FunctionalInterface} so that the compiler keeps it a single-method type and
 * callers can pass lambdas such as {@code (timeout, argv) -> ...} to {@link Timer#newTimeout}.
 */
@FunctionalInterface
public interface TimerTask {

    /**
     * Runs the task.
     *
     * @param timeout the handle returned when this task was scheduled
     * @param argv    the argument that was passed to {@link Timer#newTimeout}
     * @throws Exception any failure; the timer implementation decides how it is reported
     */
    void run(Timeout timeout, String argv) throws Exception;
}

TimerWheel.java

/*

* Copyright 2012 The Netty Project

*

* The Netty Project licenses this file to you under the Apache License,

* version 2.0 (the "License"); you may not use this file except in compliance

* with the License. You may obtain a copy of the License at:

*

* http://www.apache.org/licenses/LICENSE-2.0

*

* Unless required by applicable law or agreed to in writing, software

* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT

* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the

* License for the specific language governing permissions and limitations

* under the License.

*/

package com.tanghuachun.timer;

import io.netty.util.*;

import io.netty.util.internal.PlatformDependent;

import io.netty.util.internal.StringUtil;

import io.netty.util.internal.logging.InternalLogger;

import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.Collections;

import java.util.HashSet;

import java.util.Queue;

import java.util.Set;

import java.util.concurrent.CountDownLatch;

import java.util.concurrent.Executors;

import java.util.concurrent.ThreadFactory;

import java.util.concurrent.TimeUnit;

import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * A {@link Timer} backed by a hashed timing wheel, derived from the Netty project's hashed wheel
 * timer (see the license header). Every timeout carries a {@code String} argument that is handed
 * to {@link TimerTask#run(Timeout, String)} when it fires.
 *
 * <p>A single worker thread advances the wheel once per {@code tickDuration}. Newly submitted
 * timeouts sit in a queue and are moved into their bucket by the worker on the next tick, so
 * tasks fire with tick granularity rather than at the exact requested instant.
 */
public class TimerWheel implements Timer {

    static final InternalLogger logger =
            InternalLoggerFactory.getInstance(TimerWheel.class);

    private static final ResourceLeakDetector<TimerWheel> leakDetector = ResourceLeakDetectorFactory.instance()
            .newResourceLeakDetector(TimerWheel.class, 1, Runtime.getRuntime().availableProcessors() * 4L);

    // CAS access to the volatile workerState field. Netty's PlatformDependent may return null,
    // in which case the plain JDK updater is used instead.
    private static final AtomicIntegerFieldUpdater<TimerWheel> WORKER_STATE_UPDATER;

    static {
        AtomicIntegerFieldUpdater<TimerWheel> workerStateUpdater =
                PlatformDependent.newAtomicIntegerFieldUpdater(TimerWheel.class, "workerState");
        if (workerStateUpdater == null) {
            workerStateUpdater = AtomicIntegerFieldUpdater.newUpdater(TimerWheel.class, "workerState");
        }
        WORKER_STATE_UPDATER = workerStateUpdater;
    }

    private final ResourceLeak leak;

    private final Worker worker = new Worker();

    private final Thread workerThread;

    public static final int WORKER_STATE_INIT = 0;

    public static final int WORKER_STATE_STARTED = 1;

    public static final int WORKER_STATE_SHUTDOWN = 2;

    @SuppressWarnings({ "unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
    private volatile int workerState = WORKER_STATE_INIT; // 0 - init, 1 - started, 2 - shut down

    // Length of one tick, in nanoseconds.
    private final long tickDuration;

    // The wheel's buckets; its length is always a power of two.
    private final HashedWheelBucket[] wheel;

    // wheel.length - 1, so that "tick & mask" maps a tick number to a bucket index.
    private final int mask;

    // Released by the worker once startTime is set; start() waits on it.
    private final CountDownLatch startTimeInitialized = new CountDownLatch(1);

    // Timeouts submitted by newTimeout(); moved into buckets by the worker on each tick.
    private final Queue<HashedWheelTimeout> timeouts = PlatformDependent.newMpscQueue();

    // Timeouts whose cancel() succeeded; unlinked from their bucket by the worker on each tick.
    private final Queue<HashedWheelTimeout> cancelledTimeouts = PlatformDependent.newMpscQueue();

    // System.nanoTime() captured by the worker when it starts; 0 means "not started yet".
    private volatile long startTime;

    /**
     * Creates a new timer with the default thread factory
     * ({@link Executors#defaultThreadFactory()}), default tick duration, and
     * default number of ticks per wheel.
     */
    public TimerWheel() {
        this(Executors.defaultThreadFactory());
    }

    /**
     * Creates a new timer with the default thread factory
     * ({@link Executors#defaultThreadFactory()}) and default number of ticks
     * per wheel.
     *
     * @param tickDuration the duration between tick
     * @param unit the time unit of the {@code tickDuration}
     * @throws NullPointerException if {@code unit} is {@code null}
     * @throws IllegalArgumentException if {@code tickDuration} is <= 0
     */
    public TimerWheel(long tickDuration, TimeUnit unit) {
        this(Executors.defaultThreadFactory(), tickDuration, unit);
    }

    /**
     * Creates a new timer with the default thread factory
     * ({@link Executors#defaultThreadFactory()}).
     *
     * @param tickDuration the duration between tick
     * @param unit the time unit of the {@code tickDuration}
     * @param ticksPerWheel the size of the wheel
     * @throws NullPointerException if {@code unit} is {@code null}
     * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0
     */
    public TimerWheel(long tickDuration, TimeUnit unit, int ticksPerWheel) {
        this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel);
    }

    /**
     * Creates a new timer with the default tick duration (100 ms) and default number of
     * ticks per wheel (512).
     *
     * @param threadFactory a {@link ThreadFactory} that creates a
     *        background {@link Thread} which is dedicated to
     *        {@link TimerTask} execution.
     * @throws NullPointerException if {@code threadFactory} is {@code null}
     */
    public TimerWheel(ThreadFactory threadFactory) {
        this(threadFactory, 100, TimeUnit.MILLISECONDS);
    }

    /**
     * Creates a new timer with the default number of ticks per wheel (512).
     *
     * @param threadFactory a {@link ThreadFactory} that creates a
     *        background {@link Thread} which is dedicated to
     *        {@link TimerTask} execution.
     * @param tickDuration the duration between tick
     * @param unit the time unit of the {@code tickDuration}
     * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null}
     * @throws IllegalArgumentException if {@code tickDuration} is <= 0
     */
    public TimerWheel(
            ThreadFactory threadFactory, long tickDuration, TimeUnit unit) {
        this(threadFactory, tickDuration, unit, 512);
    }

    /**
     * Creates a new timer with leak detection always enabled.
     *
     * @param threadFactory a {@link ThreadFactory} that creates a
     *        background {@link Thread} which is dedicated to
     *        {@link TimerTask} execution.
     * @param tickDuration the duration between tick
     * @param unit the time unit of the {@code tickDuration}
     * @param ticksPerWheel the size of the wheel
     * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null}
     * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0
     */
    public TimerWheel(
            ThreadFactory threadFactory,
            long tickDuration, TimeUnit unit, int ticksPerWheel) {
        this(threadFactory, tickDuration, unit, ticksPerWheel, true);
    }

    /**
     * Creates a new timer.
     *
     * @param threadFactory a {@link ThreadFactory} that creates a
     *        background {@link Thread} which is dedicated to
     *        {@link TimerTask} execution.
     * @param tickDuration the duration between tick
     * @param unit the time unit of the {@code tickDuration}
     * @param ticksPerWheel the size of the wheel; rounded up to the next power of two
     * @param leakDetection {@code true} if leak detection should be enabled always, if false it will only be enabled
     *        if the worker thread is not a daemon thread.
     * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null}
     * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0,
     *         if {@code ticksPerWheel} exceeds 2^30, or if the wheel's total span would overflow a {@code long}
     */
    public TimerWheel(
            ThreadFactory threadFactory,
            long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) {
        if (threadFactory == null) {
            throw new NullPointerException("threadFactory");
        }
        if (unit == null) {
            throw new NullPointerException("unit");
        }
        if (tickDuration <= 0) {
            throw new IllegalArgumentException("tickDuration must be greater than 0: " + tickDuration);
        }
        if (ticksPerWheel <= 0) {
            throw new IllegalArgumentException("ticksPerWheel must be greater than 0: " + ticksPerWheel);
        }

        // Normalize ticksPerWheel to power of two and initialize the wheel.
        wheel = createWheel(ticksPerWheel);
        mask = wheel.length - 1;

        // Convert tickDuration to nanos.
        this.tickDuration = unit.toNanos(tickDuration);

        // Prevent overflow: tickDuration * wheel.length must fit in a long.
        if (this.tickDuration >= Long.MAX_VALUE / wheel.length) {
            // Report the normalized nanosecond value, which is what the bound refers to.
            throw new IllegalArgumentException(String.format(
                    "tickDuration: %d (expected: 0 < tickDuration in nanos < %d)",
                    this.tickDuration, Long.MAX_VALUE / wheel.length));
        }

        workerThread = threadFactory.newThread(worker);
        leak = leakDetection || !workerThread.isDaemon() ? leakDetector.open(this) : null;
    }

    /** Builds a wheel whose size is {@code ticksPerWheel} rounded up to a power of two. */
    private static HashedWheelBucket[] createWheel(int ticksPerWheel) {
        if (ticksPerWheel <= 0) {
            throw new IllegalArgumentException(
                    "ticksPerWheel must be greater than 0: " + ticksPerWheel);
        }
        if (ticksPerWheel > 1073741824) {
            throw new IllegalArgumentException(
                    "ticksPerWheel may not be greater than 2^30: " + ticksPerWheel);
        }

        ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel);
        HashedWheelBucket[] wheel = new HashedWheelBucket[ticksPerWheel];
        for (int i = 0; i < wheel.length; i ++) {
            wheel[i] = new HashedWheelBucket();
        }
        return wheel;
    }

    /** Returns the smallest power of two that is >= {@code ticksPerWheel}. */
    private static int normalizeTicksPerWheel(int ticksPerWheel) {
        int normalizedTicksPerWheel = 1;
        while (normalizedTicksPerWheel < ticksPerWheel) {
            normalizedTicksPerWheel <<= 1;
        }
        return normalizedTicksPerWheel;
    }

    /**
     * Starts the background thread explicitly. The background thread will
     * start automatically on demand even if you did not call this method.
     *
     * <p>Blocks until the worker has recorded its start time. An interrupt received while waiting
     * does not abort the wait, but the caller's interrupt status is restored before returning.
     *
     * @throws IllegalStateException if this timer has been
     *         {@linkplain #stop() stopped} already
     */
    public void start() {
        switch (WORKER_STATE_UPDATER.get(this)) {
            case WORKER_STATE_INIT:
                if (WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) {
                    workerThread.start();
                }
                break;
            case WORKER_STATE_STARTED:
                break;
            case WORKER_STATE_SHUTDOWN:
                throw new IllegalStateException("cannot be started once stopped");
            default:
                throw new Error("Invalid WorkerState");
        }

        // Wait until the startTime is initialized by the worker. The wait is short, so keep
        // waiting through interrupts, but do not silently discard the interrupt.
        boolean interrupted = false;
        while (startTime == 0) {
            try {
                startTimeInitialized.await();
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Stops the worker thread and returns the timeouts that had neither expired nor been
     * cancelled. Must not be called from a {@link TimerTask} running on this timer.
     *
     * @return the unprocessed timeouts, or an empty set if the worker was never started
     * @throws IllegalStateException if called from the worker thread
     */
    @Override
    public Set<Timeout> stop() {
        if (Thread.currentThread() == workerThread) {
            throw new IllegalStateException(
                    TimerWheel.class.getSimpleName() +
                            ".stop() cannot be called from " +
                            TimerTask.class.getSimpleName());
        }

        if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) {
            // workerState can be 0 or 2 at this moment - let it always be 2.
            WORKER_STATE_UPDATER.set(this, WORKER_STATE_SHUTDOWN);
            if (leak != null) {
                leak.close();
            }
            return Collections.emptySet();
        }

        // Keep interrupting until the worker has left its sleep and terminated; remember any
        // interrupt of the calling thread and restore it afterwards.
        boolean interrupted = false;
        while (workerThread.isAlive()) {
            workerThread.interrupt();
            try {
                workerThread.join(100);
            } catch (InterruptedException ignored) {
                interrupted = true;
            }
        }

        if (interrupted) {
            Thread.currentThread().interrupt();
        }

        if (leak != null) {
            leak.close();
        }
        return worker.unprocessedTimeouts();
    }

    /**
     * Schedules {@code task} to run once after {@code delay}, starting the worker if necessary.
     *
     * @param task  the task to run
     * @param delay how long to wait; very large values are clamped instead of overflowing
     * @param unit  the unit of {@code delay}
     * @param argv  passed to {@link TimerTask#run(Timeout, String)} when the task fires
     * @return the handle of the scheduled task
     * @throws NullPointerException if {@code task} or {@code unit} is {@code null}
     * @throws IllegalStateException if this timer has been stopped
     */
    @Override
    public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit, String argv) {
        if (task == null) {
            throw new NullPointerException("task");
        }
        if (unit == null) {
            throw new NullPointerException("unit");
        }
        start();

        // Add the timeout to the timeout queue which will be processed on the next tick.
        // During processing all the queued HashedWheelTimeouts will be added to the correct HashedWheelBucket.
        // The deadline is relative to startTime.
        long deadline = (System.nanoTime() - startTime) + unit.toNanos(delay);

        // Guard against overflow: a huge positive delay (toNanos saturates at Long.MAX_VALUE)
        // would otherwise wrap to a negative deadline and fire immediately.
        if (delay > 0 && deadline < 0) {
            deadline = Long.MAX_VALUE;
        }

        HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline, argv);
        timeouts.add(timeout);
        return timeout;
    }

    /**
     * Body of the worker thread. On every tick it unlinks cancelled timeouts, moves newly
     * submitted timeouts into their buckets, and expires the due timeouts of the current bucket.
     */
    private final class Worker implements Runnable {
        // Filled after shutdown with every timeout that neither expired nor was cancelled.
        private final Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>();

        // Number of ticks processed so far; only accessed by the worker thread.
        private long tick;

        @Override
        public void run() {
            // Initialize the startTime.
            startTime = System.nanoTime();
            if (startTime == 0) {
                // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized.
                startTime = 1;
            }

            // Notify the other threads waiting for the initialization at start().
            startTimeInitialized.countDown();

            do {
                final long deadline = waitForNextTick();
                if (deadline > 0) {
                    int idx = (int) (tick & mask);
                    processCancelledTasks();
                    HashedWheelBucket bucket = wheel[idx];
                    transferTimeoutsToBuckets();
                    bucket.expireTimeouts(deadline);
                    tick++;
                }
            } while (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_STARTED);

            // Fill the unprocessedTimeouts so we can return them from stop() method.
            for (HashedWheelBucket bucket: wheel) {
                bucket.clearTimeouts(unprocessedTimeouts);
            }
            for (;;) {
                HashedWheelTimeout timeout = timeouts.poll();
                if (timeout == null) {
                    break;
                }
                if (!timeout.isCancelled()) {
                    unprocessedTimeouts.add(timeout);
                }
            }
            processCancelledTasks();
        }

        /** Moves queued timeouts into the bucket matching their deadline. */
        private void transferTimeoutsToBuckets() {
            // transfer only max. 100000 timeouts per tick to prevent a thread to stale the workerThread when it just
            // adds new timeouts in a loop.
            for (int i = 0; i < 100000; i++) {
                HashedWheelTimeout timeout = timeouts.poll();
                if (timeout == null) {
                    // all processed
                    break;
                }
                if (timeout.state() == HashedWheelTimeout.ST_CANCELLED) {
                    // Was cancelled in the meantime.
                    continue;
                }

                long calculated = timeout.deadline / tickDuration;
                // Number of full wheel revolutions to skip before this timeout is due.
                timeout.remainingRounds = (calculated - tick) / wheel.length;

                final long ticks = Math.max(calculated, tick); // Ensure we don't schedule for past.
                int stopIndex = (int) (ticks & mask);

                HashedWheelBucket bucket = wheel[stopIndex];
                bucket.addTimeout(timeout);
            }
        }

        /** Unlinks every cancelled timeout from the bucket it was added to. */
        private void processCancelledTasks() {
            for (;;) {
                HashedWheelTimeout timeout = cancelledTimeouts.poll();
                if (timeout == null) {
                    // all processed
                    break;
                }
                try {
                    timeout.remove();
                } catch (Throwable t) {
                    if (logger.isWarnEnabled()) {
                        logger.warn("An exception was thrown while process a cancellation task", t);
                    }
                }
            }
        }

        /**
         * calculate goal nanoTime from startTime and current tick number,
         * then wait until that goal has been reached.
         * @return Long.MIN_VALUE if received a shutdown request,
         * current time otherwise (with Long.MIN_VALUE changed by +1)
         */
        private long waitForNextTick() {
            long deadline = tickDuration * (tick + 1);

            for (;;) {
                final long currentTime = System.nanoTime() - startTime;
                // Round up to whole milliseconds so we never wake before the deadline.
                long sleepTimeMs = (deadline - currentTime + 999999) / 1000000;

                if (sleepTimeMs <= 0) {
                    if (currentTime == Long.MIN_VALUE) {
                        return -Long.MAX_VALUE;
                    } else {
                        return currentTime;
                    }
                }

                // Check if we run on windows, as if thats the case we will need
                // to round the sleepTime as workaround for a bug that only affect
                // the JVM if it runs on windows.
                //
                // See https://github.com/netty/netty/issues/356
                if (PlatformDependent.isWindows()) {
                    sleepTimeMs = sleepTimeMs / 10 * 10;
                    // Rounding turns 1..9 ms into 0, and Thread.sleep(0) would busy-spin the
                    // worker until the deadline; sleep at least 1 ms instead.
                    if (sleepTimeMs == 0) {
                        sleepTimeMs = 1;
                    }
                }

                try {
                    Thread.sleep(sleepTimeMs);
                } catch (InterruptedException ignored) {
                    // stop() interrupts the worker; any other interrupt just re-evaluates the wait.
                    if (WORKER_STATE_UPDATER.get(TimerWheel.this) == WORKER_STATE_SHUTDOWN) {
                        return Long.MIN_VALUE;
                    }
                }
            }
        }

        /** Read-only view of the timeouts collected at shutdown. */
        public Set<Timeout> unprocessedTimeouts() {
            return Collections.unmodifiableSet(unprocessedTimeouts);
        }
    }

    /** A scheduled task; doubles as a node of its bucket's doubly-linked list. */
    private static final class HashedWheelTimeout implements Timeout {

        private static final int ST_INIT = 0;
        private static final int ST_CANCELLED = 1;
        private static final int ST_EXPIRED = 2;

        private static final AtomicIntegerFieldUpdater<HashedWheelTimeout> STATE_UPDATER;

        static {
            AtomicIntegerFieldUpdater<HashedWheelTimeout> updater =
                    PlatformDependent.newAtomicIntegerFieldUpdater(HashedWheelTimeout.class, "state");
            if (updater == null) {
                updater = AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state");
            }
            STATE_UPDATER = updater;
        }

        private final TimerWheel timer;
        private final TimerTask task;
        // Deadline in nanoseconds, relative to the timer's startTime.
        private final long deadline;

        @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization" })
        private volatile int state = ST_INIT;

        // remainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the
        // HashedWheelTimeout will be added to the correct HashedWheelBucket.
        long remainingRounds;

        // Argument handed to task.run() when this timeout expires; never reassigned.
        private final String argv;

        // This will be used to chain timeouts in HashedWheelTimerBucket via a double-linked-list.
        // As only the workerThread will act on it there is no need for synchronization / volatile.
        HashedWheelTimeout next;
        HashedWheelTimeout prev;

        // The bucket to which the timeout was added
        HashedWheelBucket bucket;

        HashedWheelTimeout(TimerWheel timer, TimerTask task, long deadline, String argv) {
            this.timer = timer;
            this.task = task;
            this.deadline = deadline;
            this.argv = argv;
        }

        @Override
        public Timer timer() {
            return timer;
        }

        @Override
        public TimerTask task() {
            return task;
        }

        /**
         * Marks this timeout cancelled; the worker unlinks it from its bucket on the next tick.
         *
         * @return {@code false} if it had already expired or been cancelled
         */
        @Override
        public boolean cancel() {
            // only update the state it will be removed from HashedWheelBucket on next tick.
            if (!compareAndSetState(ST_INIT, ST_CANCELLED)) {
                return false;
            }
            // If a task should be canceled we put this to another queue which will be processed on each tick.
            // So this means that we will have a GC latency of max. 1 tick duration which is good enough. This way
            // we can make again use of our MpscLinkedQueue and so minimize the locking / overhead as much as possible.
            timer.cancelledTimeouts.add(this);
            return true;
        }

        /** Unlinks this timeout from its bucket, if it has been placed in one. */
        void remove() {
            HashedWheelBucket bucket = this.bucket;
            if (bucket != null) {
                bucket.remove(this);
            }
        }

        public boolean compareAndSetState(int expected, int state) {
            return STATE_UPDATER.compareAndSet(this, expected, state);
        }

        public int state() {
            return state;
        }

        @Override
        public boolean isCancelled() {
            return state() == ST_CANCELLED;
        }

        @Override
        public boolean isExpired() {
            return state() == ST_EXPIRED;
        }

        /** Runs the task unless already cancelled/expired; task exceptions are logged, not rethrown. */
        public void expire() {
            if (!compareAndSetState(ST_INIT, ST_EXPIRED)) {
                return;
            }

            try {
                task.run(this, argv);
            } catch (Throwable t) {
                if (logger.isWarnEnabled()) {
                    logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t);
                }
            }
        }

        @Override
        public String toString() {
            final long currentTime = System.nanoTime();
            long remaining = deadline - currentTime + timer.startTime;

            StringBuilder buf = new StringBuilder(192)
                    .append(StringUtil.simpleClassName(this))
                    .append('(')
                    .append("deadline: ");
            if (remaining > 0) {
                buf.append(remaining)
                        .append(" ns later");
            } else if (remaining < 0) {
                buf.append(-remaining)
                        .append(" ns ago");
            } else {
                buf.append("now");
            }

            if (isCancelled()) {
                buf.append(", cancelled");
            }

            return buf.append(", task: ")
                    .append(task())
                    .append(')')
                    .toString();
        }
    }

    /**
     * Bucket that stores HashedWheelTimeouts. These are stored in a linked-list like datastructure to allow easy
     * removal of HashedWheelTimeouts in the middle. Also the HashedWheelTimeout act as nodes themself and so no
     * extra object creation is needed.
     */
    private static final class HashedWheelBucket {
        // Used for the linked-list datastructure
        private HashedWheelTimeout head;
        private HashedWheelTimeout tail;

        /**
         * Add {@link HashedWheelTimeout} to this bucket.
         */
        public void addTimeout(HashedWheelTimeout timeout) {
            assert timeout.bucket == null;
            timeout.bucket = this;
            if (head == null) {
                head = tail = timeout;
            } else {
                tail.next = timeout;
                timeout.prev = tail;
                tail = timeout;
            }
        }

        /**
         * Expire all {@link HashedWheelTimeout}s for the given {@code deadline}; timeouts that still
         * have remaining rounds are decremented, cancelled ones are unlinked.
         */
        public void expireTimeouts(long deadline) {
            HashedWheelTimeout timeout = head;

            // process all timeouts
            while (timeout != null) {
                boolean remove = false;
                if (timeout.remainingRounds <= 0) {
                    if (timeout.deadline <= deadline) {
                        timeout.expire();
                    } else {
                        // The timeout was placed into a wrong slot. This should never happen.
                        throw new IllegalStateException(String.format(
                                "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline));
                    }
                    remove = true;
                } else if (timeout.isCancelled()) {
                    remove = true;
                } else {
                    timeout.remainingRounds --;
                }
                // store reference to next as we may null out timeout.next in the remove block.
                HashedWheelTimeout next = timeout.next;
                if (remove) {
                    remove(timeout);
                }
                timeout = next;
            }
        }

        /** Unlinks {@code timeout} from this bucket and clears its list pointers. */
        public void remove(HashedWheelTimeout timeout) {
            HashedWheelTimeout next = timeout.next;
            // remove timeout that was either processed or cancelled by updating the linked-list
            if (timeout.prev != null) {
                timeout.prev.next = next;
            }
            if (timeout.next != null) {
                timeout.next.prev = timeout.prev;
            }

            if (timeout == head) {
                // if timeout is also the tail we need to adjust the entry too
                if (timeout == tail) {
                    tail = null;
                    head = null;
                } else {
                    head = next;
                }
            } else if (timeout == tail) {
                // if the timeout is the tail modify the tail to be the prev node.
                tail = timeout.prev;
            }
            // null out prev, next and bucket to allow for GC.
            timeout.prev = null;
            timeout.next = null;
            timeout.bucket = null;
        }

        /**
         * Clear this bucket and return all not expired / cancelled {@link Timeout}s.
         */
        public void clearTimeouts(Set<Timeout> set) {
            for (;;) {
                HashedWheelTimeout timeout = pollTimeout();
                if (timeout == null) {
                    return;
                }
                if (timeout.isExpired() || timeout.isCancelled()) {
                    continue;
                }
                set.add(timeout);
            }
        }

        /** Detaches and returns the head of the list, or {@code null} if the bucket is empty. */
        private HashedWheelTimeout pollTimeout() {
            HashedWheelTimeout head = this.head;
            if (head == null) {
                return null;
            }
            HashedWheelTimeout next = head.next;
            if (next == null) {
                tail = this.head = null;
            } else {
                this.head = next;
                next.prev = null;
            }

            // null out prev and next to allow for GC.
            head.next = null;
            head.prev = null;
            head.bucket = null;
            return head;
        }
    }
}

编写测试类Main.java

package com.tanghuachun.timer;

import java.util.concurrent.TimeUnit;

/**

* Created by darren on 2016/11/17.

*/

/**
 * Demo: schedules ten one-shot tasks 5 seconds in the future, prints their argument when they
 * fire, and then stops the timer so the JVM can exit.
 */
public class Main implements TimerTask {

    /** Shared timer; its worker thread is started by the first newTimeout call. */
    static final Timer timer = new TimerWheel();

    /** Number of demo tasks scheduled by {@link #main(String[])}. */
    private static final int TASK_COUNT = 10;

    /** Counted down once per executed task so main() knows when it may stop the timer. */
    private static final java.util.concurrent.CountDownLatch DONE =
            new java.util.concurrent.CountDownLatch(TASK_COUNT);

    public static void main(String[] args) {
        TimerTask timerTask = new Main();
        for (int i = 0; i < TASK_COUNT; i++) {
            timer.newTimeout(timerTask, 5, TimeUnit.SECONDS, String.valueOf(i));
        }

        // TimerWheel() uses Executors.defaultThreadFactory(), whose threads are non-daemon, so
        // the JVM would never exit unless the timer is stopped. Wait (bounded) for the tasks first.
        try {
            DONE.await(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            timer.stop();
        }
    }

    @Override
    public void run(Timeout timeout, String argv) throws Exception {
        System.out.println("timeout, argv = " + argv);
        DONE.countDown();
    }
}

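然后运行 Main,大约 5 秒后就可以看到运行结果:控制台会输出 timeout, argv = 0 到 timeout, argv = 9 共 10 行。注意 newTimeout 安排的任务只会触发一次;如果像开头的场景那样需要每隔一段时间执行,可以在 TimerTask 的 run() 方法中再次调用 timer.newTimeout 重新调度。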

工程代码下载(以maven的方式导入)。

以上是 Java时间轮算法的实现代码示例 的全部内容, 来源链接: utcz.com/p/214403.html

回到顶部