package code.security.ahs;
import code.*;
import code.security.*;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Vector;

/**
 * A node of a binary hash tree built over secure invalidates.
 *
 * <p>Three kinds of nodes are created by this class:
 * <ul>
 *   <li>a <em>dummy</em> node (no-arg constructor) with level -1 and timestamps -1,</li>
 *   <li>a <em>leaf</em> node (level 0) built from a {@code SecurePreciseInv},</li>
 *   <li>an <em>internal</em> node built from two children of equal level.</li>
 * </ul>
 * Subclass {@code ImpreciseTreeNode} is treated as a leaf by the traversal
 * methods of this class even when its level is greater than 0.
 *
 * <p>Thread-safety: the child/neighbour accessors are {@code synchronized},
 * but several traversal methods ({@link #getTreeAHS(long, long, AHS)},
 * {@link #getLeafTS(long)}, {@link #getTreeNodeTS(long)}) read the link
 * fields directly without holding the lock.
 */
public class TreeNode{

  /**
   * Length in bytes of the all-zero placeholder used for a leaf's
   * {@link #hash} and for a missing previous superhash in
   * {@link #updateSuperHash()}.
   * NOTE(review): presumably the digest length of DataHash.getMD - confirm.
   */
  private static final int ZERO_HASH_LENGTH = 20;

  /**
   * For an internal node: digest of the concatenated superhashes of both
   * children (null when SangminConfig.forkjoin is set). For a leaf built
   * from a SecurePreciseInv: an all-zero array. Null for the dummy node.
   */
  protected final byte[] hash;
  /**
   * the level of the current treenode: 0 for leaves, child level + 1 for
   * internal nodes, -1 for the dummy node
   */
  protected final int level;
  /**
   * first timestamp covered by this node
   */
  protected final long startTS;
  /**
   * last timestamp covered by this node
   */
  protected final long endTS;
  /**
   * invalidation target of this node; for an internal node the union of the
   * children's targets
   */
  protected final InvalTarget invalTarget;
  /**
   * the maxDVV and the corresponding hashes (nodeID-> (TS, Hash))
   */
  protected final DVVMap maxDVVMap;
  /**
   * digest over {@link #hash} and the local fields; see
   * generateSuperHash() and {@link #updateSuperHash()}
   */
  protected byte[] superhash;

  /**
   * leftChild : only valid for internal treeNode
   */
  private TreeNode leftChild;
  /**
   * rightChild: only valid for internal treeNode
   */
  private TreeNode rightChild;
  /**
   * only valid for leaf nodes and imprecise TreeNodes: first one returns dummyNode
   */
  private TreeNode prevNode;
  /**
   * only valid for leaf nodes and imprecise treeNodes
   */
  private TreeNode nextNode;

  /**
   * the actual invalidate (if any) that corresponds to this treeNode
   */
  protected SecureInv inval;

  /**
   * subscription set passed to InvalTarget.getUnion when the targets of two
   * children are merged
   */
  protected final static SubscriptionSet ss = SubscriptionSet.makeSubscriptionSet("/*");

  /**
   * Creates the dummy TreeNode: level -1, timestamps -1, every reference null.
   */
  protected TreeNode(){
    hash = null;
    level = -1;
    startTS = endTS = -1;
    invalTarget = null;
    maxDVVMap = null;
    superhash = null;
    leftChild = rightChild = null;
    prevNode = nextNode = null;
  }

  /**
   * Creates an unlinked node from the fields of an AHSEntry and recomputes
   * its superhash locally. The process exits with status -1 if computing the
   * superhash raises an IOException.
   * @param ahsEntry source of level, hash, timestamps, target and maxDVVMap
   */
  protected TreeNode(AHSEntry ahsEntry){
    this.level = ahsEntry.level;
    this.hash = ahsEntry.hash;
    this.startTS = ahsEntry.startTS;
    this.endTS = ahsEntry.endTS;
    this.invalTarget = ahsEntry.invalTarget;
    this.maxDVVMap = ahsEntry.maxDVVMap;
    leftChild = rightChild = null;
    prevNode = nextNode = null;
    byte[] b = null;
    try{
      b = generateSuperHash();
    }catch(IOException e){
      System.err.println(e.toString());
      System.exit(-1);
    }
    this.superhash = b;
    sanityCheck();
  }

