#include "NTL/ZZ_pX.h"

#include "FHEProxy.h"
#include "FHE/FHE-SI.h"
#include "FHE/Util.h"
#include "FHE/Serialization.h"
#include <algorithm>

// Split the record ids in `records` into `numSlots` buckets of
// recordsPerSlot = ceil(nRecords / numSlots) consecutive ids, and for every
// bucket build (via BuildFromRoots) the polynomial whose roots are the ids
// that fall into it.
//
//   polynomials - output; reset to exactly numSlots entries. Slot k holds the
//                 polynomial with roots { r : r / recordsPerSlot == k }, or
//                 the constant 1 when no record falls into slot k.
//   maxDegree   - output; the largest number of roots placed in any one slot
//                 (0 when `records` is empty, i.e. every slot is 1).
//   records     - record ids in ascending order; grouping relies on this
//                 (BuildIndex checks it while reading).
//   nRecords    - total number of records; must be > 0 when `records` is
//                 non-empty.
//   numSlots    - number of plaintext slots available for batching.
void BatchPolynomial(vector<ZZ_pX> &polynomials, int &maxDegree, 
                     const vector<unsigned int> &records,
                     size_t nRecords, size_t numSlots) {
  // assign() rather than resize(): a vector that already holds polynomials
  // (e.g. from a repeated tag) must not keep stale roots in slots that this
  // call does not overwrite.
  polynomials.assign(numSlots, to_ZZ_pX(1));

  // Slots without roots hold the constant 1, which has degree 0.
  maxDegree = 0;
  if (records.empty()) {
    return;
  }

  // recordsPerSlot is 0 when nRecords == 0, and numSlots == 0 divides by 0.
  assert(nRecords > 0 && numSlots > 0);
  const size_t recordsPerSlot = (nRecords + numSlots - 1) / numSlots;

  vec_ZZ_p roots;
  size_t curInd = records[0] / recordsPerSlot;

  // Turn the roots collected for slot `curInd` into its polynomial, track the
  // largest degree seen, and start an empty root list for the next slot.
  auto flushSlot = [&]() {
    assert(curInd < numSlots);
    BuildFromRoots(polynomials[curInd], roots);
    maxDegree = std::max(maxDegree, static_cast<int>(roots.length()));
    roots.SetLength(0);
  };

  for (unsigned int record : records) {
    // Records are sorted, so crossing the current slot's upper bound means
    // the current slot is complete.
    if (record >= (curInd + 1) * recordsPerSlot) {
      flushSlot();
      curInd = record / recordsPerSlot;
    }
    append(roots, to_ZZ_p(record));
  }
  flushSlot();
}

// Encrypt a batch of per-slot polynomials as one polynomial whose
// coefficients are ciphertexts. Coefficient `power` of `batchedPoly`
// encrypts, in slot j, coefficient `power` of sets[j]. The value is zero
// where sets[j] has a lower degree.
//
//   batchedPoly - output; set to degree `maxDegree`.
//   sets        - one plaintext polynomial per slot.
//   publicKey   - key used for every encryption.
//   nRecords    - not used by this function.
//   maxDegree   - highest coefficient index to encrypt.
void EncryptPolynomial(Polynomial<Ciphertext> &batchedPoly, vector<ZZ_pX> &sets,
                       FHESIPubKey &publicKey, size_t nRecords, int maxDegree) {
  Ciphertext placeholder(publicKey);
  batchedPoly.SetDegree(maxDegree, placeholder);

  const size_t slotCount = sets.size();
  for (int power = 0; power <= maxDegree; ++power) {
    // Gather coefficient `power` from every slot; absent terms stay zero.
    vector<ZZ_pX> slotValues(slotCount, ZZ_pX::zero());
    for (size_t slot = 0; slot < slotCount; ++slot) {
      const ZZ_pX &source = sets[slot];
      if (power <= deg(source)) {
        slotValues[slot] = to_ZZ_pX(source.rep[power]);
      }
    }

    // Pack the gathered values into the plaintext slots and encrypt them as
    // coefficient `power`.
    Plaintext packed(publicKey.GetContext());
    packed.EmbedInSlots(slotValues, false);
    publicKey.Encrypt(batchedPoly[power], packed);
  }
}

