package code.security;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.Vector;

import code.AcceptStamp;
import code.AcceptVV;
import code.CounterVV;
import code.HierInvalTarget;
import code.ImpreciseInv;
import code.InvalTarget;
import code.NoSuchEntryException;
import code.NodeId;
import code.SubscriptionSet;
import code.VV;
import code.VVIterator;
import code.security.ahs.AHSEntry;
import code.security.ahs.AHSMap;
import code.security.ahs.DVVMap;
import code.security.ahs.DependencyVV;
import code.security.ahs.IH;
import code.security.ahs.RootList;
import code.security.ahs.TreeNode;
import code.security.ahs.AHS;
import code.security.ahs.UnmatchingTreeNodeException;

/**
 * Imprecise-invalidation handling for a {@link SecurityFilter}.
 *
 * <p>On the receive side it verifies the AHS entries carried by a
 * {@link SecureImpreciseInv} and applies them to the filter's {@link AHSMap}. On the send
 * side it turns a plain {@link ImpreciseInv} into the form selected by
 * {@code SangminConfig.securityLevel}.
 */
class ImpreciseInvFilter {

  SecurityFilter securityFilter;

  ImpreciseInvFilter(SecurityFilter securityFilter)
  {
    this.securityFilter = securityFilter;
  }

  /**
   * Verifies the AHSEntry with the inclusion, compatibility and hash checks. Equivalent to
   * {@code verifyAHSEntry(nodeId, ahsEntry, sender, cvv, true)}, so with assertions enabled
   * it requires {@code SangminConfig.optimizeDVVMapHashes}.
   *
   * @param nodeId the writer the entry belongs to
   * @param ahsEntry the entry to verify
   * @param sender the node who sent this entry
   * @param cvv version vector that must include the entry's dependency VV
   * @return true if every check passes
   */
  public  boolean verifyAHSEntry(NodeId nodeId, AHSEntry ahsEntry, NodeId sender, AcceptVV cvv){
    return verifyAHSEntry(nodeId, ahsEntry, sender, cvv, true);
  }

  /**
   * Verifies the AHSEntry.
   *
   * <ol>
   * <li>Inclusion: {@code cvv} must include the DVV of the entry's max DVVMap.</li>
   * <li>Compatibility: let {@code prev} be the last local tree node for {@code nodeId} prior
   *     to the entry's start timestamp. If {@code prev} has a successor, that successor must
   *     start at the entry's start timestamp. In addition, if the DVV has a stamp for
   *     {@code nodeId}, that stamp must equal {@code prev}'s end timestamp. If it has no
   *     such stamp, {@code prev}'s end timestamp must be -1.</li>
   * <li>Hash (only when {@code verifyHash}): the DVVMap restricted with
   *     {@code retain(nodeId)} must pass {@code verifyDVVMap}.</li>
   * </ol>
   *
   * @param nodeId the writer the entry belongs to
   * @param ahsEntry the entry to verify
   * @param sender the node who sent this entry (passed on to the hash verification)
   * @param cvv version vector that must include the entry's dependency VV
   * @param verifyHash whether the SH present in the DVV entries should be verified. If false,
   *        only the inclusion and compatibility checks are performed.
   * @return true if every performed check passes
   */
  public  boolean verifyAHSEntry(NodeId nodeId, AHSEntry ahsEntry, NodeId sender, AcceptVV cvv, boolean verifyHash){
    assert SangminConfig.optimizeDVVMapHashes || !verifyHash;
    DVVMap dvvMap = ahsEntry.getMaxDVVMap();
    DependencyVV dvv = dvvMap.getDVV();

    // Inclusion check.
    if(!cvv.includes(dvv)){
      assert false: "CVV" + cvv + " DVV " + dvvMap;
      return false;
    }

    // Compatibility check against what we have already accepted for this writer.
    TreeNode prev = securityFilter.getAhsMap().getLastTreeNodePriorTo(nodeId, ahsEntry.getStartTS());
    if(prev != null && prev.getNext() != null){
      if(prev.getNext().getStartTS() != ahsEntry.getStartTS()){
        System.out.println("PROBLEM: compatibility check failed" + ahsEntry + " prev " + prev + " " +  prev.getNext().getStartTS());
        return false;
      }
      try{
        if(dvv.containsNodeId(nodeId)){
          if(dvv.getStampByServer(nodeId) != prev.getEndTS()){
            System.out.println("PROBLEM: compatibility check failed" + ahsEntry + " prev " + prev + " " +  prev.getEndTS());
            return false;
          }
        }else if(prev.getEndTS() != -1){
          System.out.println("PROBLEM: compatibility check failed" + ahsEntry + " prev " + prev);
          return false;
        }
      }catch(NoSuchEntryException e){
        // getStampByServer is only called after containsNodeId returned true, so this should
        // not happen. If it does, reject the entry instead of falling through and accepting
        // an entry whose compatibility could not be established.
        e.printStackTrace();
        assert false;
        return false;
      }
    }

    if(verifyHash){
      return dvvMap.retain(nodeId).verifyDVVMap(securityFilter.getAhsMap(), securityFilter.secureRMIClient, sender, securityFilter);
    }
    return true;
  }

  /**
   * Tries to verify and apply the IH tuples of {@code sii} to the filter's AHSMap. The
   * process fails if an entry is not included in (current VV advanced by sii's end VV) or
   * fails {@link #verifyAHSEntry(NodeId, AHSEntry, NodeId, AcceptVV)}, for example because
   * we have accepted incompatible updates for the overlapping range.
   *
   * <p>Must be called with {@code securityFilter.core}'s special lock held (asserted).
   * NOTE(review): entries applied before a failing entry are not rolled back.
   *
   * @param sii the secure imprecise inval whose IH tuples are applied
   * @param sender the node that sent {@code sii}
   * @return true if every tuple was verified and applied
   */
  boolean tryAndApply(SecureImpreciseInv sii, NodeId sender){
    assert securityFilter.core.specialLockHeldByMe();

    CounterVV cvv = new CounterVV(securityFilter.getCurrentVV());
    cvv.advanceTimestamps(sii.getEndVV());
    // cvv is not modified inside the loop, so build the AcceptVV argument once.
    AcceptVV acceptCvv = new AcceptVV(cvv);

    // Apply ih tuples for each writer.
    try{
      for(NodeId nId: sii.ihTuples.keySet()){
        for(IH ih: sii.ihTuples.get(nId)){
          AHSEntry entry = ih.getAHSEntry();
          if(!cvv.includes(entry.getMaxDVVMap().getDVV()) || !verifyAHSEntry(nId, entry, sender, acceptCvv)){
            assert false;
            return false;
          }
          securityFilter.getAhsMap().applyAHS(nId, new AHS(entry), sii);
        }
      }
    }catch(UnmatchingTreeNodeException e){
      e.printStackTrace();
      assert false;
      //TODO: call controller
      return false;
    }

    return true;
  }

  /**
   * Builds an insecure ImpreciseInv covering all writers in {@code ihs}. For each writer, the
   * start stamp is the minimum IH start and the end stamp is the maximum IH end. The target
   * is the union of every IH's inval target.
   *
   * @param ihs per-writer IH tuples; must not be empty
   * @return the equivalent ImpreciseInv
   */
  ImpreciseInv getInsecureImpreciseInv(Hashtable<NodeId, Vector<IH>> ihs){
    assert !ihs.isEmpty();
    CounterVV startVV = new CounterVV();
    CounterVV endVV = new CounterVV();
    InvalTarget it = new HierInvalTarget();
    SubscriptionSet ss = SubscriptionSet.makeSubscriptionSet("/*");
    for(NodeId nId: ihs.keySet()){
      it = accumulateIHs(nId, ihs.get(nId), startVV, endVV, it, ss);
    }
    assert !it.isEmpty();
    return new ImpreciseInv(it, startVV.cloneAcceptVV(), endVV.cloneAcceptVV());
  }

  /**
   * Builds an insecure ImpreciseInv for a single writer from its IH tuples. See
   * {@link #getInsecureImpreciseInv(Hashtable)}.
   *
   * @param nId the writer
   * @param ihs the writer's IH tuples; must not be empty, otherwise the start/end stamps
   *        would be Long.MAX_VALUE/Long.MIN_VALUE
   * @return the equivalent ImpreciseInv
   */
  ImpreciseInv getInsecureImpreciseInv(NodeId nId, Vector<IH> ihs){
    assert !ihs.isEmpty(): "can't build an imprecise inval from no IHs for " + nId;
    CounterVV startVV = new CounterVV();
    CounterVV endVV = new CounterVV();
    InvalTarget it = accumulateIHs(nId, ihs, startVV, endVV, new HierInvalTarget(),
        SubscriptionSet.makeSubscriptionSet("/*"));
    return new ImpreciseInv(it, startVV.cloneAcceptVV(), endVV.cloneAcceptVV());
  }

