Skip to content

Commit 115b517

Browse files
authored
Merge pull request PaddlePaddle#814 from littletomatodonkey/add_tia
add tia aug
2 parents 5d202e4 + 3a18b08 commit 115b517

File tree

3 files changed

+315
-38
lines changed

3 files changed

+315
-38
lines changed

ppocr/data/rec/img_tools.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from ppocr.utils.utility import initial_logger
2020
logger = initial_logger()
2121

22+
from .text_image_aug.augment import tia_distort, tia_stretch, tia_perspective
23+
2224

2325
def get_bounding_box_rect(pos):
2426
left = min(pos[0])
@@ -196,6 +198,9 @@ def make(self, w, h, ang):
196198
self.h = h
197199

198200
self.perspective = True
201+
self.stretch = True
202+
self.distort = True
203+
199204
self.crop = True
200205
self.affine = False
201206
self.reverse = True
@@ -299,41 +304,40 @@ def warp(img, ang):
299304
config.make(w, h, ang)
300305
new_img = img
301306

307+
prob = 0.4
308+
309+
if config.distort:
310+
img_height, img_width = img.shape[0:2]
311+
if random.random() <= prob and img_height >= 20 and img_width >= 20:
312+
new_img = tia_distort(new_img, random.randint(3, 6))
313+
314+
if config.stretch:
315+
img_height, img_width = img.shape[0:2]
316+
if random.random() <= prob and img_height >= 20 and img_width >= 20:
317+
new_img = tia_stretch(new_img, random.randint(3, 6))
318+
302319
if config.perspective:
303-
tp = random.randint(1, 100)
304-
if tp >= 50:
305-
warpR, (r1, c1), ratio, dst = get_warpR(config)
306-
new_w = int(np.max(dst[:, 0])) - int(np.min(dst[:, 0]))
307-
new_img = cv2.warpPerspective(
308-
new_img,
309-
warpR, (int(new_w * ratio), h),
310-
borderMode=config.borderMode)
320+
if random.random() <= prob:
321+
new_img = tia_perspective(new_img)
322+
311323
if config.crop:
312324
img_height, img_width = img.shape[0:2]
313-
tp = random.randint(1, 100)
314-
if tp >= 50 and img_height >= 20 and img_width >= 20:
325+
if random.random() <= prob and img_height >= 20 and img_width >= 20:
315326
new_img = get_crop(new_img)
316-
if config.affine:
317-
warpT = get_warpAffine(config)
318-
new_img = cv2.warpAffine(
319-
new_img, warpT, (w, h), borderMode=config.borderMode)
327+
320328
if config.blur:
321-
tp = random.randint(1, 100)
322-
if tp >= 50:
329+
if random.random() <= prob:
323330
new_img = blur(new_img)
324331
if config.color:
325-
tp = random.randint(1, 100)
326-
if tp >= 50:
332+
if random.random() <= prob:
327333
new_img = cvtColor(new_img)
328334
if config.jitter:
329335
new_img = jitter(new_img)
330336
if config.noise:
331-
tp = random.randint(1, 100)
332-
if tp >= 50:
337+
if random.random() <= prob:
333338
new_img = add_gasuss_noise(new_img)
334339
if config.reverse:
335-
tp = random.randint(1, 100)
336-
if tp >= 50:
340+
if random.random() <= prob:
337341
new_img = 255 - new_img
338342
return new_img
339343

@@ -382,6 +386,7 @@ def process_image(img,
382386
% loss_type
383387
return (norm_img)
384388

389+
385390
def resize_norm_img_srn(img, image_shape):
386391
imgC, imgH, imgW = image_shape
387392

@@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape):
408413

409414
return np.reshape(img_black, (c, row, col)).astype(np.float32)
410415

411-
def srn_other_inputs(image_shape, num_heads, max_text_length, char_num):
    """Build the auxiliary (non-image) inputs required by the SRN head.

    Args:
        image_shape: (C, H, W) of the network input image.
        num_heads: number of self-attention heads in GSRM.
        max_text_length: maximum decoded sequence length.
        char_num: size of the character set (last index is the padding id).

    Returns:
        [lbl_weight, encoder_word_pos, gsrm_word_pos,
         gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
    """
    _, img_h, img_w = image_shape
    # The backbone downsamples by 8 in each spatial dimension.
    feature_dim = int((img_h / 8) * (img_w / 8))

    # Position ids for the visual features and for the GSRM tokens.
    encoder_word_pos = np.arange(feature_dim).reshape(
        (feature_dim, 1)).astype('int64')
    gsrm_word_pos = np.arange(max_text_length).reshape(
        (max_text_length, 1)).astype('int64')

    # Every target slot starts weighted as padding (char_num - 1).
    lbl_weight = np.full((max_text_length, 1), int(char_num - 1),
                         dtype='int64')

    # Causal masks: bias1 hides the future (strict upper triangle),
    # bias2 hides the past (strict lower triangle); -1e9 ~= -inf pre-softmax.
    ones = np.ones((1, max_text_length, max_text_length))

    gsrm_slf_attn_bias1 = np.triu(ones, 1).reshape(
        [-1, 1, max_text_length, max_text_length])
    gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
                                  [1, num_heads, 1, 1]) * [-1e9]

    gsrm_slf_attn_bias2 = np.tril(ones, -1).reshape(
        [-1, 1, max_text_length, max_text_length])
    gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
                                  [1, num_heads, 1, 1]) * [-1e9]

    # Add the leading batch axis expected downstream.
    encoder_word_pos = encoder_word_pos[np.newaxis, :]
    gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

    return [
        lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
        gsrm_slf_attn_bias2
    ]
