Last active
August 29, 2015 13:57
-
-
Save jkdeveyra/9544085 to your computer and use it in GitHub Desktop.
ERC Forest
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Settings for an [[ERCForest]].
 *
 * @param nbTrees  number of trees (presumably how many the forest builder grows — builder not shown)
 * @param treeConf per-tree settings (presumably shared by every tree — confirm against the builder)
 */
case class ERCForestConf(
  nbTrees: Int = 8,
  treeConf: ERCTreeConf = ERCTreeConf()
)
/** A forest of [[ERCTree]]s whose leaves together form the cluster vocabulary. */
case class ERCForest(trees: IndexedSeq[ERCTree]) {

  /**
   * Concatenates, in `trees` order, each tree's leaf-hit histogram for the points in `pt`.
   *
   * @param pt feature vectors to assign
   * @return one count per leaf of every tree, tree after tree
   */
  def assignIndexSeq(pt: Seq[Seq[Float]]): Seq[Int] =
    trees.flatMap(tree => tree.query(pt))

  /** Total number of leaves across all trees. */
  def numClusters = trees.foldLeft(0)((total, tree) => total + tree.leavesSize)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Extremely Randomized Clustering Tree.
 *
 * @param root       root node; every query path is expected to end in a [[LeafNode]]
 * @param leavesSize number of leaves; leaf ids are expected to lie in `0 until leavesSize`
 * @param conf       configuration the tree was built with
 */
case class ERCTree(root: Node, leavesSize: Int, conf: ERCTreeConf) {

  /**
   * Builds a histogram of leaf hits for a bag of feature vectors.
   *
   * @param seq feature vectors to route through the tree
   * @return a sequence of length `leavesSize` whose `i`-th entry counts the vectors
   *         that landed in the leaf with id `i`
   * @throws IllegalStateException if a vector is routed to something other than a leaf
   */
  def query(seq: Seq[Seq[Float]]): Seq[Int] = {
    val counts = new Array[Int](leavesSize)
    for (vector <- seq) {
      root.query(vector) match {
        case leaf: LeafNode => counts(leaf.id) += 1
        case other =>
          // A well-formed tree only terminates in leaves; report anything else clearly
          // instead of failing with an opaque ClassCastException.
          throw new IllegalStateException(
            s"query reached a non-leaf node: ${other.getClass.getName}")
      }
    }
    counts
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package xtag.clustering.erc | |
import xtag.clustering.TrainingSet | |
import scala.util.Random | |
/**
 * Grows a single [[ERCTree]] (Extremely Randomized Clustering Tree) from a training set.
 *
 * Every internal node is created by drawing random (attribute, threshold) pairs and keeping
 * the best-scoring split, stopping once the score reaches `conf.sMin` or after `conf.tMax`
 * attempts.
 *
 * @param set  training vectors the tree is grown from
 * @param conf tree parameters (depth, threshold range `min`/`max`, `sMin`, `tMax`)
 */
class ERCTreeBuilder(set: TrainingSet, conf: ERCTreeConf) {
  /** Number of leaves created so far by the current [[build]]. */
  var leavesSize = 0
  /** Id that [[nextId]] hands out next; leaf ids are dense starting from 0. */
  var id = 0

  /** Returns the current id and advances the counter (post-increment). */
  def nextId() = {
    id += 1
    id - 1
  }

  /**
   * Recursively grows a subtree.
   *
   * @param set   vectors reaching this node
   * @param score score of the split that produced `set`; stored on the leaf if one is emitted
   * @param depth remaining depth budget; a leaf is emitted once it is exhausted (<= 0)
   * @return root of the grown subtree
   */
  def tree(set: TrainingSet, score: Double, depth: Int): Node = {
    if (stopSplitting(set, depth)) {
      leavesSize += 1
      LeafNode(id = nextId(), score = score)
    } else {
      var tries = 0
      var bestScore = 0.0
      var bestAttrib = 0
      var bestThreshold = 0f
      var bestLeft = TrainingSet.empty(set.dimension)
      var bestRight = TrainingSet.empty(set.dimension)
      // Draw random splits until one is balanced enough (>= sMin) or the try budget is spent.
      // Scores are never negative, so the first draw is always accepted; `>=` lets a later
      // draw win ties.
      do {
        tries += 1
        val attrib = Random.nextInt(set.dimension)
        val threshold = (conf.max - conf.min) * math.random().toFloat + conf.min
        val (left, right) = set.split(attrib, threshold)
        val score = Score(left, right)
        if (score >= bestScore) {
          bestScore = score
          bestAttrib = attrib
          bestThreshold = threshold
          bestLeft = left
          bestRight = right
        }
      } while (bestScore < conf.sMin && tries < conf.tMax)
      // Arguments are evaluated left to right, so left-subtree leaves get the smaller ids.
      DecisionNode(
        bestAttrib,
        bestThreshold,
        tree(bestLeft, bestScore, depth - 1),
        tree(bestRight, bestScore, depth - 1))
    }
  }

  /**
   * Grows a tree over the builder's training set.
   *
   * The leaf counters are reset first so each returned tree has leaf ids `0 until leavesSize`,
   * even when `build()` is called more than once on the same builder.
   */
  def build(): ERCTree = {
    leavesSize = 0
    id = 0
    val root = tree(set, 0, conf.depth)
    ERCTree(root, leavesSize, conf)
  }

  /**
   * Balance score of a split: the binary entropy (in bits) of the left/right proportions.
   * 0 for a one-sided split, 1 for a perfectly balanced one.
   */
  def Score(left: TrainingSet, right: TrainingSet): Double = {
    val totalSize = left.size + right.size
    entropy(left.size, totalSize) + entropy(right.size, totalSize)
  }

  /** One term `-p * log2(p)` of an entropy sum, with `p = part / total`; 0 when `p` is 0 or undefined. */
  def entropy(part: Double, total: Double): Double = {
    if (total <= 0 || part <= 0) 0.0
    else {
      val per = part / total
      -per * (math.log(per) / math.log(2))
    }
  }

  /**
   * Whether `set` should become a leaf: it is empty, a single vector, all-identical,
   * or the depth budget is exhausted. `<= 0` also covers a negative starting depth.
   */
  def stopSplitting(set: TrainingSet, level: Int): Boolean =
    set.isEmpty || set.size == 1 || set.sameAll || level <= 0
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package xtag.clustering | |
/** A node of a binary decision tree. */
abstract class Node {
  /** Routes `seq` from this node downward and returns the node where routing stops. */
  def query(seq: Seq[Float]): Node
  /** Left child. */
  def left: Node
  /** Right child. */
  def right: Node

  /**
   * Renders this node and, recursively, both children as an indented outline: `lvl` tabs
   * followed by this node's `toString`, then the left and right subtrees on following lines.
   */
  def print(lvl: Int): String = {
    val indent = "\t" * lvl
    val children = Seq(left, right).map(child => child.print(lvl + 1)).mkString("\n")
    indent + this.toString + "\n" + children
  }
}
/**
 * Internal node testing a single component of the input vector.
 *
 * @param attrib    index of the component that is tested
 * @param threshold vectors with `seq(attrib) < threshold` go left, all others go right
 */
case class DecisionNode(attrib: Int, threshold: Double, left: Node, right: Node) extends Node {
  def query(seq: Seq[Float]) = {
    val branch = if (seq(attrib) < threshold) left else right
    branch.query(seq)
  }

  override def toString = "x[%d] < %f".format(attrib, threshold)
}
/**
 * Terminal node of a tree.
 *
 * @param id    leaf identifier
 * @param score score recorded for this leaf when it was created
 */
case class LeafNode(id: Int = 0, score: Double) extends Node {
  /** Routing always stops at a leaf. */
  def query(seq: Seq[Float]) = this

  /** Leaves have no children; both sides are the `EmptyNode` sentinel. */
  def left = EmptyNode
  def right = EmptyNode

  override def toString = "Leaf(score = %f)".format(score)

  /** Renders only this leaf, indented by `lvl` tabs, without descending into children. */
  override def print(lvl: Int) = ("\t" * lvl) + toString
}
/** Sentinel used as the (absent) children of a leaf; it is its own child and query result. */
object EmptyNode extends Node with Serializable {
  def query(seq: Seq[Float]) = this
  def left = this
  def right = this
  override def toString = ""

  /**
   * There is nothing to render. Overridden because the inherited `Node.print` recurses into
   * `left`/`right`, which are `this` here and would recurse forever.
   */
  override def print(lvl: Int) = ""
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package xtag.clustering | |
import scala.util.Random | |
import collection.mutable.ArrayBuffer | |
import xtag.util.Stopwatch | |
object TrainingSet {
  /** An empty training set of the given dimension. */
  def empty(dim: Int): TrainingSet = apply(dim, IndexedSeq.empty)

  /** A training set over all of `elem`; the dimension is taken from the first vector, so `elem` must be non-empty. */
  def apply(elem: IndexedSeq[Seq[Float]]): TrainingSet = apply(elem.head.size, elem)

  /** A training set of the given dimension that includes every vector of `elem`. */
  def apply(dim: Int, elem: IndexedSeq[Seq[Float]]): TrainingSet =
    new TrainingSet(dim, elem, 0 until elem.size)
}

/**
 * An immutable training set: a view over a shared pool of vectors `elem`, restricted to the
 * positions listed in `indices`. Splitting, grouping, shuffling and taking produce new views
 * over the same pool instead of copying vectors.
 *
 * @param dimension number of components in each vector; must be positive
 * @param elem      shared pool of vectors
 * @param indices   positions in `elem` that belong to this set, in iteration order
 */
case class TrainingSet(dimension: Int, elem: IndexedSeq[Seq[Float]], indices: IndexedSeq[Int]) {
  require(dimension > 0, s"dimension must be positive, got $dimension")

  /** Vectors belonging to this set, in `indices` order. */
  lazy val elements: IndexedSeq[Seq[Float]] = indices map elem.apply

  /** Largest component over the vectors of this set (not the whole pool); throws if empty. */
  lazy val max: Float = elements.flatten.max

  /** Smallest component over the vectors of this set (not the whole pool); throws if empty. */
  lazy val min: Float = elements.flatten.min

  /**
   * Returns a new set with `seq` appended to the pool and included in the view.
   * The new vector lands at position `elem.size` of the extended pool, so exactly one index is added.
   */
  def +(seq: Seq[Float]): TrainingSet =
    copy(elem = elem :+ seq, indices = indices :+ elem.size)

  /** The `i`-th vector of this set. */
  def apply(i: Int): Seq[Float] = elem(indices(i))

  def size: Int = indices.size

  def isEmpty: Boolean = size == 0

  /** Distinct vectors of this set. */
  def distinct: IndexedSeq[Seq[Float]] = elements.distinct

  /** True when every vector equals the first one (vacuously true for an empty set). */
  def sameAll: Boolean = elements forall (_ == elements.head)

  /**
   * Partitions this set on one component, preserving order.
   *
   * @return (vectors with `v(attr) < threshold`, all remaining vectors)
   */
  def split(attr: Int, threshold: Float): (TrainingSet, TrainingSet) = {
    val (leftIndices, rightIndices) = indices.partition(i => elem(i)(attr) < threshold)
    (copy(indices = leftIndices), copy(indices = rightIndices))
  }

  /** Shuffles this set and cuts it into consecutive groups of `n` vectors (the last may be smaller). */
  def grouped(n: Int): IndexedSeq[TrainingSet] =
    Random.shuffle(indices).grouped(n).toIndexedSeq.map(group => copy(indices = group))

  /** The same vectors in random order. */
  def shuffled(): TrainingSet = copy(indices = Random.shuffle(indices))

  /** The first `n` vectors of this set. */
  def take(n: Int): TrainingSet = copy(indices = indices.take(n))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment