Skip to content
This repository has been archived by the owner on Apr 10, 2019. It is now read-only.

WIP: New goal mix dataset #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/main/kotlin/net/gosecure/spotbugs/MixDatasetMojo.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package net.gosecure.spotbugs

import net.gosecure.spotbugs.datasource.ml.MLUtils
import org.apache.maven.plugin.AbstractMojo
import org.apache.maven.plugins.annotations.Mojo
import org.apache.maven.plugins.annotations.Parameter
import org.apache.maven.project.MavenProject
import weka.classifiers.Classifier
import weka.classifiers.bayes.NaiveBayes

/**
 * Maven goal `train-predict`: works on a single CSV in which some rows are already
 * labeled and others are not.
 *
 * The goal converts the CSV to ARFF, filters the metadata attributes, splits the data
 * into a training part and a part to predict (see `MLUtils.splitDataset`), trains each
 * configured classifier on the training part, saves it, and finally writes predictions
 * for the remaining rows.
 */
@Mojo(name="train-predict")
class MixDatasetMojo : AbstractMojo() {

    // Injected by Maven before execute() is called.
    @Parameter(readonly = true, defaultValue = "\${project}")
    private lateinit var project: MavenProject

    private val FILE_INPUT = "aggregate-results_classified_sample.csv"
    private val FILE_OUTPUT = "aggregate-results_classified_sample.arff"
    private val FILE_RESULTS = "aggregate-results_classified_sample_labeled.csv"
    private val MODEL_SAVED = "test-saved-model.model"