// Build the encrypted modular-reduction table for one tag. For every power
// d in (degree, maxDegree], reduce X^d modulo each slot's polynomial f_j and
// encrypt the results batched across slots. Coefficient i, slot j of the
// entry for d is coefficient i of (X^d mod f_j).
//
//   powers      - output; resized to max(maxDegree - degree, 0) entries;
//                 powers[d - degree - 1] holds the entry for X^d.
//   polynomials - one polynomial per slot (the f_j). They are used as
//                 divisors, so they must be nonzero.
//   degree      - lowest power already represented; BuildIndex passes the
//                 degree of the tag's encrypted index polynomial.
//   maxDegree   - highest power of X to reduce.
//   publicKey   - key used for every encryption.
void ComputePowers(vector<Polynomial<Ciphertext>> &powers, vector<ZZ_pX> &polynomials,
                   int degree, int maxDegree, FHESIPubKey &publicKey) {
  Ciphertext dummy(publicKey);

  // Guard against maxDegree < degree: the int difference would otherwise
  // wrap to an enormous size_t in resize().
  const int count = maxDegree > degree ? maxDegree - degree : 0;
  powers.resize(count);

  const size_t slotCount = polynomials.size();
  // Reused across powers; rem() overwrites every entry on each iteration.
  vector<ZZ_pX> reduced(slotCount);

  for (int d = degree + 1; d <= maxDegree; d++) {
    ZZ_pX power;
    SetCoeff(power, d, 1);  // power = X^d

    // Track the degree per power. It previously carried over from earlier
    // powers, which could pad this entry with needless zero coefficients.
    int reducedDegree = 0;
    for (size_t i = 0; i < slotCount; i++) {
      rem(reduced[i], power, polynomials[i]);
      reducedDegree = std::max(reducedDegree, static_cast<int>(deg(reduced[i])));
    }

    Polynomial<Ciphertext> &batchedPoly = powers[d - degree - 1];
    batchedPoly.SetDegree(reducedDegree, dummy);

    for (int i = 0; i <= reducedDegree; i++) {
      // Default-constructed ZZ_pX values are zero, covering slots whose
      // remainder has lower degree than i.
      vector<ZZ_pX> batchedCoeffs(slotCount);
      for (size_t j = 0; j < slotCount; j++) {
        if (deg(reduced[j]) >= i) {
          batchedCoeffs[j] = to_ZZ_pX(reduced[j].rep[i]);
        }
      }

      Plaintext coeff(publicKey.GetContext());
      coeff.EmbedInSlots(batchedCoeffs, false);
      publicKey.Encrypt(batchedPoly[i], coeff);
    }
  }
}

