-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
189 lines (161 loc) · 6.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright 2018 Giorgos Kordopatis-Zilos. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import division
from __future__ import print_function
import numpy as np
import pickle as pk
import matplotlib.pylab as plt
from future.utils import lrange
from sklearn.metrics import precision_recall_curve
def load_dataset(dataset):
    """
    Function that loads a dataset object from its pickle file.
    Args:
        dataset: dataset name (file datasets/<dataset>.pickle is read)
    Returns:
        the unpickled dataset object
    """
    # Use a context manager so the file handle is always closed;
    # the original left an unclosed handle to the garbage collector.
    with open('datasets/{}.pickle'.format(dataset), 'rb') as f:
        return pk.load(f)
def load_feature_files(feature_files):
    """
    Function that loads the feature directories.
    Args:
        feature_files: file that contains the feature directories
    Returns:
        dictionary that contains the feature directories for each video id
    Raise:
        file is not in the right format
    """
    try:
        with open(feature_files, 'r') as f:
            # One mapping per line: "<video_id>\t<path-to-.npy>".
            # Iterate the file directly instead of materializing readlines().
            return {line.split('\t')[0]: line.split('\t')[1].strip()
                    for line in f}
    # Narrowed from a bare except: a bare clause would also swallow
    # KeyboardInterrupt/SystemExit. IOError -> unreadable file,
    # IndexError -> a line with no tab separator.
    except (IOError, OSError, IndexError):
        raise Exception('''--feature_files provided is in wrong format. Each line of the
            file have to contain the video id (name of the video file)
            and the full path to the corresponding .npy file, separated
            by a tab character (\\t). Example:

            23254771545e5d278548ba02d25d32add952b2a4 features/23254771545e5d278548ba02d25d32add952b2a4.npy
            468410600142c136d707b4cbc3ff0703c112575d features/468410600142c136d707b4cbc3ff0703c112575d.npy
            67f1feff7f624cf0b9ac2ebaf49f547a922b4971 features/67f1feff7f624cf0b9ac2ebaf49f547a922b4971.npy
            7deff9e47e47c98bb341c4355dfff9a82bfba221 features/7deff9e47e47c98bb341c4355dfff9a82bfba221.npy
            ...''')
def load_features(video):
    """
    Function that loads the video frame vectors.
    Args:
        video: path to the .npy feature file of a video
    Returns:
        the video frame vectors, or an empty array when loading fails
    """
    try:
        return np.load(video)
    except Exception as e:
        # Bug fix: Exception.message was removed in Python 3, so the
        # original handler itself raised AttributeError. Format the
        # exception object directly instead.
        if video:
            print('Can\'t load feature file {}\n{}'.format(video, e))
        return np.array([])
def normalize(X, zero_mean=True, l2_norm=True):
    """
    Function that apply zero mean and l2-norm to every vector.
    Args:
        X: input feature vectors, shape (n_vectors, dim)
        zero_mean: apply zero mean (subtract each row's mean)
        l2_norm: apply l2-norm (divide each row by its l2-norm)
    Returns:
        the normalized vectors (a new float array; the input is not modified)
    """
    # Work on a floating-point copy: the original mutated the caller's
    # array in place via -= and /= , and raised a casting error outright
    # for integer-dtype inputs. Float inputs keep their dtype.
    X = np.asarray(X)
    X = X.copy() if np.issubdtype(X.dtype, np.floating) else X.astype(np.float64)
    if zero_mean:
        X -= X.mean(axis=1, keepdims=True)
    if l2_norm:
        # Epsilon guards against division by zero for all-zero rows.
        X /= np.linalg.norm(X, axis=1, keepdims=True) + 1e-15
    return X
def global_vector(video):
    """
    Function that calculate the global feature vector from the
    frame features vectors. First, all frame features vectors
    are normalized, then they are averaged on each dimension to
    produce the global vector, and finally the global vector is
    normalized again.
    Args:
        video: path to feature file of a video
    Returns:
        X: the normalized global feature vector, shape (1, dim),
           or an empty array when the features cannot be processed
    """
    try:
        X = load_features(video)
        X = normalize(X)
        # Average the per-frame vectors into one global descriptor.
        X = X.mean(axis=0, keepdims=True)
        X = normalize(X)
        return X
    # Narrowed from a bare except (which would also swallow
    # KeyboardInterrupt/SystemExit). Failure here typically means
    # the feature file was empty or malformed.
    except Exception:
        return np.array([])
def plot_pr_curve(pr_curve_dml, pr_curve_base, title):
    """
    Function that plots two PR-curves (DML model vs. baseline) on one figure.
    Args:
        pr_curve_dml: precision of the DML model at each of 21 recall steps
        pr_curve_base: precision of the baseline at each of 21 recall steps
        title: the title of the plot
    """
    # Recall axis: 0.0, 0.05, ..., 1.0 (21 sample points).
    recall_steps = np.arange(0.0, 1.05, 0.05)
    plt.figure(figsize=(16, 9))
    # Baseline drawn first in red, DML model on top in blue.
    plt.plot(recall_steps, pr_curve_base,
             color='r', marker='o', linewidth=3, markersize=10)
    plt.plot(recall_steps, pr_curve_dml,
             color='b', marker='o', linewidth=3, markersize=10)
    plt.grid(True, linestyle='dotted')
    plt.xlabel('Recall', color='k', fontsize=27)
    plt.ylabel('Precision', color='k', fontsize=27)
    plt.yticks(color='k', fontsize=20)
    plt.xticks(color='k', fontsize=20)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(title, color='k', fontsize=27)
    plt.tight_layout()
    plt.show()
def evaluate(ground_truth, similarities, positive_labels='ESLMV', all_videos=False):
    """
    Function that computes the mAP and the PR-curve of the retrieval results.
    Args:
        ground_truth: the ground truth labels for each query
        similarities: the similarities of each query with the videos in the dataset
        positive_labels: labels that are considered positives
        all_videos: indicator of whether all videos are considered for the evaluation
        or only the videos in the query subset
    Returns:
        mAP: the mean Average Precision
        ps_curve: the values of the PR-curve
    """
    pr, mAP = [], 0.0
    for query_set, labels in enumerate(ground_truth):
        i = 0.0    # number of relevant videos retrieved so far
        ri = 0     # rank of the current retrieved video
        s = 0.0    # running sum of precision@rank at each relevant hit
        y_target, y_score = [], []
        # similarities[query_set] is assumed sorted by descending similarity,
        # so ri is the retrieval rank -- TODO confirm against the caller.
        for video, sim in similarities[query_set]:
            if all_videos or video in labels:
                y_score += [sim]
                y_target += [0.0]
                ri += 1
                if video in labels and labels[video] in positive_labels:
                    i += 1.0
                    s += i / ri
                    y_target[-1] = 1.0
        # Average precision of this query: precision summed at each relevant
        # rank, divided by the total number of positives in the ground truth.
        mAP += s / np.sum([1.0 for label in labels.values() if label in positive_labels])
        precision, recall, thresholds = precision_recall_curve(y_target, y_score)
        p = []
        # Sample the interpolated precision at recall >= 1.0, 0.95, ..., 0.05.
        # Builtin range replaces future.utils.lrange: both iterate 20..1 and
        # the eager list materialization was never needed here.
        for step in range(20, 0, -1):
            idx = np.where((recall >= step * 0.05))[0]
            p += [np.max(precision[idx])]
        # Append precision 1.0 at recall 0, then reverse below so the curve
        # runs from recall 0.0 to 1.0.
        pr += [p + [1.0]]
    return mAP / len(ground_truth), np.mean(pr, axis=0)[::-1]