python 으로 화훼 식별 시스템 을 구축 하 다.
1.데이터 세트 가 져 오기
사용 절 차 는 다음 과 같다.
*(1)dataset 폴 더 아래 새 폴 더 만 들 기"flowerdata"
*(2)링크 를 클릭 하여 꽃 분류 데이터 세트 다운 로드 를 다운로드 합 니 다.tensorflow.org/example\im…
*(3)압축 해제 데이터 flowerdata 폴 더 아래
*(4)"split"실행data.py"스 크 립 트 는 자동 으로 데이터 세트 를 훈련 집합 train 과 검증 집합 val 로 나 눕 니 다.
split_data.py
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# ,
rmtree(file_path)
os.makedirs(file_path)
def main():
#
random.seed(0)
# 10%
split_rate = 0.1
# flower_photos
cwd = os.getcwd()
data_root = os.path.join(cwd, "flower_data")
origin_flower_path = os.path.join(data_root, "flower_photos")
assert os.path.exists(origin_flower_path)
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]
#
train_root = os.path.join(data_root, "train")
mk_file(train_root)
for cla in flower_class:
#
mk_file(os.path.join(train_root, cla))
#
val_root = os.path.join(data_root, "val")
mk_file(val_root)
for cla in flower_class:
#
mk_file(os.path.join(val_root, cla))
for cla in flower_class:
cla_path = os.path.join(origin_flower_path, cla)
images = os.listdir(cla_path)
num = len(images)
#
eval_index = random.sample(images, k=int(num*split_rate))
for index, image in enumerate(images):
if image in eval_index:
#
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
else:
#
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
print()
print("processing done!")
if __name__ == '__main__':
main()
2.신경 망 모형model.py
import torch.nn as nn
import torch
class AlexNet(nn.Module):
def __init__(self, num_classes=1000, init_weights=False):
super(AlexNet, self).__init__()
# nn.Sequential() ,
self.features = nn.Sequential( #
nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), # input[3, 224, 224] output[48, 55, 55]
nn.ReLU(inplace=True), # ,
nn.MaxPool2d(kernel_size=3, stride=2), # output[48, 27, 27]
nn.Conv2d(48, 128, kernel_size=5, padding=2), # output[128, 27, 27]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 13, 13]
nn.Conv2d(128, 192, kernel_size=3, padding=1), # output[192, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(192, 192, kernel_size=3, padding=1), # output[192, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(192, 128, kernel_size=3, padding=1), # output[128, 13, 13]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[128, 6, 6]
)
self.classifier = nn.Sequential( #
nn.Dropout(p=0.5), # Dropout , 0.5
nn.Linear(128 * 6 * 6, 2048),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5),
nn.Linear(2048, 2048),
nn.ReLU(inplace=True),
nn.Linear(2048, num_classes),
)
if init_weights:
self._initialize_weights()
#
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, start_dim=1) #
x = self.classifier(x)
return x
# , pytorch
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d): #
nn.init.kaiming_normal_(m.weight, mode='fan_out', # ( )kaiming_normal_
nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0) # 0
elif isinstance(m, nn.Linear): #
nn.init.normal_(m.weight, 0, 0.01) #
nn.init.constant_(m.bias, 0) # 0
3.신경 망 훈련train.py
#
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time
# GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open(os.path.join("train.log"), "a") as log:
log.write(str(device)+"
")
#
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224), # , 224×224
transforms.RandomHorizontalFlip(p=0.5), # , 0.5, ,
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
"val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
# 、
#
#train_set = torchvision.datasets.CIFAR10(root='./data', #
# train=True, #
# download=True, # True, , False
# transform=transform) #
#
#train_loader = torch.utils.data.DataLoader(train_set, #
# batch_size=50, #
# shuffle=False, #
# num_workers=0) # num_workers windows 0
#
data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path
image_path = data_root + "/jqsj/data_set/flower_data/" # flower data_set path
#
train_dataset = datasets.ImageFolder(root=image_path + "/train",
transform=data_transform["train"])
train_num = len(train_dataset)
# batch_size
train_loader = torch.utils.data.DataLoader(train_dataset, #
batch_size=32, #
shuffle=True, #
num_workers=0) # , windows 0
# 、
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])
val_num = len(validate_dataset)
#
validate_loader = torch.utils.data.DataLoader(validate_dataset, #
batch_size=32,
shuffle=True,
num_workers=0)
# :
# , : {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
# flower_list key val
cla_dict = dict((val, key) for key, val in flower_list.items())
# cla_dict json
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
#
net = AlexNet(num_classes=5, init_weights=True) # ( 5, )
net.to(device) # (GPU/CPU)
loss_function = nn.CrossEntropyLoss() #
optimizer = optim.Adam(net.parameters(), lr=0.0002) # ( , )
save_path = './AlexNet.pth'
best_acc = 0.0
for epoch in range(150):
########################################## train ###############################################
net.train() # Dropout
running_loss = 0.0 # epoch running_loss
time_start = time.perf_counter() # epoch
for step, data in enumerate(train_loader, start=0): # ,step 0
images, labels = data #
optimizer.zero_grad() #
outputs = net(images.to(device)) #
loss = loss_function(outputs, labels.to(device)) #
loss.backward() #
optimizer.step() #
running_loss += loss.item()
# ( )
rate = (step + 1) / len(train_loader) # = step / epoch step
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
with open(os.path.join("train.log"), "a") as log:
log.write(str("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss))+"
")
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
print()
with open(os.path.join("train.log"), "a") as log:
log.write(str('%f s' % (time.perf_counter()-time_start))+"
")
print('%f s' % (time.perf_counter()-time_start))
########################################### validate ###########################################
net.eval() # Dropout
acc = 0.0
with torch.no_grad():
for val_data in validate_loader:
val_images, val_labels = val_data
outputs = net(val_images.to(device))
predict_y = torch.max(outputs, dim=1)[1] # output ( )
acc += (predict_y == val_labels.to(device)).sum().item()
val_accurate = acc / val_num
#
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
with open(os.path.join("train.log"), "a") as log:
log.write(str('[epoch %d] train_loss: %.3f test_accuracy: %.3f
' %
(epoch + 1, running_loss / step, val_accurate))+"
")
print('[epoch %d] train_loss: %.3f test_accuracy: %.3f
' %
(epoch + 1, running_loss / step, val_accurate))
with open(os.path.join("train.log"), "a") as log:
log.write(str('Finished Training')+"
")
print('Finished Training')
훈련 결과 정확도 94%훈련 일 지 는 다음 과 같다.
4.모델 에 대한 예측
predict.py
import torch
이 어 그 중의 한 화훼 사진 을 식별 한 결과 다음 과 같다.하나의 식별 결과(daisy 데이지)와 정확도 1.0 은 100%(범 위 는 0~1 이 므 로 1 대응 100%)만 볼 수 있다.
이 신경 망 을 편리 하 게 사용 하기 위해 서,이어서 우 리 는 그것 을 시각 화 된 인터페이스 조작 으로 개발 했다.
2.화훼 식별 시스템 구축(flask)
1.페이지 구축:
2.신경 망 모형 호출
main.py
# coding:utf-8
from flask import Flask, render_template, request, redirect, url_for, make_response, jsonify
from werkzeug.utils import secure_filename
import os
import time
###################
#
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
# read class_indict
try:
json_file = open('./class_indices.json', 'r')
class_indict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
#, map_location='cpu'
model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
# Dropout
model.eval()
###################
from datetime import timedelta
#
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp'])
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
app = Flask(__name__)
#
app.send_file_max_age_default = timedelta(seconds=1)
#
def tran(img_path):
#
data_transform = transforms.Compose(
[transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# load image
img = Image.open("pgy2.jpg")
#plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
return img
@app.route('/upload', methods=['POST', 'GET']) #
def upload():
path=""
if request.method == 'POST':
f = request.files['file']
if not (f and allowed_file(f.filename)):
return jsonify({"error": 1001, "msg": " , png、PNG、jpg、JPG、bmp"})
basepath = os.path.dirname(__file__) #
path = secure_filename(f.filename)
upload_path = os.path.join(basepath, 'static/images', secure_filename(f.filename)) # : ,
# upload_path = os.path.join(basepath, 'static/images','test.jpg') # : ,
print(path)
img = tran('static/images'+path)
##########################
#
with torch.no_grad():
# predict class
output = torch.squeeze(model(img)) # , batch
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
res = class_indict[str(predict_cla)]
pred = predict[predict_cla].item()
#print(class_indict[str(predict_cla)], predict[predict_cla].item())
res_chinese = ""
if res=="daisy":
res_chinese=" "
if res=="dandelion":
res_chinese=" "
if res=="roses":
res_chinese=" "
if res=="sunflower":
res_chinese=" "
if res=="tulips":
res_chinese=" "
#print('result:', class_indict[str(predict_class)], 'accuracy:', prediction[predict_class])
##########################
f.save(upload_path)
pred = pred*100
return render_template('upload_ok.html', path=path, res_chinese=res_chinese,pred = pred, val1=time.time())
return render_template('upload.html')
if __name__ == '__main__':
# app.debug = True
app.run(host='127.0.0.1', port=80,debug = True)
3.시스템 식별 결과
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title> - v1.0</title>
<link rel="stylesheet" type="text/css" href="../static/css/bootstrap.min.css" rel="external nofollow" >
<link rel="stylesheet" type="text/css" href="../static/css/fileinput.css" rel="external nofollow" >
<script src="../static/js/jquery-2.1.4.min.js"></script>
<script src="../static/js/bootstrap.min.js"></script>
<script src="../static/js/fileinput.js"></script>
<script src="../static/js/locales/zh.js"></script>
</head>
<body>
<h1 align="center"> - v1.0</h1>
<div align="center">
<form action="" enctype='multipart/form-data' method='POST'>
<input type="file" name="file" class="file" data-show-preview="false" style="margin-top:20px;"/>
<br>
<input type="submit" value=" " class="button-new btn btn-primary" style="margin-top:15px;"/>
</form>
<p style="size:15px;color:blue;"> :{{res_chinese}}</p>
</br>
<p style="size:15px;color:red;"> :{{pred}}%</p>
<img src="{{ './static/images/'+path }}" width="400" height="400" alt=""/>
</div>
</body>
</html>
4.시스템 시작:
python main.py
이 어 브 라 우 저 에서 브 라 우 저 에 접근 합 니 다.
http://127.0.0.1/upload
다음 화면 이 나타 납 니 다:마지막 으로 식별 과정의 움 직 이 는 그림.
3.총화
ok,이 화훼 시스템 은 이미 구축 되 었 습 니 다.아주 간단 하지 않 습 니까?저도 이 기계 시각 을 고 친 틈 을 타 이런 시스템 을 만 들 었 습 니 다.예전 의 지식 을 돌 이 켜 보 겠 습 니 다.하하 하.
이상 은 python 으로 화훼 식별 시스템 을 구축 하 는 상세 한 내용 입 니 다.python 화훼 식별 시스템 에 관 한 자 료 는 우리 의 다른 관련 글 을 주목 하 십시오!
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
로마 숫자를 정수로 또는 그 반대로 변환그 중 하나는 로마 숫자를 정수로 변환하는 함수를 만드는 것이었고 두 번째는 그 반대를 수행하는 함수를 만드는 것이었습니다. 문자만 포함합니다'I', 'V', 'X', 'L', 'C', 'D', 'M' ; 문자열이 ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.