// Read the tag/record database in `filename`, build and encrypt one batched
// polynomial per tag, and install the resulting index in `proxy`. When
// `modularReduction` is set, it also builds and installs the
// modular-reduction table.
//
// File format (whitespace separated):
//   nTags nRecords
//   followed by nTags entries of:  tag recordsForTag r_1 ... r_recordsForTag
// where r_1 < r_2 < ... (strictly increasing).
//
//   nRecords  - output; record count read from the file header.
//   maxDegree - output; largest per-tag polynomial degree (-1 if no tags).
//   output    - print a line per processed tag.
// Returns false if the file cannot be opened or is truncated/malformed,
// including non-increasing record ids. In that case `proxy` is left
// untouched.
bool BuildIndex(FHEProxy &proxy, size_t &nRecords, int &maxDegree,
                const string &filename, FHESIPubKey &publicKey,
                bool modularReduction = false, bool output = false) {
  size_t numSlots = publicKey.GetContext().GetPlaintextSpace().GetTotalSlots();

  ifstream fin(filename);
  if (!fin) {
    return false;
  }

  // Check every extraction: a failed read would otherwise leave zeroed
  // values and silently build a garbage index.
  int nTags;
  if (!(fin >> nTags >> nRecords) || nTags < 0) {
    return false;
  }

  map<unsigned int, vector<ZZ_pX>> polynomials;
  map<unsigned int, Polynomial<Ciphertext>> index;

  maxDegree = -1;

  for (int i = 0; i < nTags; i++) {
    unsigned int tag, recordsForTag;
    if (!(fin >> tag >> recordsForTag)) {
      return false;
    }

    // For convenience, records must be listed in strictly increasing order;
    // BatchPolynomial relies on this to group them by slot. The check uses
    // long long so ids above INT_MAX compare correctly, and it is a real
    // check rather than an assert so release builds reject bad input too.
    long long prev = -1;
    vector<unsigned int> records(recordsForTag);
    for (size_t j = 0; j < recordsForTag; j++) {
      if (!(fin >> records[j])) {
        return false;
      }
      if (static_cast<long long>(records[j]) <= prev) {
        return false;
      }
      prev = records[j];
    }

    int degree;
    BatchPolynomial(polynomials[tag], degree, records, nRecords, numSlots);
    EncryptPolynomial(index[tag], polynomials[tag], publicKey, nRecords, degree);

    maxDegree = std::max(degree, maxDegree);

    if (output) {
      cout << "Built entry for tag " << tag << " (degree "
           << degree << ")" << endl;
    }
  }

  proxy.SetIndex(index);

  if (modularReduction) {
    map<unsigned int, vector<Polynomial<Ciphertext>>> modReductionTbl;

    // Reduce powers above each tag's own degree, up to the global maximum.
    for (auto &entry : polynomials) {
      unsigned int tag = entry.first;
      ComputePowers(modReductionTbl[tag], entry.second, index[tag].GetDegree(),
                    maxDegree, publicKey);
    }

    proxy.SetModReductionTable(modReductionTbl);
  }

  return true;
}

// Read a query file and append its queries to `queries`.
//
// File format (whitespace separated):
//   nQueries
//   followed by nQueries entries of:  nTags tag_1 ... tag_nTags
//
// Returns false if the file cannot be opened or is truncated/malformed
// (including negative counts). On failure, `queries` is restored to its size
// on entry, so no partially parsed query is left behind.
bool ReadQueries(std::vector<std::vector<unsigned int>> &queries,
                 const std::string &filename) {
  std::ifstream fin(filename);
  if (!fin) {
    return false;
  }

  const size_t originalSize = queries.size();
  auto fail = [&]() {
    queries.resize(originalSize);
    return false;
  };

  int nQueries;
  if (!(fin >> nQueries) || nQueries < 0) {
    return fail();
  }

  for (int i = 0; i < nQueries; i++) {
    // Read the count as a signed value so "-1" is rejected instead of
    // wrapping to an enormous size_t and exhausting memory.
    long long nTags;
    if (!(fin >> nTags) || nTags < 0) {
      return fail();
    }

    // Construct the query in place and fill it directly.
    queries.emplace_back(static_cast<size_t>(nTags));
    for (unsigned int &tag : queries.back()) {
      if (!(fin >> tag)) {
        return fail();
      }
    }
  }

  return true;
}

