python mean-shift 집합 알고리즘 구현
1.새 MeanShift.py 파일
import numpy as np
#
STOP_THRESHOLD = 1e-4
CLUSTER_THRESHOLD = 1e-1
#
def distance(a, b):
return np.linalg.norm(np.array(a) - np.array(b))
#
def gaussian_kernel(distance, bandwidth):
return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth)) ** 2)
# mean_shift
class mean_shift(object):
def __init__(self, kernel=gaussian_kernel):
self.kernel = kernel
def fit(self, points, kernel_bandwidth):
shift_points = np.array(points)
shifting = [True] * points.shape[0]
while True:
max_dist = 0
for i in range(0, len(shift_points)):
if not shifting[i]:
continue
p_shift_init = shift_points[i].copy()
shift_points[i] = self._shift_point(shift_points[i], points, kernel_bandwidth)
dist = distance(shift_points[i], p_shift_init)
max_dist = max(max_dist, dist)
shifting[i] = dist > STOP_THRESHOLD
if(max_dist < STOP_THRESHOLD):
break
cluster_ids = self._cluster_points(shift_points.tolist())
return shift_points, cluster_ids
def _shift_point(self, point, points, kernel_bandwidth):
shift_x = 0.0
shift_y = 0.0
scale = 0.0
for p in points:
dist = distance(point, p)
weight = self.kernel(dist, kernel_bandwidth)
shift_x += p[0] * weight
shift_y += p[1] * weight
scale += weight
shift_x = shift_x / scale
shift_y = shift_y / scale
return [shift_x, shift_y]
def _cluster_points(self, points):
cluster_ids = []
cluster_idx = 0
cluster_centers = []
for i, point in enumerate(points):
if(len(cluster_ids) == 0):
cluster_ids.append(cluster_idx)
cluster_centers.append(point)
cluster_idx += 1
else:
for center in cluster_centers:
dist = distance(point, center)
if(dist < CLUSTER_THRESHOLD):
cluster_ids.append(cluster_centers.index(center))
if(len(cluster_ids) < i + 1):
cluster_ids.append(cluster_idx)
cluster_centers.append(point)
cluster_idx += 1
return cluster_ids
2.상기 py 파일 호출
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 09 11:02:08 2018
@author: muli
"""
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
import random
import numpy as np
import MeanShift
def colors(n):
ret = []
for i in range(n):
ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
return ret
def main():
centers = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4)
mean_shifter = MeanShift.mean_shift()
_, mean_shift_result = mean_shifter.fit(X, kernel_bandwidth=0.5)
np.set_printoptions(precision=3)
print('input: {}'.format(X))
print('assined clusters: {}'.format(mean_shift_result))
color = colors(np.unique(mean_shift_result).size)
for i in range(len(mean_shift_result)):
plt.scatter(X[i, 0], X[i, 1], color = color[mean_shift_result[i]])
plt.show()
if __name__ == '__main__':
main()
결 과 는 그림 과 같다.참조
이상 이 바로 본 고의 모든 내용 입 니 다.여러분 의 학습 에 도움 이 되 고 저 희 를 많이 응원 해 주 셨 으 면 좋 겠 습 니 다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
로마 숫자를 정수로 또는 그 반대로 변환그 중 하나는 로마 숫자를 정수로 변환하는 함수를 만드는 것이었고 두 번째는 그 반대를 수행하는 함수를 만드는 것이었습니다. 문자만 포함합니다'I', 'V', 'X', 'L', 'C', 'D', 'M' ; 문자열이 ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.