Skip to content

Commit

Permalink
Merge pull request #1122 from Laicheng0830/app_format
Browse files Browse the repository at this point in the history
do format
  • Loading branch information
ChengLai authored Feb 22, 2021
2 parents 1f30ceb + 8d9ead0 commit b8c80d9
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
# -*- coding: utf-8 -*-

from tensorlayer.app.human_pose_estimation.common import DataReader, visualize_3D_pose, flip_data
from tensorlayer.app.human_pose_estimation.LCN import CGCNN
from tensorlayer.app import computer_vision
import numpy as np

datareader = DataReader()
train_data, test_data = datareader.read_2d(which='scale', mode='gt', read_confidence=False)
train_labels, test_labels = datareader.read_3d(which='scale', mode='gt')
network = CGCNN(pretrained=True)
network = computer_vision.human_pose_estimation('3D-pose')
test_data = flip_data(test_data)
result = network(test_data, is_train=False)
result = network(test_data)
result = datareader.denormalize3D(np.asarray(result), which='scale')
test_data = datareader.denormalize2D(test_data, which='scale')
test_labels = datareader.denormalize3D(test_labels, which='scale')
Expand Down
135 changes: 53 additions & 82 deletions tensorlayer/app/computer_vision.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

from tensorlayer.app import YOLOv4, get_anchors, decode, filter_boxes
from tensorlayer.app import YOLOv4
from tensorlayer.app import CGCNN
import numpy as np
import tensorflow as tf
from tensorlayer import logging
import cv2
from tensorlayer.app import yolo4_input_processing, yolo4_output_processing, result_to_json


class object_detection(object):
Expand Down Expand Up @@ -42,8 +40,6 @@ def __init__(self, model_name='yolo4-mscoco'):
self.model_name = model_name
if self.model_name == 'yolo4-mscoco':
self.model = YOLOv4(NUM_CLASS=80, pretrained=True)
elif self.model_name == 'lcn':
self.model = CGCNN(pretrained=True)
else:
raise ("The model does not support.")

Expand All @@ -53,8 +49,6 @@ def __call__(self, input_data):
feature_maps = self.model(batch_data, is_train=False)
pred_bbox = yolo4_output_processing(feature_maps)
output = result_to_json(input_data, pred_bbox)
elif self.model_name == 'lcn':
output = self.model(input_data)
else:
raise NotImplementedError

Expand All @@ -70,78 +64,55 @@ def list(self):
logging.info("The model name list: 'yolov4-mscoco', 'lcn'")


def yolo4_input_processing(original_image):
image_data = cv2.resize(original_image, (416, 416))
image_data = image_data / 255.
images_data = []
for i in range(1):
images_data.append(image_data)
images_data = np.asarray(images_data).astype(np.float32)
batch_data = tf.constant(images_data)
return batch_data


def yolo4_output_processing(feature_maps):
STRIDES = [8, 16, 32]
ANCHORS = get_anchors([12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401])
NUM_CLASS = 80
XYSCALE = [1.2, 1.1, 1.05]
iou_threshold = 0.45
score_threshold = 0.25

