결정 트리 ID3 알고리즘
package graph;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
/**
 * ID3 decision tree built from a table of string rows.
 *
 * <p>Each training row is a {@code String[]}; one column (the "target" column,
 * chosen by the caller) holds the class label and every other column is a
 * candidate attribute to split on. The tree is grown by repeatedly choosing the
 * attribute with the highest information gain.
 *
 * <p>Reference: http://www.blog.edu.cn/user2/huangbo929/archives/2006/1533249.shtml
 *
 * @author Leon.Chen
 */
public class DTree {

    /** Sentinel meaning "no such attribute / no such column". */
    private static final int NO_FOUND = -1;

    /**
     * Display names of the columns, indexed by column position. Used to name
     * internal nodes and to map a node name back to its column.
     *
     * NOTE(review): every entry is a single space, so all columns share the same
     * name and {@link #getNodeIndex(String)} can only ever resolve to column 0.
     * Presumably the real names were lost before this file was saved - confirm
     * and restore them.
     */
    private static final String[] ATTRIBUTE_NAMES = { " ", " ", " ", " ", " ", " " };

    /** Root of the built tree; {@code null} until {@link #createDTree(Object[])} runs. */
    TreeNode root;

    /**
     * Per-column flag: {@code true} when the column must not be chosen as a split
     * (it is the target column, or it is already used on the current path).
     */
    private boolean[] visited;

    /** The full training table; every element is a {@code String[]} row. */
    private Object[] trainingArray;

    /** Index of the target (class label) column. */
    private int nodeIndex;

    /**
     * Demo entry point: builds a tree from a small table and classifies one row.
     *
     * NOTE(review): the sample values are all single spaces, presumably lost in
     * the same way as the attribute names - confirm.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Object[] array = new Object[] {
                new String[] { " ", " ", " ", " ", " ", " " },
                new String[] { " ", " ", " ", " ", " ", " " },
                new String[] { " ", " ", " ", " ", " ", " " },
                new String[] { " ", " ", " ", " ", " ", " " } };
        DTree tree = new DTree();
        tree.create(array, 5);
        System.out.println("===============END PRINT TREE===============");
        String[] printData = new String[] { " ", " ", " ", " ", " " };
        System.out.println("===============DECISION RESULT===============");
        tree.compare(printData, tree.root);
    }

    /**
     * Walks the tree along the branches matching {@code printData} and prints the
     * label and percentage of the leaf that is reached. Nothing is printed when
     * no branch matches.
     *
     * @param printData attribute values of the row to classify, indexed by column
     * @param node      node to start from (normally {@link #root})
     */
    public void compare(String[] printData, TreeNode node) {
        int index = getNodeIndex(node.nodeName);
        // Leaves are the only nodes built without children, so test that as well:
        // a leaf label may coincide with an attribute name.
        if (index == NO_FOUND || node.childNodes.length == 0) {
            System.out.println(node.nodeName);
            System.out.println((node.percent * 100) + "%");
            return;
        }
        for (TreeNode child : node.childNodes) {
            if (child != null && child.parentArrtibute.equals(printData[index])) {
                compare(printData, child);
            }
        }
    }

    /**
     * Builds the tree from {@code array} and prints it.
     *
     * @param array training rows, each a {@code String[]}
     * @param index index of the target (class label) column
     */
    public void create(Object[] array, int index) {
        this.trainingArray = array;
        init(array, index);
        createDTree(array);
        printDTree(root);
    }

    /**
     * Finds the not-yet-used attribute with the largest positive information gain
     * over {@code array} and marks it as used.
     *
     * @param array rows to evaluate
     * @return a two-element array: {@code [0]} the best gain as a {@code Double}
     *         (0 when nothing qualifies), {@code [1]} the column index as an
     *         {@code Integer} (-1 when nothing qualifies)
     */
    public Object[] getMaxGain(Object[] array) {
        double bestGain = 0;
        int bestIndex = NO_FOUND;
        for (int i = 0; i < visited.length; i++) {
            if (visited[i]) {
                continue;
            }
            double candidate = gain(array, i);
            if (candidate > bestGain) {
                bestGain = candidate;
                bestIndex = i;
            }
        }
        if (bestIndex != NO_FOUND) {
            visited[bestIndex] = true;
        }
        return new Object[] { Double.valueOf(bestGain), Integer.valueOf(bestIndex) };
    }

    /**
     * Creates the root node and grows the tree beneath it. Does nothing when a
     * root already exists.
     *
     * @param array training rows, each a {@code String[]}
     */
    public void createDTree(Object[] array) {
        // Check first: getMaxGain marks an attribute as used, which must not
        // happen when no tree is going to be built.
        if (root != null) {
            return;
        }
        int index = ((Integer) getMaxGain(array)[1]).intValue();
        if (index == NO_FOUND) {
            // No attribute separates the rows: the whole tree is a single leaf.
            root = createLeaf(null, null, array);
            return;
        }
        root = createBranch(null, null, index);
        insertTree(array, root);
    }

    /**
     * Adds one child to {@code parentNode} for every value of its attribute that
     * occurs in {@code array}, recursing into children that can be split further.
     * Values with no matching rows leave their child slot {@code null}.
     *
     * @param array      rows that reach {@code parentNode}
     * @param parentNode node whose children are to be created
     */
    public void insertTree(Object[] array, TreeNode parentNode) {
        String[] values = parentNode.arrtibutes;
        int parentColumn = getNodeIndex(parentNode.nodeName);
        for (int i = 0; i < values.length; i++) {
            Object[] subset = pickUpAndCreateArray(array, values[i], parentColumn);
            if (subset.length == 0) {
                // The value exists in the training table but not on this path;
                // there is no row to take a label from.
                continue;
            }
            int index = ((Integer) getMaxGain(subset)[1]).intValue();
            if (index == NO_FOUND) {
                parentNode.childNodes[i] = createLeaf(parentNode, values[i], subset);
                continue;
            }
            TreeNode branch = createBranch(parentNode, values[i], index);
            parentNode.childNodes[i] = branch;
            insertTree(subset, branch);
            // The attribute is only spent on this path; sibling branches may
            // still split on it.
            visited[index] = false;
        }
    }

    /** Builds an internal node that splits on the column {@code index}. */
    private TreeNode createBranch(TreeNode parent, String parentAttribute, int index) {
        TreeNode node = new TreeNode();
        node.parent = parent;
        node.parentArrtibute = parentAttribute;
        node.arrtibutes = getArrtibutes(index);
        node.nodeName = getNodeName(index);
        node.childNodes = new TreeNode[node.arrtibutes.length];
        return node;
    }

    /** Builds a leaf labelled with the target value of the first row of {@code rows}. */
    private TreeNode createLeaf(TreeNode parent, String parentAttribute, Object[] rows) {
        TreeNode leaf = new TreeNode();
        leaf.parent = parent;
        leaf.parentArrtibute = parentAttribute;
        leaf.arrtibutes = new String[0];
        leaf.nodeName = getLeafNodeName(rows);
        leaf.childNodes = new TreeNode[0];
        leaf.percent = leafPercent(leaf.nodeName, rows);
        return leaf;
    }

    /**
     * Rows of {@code rows} carrying {@code label}, as a fraction of all training
     * rows carrying {@code label}; 0 when the label is absent from the training table.
     */
    private double leafPercent(String label, Object[] rows) {
        if (label == null) {
            return 0;
        }
        int inSubset = pickUpAndCreateArray(rows, label, this.nodeIndex).length;
        int inTraining = pickUpAndCreateArray(this.trainingArray, label, this.nodeIndex).length;
        return inTraining == 0 ? 0 : (double) inSubset / inTraining;
    }

    /**
     * Prints the subtree rooted at {@code node}: the node name, then for each
     * existing child the branch value followed by the child's subtree.
     *
     * @param node subtree root
     */
    public void printDTree(TreeNode node) {
        System.out.println(node.nodeName);
        for (TreeNode child : node.childNodes) {
            if (child != null) {
                System.out.println(child.parentArrtibute);
                printDTree(child);
            }
        }
    }

    /**
     * Records the target column and resets the used-attribute flags so that only
     * the target column is excluded from splitting.
     *
     * @param dataArray training rows; the first row determines the column count
     * @param index     index of the target (class label) column
     */
    public void init(Object[] dataArray, int index) {
        this.nodeIndex = index;
        // A new boolean[] starts out all false.
        visited = new boolean[((String[]) dataArray[0]).length];
        if (index >= 0 && index < visited.length) {
            visited[index] = true;
        }
    }

    /**
     * Selects the rows whose value in column {@code index} equals {@code arrtibute}.
     *
     * @param array     rows to filter, each a {@code String[]}
     * @param arrtibute value to match
     * @param index     column to test
     * @return the matching rows, in their original order
     */
    public Object[] pickUpAndCreateArray(Object[] array, String arrtibute,
            int index) {
        List<String[]> matches = new ArrayList<String[]>();
        for (Object element : array) {
            String[] row = (String[]) element;
            if (row[index].equals(arrtibute)) {
                matches.add(row);
            }
        }
        return matches.toArray();
    }

    /**
     * Information gain of splitting {@code array} on column {@code index}:
     * Entropy(S) minus the sum over values v of (|Sv| / |S|) * Entropy(Sv).
     *
     * @param array rows to evaluate
     * @param index candidate attribute column
     * @return the gain; 0 for an empty {@code array}
     */
    public double gain(Object[] array, int index) {
        if (array.length == 0) {
            // Nothing to split; also avoids the 0/0 the formula would produce.
            return 0;
        }
        String[] labels = getArrtibutes(this.nodeIndex);
        double entropyS = entropy(countLabels(array, labels), array.length);
        double weightedEntropy = 0;
        for (String value : getArrtibutes(index)) {
            weightedEntropy += entropySv(array, index, value, array.length);
        }
        return entropyS - weightedEntropy;
    }

    /**
     * Weighted entropy of one partition: (|Sv| / |S|) * Entropy(Sv), where Sv is
     * the set of rows whose column {@code index} equals {@code arrtibute}.
     *
     * @param array     rows to evaluate
     * @param index     attribute column
     * @param arrtibute attribute value selecting the partition
     * @param allTotal  |S|, the size of the set being split
     * @return the weighted entropy; 0 when the partition is empty
     */
    public double entropySv(Object[] array, int index, String arrtibute,
            int allTotal) {
        String[] labels = getArrtibutes(this.nodeIndex);
        int[] counts = countLabels(pickUpAndCreateArray(array, arrtibute, index), labels);
        int total = 0;
        for (int count : counts) {
            total += count;
        }
        if (total == 0) {
            return 0;
        }
        return DTreeUtil.getPi(total, allTotal) * entropy(counts, total);
    }

    /** Counts, for each entry of {@code labels}, the rows whose target column equals it. */
    private int[] countLabels(Object[] rows, String[] labels) {
        int[] counts = new int[labels.length];
        for (Object element : rows) {
            String[] row = (String[]) element;
            for (int j = 0; j < labels.length; j++) {
                if (row[this.nodeIndex].equals(labels[j])) {
                    counts[j]++;
                }
            }
        }
        return counts;
    }

    /** Sum of the {@link DTreeUtil#sigma(int, int)} terms for {@code counts} out of {@code total}. */
    private double entropy(int[] counts, int total) {
        double sum = 0;
        for (int count : counts) {
            sum += DTreeUtil.sigma(count, total);
        }
        return sum;
    }

    /**
     * Distinct values found in column {@code index} of the training table.
     *
     * @param index column to scan
     * @return the distinct values, ordered by {@link SequenceComparator}
     */
    @SuppressWarnings("unchecked")
    public String[] getArrtibutes(int index) {
        TreeSet<String> distinct = new TreeSet<String>(new SequenceComparator());
        for (Object element : trainingArray) {
            distinct.add(((String[]) element)[index]);
        }
        return distinct.toArray(new String[distinct.size()]);
    }

    /**
     * Display name of the column at {@code index}.
     *
     * @param index column position
     * @return the name, or {@code null} when {@code index} is out of range
     */
    public String getNodeName(int index) {
        if (index < 0 || index >= ATTRIBUTE_NAMES.length) {
            return null;
        }
        return ATTRIBUTE_NAMES[index];
    }

    /**
     * Label for a leaf covering {@code array}: the target-column value of its
     * first row.
     *
     * @param array rows covered by the leaf
     * @return the label, or {@code null} when {@code array} is null or empty
     */
    public String getLeafNodeName(Object[] array) {
        if (array == null || array.length == 0) {
            return null;
        }
        return ((String[]) array[0])[nodeIndex];
    }

    /**
     * Column position of the attribute called {@code name}.
     *
     * @param name attribute name; may be {@code null}
     * @return the first matching column index, or -1 when there is none
     */
    public int getNodeIndex(String name) {
        if (name == null) {
            return NO_FOUND;
        }
        for (int i = 0; i < ATTRIBUTE_NAMES.length; i++) {
            if (name.equals(ATTRIBUTE_NAMES[i])) {
                return i;
            }
        }
        return NO_FOUND;
    }
}
package graph;
/**
 * One node of the decision tree built by {@code DTree}. A plain data holder:
 * all fields are package-private and are filled in directly by {@code DTree}.
 *
 * @author Leon.Chen
 */
public class TreeNode {
/**
 * Node this one hangs from; {@code DTree} sets it to {@code null} for the root.
 */
TreeNode parent;
/**
 * Value of the parent's attribute that leads to this node; {@code null} for
 * the root. {@code DTree.compare} follows a child when this equals the
 * corresponding value of the row being classified.
 */
String parentArrtibute;
/**
 * For an internal node, the name of the attribute it splits on (from
 * {@code DTree.getNodeName}); for a leaf, the label taken from the target
 * column (from {@code DTree.getLeafNodeName}).
 */
String nodeName;
/**
 * Distinct values of this node's attribute, as returned by
 * {@code DTree.getArrtibutes}; an empty array for a leaf.
 */
String[] arrtibutes;
/**
 * Children, index-aligned with {@link #arrtibutes}: slot {@code i} is the
 * branch for {@code arrtibutes[i]}. Readers in {@code DTree} skip
 * {@code null} slots. An empty array for a leaf.
 */
TreeNode[] childNodes;
/**
 * Set on leaves by {@code DTree.insertTree}: the number of rows in the leaf's
 * subset carrying the leaf's label, divided by the number of rows in the whole
 * training table carrying that label. Left at 0 on internal nodes.
 */
double percent;
}
package graph;
/**
 * Numeric helpers for the entropy calculations of the decision tree.
 *
 * @author Leon.Chen
 */
public class DTreeUtil {

    /** Natural logarithm of 2, the divisor that converts ln to log base 2. */
    private static final double LN_2 = Math.log(2);

    /**
     * One term of the entropy sum, {@code -p * log2(p)} with {@code p = x / total}.
     *
     * @param x     number of occurrences of one outcome
     * @param total number of observations overall
     * @return the entropy term; 0 when {@code x} is 0
     */
    public static double sigma(int x, int total) {
        if (x != 0) {
            double p = getPi(x, total);
            return -(p * logYBase2(p));
        }
        // An outcome that never occurs contributes nothing.
        return 0;
    }

    /**
     * Base-2 logarithm.
     *
     * @param y value to take the logarithm of
     * @return log2(y)
     */
    public static double logYBase2(double y) {
        return Math.log(y) / LN_2;
    }

    /**
     * Proportion {@code x / total}, computed in floating point.
     *
     * @param x     number of occurrences
     * @param total number of observations overall
     * @return x divided by total as a double
     */
    public static double getPi(int x, int total) {
        return (double) x / total;
    }
}
package graph;
import java.util.Comparator;
/**
 * Orders strings by their natural ({@link String#compareTo(String)}) order.
 * Declared against the raw {@code Comparator} type, so both arguments are
 * cast to {@code String} at call time.
 *
 * @author Leon.Chen
 */
@SuppressWarnings("unchecked")
public class SequenceComparator implements Comparator {

    /**
     * Compares two strings.
     *
     * @param o1 first string
     * @param o2 second string
     * @return negative, zero or positive as {@code o1} sorts before, equal to
     *         or after {@code o2}
     * @throws ClassCastException when either argument is not a {@code String}
     */
    @Override
    public int compare(Object o1, Object o2) throws ClassCastException {
        return ((String) o1).compareTo((String) o2);
    }
}
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
【Codility Lesson3】FrogJmp: A small frog wants to get to the other side of the road. The frog is currently located at position X and wants to get to...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다. 하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.