  /**
   * Creates an unlinked shallow copy of {@code org}: all value fields and the
   * invalidate are shared, child and neighbour links are left null.
   * With assertions enabled (and outside fork-join mode) the superhash of
   * {@code org} is verified against a fresh recomputation.
   * @param org node to copy
   */
  protected TreeNode(TreeNode org){
    hash = org.hash;
    level = org.level;
    startTS = org.startTS;
    endTS = org.endTS;
    invalTarget = org.invalTarget;
    superhash = org.superhash;
    maxDVVMap = org.maxDVVMap;
    leftChild = rightChild = null;
    prevNode = nextNode = null;

    try{
      byte[] b = generateSuperHash();
      // In fork-join mode generateSuperHash() returns null by design, so the
      // recomputed value cannot be compared with the stored superhash.
      assert SangminConfig.forkjoin || Arrays.equals(b, superhash): "actual " + DVVMapEntry.byteString(b) + " reported " + DVVMapEntry.byteString(superhash)  + "\n" + this.toString();
    }catch(IOException e){
      System.err.println(e.toString());
      System.exit(-1);
    }

    inval = org.inval;
    sanityCheck();
  }

  /**
   * Creates a leaf treeNode using a secure precise inv.
   * The leaf has level 0, an all-zero {@link #hash}, and start/end timestamps
   * equal to the local clock of the invalidate's accept stamp. When
   * SecurityFilter.measureTime is set, the elapsed time is added to
   * AHSMap.linkCheck.
   * @param spi the invalidate this leaf represents
   * @param ahsMap passed to the DVVMap built from the invalidate's DVV
   */
  public TreeNode(SecurePreciseInv spi, AHSMap ahsMap){
    long start = 0;
    if(SecurityFilter.measureTime){
      start = System.currentTimeMillis();
    }
    this.level = 0;
    // a freshly allocated array is already zero-filled
    hash = new byte[ZERO_HASH_LENGTH];
    startTS = endTS = spi.getAcceptStamp().getLocalClock();
    invalTarget = spi.getInvalTarget();

    maxDVVMap = new DVVMap(spi.getDVV(), ahsMap);
    leftChild = rightChild = null;
    prevNode = nextNode = null;
    byte[] b = null;
    try{
      b = generateSuperHash();
    }catch(IOException e){
      System.err.println(e.toString());
      System.exit(-1);
    }
    finally{
      if(SecurityFilter.measureTime){
        AHSMap.linkCheck += System.currentTimeMillis() - start;
      }
    }
    superhash = b;
    inval = spi;
    sanityCheck();
  }

  /**
   * When SecurityFilter.sanityCheck is set, asserts for every node id in the
   * DVV of {@link #maxDVVMap} that its stamp does not exceed {@link #endTS}
   * and that the corresponding map entry carries a hash.
   */
  private void sanityCheck(){
    if(SecurityFilter.sanityCheck){
      DependencyVV dvv = maxDVVMap.getDVV();
      VVIterator vvi = dvv.getIterator();
      long ts = endTS;
      while(vvi.hasMoreElements()){
        NodeId n = vvi.getNext();
        assert dvv.getStampByIteratorToken(n) <= ts: dvv + " ts " + ts;
        assert maxDVVMap.getEntry(n).hasHash(): this;
      }
    }
  }

  /**
   * Computes the superhash from {@link #hash} and the local fields
   * (timestamps, level, target, maxDVVMap).
   * @return the digest, or null in fork-join mode, where the superhash is
   *         generated when this node is linked with the previous leaf node
   * @throws IOException if serialising the fields fails
   */
  private byte[] generateSuperHash() throws IOException{
    if(SangminConfig.forkjoin){
      return null;
    }
    return digestLocalFields(hash, true);
  }

