package com.xebialabs.database.anonymizer

import com.xebialabs.database.anonymizer.AnonymizerContants.{DEFAULT_FILE_NAME, DEFAULT_REPORTING_FILE_NAME}
import org.dbunit.database.DatabaseConnection
import org.dbunit.dataset.stream.{IDataSetProducer, StreamingDataSet}
import org.dbunit.dataset.xml.FlatXmlProducer
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Service

import java.sql.SQLException
import scala.xml.InputSource

/**
 * Spring service that loads a DbUnit flat-XML dataset into a database whose connection is
 * obtained from [[DatabaseRepository]].
 */
@Service
class ImportService {

  @transient private lazy val logger = LoggerFactory.getLogger(getClass)

  @Autowired
  var repository: DatabaseRepository = _

  /**
   * Imports the configured dataset file into the target database.
   *
   * Opens a connection for `options.reportingDb`, enables batch processing when
   * `options.batchSize` is positive, streams the dataset from `options.sourceFileName` (or the
   * default file), executes the database operation the repository selects for the detected
   * database name and `options.dbRefresh`, then resets sequences (PostgreSQL only).
   * The connection is always closed afterwards; an `SQLException` while closing it is logged
   * rather than propagated.
   *
   * @param options import settings
   * @throws Exception anything thrown by the repository or DbUnit during the import
   */
  @throws[Exception]
  def start(options: AnonymizerOptions): Unit = {
    logger.info("Data import has been started. Please wait....")
    // Acquired outside the try: if acquisition fails there is nothing to close.
    val connection = repository.getDatabaseConnection(options.reportingDb)
    try {
      if (options.batchSize > 0) repository.enableBatchProcessing(connection, options.batchSize)

      val dataSet = getStreamingDataset(options.reportingDb, options.sourceFileName)
      val databaseName = repository.getDatabaseName(connection.getConnection)
      repository.getDatabaseOperation(databaseName, options.dbRefresh).execute(connection, dataSet)
      resetDatabase(connection, databaseName)

      logger.info("Data import has been successfully finished")
    } finally closeConnection(connection)
  }

  /**
   * Builds a streaming dataset over a flat-XML file, with validation and column sensing disabled.
   *
   * @param isReportingDb chooses which default file name is used when `fileName` is null
   * @param fileName      dataset location passed to the SAX `InputSource`; `null` means "use the default"
   */
  private def getStreamingDataset(isReportingDb: Boolean, fileName: String): StreamingDataSet = {
    val defaultFileName = if (isReportingDb) DEFAULT_REPORTING_FILE_NAME else DEFAULT_FILE_NAME
    val producer = new FlatXmlProducer(new InputSource(Option(fileName).getOrElse(defaultFileName)), true)
    producer.setValidating(false)
    producer.setColumnSensing(false)
    // FlatXmlProducer already implements IDataSetProducer: a type ascription suffices, no cast.
    new StreamingDataSet(producer: IDataSetProducer)
  }

  /**
   * Re-aligns sequences with the freshly imported data. Only PostgreSQL is handled; every other
   * database is left untouched.
   */
  private def resetDatabase(connection: DatabaseConnection, databaseName: String): Unit =
    DatabaseName.toValue(databaseName) match {
      case DatabaseName.POSTGRES => resetPostgresSequences(connection.getConnection)
      case _ =>
    }

  /**
   * For every sequence listed in `pg_class`, derives the table name by removing "_ID_seq" from the
   * sequence name and calls SETVAL with that table's `MAX("ID")+1`.
   *
   * NOTE(review): two-argument SETVAL marks the value as already used, so the next nextval() is
   * MAX+2; and for an empty table MAX is NULL — confirm both behaviours are intended.
   */
  private def resetPostgresSequences(jdbc: java.sql.Connection): Unit = {
    // Materialise the names first so the catalog result set and its statement are closed
    // deterministically, instead of relying on closeOnCompletion (which never fires when the
    // result set itself is left open).
    val sequences: List[String] = {
      val catalogStmt = jdbc.createStatement()
      try {
        val rs = catalogStmt.executeQuery("SELECT c.relname FROM pg_class c WHERE c.relkind = 'S';")
        try Iterator.continually(rs).takeWhile(_.next()).map(_.getString("relname")).toList
        finally rs.close()
      } finally catalogStmt.close()
    }

    sequences.foreach { sequence =>
      val table = sequence.replace("_ID_seq", "")
      val query = "SELECT SETVAL('\"%s\"', (SELECT MAX(\"ID\")+1 FROM \"%s\"));".format(sequence, table)
      val updStmt = jdbc.createStatement()
      try updStmt.executeQuery(query).close()
      finally updStmt.close()
    }
  }

  /** Closes the connection if one was obtained, logging (not propagating) an `SQLException`. */
  private def closeConnection(connection: DatabaseConnection): Unit =
    if (connection != null) {
      try connection.close()
      catch {
        case exception: SQLException =>
          logger.error("Exception occurred while closing the connection ", exception)
      }
    }
}
