package com.xebialabs.database.anonymizer

import org.dbunit.dataset.{IDataSet, ITable, ITableIterator, ITableMetaData, ReplacementDataSet}
import org.slf4j.{Logger, LoggerFactory}

import java.util
import scala.jdk.CollectionConverters._

/**
 * A `ReplacementDataSet` that hands out every table wrapped in an `XldReplacementTable`, so the
 * registered object/substring replacements and the anonymization settings taken from the
 * [[AnonymizerConfiguration]] are applied to the wrapped tables.
 *
 * Replacements registered through [[addReplacementObject]] / [[addReplacementSubstring]] are kept
 * in this class' own maps, which are passed by reference to every table created by this data set.
 *
 * NOTE(review): tables are always created with strict replacement enabled and with `null` start/end
 * substring delimiters; values set through the inherited `setStrictReplacement` /
 * `setSubstringDelimiters` are not consulted here — confirm this is intended.
 *
 * @param configuration provides `tablesToAnonymize` and `encryptedFieldsToIgnore`
 * @param dataSet       the underlying data set whose tables are wrapped
 */
class XldReplacementDataSet(val configuration: AnonymizerConfiguration, val dataSet: IDataSet) extends ReplacementDataSet(dataSet) {

  private val logger: Logger = LoggerFactory.getLogger(classOf[XldReplacementDataSet])

  // Filled by the addReplacement* methods and shared (by reference) with each XldReplacementTable.
  private val replacedItems: util.Map[AnyRef, AnyRef] = new util.HashMap[AnyRef, AnyRef]
  private val replacedSubstrings: util.Map[AnyRef, AnyRef] = new util.HashMap[AnyRef, AnyRef]

  // Converted to immutable Scala lists once, instead of once per created table.
  // NOTE(review): this snapshots the configuration's collections at construction time; later
  // mutations of those Java collections are no longer observed — confirm configs are not mutated.
  private val tablesToAnonymize = configuration.tablesToAnonymize.asScala.toList
  private val fieldsToIgnore = configuration.encryptedFieldsToIgnore.asScala.toList

  /**
   * Looks up a table in the underlying data set and wraps it so replacements are applied.
   *
   * @param tableName name of the table to fetch from `dataSet`
   * @return the wrapped table
   */
  override def getTable(tableName: String): ITable = {
    logger.debug("getTable(tableName={}) - start", tableName)
    createXldReplacementTable(dataSet.getTable(tableName))
  }

  /**
   * Registers a substring replacement shared by all tables of this data set.
   * Pairs where either side is `null` are silently ignored.
   */
  override def addReplacementSubstring(originalSubstring: String, replacementSubstring: String): Unit = {
    logger.debug("addReplacementSubstring(originalSubstring={}, replacementSubstring={}) - start", originalSubstring, replacementSubstring)
    if (originalSubstring != null && replacementSubstring != null) replacedSubstrings.put(originalSubstring, replacementSubstring)
  }

  /** Registers a whole-value replacement shared by all tables of this data set. */
  override def addReplacementObject(originalItem: AnyRef, replacementItem: AnyRef): Unit = {
    logger.debug("addReplacementObject(originalObject={}, replacementObject={}) - start", originalItem, replacementItem)
    replacedItems.put(originalItem, replacementItem)
  }

  /** Wraps `table` in an `XldReplacementTable` using the shared maps and configuration lists. */
  private def createXldReplacementTable(table: ITable): XldReplacementTable = {
    logger.debug("createXldReplacementTable(table={}) - start", table)
    // null, null: no start/end substring delimiters are configured.
    val replacementTable: XldReplacementTable = new XldReplacementTable(table, replacedItems,
      replacedSubstrings, null, null, tablesToAnonymize, fieldsToIgnore)
    replacementTable.setStrictReplacement(true)
    replacementTable
  }

  /**
   * Iterates the underlying data set (forward or reversed) and wraps each table it yields.
   *
   * @param reversed whether to iterate the underlying data set in reverse order
   */
  override protected def createIterator(reversed: Boolean): ITableIterator = {
    logger.debug("createIterator(reversed={}) - start", reversed)
    new XldReplacementIterator(if (reversed) dataSet.reverseIterator else dataSet.iterator)
  }

  /** Delegating iterator that wraps each table returned by `underlying`. */
  private class XldReplacementIterator(underlying: ITableIterator) extends ITableIterator {
    // Advances the underlying iterator, hence the explicit parentheses.
    override def next(): Boolean = underlying.next()

    override def getTableMetaData: ITableMetaData = {
      logger.debug("getTableMetaData() - start")
      underlying.getTableMetaData
    }

    override def getTable: ITable = {
      logger.debug("getTable() - start")
      createXldReplacementTable(underlying.getTable)
    }
  }
}
