Commit a7c2622d by 柴鹏飞

优化代码写法,测试指标调整

parent ce4a0f69
...@@ -7,6 +7,6 @@ set -x ...@@ -7,6 +7,6 @@ set -x
for port in `echo '6001,6002' | tr ',' ' '` for port in `echo '6001,6002' | tr ',' ' '`
do do
echo $port echo $port
nohup python ../src/service/service.py --port=$port > ../log/console_service_$port.log 2>&1 & nohup python ../src/service/recommend_service.py --port=$port > ../log/console_service_$port.log 2>&1 &
echo $! >> .start_service.pid echo $! >> .start_service.pid
done done
\ No newline at end of file
...@@ -45,6 +45,8 @@ def load_test_data(): ...@@ -45,6 +45,8 @@ def load_test_data():
def evaluation(result_detail): def evaluation(result_detail):
top_n_list = [1, 3, 5, 10]
metrics = { metrics = {
'all_test_cnt': 0, 'all_test_cnt': 0,
'all_recall_cnt': 0, 'all_recall_cnt': 0,
...@@ -54,6 +56,8 @@ def evaluation(result_detail): ...@@ -54,6 +56,8 @@ def evaluation(result_detail):
'new_user_recall_cnt': 0, 'new_user_recall_cnt': 0,
'same_user_recall_cnt': 0, 'same_user_recall_cnt': 0,
'similar_user_recall_cnt': 0, 'similar_user_recall_cnt': 0,
'default_recall_cnt': 0,
'top_n_recall_cnt': [0] * len(top_n_list),
} }
for rd in result_detail: for rd in result_detail:
...@@ -62,12 +66,22 @@ def evaluation(result_detail): ...@@ -62,12 +66,22 @@ def evaluation(result_detail):
metrics['all_recall_cnt'] += 1 metrics['all_recall_cnt'] += 1
is_same_user, is_similar_user = False, False is_same_user, is_similar_user = False, False
for counselor in rd['recall_counselors']: for top_n, counselor in enumerate(rd['recall_counselors']):
from_id = counselor['from'].split(' ')[1] if counselor['counselor'] == rd['supplier_id']:
from_info = counselor['from'].split(' ')
if from_info[0] == 'top_50':
metrics['default_recall_cnt'] += 1
else:
from_id = from_info[1]
if from_id == rd['uid']: if from_id == rd['uid']:
is_same_user = True is_same_user = True
if from_id != rd['uid']: if from_id != rd['uid']:
is_similar_user = True is_similar_user = True
for i, n in enumerate(top_n_list):
if n > top_n:
metrics['top_n_recall_cnt'][i] += 1
if is_same_user: if is_same_user:
metrics['same_user_recall_cnt'] += 1 metrics['same_user_recall_cnt'] += 1
...@@ -97,8 +111,14 @@ def evaluation(result_detail): ...@@ -97,8 +111,14 @@ def evaluation(result_detail):
logger.info('--' * 45) logger.info('--' * 45)
logger.info('') logger.info('')
logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt'])) logger.info('用户自己召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt'])) logger.info('相似用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('兜底用户召回数 {:<4} 占总召回比例 {:.2%}'.format(metrics['default_recall_cnt'], metrics['default_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('--' * 45)
logger.info('')
for i, n in enumerate(top_n_list):
logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], metrics['top_n_recall_cnt'][i] / metrics['all_test_cnt']))
def do_test(args): def do_test(args):
...@@ -113,6 +133,7 @@ def do_test(args): ...@@ -113,6 +133,7 @@ def do_test(args):
recommender = UserCFRecommender(top_n=args.top_n, k=args.k, is_use_db=False) recommender = UserCFRecommender(top_n=args.top_n, k=args.k, is_use_db=False)
result_detail = [] result_detail = []
logger.info('开始测试')
for index, order_info in test_orders.iterrows(): for index, order_info in test_orders.iterrows():
if args.max_test > 0: if args.max_test > 0:
if index >= args.max_test: if index >= args.max_test:
...@@ -124,7 +145,8 @@ def do_test(args): ...@@ -124,7 +145,8 @@ def do_test(args):
continue continue
is_merge = args.mode == 0 is_merge = args.mode == 0
recommend_result = recommender.recommend_with_profile(profile, is_merge=is_merge) size = 10 if args.mode == 0 else 0
recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)
recall_resons = [] recall_resons = []
for rr in recommend_result: for rr in recommend_result:
if rr['counselor'] == order_info['supplier_id']: if rr['counselor'] == order_info['supplier_id']:
...@@ -138,6 +160,7 @@ def do_test(args): ...@@ -138,6 +160,7 @@ def do_test(args):
'is_recall': len(recall_resons) > 0, 'is_recall': len(recall_resons) > 0,
'recall_reason': '|'.join(recall_resons), 'recall_reason': '|'.join(recall_resons),
}) })
logger.info('测试结束')
# 测试结果统计 # 测试结果统计
evaluation(result_detail) evaluation(result_detail)
...@@ -158,6 +181,7 @@ def update_test_data(args): ...@@ -158,6 +181,7 @@ def update_test_data(args):
logger.error('args.start_date 参数格式错误,%s', args.start_date) logger.error('args.start_date 参数格式错误,%s', args.start_date)
raise raise
logger.info('开始更新测试数据,测试数据获取条件为 create_time >= %s ', start_date)
conditions = ['create_time >= "{}"'.format(start_date)] conditions = ['create_time >= "{}"'.format(start_date)]
client = MySQLClient.create_from_config_file(get_conf_path()) client = MySQLClient.create_from_config_file(get_conf_path())
...@@ -175,8 +199,9 @@ def update_test_data(args): ...@@ -175,8 +199,9 @@ def update_test_data(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--k', default=5, type=int, help='召回相似用户的数量')
parser.add_argument('--mode', default=0, type=int, help='模式:0-推荐的咨询师列表去重(默认,与实际线上一样);1-推荐的咨询师列表没有去重') parser.add_argument('--mode', default=0, type=int, help='模式:0-推荐的咨询师列表去重(默认,与实际线上一样);1-推荐的咨询师列表没有去重')
parser.add_argument('--k', default=10, type=int, help='召回相似用户的数量')
parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量') parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量') parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
......
...@@ -24,7 +24,7 @@ class Recommender(): ...@@ -24,7 +24,7 @@ class Recommender():
class UserCFRecommender(Recommender): class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=5, is_use_db=True) -> None: def __init__(self, top_n=5, k=10, is_use_db=True) -> None:
super().__init__() super().__init__()
# 召回 top_n 个相似用户 # 召回 top_n 个相似用户
self.top_n = top_n self.top_n = top_n
...@@ -91,8 +91,8 @@ class UserCFRecommender(Recommender): ...@@ -91,8 +91,8 @@ class UserCFRecommender(Recommender):
return self.manager.profile_to_embedding(user_profile) return self.manager.profile_to_embedding(user_profile)
def _recommend_top(self): def _recommend_top(self, size=50):
return self.default_counselor return self.default_counselor[:size]
def _recommend(self, user_embedding): def _recommend(self, user_embedding):
...@@ -115,12 +115,12 @@ class UserCFRecommender(Recommender): ...@@ -115,12 +115,12 @@ class UserCFRecommender(Recommender):
return counselors return counselors
def recommend_with_profile(self, user_profile, count=0, is_merge=True): def recommend_with_profile(self, user_profile, size=0, is_merge=True):
user_embedding = self.user_token(user_profile) user_embedding = self.user_token(user_profile)
counselors = self._recommend(user_embedding) counselors = self._recommend(user_embedding)
# count == 0 时,不追加默认推荐咨询师 # size == 0 时,不追加默认推荐咨询师
if count > 0: if size > 0:
counselors.extend(self._recommend_top()) counselors.extend(self._recommend_top())
if is_merge: if is_merge:
...@@ -132,21 +132,21 @@ class UserCFRecommender(Recommender): ...@@ -132,21 +132,21 @@ class UserCFRecommender(Recommender):
merged_counselors.append(counselor) merged_counselors.append(counselor)
counselors = merged_counselors counselors = merged_counselors
if count > 0: if size > 0:
counselors = counselors[:count] counselors = counselors[:size]
return counselors return counselors
def recommend(self, user_id, count=0, is_merge=True): def recommend(self, user_id, size=0, is_merge=True):
""" """
根据用户画像,推荐咨询师 根据用户画像,推荐咨询师
若获取不到用户画像,推荐默认咨询师(订单最多的) 若获取不到用户画像,推荐默认咨询师(订单最多的)
""" """
user_profile = self.get_user_profile(user_id) user_profile = self.get_user_profile(user_id)
if not user_profile: if not user_profile:
return self._recommend_top() return self._recommend_top(size)
return self.recommend_with_profile(user_profile, count, is_merge) return self.recommend_with_profile(user_profile, size, is_merge)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -26,14 +26,14 @@ class RecommendHandler(tornado.web.RequestHandler): ...@@ -26,14 +26,14 @@ class RecommendHandler(tornado.web.RequestHandler):
if uid is None: if uid is None:
logger.warn('请求参数不正确,无uid') logger.warn('请求参数不正确,无uid')
row = self.get_argument('row', 10) size = self.get_argument('size', 10)
try: try:
row = int(row) size = int(size)
except Exception as e: except Exception as e:
logger.warn('row=%s 不是数字', row) logger.warn('size=%s 不是数字', size)
row = 10 size = 10
ret = yield self.run(uid, row) ret = yield self.run(uid, size)
self.write(ret) self.write(ret)
...@@ -41,24 +41,24 @@ class RecommendHandler(tornado.web.RequestHandler): ...@@ -41,24 +41,24 @@ class RecommendHandler(tornado.web.RequestHandler):
def post(self): def post(self):
param = json.loads(self.request.body.decode('utf-8')) param = json.loads(self.request.body.decode('utf-8'))
uid = param.get('uid', None) uid = param.get('uid', None)
row = param.get('row', 10) size = param.get('size', 10)
if uid is None: if uid is None:
logger.warn('请求参数不正确,无uid') logger.warn('请求参数不正确,无uid')
ret = yield self.run(uid, row) ret = yield self.run(uid, size)
self.write(ret) self.write(ret)
@run_on_executor @run_on_executor
def run(self, uid, row=10): def run(self, uid, size=10):
logger.info('request@@uid=%s@@row=%s', uid, row) logger.info('request@@uid=%s@@size=%s', uid, size)
try: try:
recommend_result = recommender.recommend(uid, count=row, is_merge=True) recommend_result = recommender.recommend(uid, size=size, is_merge=True)
ret = { ret = {
'status': 'success', 'status': 'success',
'code': 0, 'code': 0,
'data': recommend_result, 'data': recommend_result,
'row': len(recommend_result), 'total_count': len(recommend_result),
} }
except Exception as e: except Exception as e:
logger.error('执行推荐函数报错', exc_info=True) logger.error('执行推荐函数报错', exc_info=True)
...@@ -66,7 +66,7 @@ class RecommendHandler(tornado.web.RequestHandler): ...@@ -66,7 +66,7 @@ class RecommendHandler(tornado.web.RequestHandler):
'status': 'error', 'status': 'error',
'code': 1, 'code': 1,
'data': [], 'data': [],
'row': 0, 'total_count': 0,
} }
ret_str = json.dumps(ret, ensure_ascii=False) ret_str = json.dumps(ret, ensure_ascii=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment