#include "FHEProxy.h"

void FHEProxy::PerformQuery(Polynomial<Ciphertext> &result,
                            const vector<unsigned int> &queries,
                            bool modularReduction) {
  // No tags queried: leave `result` untouched.
  if (queries.empty()) {
    return;
  }

  // Since the number of tags can be large, we limit the number of
  // times we need to query the map: each tag is looked up once and its
  // polynomial is referenced through a pointer afterwards (assuming invIndex
  // is the std::map that SetIndex assigns, element addresses stay valid
  // across later insertions).
  // NOTE(review): invIndex[...] default-inserts an entry for a tag that is
  // not in the index -- confirm callers only query known tags.
  vector<Polynomial<Ciphertext> *> queryPoly(queries.size());
  queryPoly[0] = &invIndex[queries[0]];

  // Find the polynomial of lowest and highest degree (the first occurrence
  // wins on ties).
  size_t maxIndex = 0;
  size_t minIndex = 0;
  size_t maxDegree = queryPoly[0]->GetDegree();
  size_t minDegree = maxDegree;

  for (size_t i = 1; i < queries.size(); i++) {
    queryPoly[i] = &invIndex[queries[i]];
    size_t degree = queryPoly[i]->GetDegree();

    if (maxDegree < degree) {
      maxDegree = degree;
      maxIndex = i;
    }

    if (minDegree > degree) {
      minDegree = degree;
      minIndex = i;
    }
  }

  // A single tag has nothing to be intersected with: the answer is that
  // tag's own polynomial. The general path below needs two distinct
  // polynomials and would otherwise index past the end of `queries`.
  if (queries.size() == 1) {
    result = *queryPoly[0];
    return;
  }

  // minIndex and maxIndex can only coincide when every polynomial has the
  // same degree (both are then 0). Use the next polynomial as the "max" one
  // so the two roles stay distinct; queries.size() >= 2 here.
  if (minIndex == maxIndex) {
    maxIndex = 1;
  }

  // Degree of the random multipliers is "degree - 1", clamped at 0 so a
  // constant (degree-0) polynomial cannot wrap the unsigned subtraction
  // around to SIZE_MAX and request an enormous random polynomial.
  auto degreeBelow = [](size_t degree) -> size_t {
    return degree > 0 ? degree - 1 : 0;
  };

  // Linear combination of polynomials for S_2, ..., S_t: start from the
  // highest-degree polynomial and add every other non-minimal polynomial
  // scaled by a fresh slot-wise random plaintext.
  result = *queryPoly[maxIndex];

  Plaintext packedCoeffs(context);
  for (size_t i = 0; i < queries.size(); i++) {
    if (i == minIndex || i == maxIndex) {
      continue;
    }

    Polynomial<Ciphertext> matchingRecords = *queryPoly[i];

    vector<ZZ_pX> coeffs(context.GetPlaintextSpace().GetTotalSlots());
    for (size_t j = 0; j < coeffs.size(); j++) {
      ZZ_p randVal;

      random(randVal);
      coeffs[j] = to_ZZ_pX(randVal);
    }
    packedCoeffs.EmbedInSlots(coeffs, false);

    matchingRecords *= packedCoeffs.message;
    result += matchingRecords;
  }

  Polynomial<ZZ_pX> randPoly1;
  Polynomial<ZZ_pX> randPoly2;

  // Optionally reduce the combination modulo the lowest-degree polynomial
  // (ApplyModularReduction truncates it to that polynomial's degree); only
  // meaningful when the combination has a strictly higher degree.
  if (modularReduction && minDegree < maxDegree) {
    Polynomial<Ciphertext> reduced;
    ApplyModularReduction(reduced, result, queries[minIndex]);

    result = reduced;

    GenerateRandPoly(randPoly1, minDegree);
    GenerateRandPoly(randPoly2, degreeBelow(minDegree));
  } else {
    GenerateRandPoly(randPoly1, degreeBelow(minDegree));
    GenerateRandPoly(randPoly2, degreeBelow(maxDegree));
  }

  // result = randPoly1 * (combination) + randPoly2 * P_min
  result *= randPoly1;

  Polynomial<Ciphertext> lastTerm = *queryPoly[minIndex];
  lastTerm *= randPoly2;

  result += lastTerm;
}

void FHEProxy::ApplyModularReduction(Polynomial<Ciphertext> &result,
                                     const Polynomial<Ciphertext> &poly,
                                     unsigned int tag) {
  // The modulus is the index polynomial stored for `tag`; `result` is
  // rebuilt with exactly that degree.
  const size_t modDegree = invIndex[tag].GetDegree();
  const size_t polyDegree = poly.GetDegree();

  Ciphertext dummy(publicKey);

  // Coefficients 0..modDegree are carried over unchanged.
  result.SetDegree(modDegree, dummy);
  for (size_t i = 0; i <= modDegree; i++) {
    result[i] = poly[i];
  }

  // Nothing above the modulus degree to fold back in.
  if (polyDegree <= modDegree) {
    return;
  }

  // Table entry k is used for coefficient modDegree + 1 + k (presumably the
  // precomputed x^(modDegree+1+k) mod the tag's polynomial -- TODO confirm).
  // .at() instead of operator[]: an unknown tag must not silently insert an
  // empty table, and a table that is too short must throw std::out_of_range
  // rather than read out of bounds.
  const vector<Polynomial<Ciphertext>> &powers = modReductionTable.at(tag);

  for (size_t i = modDegree + 1; i <= polyDegree; i++) {
    Polynomial<Ciphertext> reduced = powers.at(i - modDegree - 1);
    reduced *= poly[i];

    // Key-switch every coefficient of the ciphertext-by-ciphertext product
    // (presumably to bring it back under the original key).
    for (size_t j = 0; j <= reduced.GetDegree(); j++) {
      keySwitch->ApplyKeySwitch(reduced[j]);
    }

    result += reduced;
  }
}

void FHEProxy::SetIndex(const map<unsigned int, Polynomial<Ciphertext>> &newIndex) {
  // Replace the entire inverted index (tag -> polynomial) with a copy of
  // newIndex. The copy is completed first and then swapped in, so the old
  // index is only discarded once the new one is fully built.
  auto replacement = newIndex;
  invIndex.swap(replacement);
}

void FHEProxy::SetModReductionTable(const map<unsigned int, vector<Polynomial<Ciphertext>>> &newTable) {
  // Replace the per-tag table consumed by ApplyModularReduction with a copy
  // of newTable; build the copy first, then swap it in.
  auto replacement = newTable;
  modReductionTable.swap(replacement);
}

void FHEProxy::GenerateRandPoly(Polynomial<ZZ_pX> &randPoly, unsigned int degree) const {
  // Produces a polynomial of the requested degree whose every coefficient is
  // a plaintext packing an independent random ZZ_p value into each slot.
  randPoly.SetDegree(degree);

  Plaintext packed(context);
  for (size_t coeffIdx = 0; coeffIdx <= degree; ++coeffIdx) {
    vector<ZZ_pX> slotValues(context.GetPlaintextSpace().GetTotalSlots());

    // Fill the slots in order so the random stream is consumed exactly as
    // one draw per slot, per coefficient.
    for (ZZ_pX &slot : slotValues) {
      ZZ_p sample;
      random(sample);
      slot = to_ZZ_pX(sample);
    }

    packed.EmbedInSlots(slotValues, false);
    randPoly[coeffIdx] = packed.message;
  }
}
