Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions textattack/commands/augment_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,36 @@ def run(self, args):
# Read in CSV file as a list of dictionaries. Use the CSV sniffer to
# try and automatically infer the correct CSV format.
csv_file = open(args.input_csv, "r")

# mark where commas and quotes occur within the text value
def markQuotes(lines):
for row in lines:
row = row.replace('"', '"/')
yield row

dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
csv_file.seek(0)
rows = [
row
for row in csv.DictReader(
csv_file, dialect=dialect, skipinitialspace=True
markQuotes(csv_file),
dialect=dialect,
skipinitialspace=True,
)
]

# replace markings with quotations and commas
for row in rows:
for item in row:
i = 0
while i < len(row[item]):
if row[item][i] == "/":
if row[item][i - 1] == '"':
row[item] = row[item][:i] + row[item][i + 1 :]
else:
row[item] = row[item][:i] + '"' + row[item][i + 1 :]
i += 1

# Validate input column.
row_keys = set(rows[0].keys())
if args.input_column not in row_keys:
Expand All @@ -174,20 +196,30 @@ def run(self, args):
augmented_row = row.copy()
augmented_row[args.input_column] = augmentation
output_rows.append(augmented_row)

# Print to file.
with open(args.output_csv, "w") as outfile:
csv_writer = csv.writer(
outfile, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
outfile, delimiter=",", quotechar="/", quoting=csv.QUOTE_MINIMAL
)
# Write header.
csv_writer.writerow(output_rows[0].keys())
# Write rows.
for row in output_rows:
csv_writer.writerow(row.values())

textattack.shared.logger.info(
f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
)

# Remove extra markings in output file
with open(args.output_csv, "r") as file:
data = file.readlines()
for i in range(len(data)):
data[i] = data[i].replace("/", "")
with open(args.output_csv, "w") as file:
file.writelines(data)

@staticmethod
def register_subcommand(main_parser: ArgumentParser):
parser = main_parser.add_parser(
Expand Down