-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathslide_test.py
More file actions
137 lines (123 loc) · 5.76 KB
/
slide_test.py
File metadata and controls
137 lines (123 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from unet.unet_ddpm import *
import cv2
import torchvision
from get_config import parse_args_and_config
import torch.nn as nn
import os
import numpy as np
from tqdm import tqdm
import random
def sliding_window(image, step_size, window_size):
    """Cut *image* (H, W, C) into overlapping window tiles.

    Each tile is converted with the module-level ``toTensor`` transform,
    given a leading batch axis, and moved to the GPU.  Tiles are emitted
    row-major (top-to-bottom, left-to-right), which is the order
    ``reconstruct_image`` expects when stitching them back together.
    """
    win_h, win_w = window_size
    max_top = image.shape[0] - win_h + 1
    max_left = image.shape[1] - win_w + 1
    return [
        toTensor(image[top:top + win_h, left:left + win_w]).unsqueeze(0).cuda()
        for top in range(0, max_top, step_size)
        for left in range(0, max_left, step_size)
    ]
def reconstruct_image(patches, image_size, step_size, window_size):
    """Stitch overlapping window tiles back into one (C, H, W) image.

    ``patches`` must be in the same row-major order ``sliding_window``
    produced them; ``image_size`` is the (H, W, C) shape of the padded
    source image.  Where tiles overlap, their values are averaged.
    """
    win_h, win_w = window_size
    canvas = torch.zeros(image_size).permute(2, 0, 1)
    weight = torch.zeros(image_size).permute(2, 0, 1)
    tiles = iter(patches)
    for top in range(0, image_size[0] - win_h + 1, step_size):
        for left in range(0, image_size[1] - win_w + 1, step_size):
            canvas[:, top:top + win_h, left:left + win_w] += next(tiles)
            weight[:, top:top + win_h, left:left + win_w] += 1
    # Guard against division by zero in any region no tile covered.
    weight[weight == 0] = 1
    canvas /= weight
    return canvas
def add_padding(image, window_size, step_size):
    """Reflect-pad the bottom/right so both dims are multiples of step_size.

    ``window_size`` is accepted for interface compatibility but unused: the
    callers rely on step_size evenly dividing the window edge, so step
    alignment alone guarantees the sliding window tiles the whole image.

    NOTE: a dimension that is already a multiple of step_size still gains a
    full extra step of padding — kept as-is for backward compatibility (the
    padding is cropped off after reconstruction, so it is harmless).
    """
    h, w = image.shape[:2]
    pad_h = (h // step_size + 1) * step_size - h
    pad_w = (w // step_size + 1) * step_size - w
    # np.pad with mode='symmetric' reproduces cv2.BORDER_REFLECT
    # (edge-inclusive reflection) without needing OpenCV for a pure array
    # operation, and also works for 2-D (grayscale) input.
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode='symmetric')
# --- Script setup: load the generator model and prepare I/O paths. ---
# sigmoid = nn.Sigmoid()
toTensor = torchvision.transforms.ToTensor()
_, config = parse_args_and_config()
modelG = ModelG(config).cuda()
# modelD = torchvision.models.resnet18(pretrained=True)
# modelD.fc = nn.Sequential(nn.Linear(512, 128), nn.LeakyReLU(inplace=True), nn.Linear(128, 32),
#                          nn.LeakyReLU(inplace=True), nn.Linear(32, 1))
# modelD.cuda()
# modelD.load_state_dict(torch.load("saved_models_stegastamp/modelD42000_20.63.pth", map_location='cpu'), strict=True)
# modelG.load_state_dict(torch.load("checkpoints/modelGsy175000_35.72.pth", map_location='cpu'), strict=True)
# modelG.load_state_dict(torch.load("saved_StegaStamp/modelG_240809-1814_19000_28.17.pth", map_location='cpu'), strict=True)
# modelG.load_state_dict(torch.load("saved_StegaStamp/modelG_240911-1216_3000_28.85.pth", map_location='cpu'), strict=True)  # on a 100-image test this scored lower than 28.17
modelG.load_state_dict(torch.load("saved_combine/modelG_241009-1112_10000_29.88_28.86_33.74.pth", map_location='cpu'), strict=True)  # best checkpoint so far
# modelG.load_state_dict(torch.load("saved_combine/modelG_241010-1830_31000_30.75_28.83_36.59.pth", map_location='cpu'), strict=True)  # BER was not high on some images
modelG.eval()
# modelD.eval()
# Dataset selection knobs: which watermarking method / split / colour domain
# to process, and how many images to sample.
method = 'StegaStamp'
mode = 'val'
domain = 'rgb'
amount = 100
# Input: watermarked images; output: watermark-removed results.
input_folder = f'/media/dongli911/Documents/Datasets/watermark_dataset/{method}/wm_{mode}'
output_folder = f'/media/dongli911/Documents/Datasets/watermark_dataset/{method}/rm_{mode}'
# Create the output folder, or empty it if it already exists so stale
# results from a previous run never mix with the new ones.
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
else:
    for filename in os.listdir(output_folder):
        file_path = os.path.join(output_folder, filename)
        os.remove(file_path)
# Sliding-window parameters: 256x256 model input tiled with a 32-px stride.
batch_size = 16
step_size = 32
window_size = (256, 256)
# Collect all image files in the input folder.
image_files = [f for f in os.listdir(input_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
# image_files = image_files[:int(len(image_files) * 0.8)]
# image_files = image_files[-int(len(image_files) * 0.2):]
image_files = random.sample(image_files, min(amount, len(image_files)))
# Randomly sample up to `amount` images (earlier versions took a fixed slice).
# Process each sampled image: convert colour space, reflect-pad, tile into
# overlapping windows, run the removal model batch-by-batch, average the
# tiles back together, crop off the padding, and save the result.
for filename in tqdm(image_files):
    img_path = os.path.join(input_folder, filename)
    img = cv2.imread(img_path)
    if img is None:
        # Robustness fix: cv2.imread returns None for unreadable/corrupt
        # files; the original crashed later in cvtColor. Skip instead.
        print(f'Skipping unreadable image: {img_path}')
        continue
    if domain == 'rgb':
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
    # Pad so the sliding window tiles the whole image.
    padded_img = add_padding(img, window_size, step_size)
    # Generate the overlapping window tiles (each already on the GPU).
    patches = sliding_window(padded_img, step_size, window_size)
    # Run the model over the tiles in batches.
    processed_patches = []
    for i in range(0, len(patches), batch_size):
        # Bug fix: the original wrapped this slice in try/except IndexError,
        # but list slicing never raises IndexError — the final slice is just
        # shorter than batch_size, so no special case is needed.
        batch_patches = torch.cat(patches[i:i + batch_size], dim=0)
        with torch.no_grad():
            batch_reconstructed = torch.clamp(modelG(batch_patches), max=1, min=0)
        # Split the batch result back into per-tile arrays for stitching.
        for j in range(len(batch_patches)):
            processed_patches.append(batch_reconstructed[j].cpu().numpy())
    # Average the overlapping tiles back into one (C, H, W) image.
    reconstructed_image = reconstruct_image(processed_patches, padded_img.shape, step_size, window_size)
    # Crop the reflect padding back off.
    reconstructed_image = reconstructed_image[:, :img.shape[0], :img.shape[1]]
    reconstructed_image_np = (reconstructed_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
    # Convert back to BGR (what cv2.imwrite expects) and save under the
    # original filename in the output folder.
    output_path = os.path.join(output_folder, filename)
    if domain == 'rgb':
        reconstructed_image_np = cv2.cvtColor(reconstructed_image_np, cv2.COLOR_RGB2BGR)
    else:
        reconstructed_image_np = cv2.cvtColor(reconstructed_image_np, cv2.COLOR_YUV2BGR)
    cv2.imwrite(output_path, reconstructed_image_np)