    /**
     * Runs the whole train-then-predict pipeline for every classifier in the model list.
     */
    override fun execute() {
        log.info("Training and predicting...")

        //Instantiate configuration
        val cfg = MLUtils().initConfig()

        val dataUnfiltered = MLUtils().getInstances(project, FILE_INPUT, FILE_OUTPUT)
        val dataFiltered = MLUtils().filterMeta(dataUnfiltered)
        // The class (label) attribute is the last one.
        dataFiltered.setClassIndex(dataFiltered.numAttributes() - 1)

        val dataSplit = MLUtils().splitDataset(dataFiltered)
        val dataTrain = dataSplit[0]
        val dataPredict = dataSplit[1]

        // Use a set of classifiers
        val models = arrayOf<Classifier>(
                NaiveBayes())

        // Run for each model
        for (candidate in models) {
            System.out.println("\n" + candidate.javaClass.simpleName)

            //10 fold-cross validation, print stats data in html
            MLUtils().trainStats(project, cfg, candidate, dataTrain)

            //Train on full data : build the classifier
            val model : Classifier = MLUtils().trainFullData(candidate, dataTrain)
            //if needed
            MLUtils().saveModel(project, model, MODEL_SAVED)

            //Predict
            // Start from the untouched split for every model instead of re-assigning a shared
            // variable, so a second classifier does not get a class attribute appended on top
            // of the one added for the first.
            // NOTE(review): assumes createClassAttribute returns a new data set rather than
            // mutating its argument - confirm.
            val toPredict = MLUtils().createClassAttribute(dataPredict)
            toPredict.setClassIndex(toPredict.numAttributes() - 1)

            MLUtils().makePredictions(project, cfg, dataUnfiltered, toPredict, model, FILE_INPUT, FILE_RESULTS)
        }
    }
}
3 changes: 2 additions & 1 deletion src/main/kotlin/net/gosecure/spotbugs/PredictMojo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class PredictMojo : AbstractMojo() {
private val FILE_OUTPUT = "aggregate-results_classified_sample.arff"
private val FILE_INPUT_TEST = "aggregate-results_classified_sample_unlabeled.csv"
private val FILE_OUTPUT_TEST = "aggregate-results_classified_sample_unlabeled.arff"
private val FILE_RESULTS = "aggregate-results_classified_sample_labeled.csv"
private val MODEL_SAVED = "test-saved-model.model"

override fun execute() {
Expand All @@ -35,6 +36,6 @@ class PredictMojo : AbstractMojo() {

dataUnlabeled.setClassIndex(dataUnlabeled.numAttributes() - 1)

MLUtils().makePredictions(project, cfg, dataUnfiltered, dataUnlabeled, model)
MLUtils().makePredictions(project, cfg, dataUnfiltered, dataUnlabeled, model, FILE_INPUT_TEST, FILE_RESULTS)
}
}
16 changes: 14 additions & 2 deletions src/main/kotlin/net/gosecure/spotbugs/TrainMojo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import org.apache.maven.project.MavenProject
import weka.classifiers.*
import net.gosecure.spotbugs.datasource.ml.MLUtils
import weka.classifiers.bayes.NaiveBayes
import weka.classifiers.functions.MultilayerPerceptron
import weka.classifiers.functions.SMO
import weka.classifiers.trees.J48
import weka.classifiers.trees.RandomForest
import weka.classifiers.trees.RandomTree

@Mojo(name="train")
class TrainMojo: AbstractMojo() {
Expand All @@ -33,17 +38,24 @@ class TrainMojo: AbstractMojo() {
val models = arrayOf<Classifier>(
NaiveBayes())

/*val options = arrayOfNulls<String>(2)
options[0] = "-K"
options[1] = "4"
(models[0] as RandomForest).setOptions(options)*/

// Run for each model
for (j in models.indices) {
System.out.println("\n" + models[j].javaClass.simpleName)
//log.info((models[j] as RandomForest).numIterations.toString())
//log.info((models[j] as RandomForest).numFeatures.toString())

//10 fold-cross validation, print stats data in html
MLUtils().trainStats(project, cfg, models[j], dataFiltered)

//Train on full data : build the classifier
MLUtils().trainFullData(models[j], dataFiltered)
val model : Classifier = MLUtils().trainFullData(models[j], dataFiltered)

MLUtils().saveModel(project, models[j], MODEL_SAVED)
MLUtils().saveModel(project, model, MODEL_SAVED)
}
}

Expand Down
158 changes: 152 additions & 6 deletions src/main/kotlin/net/gosecure/spotbugs/datasource/ml/MLUtils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package net.gosecure.spotbugs.datasource.ml
import freemarker.template.Configuration
import freemarker.template.TemplateException
import freemarker.template.TemplateExceptionHandler
import org.apache.commons.codec.digest.DigestUtils
import org.apache.maven.project.MavenProject
import org.w3c.dom.Element
import org.xml.sax.SAXException
import weka.classifiers.Classifier
import weka.classifiers.Evaluation
import weka.core.Instances
Expand All @@ -14,6 +17,13 @@ import weka.filters.unsupervised.attribute.Add
import weka.filters.unsupervised.attribute.Remove
import java.io.*
import java.util.*
import javax.xml.parsers.DocumentBuilderFactory
import javax.xml.parsers.ParserConfigurationException
import javax.xml.transform.OutputKeys
import javax.xml.transform.TransformerException
import javax.xml.transform.TransformerFactory
import javax.xml.transform.dom.DOMSource
import javax.xml.transform.stream.StreamResult

class MLUtils {

Expand All @@ -29,6 +39,21 @@ class MLUtils {
return File(completeFileName)
}

/**
 * Reads [file] and returns its lines in file order.
 *
 * The content is decoded by [FileReader], i.e. with the platform default charset.
 * Line terminators are not part of the returned strings.
 *
 * @param file the text file to read
 * @return all lines of the file; empty when the file has no content
 * @throws IOException if the file cannot be opened or read
 */
fun readFile(file:File):List<String> {
    // `use` closes the reader even when reading fails part-way through; closing the
    // BufferedReader also closes the FileReader it wraps.
    return BufferedReader(FileReader(file)).use { reader -> reader.readLines() }
}


@Throws(IOException::class)
fun csvToArff(csv: File, arff: File) {
Expand Down Expand Up @@ -60,7 +85,7 @@ class MLUtils {

val options = arrayOfNulls<String>(2)
options[0] = "-R"
options[1] = "1,2,3,4,5,15"
options[1] = "1,2,3,4,5,7,15"
val remove = Remove()
remove.setOptions(options)
remove.setInputFormat(data)
Expand Down Expand Up @@ -99,6 +124,23 @@ class MLUtils {
weka.core.SerializationHelper.write(name, model)
}

/**
 * Splits [data] into a training set and a set to predict, based on the last attribute.
 *
 * Rows whose last attribute is `GOOD` or `BAD` are kept in the training set; every other
 * row (including rows where that attribute is missing) is kept in the prediction set.
 * [data] itself is left untouched.
 *
 * @param data the data set to split; its last attribute is read with `stringValue`
 * @return a two-element array: index 0 is the training set, index 1 the prediction set
 */
@Throws(Exception::class)
fun splitDataset(data: Instances) : Array<Instances> {
    // Independent copies: assigning `data` to both names would make them aliases of one
    // object, so a delete meant for one set would hit the other (and the input) as well.
    val dataTrain = Instances(data)
    val dataPredict = Instances(data)
    val labelIndex = data.numAttributes() - 1

    // `downTo` is required: `n - 1..0` is an empty range in Kotlin. Iterating backwards
    // keeps the indices of not-yet-visited rows stable while rows are deleted.
    for (i in data.numInstances() - 1 downTo 0) {
        val row = data.instance(i)
        val isLabeled = !row.isMissing(labelIndex) &&
                row.stringValue(labelIndex).let { it == "GOOD" || it == "BAD" }
        if (isLabeled) dataPredict.delete(i) else dataTrain.delete(i)
    }

    // arrayOf<Instances>() would have length 0 and reject any indexed assignment.
    return arrayOf(dataTrain, dataPredict)
}

@Throws(Exception::class)
fun trainStats(project: MavenProject, cfg: Configuration, model: Classifier, data: Instances) {
Expand Down Expand Up @@ -129,12 +171,15 @@ class MLUtils {

// After training, make predictions on instances, and print the prediction and real values
@Throws(Exception::class)
fun makePredictions(project: MavenProject, cfg: Configuration, unfiltered: Instances, unlabeled: Instances, model: Classifier) {
fun makePredictions(project: MavenProject, cfg: Configuration, unfiltered: Instances, unlabeled: Instances, model: Classifier, fileInput:String, fileOutput:String) {

val labeled = Instances(unlabeled)

val issues = ArrayList<Issue>()
var number = 0

val predictions = arrayOfNulls<String>(unlabeled.numInstances())
val issuesToRemove = ArrayList<Issue>()

for (i in 0 until (unlabeled.numInstances() - 1)) {

val newInst = unlabeled.instance(i)
Expand All @@ -144,19 +189,120 @@ class MLUtils {

val predString = labeled.classAttribute().value(predNb.toInt())
val pred = model.distributionForInstance(labeled.get(i))
predictions[i] = predString

//Instances classified with a probability < 90%
if (Math.max(pred[0], pred[1]) < 0.9) {
if ((Math.max(pred[0], pred[1]) < 0.9) || predString.equals("GOOD")) {
val sourceFile = unfiltered.instance(i).stringValue(0) //Source File Attribute 1
val line = Integer.toString(unfiltered.instance(i).value(1).toInt()) //Line Attribute 2
val bugType = unfiltered.instance(i).stringValue(5) //BugType Attribute 5
issues.add(Issue(sourceFile, line, bugType))
number++
if (Math.max(pred[0], pred[1]) < 0.9){
issues.add(Issue(sourceFile, line, bugType))
number++
}
if(predString.equals("GOOD")){
issuesToRemove.add(Issue(sourceFile, line, bugType))
}
}
}
val fileResults = project.build.directory + "/spotbugs-ml/" + fileOutput
resultsToCsv(project, fileInput, fileResults, predictions)
parserXml(project, issuesToRemove)
outputHtmlPredict(project, cfg, issues, number)
}

/**
 * Writes a results CSV: the first 13 columns of every data row of the input CSV, followed
 * by the prediction for that row.
 *
 * The first line of the input (the header) is skipped and no header is written. Row `i`
 * of the output gets `predictions[i]`; a missing or null prediction is written as the
 * text `null`. Rows with fewer than 13 fields are padded with empty fields.
 *
 * @param project the Maven project, used to locate [fileInput] through `getResource`
 * @param fileInput name of the input CSV
 * @param fileResults path of the CSV file to write
 * @param predictions predicted label per data row, in row order
 */
fun resultsToCsv(project: MavenProject, fileInput:String, fileResults:String, predictions: Array<String?>) {
    val keptColumns = 13
    // drop(1) skips the header and, unlike removeAt(0), tolerates an empty input file.
    val rows = readFile(getResource(project, fileInput)).drop(1)
    val eol = System.getProperty("line.separator")

    // `use` flushes and closes the writer even if writing a row fails.
    FileWriter(File(fileResults)).use { writer ->
        rows.forEachIndexed { index, row ->
            val fields = row.split(',')
            // Trailing empty fields are kept and short rows padded, so a row that ends in
            // empty columns no longer fails on a fixed index.
            val kept = (0 until keptColumns).joinToString(",") { col -> fields.getOrElse(col) { "" } }
            writer.write(kept + "," + predictions.getOrNull(index) + eol)
        }
    }
}

/**
 * Computes the SHA-1 digest of [str] (encoded as UTF-8) as a hex string.
 *
 * @param str the text to hash
 * @return the hex digest, or `null` when hashing failed; in that case the stack trace of
 *         the failure has been printed
 */
fun sha1(str: String): String? {
    return try {
        DigestUtils.sha1Hex(str.toByteArray(Charsets.UTF_8))
    } catch (failure: Exception) {
        failure.printStackTrace()
        null
    }
}

/**
 * Removes the given issues from the FindBugs XML report and writes the result to a new file.
 *
 * Reads `<build dir>/sonar/findbugs-result.xml`, removes every `BugInstance` whose
 * (source file, start line, bug type) matches one of [issues], and writes the document to
 * `<build dir>/sonar/findbugs-result-updated.xml`. The input file is not modified.
 *
 * A `BugInstance` is identified by its `type` attribute plus the `sourcepath` (prefixed
 * with `src/`) and `start` attributes of its first direct `SourceLine` child. Bug
 * instances without such a child are never removed.
 *
 * @param project the Maven project whose build directory contains the report
 * @param issues the issues to remove from the report
 */
@Throws(ParserConfigurationException::class, IOException::class, SAXException::class, TransformerException::class)
fun parserXml(project: MavenProject, issues: List<Issue>) {
    val sonarDir = project.build.directory + "/sonar/"
    val doc = DocumentBuilderFactory.newInstance()
            .newDocumentBuilder()
            .parse(File(sonarDir + "findbugs-result.xml"))
    doc.documentElement.normalize()

    val collection = doc.getElementsByTagName("BugCollection").item(0) as Element
    val bugs = collection.getElementsByTagName("BugInstance")

    // Index the report by key. Several bug instances can share one key (same type reported
    // more than once on the same line), so each key maps to a list; a single-valued map
    // would keep only the last one and leave the others in the report.
    val bugsByKey = HashMap<String?, MutableList<Element>>()
    for (i in 0 until bugs.length) {
        val bug = bugs.item(i) as Element

        // First direct child named SourceLine, starting at the very first child.
        var child = bug.firstChild
        while (child != null && child.nodeName != "SourceLine") {
            child = child.nextSibling
        }
        // No SourceLine child: nothing to match against, leave the bug in place.
        val sourceLine = child as? Element ?: continue

        val sourceFile = "src/" + sourceLine.getAttribute("sourcepath")
        val line = sourceLine.getAttribute("start")
        val bugType = bug.getAttribute("type")
        bugsByKey.getOrPut(sha1(sourceFile + line + bugType)) { ArrayList() }.add(bug)
    }

    // Keys of the issues to drop; a set so that repeated issues are processed once.
    val keysToRemove = issues.map { issue -> sha1(issue.sourceFile + issue.line + issue.bugType) }.toSet()
    for (key in keysToRemove) {
        bugsByKey[key]?.forEach { bug -> bug.parentNode?.removeChild(bug) }
    }

    doc.documentElement.normalize()
    val transformer = TransformerFactory.newInstance().newTransformer()
    transformer.setOutputProperty(OutputKeys.INDENT, "yes")
    transformer.transform(DOMSource(doc), StreamResult(File(sonarDir + "findbugs-result-updated.xml")))
    println("XML file updated successfully")
}

@Throws(IOException::class)
fun initConfig(): Configuration {
val cfg = Configuration()
Expand Down