  /**
   * Records min-start/max-end of one writer's IHs in startVV/endVV and returns {@code it}
   * unioned with their inval targets.
   */
  private static InvalTarget accumulateIHs(NodeId nId, Vector<IH> ihs, CounterVV startVV,
      CounterVV endVV, InvalTarget it, SubscriptionSet ss){
    long max = Long.MIN_VALUE;
    long min = Long.MAX_VALUE;
    for(IH ih: ihs){
      max = Math.max(max, ih.getEndTS());
      min = Math.min(min, ih.getStartTS());
      it = it.getUnion(ih.getAHSEntry().getInvalTarget(), ss);
    }
    startVV.setStamp(new AcceptStamp(min, nId));
    endVV.setStamp(new AcceptStamp(max, nId));
    return it;
  }

  /**
   * Verifies and applies the IH tuples of {@code sii} (see
   * {@link #tryAndApply(SecureImpreciseInv, NodeId)}). Only on success does it call
   * {@code securityFilter.apply(sii, initialVV)}, where {@code initialVV} is the current VV
   * captured before the tuples were applied.
   *
   * @param sii the secure imprecise inval received
   * @param sender the node that sent {@code sii}
   * @return false if verification/application failed, true otherwise
   */
  boolean applyImprecise(SecureImpreciseInv sii, NodeId sender){
    AcceptVV initialVV = securityFilter.getCurrentVV();

    if(!tryAndApply(sii, sender)){
      return false;
    }
    this.securityFilter.apply(sii, initialVV);
    return true;
  }

  /**
   * Creates the outgoing form of {@code iInv} for the configured security level.
   * <ul>
   * <li>COMPLETE: a SecureImpreciseInv carrying per-writer IH tuples. The tuples are split so
   *     the receiver can verify each maxDVVMap. It also carries an external DVV of the
   *     dependencies on writers not covered by the inval.</li>
   * <li>SIGNATURE: a SecureImpreciseInv with no IH tuples.</li>
   * <li>otherwise: {@code iInv} itself.</li>
   * </ul>
   * In the COMPLETE case any exception while generating IHs prints a stack trace and exits
   * the process.
   */
  ImpreciseInv createImpreciseInv(ImpreciseInv iInv){

    //TODO: zjd, need to sanitycheck the iInv.targetSet is
    // consistent with what we get from all the ihtuples in the secureImpreciseInv
    // we are going to generate.
    if(SangminConfig.securityLevel == SangminConfig.COMPLETE){

      // For each writer: the timestamps some generated maxDVVMap refers to for that writer.
      Hashtable<NodeId,TreeSet<Long>> splitPoints = new Hashtable<NodeId,TreeSet<Long>>();
      // For each writer in startVV: the IH tuples covering its [start, end] range.
      Hashtable<NodeId,Vector<IH>> ihTuples = new Hashtable<NodeId,Vector<IH>>();

      AcceptVV startVV = iInv.getStartVV();
      AcceptVV endVV = iInv.getEndVV();

      assert endVV.includes(startVV):"invalid imprecise inval" + iInv;
      assert !startVV.includes(endVV) || startVV.equals(endVV):"invalid imprecise inval" + iInv;
      CounterVV externalDVV = new CounterVV();

      // Step 1: generate IHs per writer, accumulate dependencies and collect split points.
      for(VVIterator vvi = startVV.getIterator(); vvi.hasMoreElements();){
        try{
          Object token = vvi.getNext();
          NodeId nodeId = startVV.getServerByIteratorToken(token);
          long startTS = startVV.getStampByIteratorToken(token);
          long endTS = endVV.getStampByServer(nodeId);
          assert !securityFilter.getAhsMap().isEmpty(nodeId):"RootList shouldn't be empty";
          Vector<IH> v = securityFilter.getAhsMap().generateIHs(nodeId,startTS,endTS);
          assert v.size() > 0:"can't generate empty imprecise invals"+nodeId+" " + startVV + " " + endVV + " iInv" + iInv;
          ihTuples.put(nodeId, v);
          DVVMap maxDVVMap = this.securityFilter.getMaxDVVMap(v,nodeId);

          externalDVV.advanceTimestamps(maxDVVMap.getDVV());

          Enumeration<NodeId> e = maxDVVMap.getNodes();
          while(e.hasMoreElements()){
            NodeId nId = e.nextElement();
            TreeSet<Long> set = splitPoints.get(nId);
            if(set == null){
              set = new TreeSet<Long>();
              splitPoints.put(nId, set);
            }
            set.add(Long.valueOf(maxDVVMap.getEntry(nId).getTimeStamp()));
          }
        }catch(Exception e){
          e.printStackTrace();
          System.exit(-1);
        }
      }

      // Step 2: split IHs so that the receiver can verify each maxDVVMap, and drop the writers
      // we are sending IHs for from the external DVV.
      for(NodeId nodeId: ihTuples.keySet()){
        TreeSet<Long> perNodeSplitPoints = splitPoints.get(nodeId);
        if(perNodeSplitPoints != null){
          // Only replaces the value of an existing key; no mapping is added or removed
          // while iterating keySet().
          ihTuples.put(nodeId, splitIH(nodeId, ihTuples.get(nodeId), perNodeSplitPoints));
        }
        // NOTE(review): presumably BEFORE_TIME_BEGAIN is negative, so dropNegatives() below
        // removes these writers from the external DVV -- confirm.
        externalDVV.setStampByServer(nodeId, AcceptVV.BEFORE_TIME_BEGAIN);
      }
      externalDVV.dropNegatives();

      // Step 3 (zjd): sanity-check that the IH tuples describe the same inval as iInv.
      ImpreciseInv ihImpreciseInv = this.getInsecureImpreciseInv(ihTuples);
      assert ihImpreciseInv.getStartVV().equals(iInv.getStartVV());
      assert ihImpreciseInv.getEndVV().equals(iInv.getEndVV());
      // iInv should be conservative -- actually we could ignore iInv.target and just take
      // ihImpreciseInv.target.
      InvalTarget ihIt = ihImpreciseInv.getInvalTarget();
      assert !ihIt.isEmpty();
      assert !iInv.getInvalTarget().isEmpty():" empty impreciseinv:" + iInv;
      assert ihIt.getIntersection(iInv.getInvalTarget()).equals(ihIt);

      return new SecureImpreciseInv(iInv, ihTuples, securityFilter.core.getMyNodeId(), externalDVV.cloneAcceptVV());
    }
    if(SangminConfig.securityLevel == SangminConfig.SIGNATURE){
      return new SecureImpreciseInv(iInv, null, securityFilter.core.getMyNodeId());
    }
    return iInv;
  }

