import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.atomic.AtomicReference;

/**
 * An unbounded lock-free LIFO stack (Treiber-style). {@link #push} and {@link #pop} each
 * read the current top and build the candidate new top. They publish it with a single
 * compare-and-set on {@link #top}, and retry after a randomized exponential backoff when
 * the CAS loses a race with another thread.
 *
 * @param <T> element type
 */
public class LockFreeStack<T> implements MyStack<T> {
    /** Head of the linked list of nodes; {@code null} when the stack is empty. */
    AtomicReference<Node> top = new AtomicReference<Node>(null);

    // Backoff bounds in milliseconds (Backoff passes the chosen delay to Thread.sleep);
    // tweak these numbers to improve performance under contention.
    static final int MIN_DELAY = 1;
    static final int MAX_DELAY = 100;

    // NOTE(review): this single Backoff is shared by every thread using the stack, so its
    // delay limit is updated without synchronization and, once grown, never shrinks back
    // toward MIN_DELAY.
    Backoff backoff = new Backoff(MIN_DELAY, MAX_DELAY);

    /**
     * Makes one attempt to link {@code node} in as the new top.
     *
     * @param node the node to push; its {@code next} field is overwritten
     * @return {@code true} if the CAS succeeded, {@code false} if another thread changed the
     *     top concurrently and the caller should retry
     */
    protected boolean tryPush(Node node) {
        Node oldTop = top.get();
        node.next = oldTop;
        return top.compareAndSet(oldTop, node);
    }

    /**
     * Pushes {@code value} onto the stack, backing off and retrying until the CAS succeeds.
     *
     * @param value the element to push (stored as-is, {@code null} included)
     */
    public void push(T value) {
        Node node = new Node(value);
        while (!tryPush(node)) {
            pause();
        }
    }

    /**
     * Makes one attempt to unlink the top node.
     *
     * @return the removed node, or {@code null} if the CAS lost a race and the caller should
     *     retry
     * @throws NoSuchElementException if the stack is empty
     */
    protected Node tryPop() throws NoSuchElementException {
        Node oldTop = top.get();
        if (oldTop == null) {
            throw new NoSuchElementException("stack is empty");
        }
        Node newTop = oldTop.next;
        return top.compareAndSet(oldTop, newTop) ? oldTop : null;
    }

    /**
     * Removes and returns the top element, backing off and retrying on CAS contention.
     *
     * @return the element that was on top
     * @throws NoSuchElementException if the stack is empty when an attempt is made
     */
    public T pop() throws NoSuchElementException {
        while (true) {
            Node popped = tryPop();
            if (popped != null) {
                return popped.value;
            }
            pause();
        }
    }

    /**
     * Sleeps for one backoff interval after a failed CAS.
     *
     * <p>push/pop cannot report an interruption to their caller. So if the thread is
     * interrupted, the interrupt status is restored and the operation simply keeps retrying.
     */
    private void pause() {
        try {
            backoff.backoff();
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller instead of silently clearing it.
            Thread.currentThread().interrupt();
        }
    }

    /** Singly linked list cell holding one stack element. */
    public class Node {
        public T value;
        // A plain (unstamped) reference is enough. push always allocates a fresh Node and
        // nodes are never reused, so while a thread holds a reference to a node, that
        // identity cannot reappear on top and the ABA problem cannot occur.
        public Node next;

        public Node(T value) {
            this.value = value;
            next = null;
        }
    }

    /**
     * Randomized exponential backoff. Each call sleeps a random time below a limit. The
     * limit starts at {@code minDelay} and doubles per call, capped at {@code maxDelay}.
     */
    public class Backoff {
        final int minDelay, maxDelay;
        int limit;
        final Random random;

        /**
         * @param min initial exclusive upper bound on the sleep, in milliseconds; must be >= 1
         * @param max cap on that bound, in milliseconds; must be >= {@code min}
         * @throws IllegalArgumentException if the bounds are out of range
         */
        public Backoff(int min, int max) {
            // Random.nextInt(limit) requires limit >= 1, and the limit must never shrink.
            if (min < 1 || max < min) {
                throw new IllegalArgumentException(
                        "require 1 <= min <= max, got min=" + min + ", max=" + max);
            }
            minDelay = min;
            maxDelay = max;
            limit = minDelay;
            random = new Random();
        }

        /**
         * Picks a delay uniformly in {@code [0, limit)} milliseconds. It then doubles
         * {@code limit} (capped at {@code maxDelay}) for the next call, and sleeps for the
         * chosen delay.
         *
         * @throws InterruptedException if the sleep is interrupted
         */
        public void backoff() throws InterruptedException {
            int delay = random.nextInt(limit);
            // Same result as Math.min(maxDelay, 2 * limit), but 2 * limit cannot overflow.
            limit = (limit > maxDelay / 2) ? maxDelay : 2 * limit;
            Thread.sleep(delay);
        }
    }

    /**
     * Benchmarks the stack with {@code StackTester} and prints the average runtime.
     *
     * <p>Optional arguments: {@code <runs> <numThreads> <numIterations>}. Either all three
     * are given or none (defaults 10, 6 and 120000).
     */
    public static void main(String[] args) {
        int runs = 10;
        int numThreads = 6;
        int numIterations = 120000;

        if (args.length > 0) {
            if (args.length < 3) {
                System.err.println("Usage: LockFreeStack <runs> <numThreads> <numIterations>");
                System.exit(1);
            }
            try {
                runs = Integer.parseInt(args[0].trim());
                numThreads = Integer.parseInt(args[1].trim());
                numIterations = Integer.parseInt(args[2].trim());
                System.out.format("Runs = %d, NumThreads = %d, NumIterations = %d\n", runs, numThreads, numIterations);
            } catch (NumberFormatException e) {
                System.err.println("Argument must be an integer.");
                System.exit(1);
            }
            // The average below divides by runs.
            if (runs <= 0) {
                System.err.println("Runs must be a positive integer.");
                System.exit(1);
            }
        }

        // Accumulate directly: no per-run array, so its size can never disagree with runs.
        double totalRuntime = 0;
        for (int i = 0; i < runs; ++i) {
            MyStack<Integer> stack = new LockFreeStack<>();
            StackTester tester = new StackTester(stack);
            totalRuntime += tester.runTest(numThreads, numIterations);
        }

        System.out.println("Average runtime of " + runs + " runs: " + (totalRuntime / runs) + "s");
    }
}