Commit 9f715802 by 柴鹏飞

新增咨询师->咨询师索引关系

parent 65e1c61d
...@@ -76,7 +76,7 @@ def evaluation(result_detail): ...@@ -76,7 +76,7 @@ def evaluation(result_detail):
for top_n, counselor in enumerate(rd['recall_counselors']): for top_n, counselor in enumerate(rd['recall_counselors']):
if counselor['counselor'] == rd['supplier_id']: if counselor['counselor'] == rd['supplier_id']:
from_info = counselor['from'].split(' ') from_info = counselor['from'].split(' ')
if from_info[0] == 'top_100': if from_info[0] == 'default':
metrics['default_recall_cnt'] += 1 metrics['default_recall_cnt'] += 1
else: else:
from_id = from_info[1] from_id = from_info[1]
...@@ -122,6 +122,8 @@ def evaluation(result_detail): ...@@ -122,6 +122,8 @@ def evaluation(result_detail):
for i, n in enumerate(top_n_list): for i, n in enumerate(top_n_list):
logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt']))) logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt'])))
return metrics
def do_test(args): def do_test(args):
user_profile_dict = load_local_user_profile() user_profile_dict = load_local_user_profile()
...@@ -147,7 +149,7 @@ def do_test(args): ...@@ -147,7 +149,7 @@ def do_test(args):
continue continue
is_merge = args.mode == 0 is_merge = args.mode == 0
size = 10 if args.mode == 0 else 0 size = 100 if args.mode == 0 else 0
recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge) recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)
recall_resons = [] recall_resons = []
for rr in recommend_result: for rr in recommend_result:
...@@ -216,6 +218,104 @@ def update_test_data(args): ...@@ -216,6 +218,104 @@ def update_test_data(args):
logger.info('测试数据更新完成') logger.info('测试数据更新完成')
def batch_test(args):
    """Run the offline recall test over a grid of recommender settings.

    For every configuration in the grid a fresh ``UserCFRecommender`` is
    built, each test order is replayed through it, and the metrics returned
    by ``evaluation`` are extended with recall ratios. One tab-separated row
    per configuration is written to ``batch_test_result.tsv`` with the
    columns: k, top_n, u2c, c2c, overall / old-user / new-user recall ratio,
    top1, top10, top50.

    params:
        args: parsed command-line arguments; ``args.max_test`` and
            ``args.mode`` are read here.
    """
    user_profile_dict = load_local_user_profile()
    try:
        old_users, test_orders = load_test_data()
    except Exception:
        # Without local test data nothing can be run; log and give up.
        logger.error('测试数据加载出错,请确认测试数据已经下载到本地,或启动命令中增加参数 "--do_update_test_data"', exc_info=True)
        return

    def _test(case):
        """Replay every test order through a recommender built from `case`."""
        banner = '##' * 20 + ' 测试case ' + '##' * 20
        logger.info(banner)
        logger.info(json.dumps(case, ensure_ascii=False))
        logger.info(banner)

        recommender = UserCFRecommender(**case)
        details = []
        for index, order_info in test_orders.iterrows():
            # A positive max_test caps how many orders are evaluated.
            if args.max_test > 0 and index >= args.max_test:
                break

            uid = order_info['uid']
            profile = user_profile_dict.get(uid)
            if profile is None:
                # Users without a local profile cannot be tested.
                continue

            merge_mode = args.mode == 0
            size = 100 if merge_mode else 0
            recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=merge_mode)

            # Collect the recall source of every hit on the ordered counselor.
            target = order_info['supplier_id']
            recall_reasons = [rr['from'] for rr in recommend_result if rr['counselor'] == target]

            details.append({
                'uid': uid,
                'supplier_id': order_info['supplier_id'],
                'is_old_user': uid in old_users,
                'recall_counselors': recommend_result,
                'is_recall': bool(recall_reasons),
                'recall_reason': '|'.join(recall_reasons),
                'update_time': order_info['update_time'],
            })
        return details

    def _evaluation(result_detail):
        """Evaluate `result_detail` and attach the derived ratio metrics."""
        metrics = evaluation(result_detail)

        # (target key, numerator key, denominator key)
        ratio_specs = (
            ('all_recall_ratio', 'all_recall_cnt', 'all_test_cnt'),
            ('old_user_recall_ratio', 'old_user_recall_cnt', 'old_user_test_cnt'),
            ('new_user_recall_ratio', 'new_user_recall_cnt', 'new_user_test_cnt'),
        )
        for target_key, cnt_key, total_key in ratio_specs:
            metrics[target_key] = _sd(metrics[cnt_key], metrics[total_key])

        # NOTE(review): positions 0 / 3 / 5 presumably map to top 1 / 10 / 50
        # in the top-n list used by `evaluation` -- confirm against that list.
        top_specs = (('top1', 0), ('top10', 3), ('top50', 5))
        for target_key, position in top_specs:
            metrics[target_key] = _sd(metrics['top_n_recall_cnt'][position], metrics['all_test_cnt'])
        return metrics

    # Grid: c2c off/on  x  u2c index type  x  (top_n, k) pairs.
    top_n_k_pairs = ((2, 50), (5, 20), (10, 10), (20, 5))
    test_cases = [
        {'top_n': top_n, 'k': k, 'u2c': u2c, 'c2c': c2c, 'is_use_db': False}
        for c2c in (None, True)
        for u2c in ('combination', 'chat', 'order')
        for (top_n, k) in top_n_k_pairs
    ]

    all_test_result = [(case, _evaluation(_test(case))) for case in test_cases]

    with open('batch_test_result.tsv', 'w', encoding='utf-8') as f:
        for case, metrics in all_test_result:
            row = [
                case['k'], case['top_n'], case['u2c'], case['c2c'],
                metrics['all_recall_ratio'], metrics['old_user_recall_ratio'], metrics['new_user_recall_ratio'],
                metrics['top1'], metrics['top10'], metrics['top50'],
            ]
            f.write('\t'.join(str(value) for value in row) + '\n')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -227,6 +327,8 @@ if __name__ == '__main__': ...@@ -227,6 +327,8 @@ if __name__ == '__main__':
parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果') parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果')
parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示') parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示')
parser.add_argument('--do_batch_test', default=False, action='store_true', help='批量测试')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据') parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天') parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
...@@ -238,5 +340,9 @@ if __name__ == '__main__': ...@@ -238,5 +340,9 @@ if __name__ == '__main__':
logger.info('测试数据创建时间 %s', args.start_date) logger.info('测试数据创建时间 %s', args.start_date)
update_test_data(args) update_test_data(args)
if args.do_batch_test:
logger.info('执行批量测试')
batch_test(args)
else:
logger.info('开始执行测试任务,测试模式为 %s', args.mode) logger.info('开始执行测试任务,测试模式为 %s', args.mode)
do_test(args) do_test(args)
\ No newline at end of file
...@@ -91,10 +91,10 @@ if __name__ == '__main__': ...@@ -91,10 +91,10 @@ if __name__ == '__main__':
if args.task == 'make_index': if args.task == 'make_index':
indexers = [ indexers = [
['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer()], ['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer(logger=logger)],
['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer()], ['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer(logger=logger)],
['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer()], ['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer(logger=logger)],
['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer()], ['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer(0.8, 0.2, logger=logger)],
] ]
logger.info('') logger.info('')
......
...@@ -7,8 +7,13 @@ from typing import List, Dict ...@@ -7,8 +7,13 @@ from typing import List, Dict
import faiss import faiss
import numpy as np import numpy as np
from ydl_ai_recommender.src.core.indexer import UserCounselorDefaultIndexer from ydl_ai_recommender.src.core.indexer import (
from ydl_ai_recommender.src.core.indexer import UserCounselorCombinationIndexer UserCounselorChatIndexer,
UserCounselorOrderIndexer,
UserCounselorDefaultIndexer,
UserCounselorCombinationIndexer,
CounselorCounselorCFIndexer,
)
from ydl_ai_recommender.src.core.profile import encode_profile from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
...@@ -26,7 +31,15 @@ class Recommender(): ...@@ -26,7 +31,15 @@ class Recommender():
class UserCFRecommender(Recommender): class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=20, is_use_db=True) -> None: def __init__(self, top_n=5, k=20, is_use_db=True, u2c='combination', c2c=None) -> None:
"""
params:
top_n: 每个召回的用户获取的相关咨询师个数
k: 召回的相似用户数
is_use_db: 是否使用数据库
u2c: [用户->咨询师] 索引方法
c2c: [咨询师->咨询师] 索引方法,None 表示不使用咨询师拓展
"""
super().__init__() super().__init__()
# 召回 top_n 个相似用户 # 召回 top_n 个相似用户
self.top_n = top_n self.top_n = top_n
...@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender): ...@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender):
self.default_indexer = UserCounselorDefaultIndexer(self.logger) self.default_indexer = UserCounselorDefaultIndexer(self.logger)
self.default_indexer.load_index_data() self.default_indexer.load_index_data()
if u2c == 'chat':
self.indexer = UserCounselorChatIndexer(self.logger)
elif u2c == 'order':
self.indexer = UserCounselorOrderIndexer(self.logger)
else:
self.indexer = UserCounselorCombinationIndexer(self.logger) self.indexer = UserCounselorCombinationIndexer(self.logger)
self.indexer.load_index_data() self.indexer.load_index_data()
if c2c:
self.c2c_indexer = CounselorCounselorCFIndexer(self.logger)
self.c2c_indexer.load_index_data()
else:
self.c2c_indexer = None
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.load_data() self.load_data()
...@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender): ...@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender):
'score': score / max(0.01, float(simi_score)), 'score': score / max(0.01, float(simi_score)),
'from': 'similar_users {}'.format(similar_user_id), 'from': 'similar_users {}'.format(similar_user_id),
} for (c_id, score) in similar_user_counselor] } for (c_id, score) in similar_user_counselor]
counselors.extend(recommend_data)
supplement_data = []
if self.c2c_indexer:
for ro in recommend_data:
supplement_data.extend([{
'counselor': sc_id,
'score': ro['score'] * score,
'from': '{} supplement {}'.format(ro['from'], ro['counselor']),
} for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n))])
# } for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n / len(recommend_data)))])
counselors.extend(recommend_data + supplement_data)
counselors.sort(key=lambda x: x['score'], reverse=True) counselors.sort(key=lambda x: x['score'], reverse=True)
return counselors return counselors
......
...@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender ...@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender
logger = create_logger(__name__, 'service.log', is_rotating=True) logger = create_logger(__name__, 'service.log', is_rotating=True)
recommender = UserCFRecommender(top_n=5, k=5) recommender = UserCFRecommender(top_n=5, k=20, c2c=True)
class RecommendHandler(tornado.web.RequestHandler): class RecommendHandler(tornado.web.RequestHandler):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment