도미 데이터 준비하기
- 생선 데이터는 캐글 https://www.kaggle.com/aungpyaeap/fish-market 에 공개
- 도미의 길이, 무게 데이터는 http://bit.ly/bream_list 에서 복사
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0,
31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0,
35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0,
500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0,
700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
- 도미 데이터 산점도 표시
- matplotlib(맷플롯립) 패키지 불러오기
- scatter() 함수 사용 : 산점도 그리기
import matplotlib.pyplot as plt # matplotlib의 pyplot 함수를 plt로 줄여서 사용
plt.scatter(bream_length, bream_weight)
plt.xlabel('length') # x축은 길이
plt.ylabel('weight') # y축은 무게
plt.show()
빙어 데이터 준비하기
- 빙어의 길이, 무게 데이터는 http://bit.ly/smelt_list 에서 복사
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]
- 도미와 빙어 데이터 산점도 표시
plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight) # 연속해서 scatter 함수 작성하면 하나의 산점도로 출력
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
K-최근접 이웃 알고리즘(KNeighborsClassifier)
: 주위의 다른 데이터를 보고 다수를 차지하는 것을 정답으로 처리
: fit() 메소드에 전달한 데이터를 모두 저장하고 있다가 새로운 데이터가 등장하면 가장 가까운 데이터 참고
- 도미와 빙어 데이터 하나로 합치기
: 두 리스트를 +로 더하면 하나의 리스트로 합쳐짐
length = bream_length+smelt_length # length = [도미 35개의 길이, 빙어 14개의 길이]
weight = bream_weight+smelt_weight # weight = [도미 35개의 무게, 빙어 14개의 무게]
- length와 weight 리스트를 2차원 리스트로 만들기
: scikit-learn(사이킷런) 패키지를 사용하려면 2차원 리스트 필요
: zip() 함수는 나열된 리스트 각각에서 하나씩 원소를 꺼내 반환
- 2차원 리스트 만드는 법
: for문은 zip() 함수로 리스트1과 리스트2에서 원소를 하나씩 꺼내 a, b에 할당
: [a, b]가 하나의 원소로 구성된 리스트 만들어짐
[[a, b] for a, b in zip(리스트1, 리스트2)]
fish_data = [[l, w] for l, w in zip(length, weight)]
print(fish_data)
# [[25.4, 242.0],
# [26.3, 290.0],
# ...
# [15.0, 19.9]]
- 정답 리스트 생성
: 도미 데이터를 1(찾는 대상)로 표시하고, 빙어 데이터를 0으로 표시
fish_target = [1]*35 + [0]*14
- KNeighborsClassifier 불러오기
: scikit-learn 패키지에서 K-최근접 이웃 알고리즘 구현한 클래스인 KNeighborsClassifier 불러오기
- 전체 패키지(모듈)에서 특정 클래스만 불러오기
from 패키지(모듈) import 클래스
from sklearn.neighbors import KNeighborsClassifier
- 객체 생성
kn = KNeighborsClassifier()
- 알고리즘 훈련
: scikit-learn의 fit() 메소드는 주어진 데이터로 알고리즘 훈련
kn.fit(fish_data, fish_target)
- 훈련 평가
: scikit-learn의 score() 메소드는 모델 평가(0~1 사이)
: 1.0이 나오면 정확도 100%
kn.score(fish_data, fish_target) # 1.0
알고리즘 적용
- 길이가 30이고, 무게가 600인 데이터를 삼각형으로 산점도 표시
- 새로운 데이터 예측
: 길이 30이고, 무게 600인 데이터를 알고리즘에 적용
: 알고리즘이 제대로 되었다면 도미(1) 출력
kn.predict([[30, 600]]) # array([1])
문제점
: K-최근접 이웃 알고리즘은 가까운 데이터(기본값 5)를 참고하여 데이터 구분
: but, n_neighboers 매개변수를 모든 데이터 수로 정하면 어떤 데이터를 넣든 무조건 다수의 데이터로 예측
kn49 = KNeighborsClassifier(n_neighbors=49) # 참고 데이터를 49개로 함
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target) # 0.7142857142857143
※kn49는 49개의 데이터 중 35개의 도미 데이터만 올바르게 맞히므로 35/49(0.7142857...)의 정확도를 보임