Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add_dir_predict_for_csv
  • Loading branch information
v5a committed Apr 10, 2023
commit 234fd4b2ca1caa8c2fa1ec2efe69d17d9c0051d6
73 changes: 69 additions & 4 deletions data_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -27,7 +27,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -49,10 +49,75 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'numpy.ndarray'>\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"label, top, left, bottom, right = 'jacket 0.51',538 ,528,707,616\n",
"data=(label, top, left, bottom, right)\n",
"array = np.asarray(data).reshape(1,5)\n",
"array = np.concatenate((array,array),axis=0)\n",
"print(type(array))\n",
"test=pd.DataFrame(array,columns=['label', 'top', 'left', 'bottom', 'right'])\n",
"test.to_csv('save1.csv', index=False, sep=',')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": []
"source": [
"import os\n",
"image_name = 'sadfhkj.jpg'\n",
"label, top, left, bottom, right = 'jacket 0.51',538 ,528,707,616\n",
"data=(image_name,label, top, left, bottom, right)\n",
"array = np.asarray(data).reshape(1,6)\n",
"array = np.concatenate((array,array),axis=0)\n",
"# image_name = np.asarray(image_name)\n",
"# array_image_name = np.concatenate((image_name,array),axis=1)\n",
"# label, top, left, bottom, right = array\n",
"# array_image_name = image_name + array\n",
"\n",
"test=pd.DataFrame(array,columns=['image_name','label', 'top', 'left', 'bottom', 'right'])\n",
"# test=pd.DataFrame(array)\n",
"if os.path.exists('save1.csv'):\n",
" test.to_csv('save1.csv', index=False, sep=',',mode='a', header=None)\n",
" \n",
"else:\n",
" test.to_csv('save1.csv', index=False, sep=',',mode='a')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"os.path.exists('save1.csv')"
]
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# 此处的classes_path用于指定需要测量VOC_map的类别
# 一般情况下与训练和预测所用的classes_path一致即可
#--------------------------------------------------------------------------------------#
classes_path = 'model_data/voc_classes.txt'
classes_path = 'model_data/cowboy_classes.txt'
#--------------------------------------------------------------------------------------#
# MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
# 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
Expand Down
4 changes: 2 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
# 'predict_onnx' 表示利用导出的onnx模型进行预测,相关参数的修改在yolo.py_423行左右处的YOLO_ONNX
#----------------------------------------------------------------------------------------------------------#
mode = "predict"
mode = "dir_predict"
#-------------------------------------------------------------------------#
# crop 指定了是否在单张图片预测后对目标进行截取
# count 指定了是否进行目标的计数
Expand Down Expand Up @@ -158,7 +158,7 @@
if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
image_path = os.path.join(dir_origin_path, img_name)
image = Image.open(image_path)
r_image = yolo.detect_image(image)
r_image = yolo.detect_image(image,image_name=img_name)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
Expand Down
27 changes: 23 additions & 4 deletions yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import time

import pandas as pd
import numpy as np

import cv2
import numpy as np
import torch
Expand Down Expand Up @@ -125,7 +128,7 @@ def generate(self, onnx=False):
#---------------------------------------------------#
# 检测图片
#---------------------------------------------------#
def detect_image(self, image, crop = False, count = False):
def detect_image(self, image, crop = False, count = False, image_name = None):
#---------------------------------------------------#
# 计算输入图片的高和宽
#---------------------------------------------------#
Expand Down Expand Up @@ -203,6 +206,8 @@ def detect_image(self, image, crop = False, count = False):
#---------------------------------------------------------#
# 图像绘制
#---------------------------------------------------------#
# data_array = np.array(('',538 ,528,707,616))
# data_array = np.asarray(data_array).reshape(1,5)
for i, c in list(enumerate(top_label)):
predicted_class = self.class_names[int(c)]
box = top_boxes[i]
Expand All @@ -216,12 +221,20 @@ def detect_image(self, image, crop = False, count = False):
right = min(image.size[0], np.floor(right).astype('int32'))

label = '{} {:.2f}'.format(predicted_class, score)
# print(label)

print(label, top, left, bottom, right)
data=([image_name],[label], [top], [left], [bottom], [right])
array = np.asarray(data).reshape(1,6)
if 'data_array' not in dir(): #查看变量有没有定义,没有就加一个定义
data_array = array
else:
data_array = np.concatenate((data_array,array),axis=0)


draw = ImageDraw.Draw(image)
label_size = draw.textsize(label, font)
label = label.encode('utf-8')
# print("sss")
print(label, top, left, bottom, right)


if top - label_size[1] >= 0:
text_origin = np.array([left, top - label_size[1]])
Expand All @@ -233,7 +246,13 @@ def detect_image(self, image, crop = False, count = False):
draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
del draw
# if image_name != None:

test=pd.DataFrame(data_array,columns=['image_name','label', 'top', 'left', 'bottom', 'right'])
if os.path.exists('save1.csv'):
test.to_csv('save1.csv', index=False, sep=',',mode='a', header=None)
else:
test.to_csv('save1.csv', index=False, sep=',',mode='a')
return image

def get_FPS(self, image, test_interval):
Expand Down