448+
435449

436450
def process_image_srn(img,
437451
image_shape,
@@ -453,14 +467,16 @@ def process_image_srn(img,
453467
return None
454468
else:
455469
if loss_type == "srn":
456-
text_padded = [int(char_num-1)] * max_text_length
470+
text_padded = [int(char_num - 1)] * max_text_length
457471
for i in range(len(text)):
458472
text_padded[i] = text[i]
459473
lbl_weight[i] = [1.0]
460474
text_padded = np.array(text_padded)
461475
text = text_padded.reshape(-1, 1)
462-
return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
476+
return (norm_img, text, encoder_word_pos, gsrm_word_pos,
477+
gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight)
463478
else:
464479
assert False, "Unsupport loss_type %s in process_image"\
465480
% loss_type
466-
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
481+
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
482+
gsrm_slf_attn_bias2)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# -*- coding:utf-8 -*-
2+
# Author: RubanSeven
3+
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python
4+
5+
# import cv2
6+
import numpy as np
7+
from .warp_mls import WarpMLS
8+
9+
10+
def tia_distort(src, segment=4):
    """Randomly distort a text image by jittering boundary control points.

    The four corners plus ``segment - 1`` evenly spaced points on the top
    and bottom edges are displaced by random offsets bounded by ``thresh``,
    then the image is re-rendered with moving-least-squares warping.

    Args:
        src: input image as an H x W (x C) ndarray.
        segment: number of horizontal segments; more segments -> finer warp.

    Returns:
        Distorted image of the same width/height as ``src``.
    """
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    # Clamp to >= 1: np.random.randint(0) raises ValueError for narrow
    # images (img_w < 3 * segment); with thresh == 1 all offsets are 0,
    # so the warp degenerates gracefully to (near) identity.
    thresh = max(cut // 3, 1)

    # Source corners of the image ...
    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    # ... and their randomly perturbed destinations.
    dst_pts = [
        [np.random.randint(thresh), np.random.randint(thresh)],
        [img_w - np.random.randint(thresh), np.random.randint(thresh)],
        [img_w - np.random.randint(thresh),
         img_h - np.random.randint(thresh)],
        [np.random.randint(thresh), img_h - np.random.randint(thresh)],
    ]

    half_thresh = thresh * 0.5

    # Interior control points along the top and bottom edges, each moved
    # by an offset in [-half_thresh, half_thresh).
    for cut_idx in range(1, segment):
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            np.random.randint(thresh) - half_thresh
        ])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            img_h + np.random.randint(thresh) - half_thresh
        ])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst
50+
51+
52+
def tia_stretch(src, segment=4):
    """Randomly stretch/squeeze a text image along the horizontal axis.

    The corners stay fixed; ``segment - 1`` interior column positions are
    shifted horizontally by a random amount bounded by ``thresh``, then the
    image is re-rendered with moving-least-squares warping.

    Args:
        src: input image as an H x W (x C) ndarray.
        segment: number of horizontal segments; more segments -> finer warp.

    Returns:
        Stretched image of the same width/height as ``src``.
    """
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    # Clamp to >= 1 so np.random.randint never gets high == 0 on very
    # narrow images; thresh == 1 yields zero movement (identity warp).
    thresh = max(cut * 4 // 5, 1)

    # Corners map onto themselves: only interior columns move.
    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    dst_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]

    half_thresh = thresh * 0.5

    for cut_idx in range(1, segment):
        # Same horizontal shift for top and bottom keeps columns vertical.
        move = np.random.randint(thresh) - half_thresh
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([cut * cut_idx + move, 0])
        dst_pts.append([cut * cut_idx + move, img_h])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst
84+
85+
86+
def tia_perspective(src):
    """Apply a random vertical perspective-like warp to a text image.

    The four corners are moved vertically (inward) by random offsets up to
    half the image height, then the image is re-rendered with
    moving-least-squares warping.

    Args:
        src: input image as an H x W (x C) ndarray.

    Returns:
        Warped image of the same width/height as ``src``.
    """
    img_h, img_w = src.shape[:2]

    # Clamp to >= 1: img_h // 2 == 0 for a 1-pixel-high input would make
    # np.random.randint raise ValueError; thresh == 1 means no movement.
    thresh = max(img_h // 2, 1)

    src_pts = [[0, 0], [img_w, 0], [img_w, img_h], [0, img_h]]
    # Corners are shifted vertically only, so left/right edges stay put.
    dst_pts = [
        [0, np.random.randint(thresh)],
        [img_w, np.random.randint(thresh)],
        [img_w, img_h - np.random.randint(thresh)],
        [0, img_h - np.random.randint(thresh)],
    ]

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst

0 commit comments

Comments
 (0)