// Test driver for FHEProxy. It reads the query file and the tag/record
// database, builds the encrypted index, runs each query through the proxy,
// decrypts the result polynomial, and prints the record ids that are its
// roots.
//
// usage: ./TestProxy datafile queryfile [mod_reduction] [outputfile] [output]
//   mod_reduction - 0 (default): no modular reduction; 1: modular reduction
//                   with key switching; any other nonzero value: modular
//                   reduction, decrypting with an extended ("tensored") key.
//   outputfile    - if given, each query's result ciphertexts are exported
//                   there. The file is truncated per query, so only the last
//                   query's result remains.
//   output        - 0 suppresses progress/timing messages. The matching
//                   record ids are printed regardless.
int main(int argc, char *argv[]) {
  // Seed both the C PRNG and NTL's PRNG from the wall clock.
  srand48(time(NULL));
  SetSeed(to_ZZ(time(NULL)));

  string inputfile, queryfile, outputfile;
  bool modularReduction = false;
  bool withKeySwitch = false;
  bool output = true;
  if (argc < 3) {
    printf("usage: ./TestProxy datafile queryfile [mod_reduction] [outputfile] [output]\n");
    return 1;
  } else {
    inputfile = argv[1];
    queryfile = argv[2];
    
    if (argc > 3) {
      // 0: no modular reduction; 1: modular reduction with key switching;
      // any other nonzero value: modular reduction without key switching.
      int selection = atoi(argv[3]);
      modularReduction = (selection != 0);
      withKeySwitch = (selection == 1);
    }
    
    if (argc > 4) {
      outputfile = string(argv[4]);
    }

    // Progress/timing output stays on unless argv[5] parses as 0.
    if (argc > 5 && atoi(argv[5]) == 0) {
      output = false;
    }
  }
  
  vector<vector<unsigned int>> queries;

  if (!ReadQueries(queries, queryfile)) {
    cout << "Unable to read query file." << endl;
    return 1;
  }

  if (output) {
    cout << "Processing database " << inputfile << endl;
    if (modularReduction) {
      cout << "Running with modular reduction optimization" << endl;
      if (withKeySwitch) {
        cout << "Running with key switching in modulus reduction" << endl;
      }
      cout << endl;
    }
  }
  
  // CPU time (clock()) at the start of the currently timed phase.
  double partStart(clock());
  
  // NOTE(review): hard-coded FHE-SI context parameters; their meaning is
  // defined by the FHE-SI library, not in this file.
  FHEcontext context(5147, 157, to_ZZ(1000051807), 2, 3);
  
  // activeContext is a global declared in an included header; it is set
  // before SetUpSIContext() is called.
  activeContext = &context;
  context.SetUpSIContext();
  
  FHESISecKey secretKey(context);
  FHESISecKey tensoredKey(context);

  // Modular reduction without key switching: build an extended key
  // [s0, s1, s1^2, ..., s1^(2n-2)] from the secret key's n-element
  // representation. It is used below to decrypt result ciphertexts that
  // have more than two components. NOTE(review): assumes n >= 2.
  if (modularReduction && !withKeySwitch) {
    vector<DoubleCRT> sKeys = secretKey.GetRepresentation();

    vector<DoubleCRT> tKeys;
    tKeys.assign(sKeys.size()*2-1, sKeys[1]);
    tKeys[0] = sKeys[0];

    // tKeys[i] = sKeys[1]^i for i >= 2.
    for (unsigned i = 2; i < tKeys.size(); i++) {
      tKeys[i] *= tKeys[i-1];
    }

    tensoredKey.UpdateRepresentation(tKeys);
  }

  FHESIPubKey publicKey(secretKey);
  
  KeySwitchSI *keySwitchPtr = NULL;
  
  keySwitchPtr = NULL;

  // Key-switching material is only created for mod_reduction == 1.
  // NOTE(review): this allocation is never deleted in main; confirm whether
  // FHEProxy takes ownership of it.
  if (modularReduction && withKeySwitch) {
    keySwitchPtr = new KeySwitchSI(secretKey);
  }
  
  if (output) {
    cout << "Setup time: " 
         << (clock()-partStart)/CLOCKS_PER_SEC << endl << endl;
  }
  
  FHEProxy proxy(context, secretKey, publicKey, keySwitchPtr);
  
  size_t totalRecords;
  int maxDegree;
  
  // Build and encrypt the index from the database file.
  partStart = clock();
  if (!BuildIndex(proxy, totalRecords, maxDegree, inputfile,
      publicKey, modularReduction, output)) {
    cout << "Unable to read database file." << endl;
    return 1;
  }

  if (output) {
    cout << "Index building time: "
         << (clock()-partStart)/CLOCKS_PER_SEC << endl << endl;
  }
  
  for (size_t i = 0; i < queries.size(); i++) {
    if (output) {
      cout << endl;
      cout << "=====================================" << endl;
      cout << "Query " << i+1 << ": ";
      PrintVector(queries[i]);
      cout << " (" << queries[i].size() << ")" << endl;
      cout << "=====================================" << endl << endl;
    }
    
    Polynomial<Ciphertext> result;
    
    double start(clock());
    partStart = clock();
    
    proxy.PerformQuery(result, queries[i], modularReduction);
    
    if (output) {
      cout << "Query time: "
           << (clock()-partStart)/CLOCKS_PER_SEC << endl;
    }
    
    partStart = clock();
    // Same bucketing as BatchPolynomial: slot k covers record ids in
    // [k*recordsPerSlot, (k+1)*recordsPerSlot).
    size_t numSlots = context.GetPlaintextSpace().GetTotalSlots();
    size_t recordsPerSlot = (totalRecords + numSlots - 1) / numSlots;
    
    // Decrypt the result one coefficient at a time and scatter it into the
    // per-slot plaintext polynomials. The constant term (rep[0]) of slot j's
    // decoded value becomes coefficient i of polynomials[j].
    // Note: this loop's `i` shadows the query index.
    vector<ZZ_pX> polynomials(numSlots);
    for (unsigned int i = 0; i <= result.GetDegree(); i++) {
      vector<ZZ_pX> batchedCoeffs;
    
      Plaintext ptxt;
      // Use the extended key for ciphertexts with more than two components.
      if (modularReduction && !withKeySwitch && result[i].size() > 2) {
       tensoredKey.Decrypt(ptxt, result[i]);
      } else {
       secretKey.Decrypt(ptxt, result[i]);
      }
      ptxt.DecodeSlots(batchedCoeffs, false);
      
      for (size_t j = 0; j < batchedCoeffs.size(); j++) {
        if (batchedCoeffs[j] == ZZ_pX::zero()) {
          SetCoeff(polynomials[j], i, 0);
        } else {
          SetCoeff(polynomials[j], i, batchedCoeffs[j].rep[0]);
        }
      }
    }
    
    if (output) {
      cout << "Decryption time: "
           << (clock()-partStart)/CLOCKS_PER_SEC << endl;
      
      cout << endl << "Results: " << endl;
    }
    
    // Root finding: evaluate slot i's polynomial at every record id the slot
    // covers. Ids above totalRecords are left at 0 and not printed. Print the
    // ids where the polynomial vanishes; this happens even when `output` is
    // false.
    partStart = clock();
    for (size_t i = 0; i < polynomials.size() && recordsPerSlot*i <= totalRecords; i++) {
      vec_ZZ_p values;
      values.SetLength(recordsPerSlot);
      for (size_t j = 0; j < recordsPerSlot; j++) {
        if (recordsPerSlot*i+j > totalRecords) {
          break;
        }
        values[j] = recordsPerSlot*i+j;
      }
      
      vec_ZZ_p evalResults;
      eval(evalResults, polynomials[i], values);
      
      for (size_t j = 0; j < recordsPerSlot; j++) {
        if (evalResults[j] == ZZ_pX::zero() && recordsPerSlot*i+j <= totalRecords) {
          cout << recordsPerSlot*i+j << " ";
        }
      }
    }
    cout << endl;
    
    if (output) {
      cout << endl;
      cout << "Degree of polynomial: " << result.GetDegree() << endl;
      
      cout << "Root finding time: "
           << (clock()-partStart)/CLOCKS_PER_SEC << endl << endl
           << "Total time: " 
           << (clock()-start)/CLOCKS_PER_SEC << endl;
    }

    // Optionally serialize the encrypted result coefficients. The file is
    // reopened with ios::out (truncating) for every query, so it ends up
    // holding only the last query's result.
    if (outputfile.length() > 0) {
      ofstream out;
      out.open(outputfile, ios::out | ios::binary);
      
      for (unsigned int i = 0; i <= result.GetDegree(); i++) {
        Export(out, result[i]);
      }

      out.close();

      if (output) {
       cout << endl << "Serialized " << result.GetDegree() + 1 << " polynomials." << endl;
      }
    }
  }
  
  return 0;
}