  /**
   * Serialises startTS, endTS, {@code linkedHash}, level, the target's string
   * form and (optionally) maxDVVMap into an ObjectOutputStream and returns
   * DataHash.getMD of the resulting bytes.
   *
   * The byte layout must stay exactly as is (including writing only the low
   * byte of level and using the platform default charset for the target),
   * otherwise previously computed superhashes would no longer match.
   *
   * @param linkedHash hash mixed into the digest (own hash or previous superhash)
   * @param includeMaxDVVMap whether the bytes of maxDVVMap are appended
   * @return digest of the serialised fields
   * @throws IOException if writing to the stream fails
   */
  private byte[] digestLocalFields(byte[] linkedHash, boolean includeMaxDVVMap) throws IOException{
    ByteArrayOutputStream bs = new ByteArrayOutputStream();
    ObjectOutputStream oos = new ObjectOutputStream(bs);
    try{
      oos.writeLong(startTS);
      oos.writeLong(endTS);
      oos.write(linkedHash);
      oos.write(level);
      oos.write(invalTarget.toString().getBytes());
      if(includeMaxDVVMap){
        oos.write(maxDVVMap.obj2Bytes());
      }
      oos.flush();
      return DataHash.getMD(bs.toByteArray());
    }finally{
      // closing the object stream also closes the underlying byte stream
      oos.close();
    }
  }

  /**
   * creates an internal treeNode using the left child, right child and the writer.
   * Both children must be on the same level and the left child must end
   * before the right child starts (checked by assertions). Imprecise children
   * get this node registered as their parent.
   * In fork-join mode {@link #hash} is null and the superhash is taken from
   * the right child; otherwise both are computed, and the process exits with
   * status -1 if that raises an IOException.
   * @param lChild left child
   * @param rChild right child
   * @param writer passed to the DVVMap that merges the children's maps
   */
  public TreeNode(TreeNode lChild, TreeNode rChild, NodeId writer){
    assert(lChild.level == rChild.level);
    assert(lChild.endTS < rChild.startTS);
    long start = 0;
    if(SecurityFilter.measureTime){
      start = System.currentTimeMillis();
    }

    this.leftChild = lChild;
    this.rightChild = rChild;

    if(lChild instanceof ImpreciseTreeNode){
      ((ImpreciseTreeNode)lChild).parent = this;
    }
    if(rChild instanceof ImpreciseTreeNode){
      ((ImpreciseTreeNode)rChild).parent = this;
    }

    byte[] childrenDigest = null;
    if(!SangminConfig.forkjoin){
      ByteArrayOutputStream buffer = new ByteArrayOutputStream();
      try{
        buffer.write(lChild.superhash);
        buffer.write(rChild.superhash);

        childrenDigest = DataHash.getMD(buffer.toByteArray());
        buffer.close();

      }catch(IOException e){
        System.err.println(e.toString());
        System.exit(-1);
      }
    }
    this.hash = childrenDigest;

    this.level = lChild.level +1;
    this.startTS = lChild.startTS;
    this.endTS = rChild.endTS;

    this.maxDVVMap = new DVVMap(lChild.getMaxDVVMap(), rChild.getMaxDVVMap(), writer);
    this.invalTarget = leftChild.invalTarget.getUnion(rightChild.invalTarget, ss);

    if(SangminConfig.forkjoin){
      this.superhash = rightChild.superhash;
    } else {
      byte[] b = null;
      try{
        b = this.generateSuperHash();
      }catch(IOException e){
        System.err.println(e.toString());
        System.exit(-1);
      }
      finally{
        if(SecurityFilter.measureTime){
          AHSMap.linkCheck += System.currentTimeMillis() - start;
        }
      }

      this.superhash = b;
    }
  }

  /**
   * @return the invalidate attached to this node, or null if none was set
   */
  public SecureInv getInv(){
    return this.inval;
  }

  /**
   * @return first timestamp covered by this node
   */
  public long getStartTS(){
    return this.startTS;
  }

  /**
   * @return last timestamp covered by this node
   */
  public long getEndTS(){
    return this.endTS;
  }

  /**
   * @return the right child, or null if there is none
   */
  public synchronized TreeNode getRightChild(){
    return this.rightChild;
  }

  /**
   * Replaces the right child. Unlike {@link #setLeftChild(TreeNode)} this
   * does not assert that the node is above level 0.
   * @param right new right child
   */
  public synchronized void setRightChild(TreeNode right){
    this.rightChild = right;
  }

  /**
   * @return the left child, or null if there is none
   */
  public synchronized TreeNode getLeftChild(){
    return this.leftChild;
  }

