Commit 9f715802 by 柴鹏飞

新增咨询师->咨询师索引关系

parent 65e1c61d
......@@ -76,7 +76,7 @@ def evaluation(result_detail):
for top_n, counselor in enumerate(rd['recall_counselors']):
if counselor['counselor'] == rd['supplier_id']:
from_info = counselor['from'].split(' ')
if from_info[0] == 'top_100':
if from_info[0] == 'default':
metrics['default_recall_cnt'] += 1
else:
from_id = from_info[1]
......@@ -122,6 +122,8 @@ def evaluation(result_detail):
for i, n in enumerate(top_n_list):
logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt'])))
return metrics
def do_test(args):
user_profile_dict = load_local_user_profile()
......@@ -147,7 +149,7 @@ def do_test(args):
continue
is_merge = args.mode == 0
size = 10 if args.mode == 0 else 0
size = 100 if args.mode == 0 else 0
recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)
recall_resons = []
for rr in recommend_result:
......@@ -216,6 +218,104 @@ def update_test_data(args):
logger.info('测试数据更新完成')
def batch_test(args):
    """
    Run the recall test once per recommender configuration and write a summary TSV.

    For every configuration a ``UserCFRecommender`` is built, each test order is
    replayed against it, the results are scored with ``evaluation`` and one line
    per configuration is written to ``batch_test_result.tsv``.

    Args:
        args: parsed command line arguments; ``args.max_test`` and ``args.mode``
            are read here.
    """
    user_profile_dict = load_local_user_profile()

    try:
        old_users, test_orders = load_test_data()
    except Exception:
        logger.error('测试数据加载出错,请确认测试数据已经下载到本地,或启动命令中增加参数 "--do_update_test_data"', exc_info=True)
        return

    def _test(case):
        """Replay every test order against a recommender built from ``case``."""
        banner = '##' * 20 + ' 测试case ' + '##' * 20
        logger.info(banner)
        logger.info(json.dumps(case, ensure_ascii=False))
        logger.info(banner)

        recommender = UserCFRecommender(**case)

        details = []
        for index, order_info in test_orders.iterrows():
            # A positive max_test caps how many orders are replayed.
            if args.max_test > 0 and index >= args.max_test:
                break

            uid = order_info['uid']
            profile = user_profile_dict.get(uid)
            if profile is None:
                continue

            is_merge = args.mode == 0
            size = 100 if is_merge else 0
            recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)

            # Every recall source that produced the counselor actually ordered.
            recall_reasons = [
                rr['from']
                for rr in recommend_result
                if rr['counselor'] == order_info['supplier_id']
            ]

            details.append({
                'uid': uid,
                'supplier_id': order_info['supplier_id'],
                'is_old_user': uid in old_users,
                'recall_counselors': recommend_result,
                'is_recall': bool(recall_reasons),
                'recall_reason': '|'.join(recall_reasons),
                'update_time': order_info['update_time'],
            })
        return details

    def _evaluation(result_detail):
        """Score ``result_detail`` and attach the ratios reported in the TSV."""
        metrics = evaluation(result_detail)

        ratio_specs = (
            ('all_recall_ratio', 'all_recall_cnt', 'all_test_cnt'),
            ('old_user_recall_ratio', 'old_user_recall_cnt', 'old_user_test_cnt'),
            ('new_user_recall_ratio', 'new_user_recall_cnt', 'new_user_test_cnt'),
        )
        for ratio_key, hit_key, total_key in ratio_specs:
            metrics[ratio_key] = _sd(metrics[hit_key], metrics[total_key])

        # NOTE(review): positions 0 / 3 / 5 of top_n_recall_cnt are presumed to be
        # the top-1 / top-10 / top-50 buckets -- confirm against evaluation().
        for label, position in (('top1', 0), ('top10', 3), ('top50', 5)):
            metrics[label] = _sd(metrics['top_n_recall_cnt'][position], metrics['all_test_cnt'])

        return metrics

    # Full grid: c2c off/on x u2c source x (top_n, k) pair, in that nesting order.
    test_cases = [
        {'top_n': top_n, 'k': k, 'u2c': u2c, 'c2c': c2c, 'is_use_db': False}
        for c2c in (None, True)
        for u2c in ('combination', 'chat', 'order')
        for top_n, k in ((2, 50), (5, 20), (10, 10), (20, 5))
    ]

    all_test_result = [(case, _evaluation(_test(case))) for case in test_cases]

    with open('batch_test_result.tsv', 'w', encoding='utf-8') as f:
        for case, metrics in all_test_result:
            fields = [
                case['k'], case['top_n'], case['u2c'], case['c2c'],
                metrics['all_recall_ratio'], metrics['old_user_recall_ratio'], metrics['new_user_recall_ratio'],
                metrics['top1'], metrics['top10'], metrics['top50'],
            ]
            f.write('\t'.join(map(str, fields)) + '\n')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......@@ -227,6 +327,8 @@ if __name__ == '__main__':
parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果')
parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示')
parser.add_argument('--do_batch_test', default=False, action='store_true', help='批量测试')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
......@@ -238,5 +340,9 @@ if __name__ == '__main__':
logger.info('测试数据创建时间 %s', args.start_date)
update_test_data(args)
if args.do_batch_test:
logger.info('执行批量测试')
batch_test(args)
else:
logger.info('开始执行测试任务,测试模式为 %s', args.mode)
do_test(args)
\ No newline at end of file
......@@ -91,10 +91,10 @@ if __name__ == '__main__':
if args.task == 'make_index':
indexers = [
['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer()],
['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer()],
['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer()],
['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer()],
['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer(logger=logger)],
['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer(logger=logger)],
['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer(logger=logger)],
['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer(0.8, 0.2, logger=logger)],
]
logger.info('')
......
......@@ -3,6 +3,7 @@
import os
import json
from collections import Counter
from itertools import combinations
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
......@@ -22,16 +23,34 @@ class Indexer():
self.logger = logger
self.local_file_dir = get_data_path()
self.index_file = ''
self.index_data = {}
def make_index(self) -> Dict[str, List]:
    """
    Build the index.

    The base class provides no implementation and always raises
    ``NotImplementedError``; concrete indexers define their own ``make_index``.
    """
    raise NotImplementedError
def load_index_data(self):
    """
    Load the JSON index stored at ``self.index_file`` into ``self.index_data``.

    Raises:
        FileNotFoundError: if ``self.index_file`` does not exist (the index has
            probably not been built yet).
    """
    if not os.path.exists(self.index_file):
        self.logger.error('%s 不存在,确认索引是否已经构建完成', self.index_file)
        # `raise` needs an exception instance; raising a plain str fails with
        # TypeError in Python 3 and hides the real cause.
        raise FileNotFoundError('文件不存在: {}'.format(self.index_file))

    with open(self.index_file, 'r', encoding='utf-8') as f:
        self.index_data = json.load(f)
def index(self, q='', count=0) -> List[Tuple[str, float]]:
    """
    Look up the entries stored for ``q`` in the loaded index.

    Return value shape: [[similar id, score], [similar id, score], ...]

    Args:
        q: key to look up in ``self.index_data``.
        count: maximum number of entries to return; ``0`` returns all of them.

    Returns:
        The entries stored under ``q`` (truncated to ``count`` when non-zero),
        or an empty list when ``q`` is not in the index.

    Raises:
        RuntimeError: if ``self.index_data`` is empty, i.e. ``load_index_data()``
            has not been called (or loaded nothing).
    """
    if len(self.index_data) == 0:
        self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
        # A bare `raise` outside an `except` block only yields the unhelpful
        # "No active exception to reraise"; raise the same type explicitly.
        raise RuntimeError('未加载索引数据,请先执行 load_index_data()')

    entries = self.index_data.get(q, [])
    return entries if count == 0 else entries[:count]
class UserCounselorDefaultIndexer(Indexer):
......@@ -40,7 +59,7 @@ class UserCounselorDefaultIndexer(Indexer):
"""
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = OrderDataManager(logger)
self.data_manager = OrderDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'index_list.txt')
self.count = 100
self.index_data = []
......@@ -88,7 +107,7 @@ class UserCounselorOrderIndexer(Indexer):
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = OrderDataManager(logger)
self.data_manager = OrderDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_order_index.json')
self.index_data = {}
self.now = datetime.now()
......@@ -100,13 +119,13 @@ class UserCounselorOrderIndexer(Indexer):
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], price / 400)
w[0] = max(w[0], min(1., price / 400))
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], price / 400)
w[1] = max(w[1], min(1., price / 400))
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], price / 400)
w[2] = max(w[2], min(1., price / 400))
else:
w[3] = max(1., w[3], price / 400)
w[3] = max(w[3], min(1., price / 400))
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
......@@ -142,22 +161,6 @@ class UserCounselorOrderIndexer(Indexer):
self.logger.info('基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorChatIndexer(Indexer):
"""
......@@ -166,7 +169,7 @@ class UserCounselorChatIndexer(Indexer):
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = ChatDataManager(logger)
self.data_manager = ChatDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_chat_index.json')
self.index_data = {}
self.now = datetime.now()
......@@ -178,13 +181,13 @@ class UserCounselorChatIndexer(Indexer):
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], (u2d + d2u) / 20)
w[0] = max(w[0], min(1., (u2d + d2u) / 20))
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], (u2d + d2u) / 20)
w[1] = max(w[1], min(1., (u2d + d2u) / 20))
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], (u2d + d2u) / 20)
w[2] = max(w[2], min(1., (u2d + d2u) / 20))
else:
w[3] = max(1., w[3], (u2d + d2u) / 20)
w[3] = max(w[3], min(1., (u2d + d2u) / 20))
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
......@@ -221,29 +224,12 @@ class UserCounselorChatIndexer(Indexer):
self.logger.info('基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorCombinationIndexer(Indexer):
"""
基于多种数据组合的[用户->咨询师]关系索引
"""
def __init__(self, order_w=0.6, chat_w=0.4, logger=None) -> None:
def __init__(self, order_w=0.8, chat_w=0.2, logger=None) -> None:
super().__init__(logger)
self.order_w = order_w
self.chat_w = chat_w
......@@ -278,21 +264,83 @@ class UserCounselorCombinationIndexer(Indexer):
return index
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
class CounselorCounselorCFIndexer(Indexer):
    """
    Collaborative-filtering based [counselor -> counselor] relation index.

    Two counselors are related through the users they share in the order data
    and the chat data.
    """

    def __init__(self, order_w=0.8, chat_w=0.2, logger=None) -> None:
        """
        Args:
            order_w: stored on the instance; not read by the methods of this class.
            chat_w: stored on the instance; not read by the methods of this class.
            logger: forwarded to the base ``Indexer`` constructor.
        """
        super().__init__(logger)
        self.order_w = order_w
        self.chat_w = chat_w
        self.index_file = os.path.join(self.local_file_dir, 'counselor_counselor_cf_index.json')
        self.index_data = {}

    def load_pair_by_chat(self):
        """Return ``{doctor_id: set of uid}`` built from the raw chat data."""
        chat_df = ChatDataManager(self.logger).load_raw_data()
        users_by_counselor = {}
        for _, record in chat_df.iterrows():
            uid, counselor_id = record['uid'], record['doctor_id']
            users_by_counselor.setdefault(counselor_id, set()).add(uid)
        return users_by_counselor

    def load_pair_by_order(self):
        """Return ``{supplier_id: set of uid}`` built from the raw order data."""
        order_df = OrderDataManager(self.logger).load_raw_data()
        users_by_counselor = {}
        for _, record in order_df.iterrows():
            uid, counselor_id = record['uid'], record['supplier_id']
            users_by_counselor.setdefault(counselor_id, set()).add(uid)
        return users_by_counselor

    def make_index(self) -> Dict[str, List]:
        """
        Build the counselor-to-counselor index, save it as JSON and return it.

        The similarity of a pair is the number of users the two counselors share
        divided by the product of their user counts. Each counselor keeps its
        100 highest scoring neighbours.

        Returns:
            ``{counselor id: [(other counselor id, similarity), ...]}``
        """
        self.logger.info('')
        self.logger.info('开始构建基于协同过滤的[咨询师->咨询师]关系索引')

        users_by_counselor = self.load_pair_by_order()
        self.logger.info('基于订单的[用户->咨询师]数据加载完成')

        chat_users_by_counselor = self.load_pair_by_chat()
        self.logger.info('基于询单的[用户->咨询师]数据加载完成')

        # Fold the chat-derived user sets into the order-derived ones.
        for counselor_id, users in chat_users_by_counselor.items():
            if counselor_id in users_by_counselor:
                users_by_counselor[counselor_id].update(users)
            else:
                users_by_counselor[counselor_id] = users

        self.logger.info('数据合并完成,共有咨询师 %s', len(users_by_counselor))

        index = {}
        for first, second in combinations(users_by_counselor.keys(), 2):
            low, high = min(first, second), max(first, second)
            low_users = users_by_counselor[low]
            high_users = users_by_counselor[high]
            sim = len(low_users & high_users) / (len(low_users) * len(high_users))

            # The relation is symmetric, so record it on both counselors.
            index.setdefault(low, []).append((high, sim))
            index.setdefault(high, []).append((low, sim))

        self.logger.info('开始咨询师相似性排序')
        # Sort by similarity score, then keep the first 100.
        index = {
            counselor_id: sorted(neighbours, key=lambda pair: pair[1], reverse=True)[:100]
            for counselor_id, neighbours in index.items()
        }

        self.logger.info('基于协同过滤的[咨询师->咨询师]关系索引构建完成,共构建 %s 条数据', len(index))

        with open(self.index_file, 'w', encoding='utf-8') as f:
            json.dump(index, f, ensure_ascii=False)

        return index
if __name__ == '__main__':
......@@ -307,3 +355,6 @@ if __name__ == '__main__':
indexer = UserCounselorCombinationIndexer()
indexer.make_index()
indexer = CounselorCounselorCFIndexer()
indexer.make_index()
\ No newline at end of file
......@@ -7,8 +7,13 @@ from typing import List, Dict
import faiss
import numpy as np
from ydl_ai_recommender.src.core.indexer import UserCounselorDefaultIndexer
from ydl_ai_recommender.src.core.indexer import UserCounselorCombinationIndexer
from ydl_ai_recommender.src.core.indexer import (
UserCounselorChatIndexer,
UserCounselorOrderIndexer,
UserCounselorDefaultIndexer,
UserCounselorCombinationIndexer,
CounselorCounselorCFIndexer,
)
from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
......@@ -26,7 +31,15 @@ class Recommender():
class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=20, is_use_db=True) -> None:
def __init__(self, top_n=5, k=20, is_use_db=True, u2c='combination', c2c=None) -> None:
"""
params:
top_n: 每个召回的用户获取的相关咨询师个数
k: 召回的相似用户数
is_use_db: 是否使用数据库
u2c: [用户->咨询师] 索引方法
c2c: [咨询师->咨询师] 索引方法,None 表示不使用咨询师拓展
"""
super().__init__()
# 召回 top_n 个相似用户
self.top_n = top_n
......@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender):
self.default_indexer = UserCounselorDefaultIndexer(self.logger)
self.default_indexer.load_index_data()
if u2c == 'chat':
self.indexer = UserCounselorChatIndexer(self.logger)
elif u2c == 'order':
self.indexer = UserCounselorOrderIndexer(self.logger)
else:
self.indexer = UserCounselorCombinationIndexer(self.logger)
self.indexer.load_index_data()
if c2c:
self.c2c_indexer = CounselorCounselorCFIndexer(self.logger)
self.c2c_indexer.load_index_data()
else:
self.c2c_indexer = None
self.local_file_dir = get_data_path()
self.load_data()
......@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender):
'score': score / max(0.01, float(simi_score)),
'from': 'similar_users {}'.format(similar_user_id),
} for (c_id, score) in similar_user_counselor]
counselors.extend(recommend_data)
supplement_data = []
if self.c2c_indexer:
for ro in recommend_data:
supplement_data.extend([{
'counselor': sc_id,
'score': ro['score'] * score,
'from': '{} supplement {}'.format(ro['from'], ro['counselor']),
} for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n))])
# } for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n / len(recommend_data)))])
counselors.extend(recommend_data + supplement_data)
counselors.sort(key=lambda x: x['score'], reverse=True)
return counselors
......
......@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender
logger = create_logger(__name__, 'service.log', is_rotating=True)
recommender = UserCFRecommender(top_n=5, k=5)
recommender = UserCFRecommender(top_n=5, k=20, c2c=True)
class RecommendHandler(tornado.web.RequestHandler):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment