|
| 1 | +import xml.etree.ElementTree as ET |
| 2 | +from os import getcwd |
| 3 | +import os |
| 4 | + |
| 5 | +save_folder = 'model_data\\' |
| 6 | +dataset_train = 'CSGO_images\\' |
| 7 | +dataset_file = save_folder+'NO_CLASSES.txt' |
| 8 | +classes_file = dataset_file[:-4]+'_classes.txt' |
| 9 | + |
| 10 | + |
| 11 | +CLS = os.listdir(dataset_train) |
| 12 | +classes =[dataset_train+CLASS for CLASS in CLS] |
| 13 | +wd = getcwd() |
| 14 | + |
| 15 | +CLASSES = [] |
| 16 | +def GetClassesNames(fullname): |
| 17 | + global CLASSES |
| 18 | + in_file = open(fullname) |
| 19 | + tree=ET.parse(in_file) |
| 20 | + root = tree.getroot() |
| 21 | + for i, obj in enumerate(root.iter('object')): |
| 22 | + name = obj.find('name').text |
| 23 | + if name not in CLASSES: |
| 24 | + CLASSES.append(name) |
| 25 | + |
| 26 | +def test(fullname): |
| 27 | + bb = "" |
| 28 | + in_file = open(fullname) |
| 29 | + tree=ET.parse(in_file) |
| 30 | + root = tree.getroot() |
| 31 | + filename = root.find('filename').text |
| 32 | + for i, obj in enumerate(root.iter('object')): |
| 33 | + difficult = obj.find('difficult').text |
| 34 | + cls = obj.find('name').text |
| 35 | + if cls not in CLS or int(difficult)==1: |
| 36 | + continue |
| 37 | + cls_id = CLS.index(cls) |
| 38 | + xmlbox = obj.find('bndbox') |
| 39 | + b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text)) |
| 40 | + bb += (" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) |
| 41 | + |
| 42 | + if bb != "": |
| 43 | + list_file = open(dataset_file, 'a') |
| 44 | + file_string = str(fullname)[:-4]+filename[-4:]+bb+'\n' |
| 45 | + list_file.write(file_string) |
| 46 | + list_file.close() |
| 47 | + |
| 48 | + |
| 49 | + |
| 50 | +for CLASS in classes: |
| 51 | + if not CLASS.endswith('.xml'): |
| 52 | + continue |
| 53 | + fullname = os.getcwd()+'\\'+CLASS#+'\\'+filename |
| 54 | + print(fullname) |
| 55 | + GetClassesNames(fullname) |
| 56 | + |
| 57 | +CLASSES.sort() |
| 58 | +CLS = CLASSES.copy() |
| 59 | +print(CLS) |
| 60 | + |
| 61 | +for CLASS in classes: |
| 62 | + if not CLASS.endswith('.xml'): |
| 63 | + continue |
| 64 | + fullname = os.getcwd()+'\\'+CLASS#+'\\'+filename |
| 65 | + test(fullname) |
| 66 | + |
| 67 | +for CLASS in CLS: |
| 68 | + list_file = open(classes_file, 'a') |
| 69 | + file_string = str(CLASS)+"\n" |
| 70 | + list_file.write(file_string) |
| 71 | + list_file.close() |
0 commit comments