  /**
   * Replaces the left child; asserts that this node is not on level 0.
   * @param left new left child
   */
  public synchronized void setLeftChild(TreeNode left){
    assert(this.level != 0);
    this.leftChild = left;
  }

  /**
   * @return the next node in the leaf chain, or null
   */
  public synchronized TreeNode getNext(){
    return this.nextNode;
  }

  /**
   * @param next new successor in the leaf chain
   */
  public synchronized void setNext(TreeNode next){
    this.nextNode = next;
  }

  /**
   * @return the previous node in the leaf chain, or null
   */
  public synchronized TreeNode getPrev(){
    return this.prevNode;
  }

  /**
   * @param prev new predecessor in the leaf chain
   */
  public synchronized void setPrev(TreeNode prev){
    this.prevNode = prev;
  }

  /**
   * Overwrites the superhash; asserted to be used only in fork-join mode on
   * the dummy node (level below 0).
   * @param superhash new superhash
   */
  public synchronized void setSuperHash(byte [] superhash){
    assert SangminConfig.forkjoin && this.level < 0; // only for dummy node
    this.superhash = superhash;
  }

  /**
   * Fork-join mode only: recomputes the superhash of this node by chaining it
   * to the superhash of the previous node. If the previous node has no
   * superhash (asserted to be a node with negative endTS), an all-zero array
   * is used instead. Unlike generateSuperHash(), maxDVVMap is not part of
   * this digest.
   * @throws IOException if serialising the fields fails
   */
  public synchronized void updateSuperHash() throws IOException{
    assert this.prevNode != null : "prevNode must be set (possibly to the dummy node) before updateSuperHash()";
    assert SangminConfig.forkjoin;
    byte[] prevHash;
    if(prevNode.superhash == null){
      assert prevNode.endTS < 0;
      // a freshly allocated array is already zero-filled
      prevHash = new byte[ZERO_HASH_LENGTH];
    } else {
      prevHash = prevNode.superhash;
    }
    superhash = digestLocalFields(prevHash, false);
  }


  /**
   * @return the maxDVV map of this node (null for the dummy node)
   */
  public DVVMap getMaxDVVMap(){
    return this.maxDVVMap;
  }

  /**
   * @return the DVV held by {@link #maxDVVMap}
   */
  public DependencyVV getDVV(){
    return this.maxDVVMap.getDVV();
  }

  /**
   * Returns the AHS corresponding to the range [start, end]
   * Only values known in the range are included. Hence, even
   * if you give larger range, it will silently report the AHS.
   * A node fully inside the range is added as a whole; a node that only
   * partially overlaps delegates to both of its children.
   * @param start first timestamp of the requested range
   * @param end last timestamp of the requested range
   * @param ahs collector the matching nodes are added to
   */
  public void getTreeAHS(long start, long end, AHS ahs){

    if(this.startTS >= start
        && this.endTS <= end){
      ahs.add(this);
      return;
    }

    if(this.endTS < start || this.startTS > end){
      return;
    }

    // clamp the requested range to what this node covers
    long from = Math.max(start, this.startTS);
    long to = Math.min(end, this.endTS);

    // right end of the left range / left end of the right range
    long rightOfLeftRange = Math.min(leftChild.endTS, to);
    long leftOfRightRange = Math.max(rightChild.startTS, from);
    this.leftChild.getTreeAHS(from, rightOfLeftRange, ahs);
    this.rightChild.getTreeAHS(leftOfRightRange, to, ahs);
  }

  /**
   * @return always true for this class
   */
  public synchronized boolean isPrecise()
  {
    return true;
  }

  /**
   * Returns the AHS corresponding to the range [startTS of this node, ts]
   * Only values known in the range are included. Hence, even
   * if you give larger range, it will silently report the AHS
   * @param ts last timestamp of the requested range
   * @param ahs collector the matching nodes are added to
   */
  public void getTreeAHS(long ts, AHS ahs){
    getTreeAHS(this.startTS, ts, ahs);
  }

