[혼공머신] K-최근접 이웃 알고리즘 - 도미와 빙어 분류 데이터

도미 데이터 준비하기

- 생선 데이터는 캐글 https://www.kaggle.com/aungpyaeap/fish-market 에 공개

- 도미의 길이, 무게 데이터는 http://bit.ly/bream_list 에서 복사

bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]

 

 

- 도미 데이터 산점도 표시

  1. matplotlib(맷플롯립) 패키지 불러오기
  2. scatter() 함수 사용 : 산점도 그리기
import matplotlib.pyplot as plt  # matplotlib의 pyplot 함수를 plt로 줄여서 사용

plt.scatter(bream_length, bream_weight)
plt.xlabel('length')  # x축은 길이
plt.ylabel('weight')  # y축은 무게
plt.show()

도미 데이터 산점도(선형적)

 

 

 

빙어 데이터 준비하기

- 빙어의 길이, 무게 데이터는 http://bit.ly/smelt_list 에서 복사

smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

 

 

- 도미와 빙어 데이터 산점도 표시

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)  # 연속해서 scatter 함수 작성하면 하나의 산점도로 출력
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

도미 : 파랑, 빙어 : 주황

 

 

 

K-최근접 이웃 알고리즘(KNeighborsClassifier)

: 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 처리

: fit() 메소드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터 참고

 

 

- 도미와 빙어 데이터 하나로 합치기

: 두 리스트를 +로 더하면 하나의 리스트로 합쳐짐

length = bream_length+smelt_length  # length = [도미 35개의 길이, 빙어 14개의 길이]
weight = bream_weight+smelt_weight  # weight = [도미 35개의 무게, 빙어 14개의 무게]

 

 

- length와 weight 리스트를 2차원 리스트로 만들기

: scikit-learn(사이킷런) 패키지를 사용하려면 2차원 리스트 필요

: zip() 함수는 나열된 리스트 각각에서 하나씩 원소를 꺼내 반환

- 2차원 리스트 만드는 법
: for문은 zip() 함수로 리스트1과 리스트2에서 원소를 하나씩 꺼내 a, b에 할당

: [a, b]가 하나의 원소로 구성된 리스트 만들어짐

[[a, b] for a, b in zip(리스트1, 리스트2)]
fish_data = [[l, w] for l, w in zip(length, weight)]
print(fish_data)

# [[25.4, 242.0],
# [26.3, 290.0],
# ...
# [15.0, 19.9]]

 

 

- 정답 리스트 생성

: 도미 데이터를 1(찾는 대상)로 표시하고, 빙어 데이터를 0으로 표시

fish_target = [1]*35 + [0]*14

 

 

- KNeighborsClassifier 불러오기

: scikit-learn 패키지에서 K-최근접 이웃 알고리즘 구현한 클래스인 KNeighborsClassifier 불러오기

- 전체 패키지(모듈)에서 특정 클래스만 불러오기
from 패키지(모듈) import 클래스
from sklearn.neighbors import KNeighborsClassifier

 

 

- 객체 생성

kn = KNeighborsClassifier()

 

 

- 알고리즘 훈련

: scikit-learn의 fit() 메소드는 주어진 데이터로 알고리즘 훈련

kn.fit(fish_data, fish_target)

 

 

- 훈련 평가

: scikit-learn의 score() 메소드는 모델 평가(0~1 사이)

: 1.0이 나오면 정확도 100%

kn.score(fish_data, fish_target)  # 1.0

 

 

 

알고리즘 적용

- 길이가 30이고, 무게가 600인 데이터를 삼각형으로 산점도 표시

삼각형이 새로운 데이터

 

 

- 새로운 데이터 예측

: 길이 30이고, 무게 600인 데이터를 알고리즘에 적용

: 알고리즘이 제대로 되었다면 도미(1) 출력

kn.predict([[30, 600]])  # array([1])

 

 

 

문제점

: K-최근접 이웃 알고리즘은 가까운 데이터(기본값 5)를 참고하여 데이터 구분

: but, n_neighboers 매개변수를 모든 데이터 수로 정하면 어떤 데이터를 넣든 무조건 다수의 데이터로 예측

kn49 = KNeighborsClassifier(n_neighbors=49)  # 참고 데이터를 49개로 함
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)  # 0.7142857142857143

※kn49는 49개의 데이터 중 35개의 도미 데이터만 올바르게 맞히므로 35/49(0.7142857...)의 정확도를 보임