bbox_tensors = []
prob_tensors = []
score_thres = 0.2
for i, fm in enumerate(feature_maps):
if i == 0:
output_tensors = decode(fm, 416 // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
elif i == 1:
output_tensors = decode(fm, 416 // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
class human_pose_estimation(object):
"""Model encapsulation.
Parameters
----------
model_name : str
Choose the model to inference.
Methods
---------
__init__()
Initializing the model.
__call__()
(1)Formatted input and output. (2)Inference model.
list()
Abstract method. Return available a list of model_name.
Examples
---------
LCN to estimate 3D human poses from 2D poses, see `tutorial_human_3dpose_estimation_LCN.py
<https://github.com/tensorlayer/tensorlayer/blob/master/example/app_tutorials/tutorial_human_3dpose_estimation_LCN.py>`__
With TensorLayer
>>> # get the whole model
>>> net = tl.app.computer_vision.human_pose_estimation('3D-pose')
>>> # use for inferencing
>>> output = net(img)
"""

def __init__(self, model_name='3D-pose'):
self.model_name = model_name
if self.model_name == '3D-pose':
self.model = CGCNN(pretrained=True)
else:
output_tensors = decode(fm, 416 // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
bbox_tensors.append(output_tensors[0])
prob_tensors.append(output_tensors[1])
pred_bbox = tf.concat(bbox_tensors, axis=1)
pred_prob = tf.concat(prob_tensors, axis=1)
boxes, pred_conf = filter_boxes(
pred_bbox, pred_prob, score_threshold=score_thres, input_shape=tf.constant([416, 416])
)
pred = {'concat': tf.concat([boxes, pred_conf], axis=-1)}

for key, value in pred.items():
boxes = value[:, :, 0:4]
pred_conf = value[:, :, 4:]

boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
scores=tf.reshape(pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
max_output_size_per_class=50, max_total_size=50, iou_threshold=iou_threshold, score_threshold=score_threshold
)
output = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
return output


def result_to_json(image, pred_bbox):
image_h, image_w, _ = image.shape
out_boxes, out_scores, out_classes, num_boxes = pred_bbox
class_names = {}
json_result = []
with open('model/coco.names', 'r') as data:
for ID, name in enumerate(data):
class_names[ID] = name.strip('\n')
nums_class = len(class_names)

for i in range(num_boxes[0]):
if int(out_classes[0][i]) < 0 or int(out_classes[0][i]) > nums_class: continue
coor = out_boxes[0][i]
coor[0] = int(coor[0] * image_h)
coor[2] = int(coor[2] * image_h)
coor[1] = int(coor[1] * image_w)
coor[3] = int(coor[3] * image_w)

score = float(out_scores[0][i])
class_ind = int(out_classes[0][i])
bbox = np.array([coor[1], coor[0], coor[3], coor[2]]).tolist() # [x1,y1,x2,y2]
json_result.append({'image': None, 'category_id': class_ind, 'bbox': bbox, 'score': score})

return json_result
raise ("The model does not support.")

def __call__(self, input_data):
if self.model_name == '3D-pose':
output = self.model(input_data, is_train=False)
else:
raise NotImplementedError

return output

def __repr__(self):
s = ('(model_name={model_name}, model_structure={model}')
s += ')'
return s.format(classname=self.__class__.__name__, **self.__dict__)

@property
def list(self):
logging.info("The model name list: '3D-pose'")
99 changes: 75 additions & 24 deletions tensorlayer/app/computer_vision_object_detection/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,78 @@ def decode_train(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYS
return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)


# def weights_sorted():
# # download weights
# maybe_download_and_extract(
# 'yolov4.npz',
# model_path,
# 'https://github.com/',
# ) # ls -al
# weights = []
# track_weights = []
# weights_dict = {}
#
# npz = np.load(model_path, allow_pickle=True)
# # get weight list
# for val in sorted(npz.items()):
# logging.info(" Loading weights %s in %s" % (str(val[1].shape), val[0]))
# try:
# weights.append(int(val[0].split('/')[0].split('-')[-1]))
# track_weights.append(val[0])
# except:
# pass
# zip_weights = zip(weights, track_weights)
# zip_weights = sorted(zip_weights)
# for value, key in zip_weights:
# print(key)
def yolo4_input_processing(original_image):
image_data = cv2.resize(original_image, (416, 416))
image_data = image_data / 255.
images_data = []
for i in range(1):
images_data.append(image_data)
images_data = np.asarray(images_data).astype(np.float32)
batch_data = tf.constant(images_data)
return batch_data


def yolo4_output_processing(feature_maps):
STRIDES = [8, 16, 32]
ANCHORS = get_anchors([12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401])
NUM_CLASS = 80
XYSCALE = [1.2, 1.1, 1.05]
iou_threshold = 0.45
score_threshold = 0.25

bbox_tensors = []
prob_tensors = []
score_thres = 0.2
for i, fm in enumerate(feature_maps):
if i == 0:
output_tensors = decode(fm, 416 // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
elif i == 1:
output_tensors = decode(fm, 416 // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
else:
output_tensors = decode(fm, 416 // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
bbox_tensors.append(output_tensors[0])
prob_tensors.append(output_tensors[1])
pred_bbox = tf.concat(bbox_tensors, axis=1)
pred_prob = tf.concat(prob_tensors, axis=1)
boxes, pred_conf = filter_boxes(
pred_bbox, pred_prob, score_threshold=score_thres, input_shape=tf.constant([416, 416])
)
pred = {'concat': tf.concat([boxes, pred_conf], axis=-1)}

for key, value in pred.items():
boxes = value[:, :, 0:4]
pred_conf = value[:, :, 4:]

boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
scores=tf.reshape(pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
max_output_size_per_class=50, max_total_size=50, iou_threshold=iou_threshold, score_threshold=score_threshold
)
output = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
return output


def result_to_json(image, pred_bbox):
image_h, image_w, _ = image.shape
out_boxes, out_scores, out_classes, num_boxes = pred_bbox
class_names = {}
json_result = []
with open('model/coco.names', 'r') as data:
for ID, name in enumerate(data):
class_names[ID] = name.strip('\n')
nums_class = len(class_names)

for i in range(num_boxes[0]):
if int(out_classes[0][i]) < 0 or int(out_classes[0][i]) > nums_class: continue
coor = out_boxes[0][i]
coor[0] = int(coor[0] * image_h)
coor[2] = int(coor[2] * image_h)
coor[1] = int(coor[1] * image_w)
coor[3] = int(coor[3] * image_w)

score = float(out_scores[0][i])
class_ind = int(out_classes[0][i])
bbox = np.array([coor[1], coor[0], coor[3], coor[2]]).tolist() # [x1,y1,x2,y2]
json_result.append({'image': None, 'category_id': class_ind, 'bbox': bbox, 'score': score})

return json_result
4 changes: 2 additions & 2 deletions tensorlayer/app/computer_vision_object_detection/yolov4.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def YOLOv4(NUM_CLASS, pretrained=False):
network = Model(input_layer, [conv_sbbox, conv_mbbox, conv_lbbox])

if pretrained:
restore_params(network, model_path='model/model.npz')
restore_params(network, model_path='model/yolov4_model.npz')

return network

Expand All @@ -236,7 +236,7 @@ def restore_params(network, model_path='models.npz'):
print("Download the model file, placed in the /model ")
print("Weights download: ", weights_url['link'], "password:", weights_url['password'])

txt_path = 'model/yolov4_config.txt'
txt_path = 'model/yolov4_weights_config.txt'
f = open(txt_path, "r")
line = f.readlines()
for i in range(len(line)):
Expand Down
4 changes: 2 additions & 2 deletions tensorlayer/app/human_pose_estimation/LCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def restore_params(network, model_path='model.npz'):
print("Download the model file, placed in the /model ")
print("Weights download: ", weights_url['link'], "password:", weights_url['password'])

txt_path = 'model/pose_config.txt'
txt_path = 'model/pose_weights_config.txt'
f = open(txt_path, "r")
line = f.readlines()
for i in range(len(line)):
Expand Down Expand Up @@ -325,7 +325,7 @@ def CGCNN(pretrained=True):
"""
if pretrained:
network = cgcnn_inference()
restore_params(network, model_path='model/model.npz')
restore_params(network, model_path='model/lcn_model.npz')
else:
network = cgcnn_train()
return network

0 comments on commit b8c80d9

Please sign in to comment.