  /**
   * returns the leaf tree node that contains this ts
   * By leaf, we means a node that has no child
   * Therefore, if it encounters imprecise tree node while traversing,
   * the imprecise tree node is returned.
   * returns null if no matching tree node (ts outside this node's range, or
   * in the gap between two children)
   * @param ts timestamp to look up
   * @return the leaf or imprecise node covering ts, or null
   */
  public TreeNode getLeafTS(long ts){
    if(this.startTS > ts || this.endTS < ts){
      return null;
    }

    TreeNode start = this;
    while(start.level > 0){
      if(start instanceof ImpreciseTreeNode){
        return start;
      }else{
        if(start.leftChild.endTS >= ts){
          start = start.leftChild;
        }else if(start.rightChild.startTS <= ts){
          start = start.rightChild;
        }else{
          return null;
        }
      }
    }
    assert start.startTS == start.endTS;
    assert start.endTS == ts;
    return start;
  }

  /**
   * returns the most precise tree (traversing RootLists recursive for impreciseTreeNodes)
   * node that contains this ts
   * returns null if ts lies outside the range of this node. If ts falls in
   * the gap between the two children of an internal node, that internal node
   * is returned.
   * @param ts timestamp to look up
   * @return the most precise node covering ts, or null
   */
  public TreeNode getTreeNodeTS(long ts){

    if(this.startTS > ts || this.endTS < ts){
      return null;
    }

    TreeNode start = this;

    while(start.level > 0){
      if(start instanceof ImpreciseTreeNode){
        return ((ImpreciseTreeNode)start).getImpreciseTreeNodeAS(ts);
      }else{
        if(start.leftChild.endTS >= ts){
          start = start.leftChild;
        }else if(start.rightChild.startTS <= ts){
          start = start.rightChild;
        }else{
          return start;
        }
      }
    }
    assert start.startTS == start.endTS;
    assert start.endTS == ts;
    return start;
  }

  /**
   * value comparison: two nodes are equal when their superhashes are equal.
   * Note that this is an overload, not an override of
   * {@link Object#equals(Object)}; collections keep using identity.
   * @param t node to compare with, may be null
   * @return true if t is non-null and has the same superhash
   */
  public boolean equals(TreeNode t){
    return t != null && Arrays.equals(t.superhash, this.superhash);
  }

  /**
   * Returns whether the treeNode passed represents a structurally equal tree to this: all values and their relative pos should be same.
   * The nodes themselves must be equal by {@link #equals(TreeNode)}, both
   * subtrees must be equal recursively, and the immediate next/previous
   * neighbours must be equal by value. Neighbours are deliberately not
   * followed recursively: next and prev point at each other, so recursing
   * through both would never terminate.
   * @param t tree to compare with, may be null
   * @return true if both trees have the same shape and values
   */
  public boolean equalsTree(TreeNode t){

    if(t == null){
      return false;
    }
    if(!this.equals(t)){
      return false;
    }
    return subtreeEquals(this.getLeftChild(), t.getLeftChild())
        && subtreeEquals(this.getRightChild(), t.getRightChild())
        && linkEquals(this.getNext(), t.getNext())
        && linkEquals(this.getPrev(), t.getPrev());
  }

  /**
   * Null-safe recursive comparison of two subtrees via equalsTree.
   */
  private static boolean subtreeEquals(TreeNode a, TreeNode b){
    if(a == null || b == null){
      return a == b;
    }
    return a.equalsTree(b);
  }

  /**
   * Null-safe, non-recursive value comparison of two neighbour links.
   */
  private static boolean linkEquals(TreeNode a, TreeNode b){
    if(a == null || b == null){
      return a == b;
    }
    return a.equals(b);
  }

  /**
   * Creates a deep copy of the subtree rooted at this node. Children are
   * cloned recursively, imprecise children get the copy registered as their
   * parent, and for a node above level 0 the right-most leaf of the cloned
   * left subtree is linked with the left-most leaf of the cloned right subtree.
   * The next/prev links of this node itself are not copied.
   * @return the copy, equal to this node by {@link #equals(TreeNode)}
   */
  @Override
  public synchronized TreeNode clone(){
    TreeNode ret = new TreeNode(this);

    if(leftChild == null){
      ret.leftChild = null;
    }else{
      ret.leftChild = this.leftChild.clone();
      if(ret.leftChild instanceof ImpreciseTreeNode){
        ((ImpreciseTreeNode)(ret.leftChild)).parent = ret;
      }
    }

    if(rightChild == null){
      ret.rightChild =null;
    }else{
      ret.rightChild = this.rightChild.clone();
      if(ret.rightChild instanceof ImpreciseTreeNode){
        ((ImpreciseTreeNode)(ret.rightChild)).parent = ret;
      }
    }

    if(ret.level != 0 && !(ret instanceof ImpreciseTreeNode)){

      TreeNode t1 = ret.leftChild;
      t1 = t1.getRightMostChild();
      TreeNode t2 = ret.rightChild.getLeftMostChild();
      t1.setNext(t2);
      t2.setPrev(t1);
    }
    assert(this.equals(ret));
    return ret;
  }

