-
Notifications
You must be signed in to change notification settings - Fork 92
/
display_results.py
94 lines (79 loc) · 3.22 KB
/
display_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import os
import re
import imageio
import matplotlib.pyplot as plt
import moviepy.editor as mvp
import numpy as np
import pydiffvg
import torch
from IPython.display import Image as Image_colab
from IPython.display import display, SVG
from PIL import Image
# Command-line interface. Both options are needed further down to locate the
# result directory and the "<N>strokes ... best.svg" file, so they are marked
# required: argparse then reports a clear usage error instead of the script
# failing later with a TypeError / IndexError on a missing value.
parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str, required=True,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int, required=True,
                    help="number of strokes of the sketch to display "
                         "(selects the '<num_strokes>strokes' result file)")
args = parser.parse_args()
def read_svg(path_svg, multiply=False):
    """Rasterize an SVG file with pydiffvg and composite it onto white.

    Args:
        path_svg: Path of the SVG file to load.
        multiply: When true, the canvas width/height and every shape's
            points and stroke width are doubled before rendering, so the
            result is rendered on a canvas twice as large in each dimension.

    Returns:
        The rendered image as a tensor whose last dimension holds the three
        colour channels obtained by alpha-compositing the render over a
        white background. The tensor stays on the device the renderer
        produced it on.
    """
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            path.points *= 2
            path.stroke_width *= 2

    render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = render(canvas_width,   # width
                 canvas_height,  # height
                 2,              # num_samples_x
                 2,              # num_samples_y
                 0,              # seed
                 None,           # no background image
                 *scene_args)

    # Composite over white: alpha * rgb + 1 * (1 - alpha). Broadcasting the
    # single-channel (1 - alpha) term over the colour channels gives the same
    # values as multiplying by an explicit tensor of ones, without having to
    # guess which device the renderer placed `img` on.
    alpha = img[:, :, 3:4]
    return alpha * img[:, :, :3] + (1 - alpha)
# --- Locate the best sketch for the requested image / stroke count ----------
abs_path = os.path.abspath(os.getcwd())
result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"

# os.listdir returns entries in arbitrary order; sort so that the same file is
# picked on every run when more than one candidate matches.
svg_files = sorted(
    f for f in os.listdir(result_path)
    if "best.svg" in f and f"{args.num_strokes}strokes" in f
)
if not svg_files:
    raise FileNotFoundError(
        f"No 'best.svg' result with {args.num_strokes} strokes found in "
        f"{result_path}")
svg_output_path = f"{result_path}/{svg_files[0]}"

# Drop the last 9 characters of the SVG path to get the directory holding the
# input image. NOTE(review): this assumes the file name ends in "_best.svg"
# (9 characters) -- confirm against the code that writes these files.
best_suffix_len = len("_best.svg")
target_path = f"{svg_output_path[:-best_suffix_len]}/input.png"

# --- Render the final sketch and show it next to the input ------------------
sketch_res = read_svg(svg_output_path, multiply=True).cpu().numpy()
sketch_res = Image.fromarray((sketch_res * 255).astype('uint8'), 'RGB')

# resize() returns a new, fully loaded image, so the file can be closed
# right away instead of staying open until garbage collection.
with Image.open(target_path) as target_im:
    input_im = target_im.resize((224, 224))
display(input_im)
display(SVG(svg_output_path))

# Name of the per-run directory: the part of the SVG file name that precedes
# "_best" (prefixes are concatenated if the marker occurs more than once).
p = re.compile("_best")
best_sketch_dir = "".join(
    svg_files[0][0: m.start()] for m in p.finditer(svg_files[0]))

sketches = []
cur_path = f"{result_path}/{best_sketch_dir}"
sketch_res.save(f"{cur_path}/final_sketch.png")
print(f"You can download the result sketch from {cur_path}/final_sketch.png")

# --- Convert the logged intermediate SVGs to PNG frames and a GIF -----------
os.makedirs(f"{cur_path}/svg_to_png", exist_ok=True)

if os.path.exists(f"{cur_path}/config.npy"):
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only load
    # config files written by your own runs. The [()] index unwraps the
    # zero-dimensional object array that np.load returns into the stored
    # object.
    config = np.load(f"{cur_path}/config.npy", allow_pickle=True)[()]
    inter = config["save_interval"]
    loss_eval = np.array(config['loss_eval'])
    # Index of the smallest evaluation loss.
    inds = np.argsort(loss_eval)
    # Iterations 0, inter, 2*inter, ... up to and including the best one.
    intervals = list(range(0, (inds[0] + 1) * inter, inter))
    for i_ in intervals:
        path_svg = f"{cur_path}/svg_logs/svg_iter{i_}.svg"
        sketch = read_svg(path_svg, multiply=True).cpu().numpy()
        sketch = Image.fromarray((sketch * 255).astype('uint8'), 'RGB')
        sketch.save("{0}/{1}/iter_{2:04}.png".format(cur_path, "svg_to_png", int(i_)))
        sketches.append(sketch)
    # The animation is only written when frames were produced above.
    imageio.mimsave(f"{cur_path}/sketch.gif", sketches)

    print(cur_path)