[~] Refactor
This commit is contained in:
parent
af61c39377
commit
d10c5664f5
@ -489,8 +489,8 @@ def kernel_7(
|
|||||||
feed = Variable(torch.from_numpy(img_test_pad)).cuda()
|
feed = Variable(torch.from_numpy(img_test_pad)).cuda()
|
||||||
output1, output2 = model(feed)
|
output1, output2 = model(feed)
|
||||||
|
|
||||||
print(output1.size())
|
#print(output1.size())
|
||||||
print(output2.size())
|
#print(output2.size())
|
||||||
|
|
||||||
heatmap = nn.UpsamplingBilinear2d((img_raw.shape[0], img_raw.shape[1])).cuda()(output2)
|
heatmap = nn.UpsamplingBilinear2d((img_raw.shape[0], img_raw.shape[1])).cuda()(output2)
|
||||||
|
|
||||||
@ -752,7 +752,14 @@ def kernel_7(
|
|||||||
model_pose = torch.nn.DataParallel(model_pose, device_ids=range(torch.cuda.device_count()))
|
model_pose = torch.nn.DataParallel(model_pose, device_ids=range(torch.cuda.device_count()))
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
def estimate_pose(img_ori, name=None):
|
def estimate_pose(
|
||||||
|
img_ori,
|
||||||
|
name=None,
|
||||||
|
scale_param=None,
|
||||||
|
):
|
||||||
|
if scale_param is None:
|
||||||
|
scale_param = [0.5, 1.0, 1.5, 2.0]
|
||||||
|
|
||||||
if name is None:
|
if name is None:
|
||||||
name = tempfile.mktemp(
|
name = tempfile.mktemp(
|
||||||
dir='/kaggle/working',
|
dir='/kaggle/working',
|
||||||
@ -763,7 +770,6 @@ def kernel_7(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# People might be at different scales in the image, perform inference at multiple scales to boost results
|
# People might be at different scales in the image, perform inference at multiple scales to boost results
|
||||||
scale_param = [0.5, 1.0, 1.5, 2.0]
|
|
||||||
|
|
||||||
# Predict Heatmaps for approximate joint position
|
# Predict Heatmaps for approximate joint position
|
||||||
# Use Part Affinity Fields (PAF's) as guidance to link joints to form skeleton
|
# Use Part Affinity Fields (PAF's) as guidance to link joints to form skeleton
|
||||||
@ -811,3 +817,20 @@ def kernel_8(
|
|||||||
arch_image = o
|
arch_image = o
|
||||||
img_ori = o_7['cv2'].imread(arch_image)
|
img_ori = o_7['cv2'].imread(arch_image)
|
||||||
o_7['estimate_pose'](img_ori)
|
o_7['estimate_pose'](img_ori)
|
||||||
|
|
||||||
|
def kernel_9_benchmark(
|
||||||
|
o_7,
|
||||||
|
):
|
||||||
|
t1 = o_7['cv2'].imread('../input/indonesian-traditional-dance/tgagrakanyar/tga_0000.jpg'
|
||||||
|
t5 = 10
|
||||||
|
t2 = datetime.datetime.now()
|
||||||
|
for k in range(t5):
|
||||||
|
o_7['estimate_pose'](
|
||||||
|
img_ori=t1,
|
||||||
|
scale_param=[1.0],
|
||||||
|
)
|
||||||
|
t3 = datetime.datetime.now()
|
||||||
|
t4 = (t3 - t2).totalseconds() / t5
|
||||||
|
pprint.pprint(
|
||||||
|
['kernel_9_benchmark', dict(t4=t4, t5=t5)]
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user