  /**
   * @return the invalidation target of this node
   */
  public InvalTarget getInvalTarget(){
    return invalTarget;
  }

  /**
   * Collects, left to right, every node of this subtree that is on level 0
   * or is an ImpreciseTreeNode (including this node itself if it qualifies).
   * @return the collected nodes
   */
  public Vector<TreeNode> getAllChildren(){
    Vector<TreeNode> vec = new Vector<TreeNode>();
    if(this.getLeftChild() != null){
      vec.addAll(this.getLeftChild().getAllChildren());
    }
    if(this.getRightChild() != null){
      vec.addAll(this.getRightChild().getAllChildren());
    }
    if(level == 0 || this instanceof ImpreciseTreeNode){
      vec.add(this);
    }
    return vec;
  }

  /**
   * Follows left children until a level-0 node or an ImpreciseTreeNode is reached.
   * @return that node (this node itself if it already qualifies)
   */
  public TreeNode getLeftMostChild(){
    TreeNode t1 = this;
    while(!(t1.level == 0 || t1 instanceof ImpreciseTreeNode)){
      t1 = t1.getLeftChild();
    }
    return t1;
  }

  /**
   * Follows right children until a level-0 node or an ImpreciseTreeNode is reached.
   * @return that node (this node itself if it already qualifies)
   */
  public TreeNode getRightMostChild(){
    TreeNode t1 = this;
    while(!(t1.level == 0 || t1 instanceof ImpreciseTreeNode)){
      t1 = t1.getRightChild();
    }
    return t1;
  }

  /**
   * @return timestamps, target, level and maxDVVMap of this node; hash and
   *         superhash are appended when SangminConfig.printDebugHashes is set
   */
  @Override
  public String toString(){
    return "TreeNode: startTS:"+startTS+" endTS:"+ endTS + "invalTarget: " + invalTarget + " level: " + level + 
    " maxDVV: " + (maxDVVMap!=null?maxDVVMap:"NULL") + (SangminConfig.printDebugHashes?" hash " + DVVMapEntry.byteString(hash) + "superhash: " + DVVMapEntry.byteString(superhash): "");
  }

  /**
   * @return the level of this node
   */
  public int getLevel(){
    return level;
  }

  /**
   * Sums up sizes over the childless nodes of this subtree.
   * @return a two-element vector: index 0 is the sum of
   *         invalTarget.onDiskSize() over all childless nodes, index 1 is
   *         the number of childless nodes
   */
  public Vector<Integer> onDiskSize(){
    TreeNode left = this.getLeftChild();
    TreeNode right = this.getRightChild();

    if(left == null && right == null){
      Vector<Integer> v = new Vector<Integer>();
      int diskSize = invalTarget.onDiskSize();
      assert diskSize >0: diskSize + " " + this;
      int logicalSize = 1;
      v.add(diskSize);
      v.add(logicalSize);
      return v;
    }
    if(left == null){
      return right.onDiskSize();
    }
    if(right == null){
      return left.onDiskSize();
    }
    Vector<Integer> v = new Vector<Integer>();
    Vector<Integer> vl = left.onDiskSize();
    Vector<Integer> vr = right.onDiskSize();
    v.add(vl.get(0) + vr.get(0));
    v.add(vl.get(1) + vr.get(1));
    return v;
  }

  /**
   * @return the superhash of this node (may be null, e.g. for a fork-join
   *         leaf before {@link #updateSuperHash()} was called)
   */
  public byte[] getSuperHash(){
    return superhash;
  }

}