  /**
   * Splits the IHs in {@code v} at the timestamps in {@code set}. Whenever a split point
   * {@code ts} satisfies {@code start <= ts < end} for the current IH, the IH is regenerated
   * as IHs for {@code [start, ts]} plus IHs from the tree node after {@code ts} up to the
   * original end.
   *
   * <p>{@code v} is not modified. Split points found to lie at or before the end of the IH
   * being processed may be removed from {@code set} as they are consumed.
   *
   * @param nodeId the writer the IHs belong to
   * @param v IHs in ascending timestamp order (assumed -- TODO confirm with generateIHs)
   * @param set split points for this writer
   * @return the split IHs
   */
  private Vector<IH> splitIH(NodeId nodeId, Vector<IH> v, TreeSet<Long> set){
    Vector<IH> ret = new Vector<IH>();
    AHSMap ahsMap = securityFilter.getAhsMap();
    // Work on a copy: remainders of split IHs are inserted into the list being walked, and
    // v came straight from generateIHs, so it must not be mutated.
    Vector<IH> work = new Vector<IH>(v);

    for(int i = 0; i < work.size(); i++){
      IH curIH = work.elementAt(i);
      TreeSet<Long> done = new TreeSet<Long>();
      Iterator<Long> it = set.iterator();

      while(it.hasNext()){
        Long boxedTS = it.next();
        long ts = boxedTS.longValue();
        if(curIH.getStartTS() <= ts && curIH.getEndTS() > ts){
          // Split this IH at ts: emit [start, ts], continue with the remainder.
          ret.addAll(ahsMap.generateIHs(nodeId, curIH.getStartTS(), ts));
          TreeNode nextNode = ahsMap.getTreeNodeAfter(nodeId, ts);
          assert nextNode != null : "no tree node after " + ts + " for " + nodeId;
          Vector<IH> left = ahsMap.generateIHs(nodeId, nextNode.getStartTS(), curIH.getEndTS());

          assert(left != null);
          assert(left.size() > 0);
          curIH = left.remove(0);
          work.addAll(i + 1, left);
          done.add(boxedTS);
        }else if(curIH.getEndTS() == ts){
          // The IH already ends at this split point.
          done.add(boxedTS);
          set.removeAll(done);
          break;
        }else if(curIH.getEndTS() < ts){
          // Remaining split points are beyond this IH (set is sorted ascending).
          set.removeAll(done);
          break;
        }
      }
      ret.add(curIH);
    }

    return ret;
  }

}
