Commit 0157477e by 柴鹏飞

测试脚本优化

parent 523f5373
...@@ -44,9 +44,16 @@ def load_test_data(): ...@@ -44,9 +44,16 @@ def load_test_data():
return old_users, test_orders return old_users, test_orders
def _sd(n1, n2):
if n2 == 0:
return 0
else:
return n1 / n2
def evaluation(result_detail): def evaluation(result_detail):
top_n_list = [1, 3, 5, 10, 20, 50, 100] top_n_list = [1, 3, 5, 10, 20, 50]
metrics = { metrics = {
'all_test_cnt': 0, 'all_test_cnt': 0,
'all_recall_cnt': 0, 'all_recall_cnt': 0,
...@@ -97,28 +104,23 @@ def evaluation(result_detail): ...@@ -97,28 +104,23 @@ def evaluation(result_detail):
if rd['is_recall']: if rd['is_recall']:
metrics['new_user_recall_cnt'] += 1 metrics['new_user_recall_cnt'] += 1
logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
logger.info('')
logger.info('相关参数配置: 相似用户数(k) %s ;每个相似用户召回咨询师数(top_n) %s', args.k, args.top_n)
logger.info('--' * 45) logger.info('--' * 45)
logger.info('') logger.info('')
logger.info('{:<13}{:<7}{:<7}{:<7}'.format('', '样本数', '召回数', '召回率')) logger.info('{:<13}{:<7}{:<7}{:<7}'.format('', '样本数', '召回数', '召回率'))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体\u3000', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt'])) logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体\u3000', metrics['all_test_cnt'], metrics['all_recall_cnt'], _sd(metrics['all_recall_cnt'], metrics['all_test_cnt'])))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], metrics['old_user_recall_cnt'] / metrics['old_user_test_cnt'])) logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], _sd(metrics['old_user_recall_cnt'], metrics['old_user_test_cnt'])))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], metrics['new_user_recall_cnt'] / metrics['new_user_test_cnt'])) logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], _sd(metrics['new_user_recall_cnt'], metrics['new_user_test_cnt'])))
logger.info('--' * 45) logger.info('--' * 45)
logger.info('') logger.info('')
logger.info('用户自己召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt'])) logger.info('用户自己召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], _sd(metrics['same_user_recall_cnt'], metrics['all_recall_cnt'])))
logger.info('相似用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt'])) logger.info('相似用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], _sd(metrics['similar_user_recall_cnt'], metrics['all_recall_cnt'])))
logger.info('兜底用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['default_recall_cnt'], metrics['default_recall_cnt'] / metrics['all_recall_cnt'])) logger.info('兜底用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['default_recall_cnt'], _sd(metrics['default_recall_cnt'], metrics['all_recall_cnt'])))
logger.info('--' * 45) logger.info('--' * 45)
logger.info('') logger.info('')
for i, n in enumerate(top_n_list): for i, n in enumerate(top_n_list):
logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], metrics['top_n_recall_cnt'][i] / metrics['all_test_cnt'])) logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt'])))
def do_test(args): def do_test(args):
...@@ -159,11 +161,24 @@ def do_test(args): ...@@ -159,11 +161,24 @@ def do_test(args):
'recall_counselors': recommend_result, 'recall_counselors': recommend_result,
'is_recall': len(recall_resons) > 0, 'is_recall': len(recall_resons) > 0,
'recall_reason': '|'.join(recall_resons), 'recall_reason': '|'.join(recall_resons),
'update_time': order_info['update_time'],
}) })
logger.info('测试结束') logger.info('测试数据推荐完成')
# 测试结果统计 # 测试结果统计
evaluation(result_detail) logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
logger.info('')
logger.info('相关参数配置: 相似用户数(k) %s ;每个相似用户召回咨询师数(top_n) %s', args.k, args.top_n)
if args.show_result_by_day:
days = set(map(lambda x: x['update_time'], result_detail))
for d in sorted(days):
logger.info('')
logger.info('**' * 15 + ' 订单日期 ' + d + '**' * 15)
evaluation(filter(lambda x: x['update_time'] == d, result_detail))
else:
evaluation(result_detail)
if args.save_test_result: if args.save_test_result:
# 保存测试结果详情数据 # 保存测试结果详情数据
...@@ -179,6 +194,9 @@ def update_test_data(args): ...@@ -179,6 +194,9 @@ def update_test_data(args):
elif re.match(r'-\d+', args.start_date): elif re.match(r'-\d+', args.start_date):
now = datetime.now() now = datetime.now()
start_date = (now - timedelta(days=int(args.start_date[1:]))).strftime('%Y-%m-%d') start_date = (now - timedelta(days=int(args.start_date[1:]))).strftime('%Y-%m-%d')
elif args.start_date == '0':
now = datetime.now()
start_date = now.strftime('%Y-%m-%d')
else: else:
logger.error('args.start_date 参数格式错误,%s', args.start_date) logger.error('args.start_date 参数格式错误,%s', args.start_date)
raise raise
...@@ -207,6 +225,7 @@ if __name__ == '__main__': ...@@ -207,6 +225,7 @@ if __name__ == '__main__':
parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量') parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量') parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果') parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果')
parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据') parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天') parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment