Commit 9f715802 by 柴鹏飞

新增咨询师->咨询师索引关系

parent 65e1c61d
...@@ -76,7 +76,7 @@ def evaluation(result_detail): ...@@ -76,7 +76,7 @@ def evaluation(result_detail):
for top_n, counselor in enumerate(rd['recall_counselors']): for top_n, counselor in enumerate(rd['recall_counselors']):
if counselor['counselor'] == rd['supplier_id']: if counselor['counselor'] == rd['supplier_id']:
from_info = counselor['from'].split(' ') from_info = counselor['from'].split(' ')
if from_info[0] == 'top_100': if from_info[0] == 'default':
metrics['default_recall_cnt'] += 1 metrics['default_recall_cnt'] += 1
else: else:
from_id = from_info[1] from_id = from_info[1]
...@@ -122,6 +122,8 @@ def evaluation(result_detail): ...@@ -122,6 +122,8 @@ def evaluation(result_detail):
for i, n in enumerate(top_n_list): for i, n in enumerate(top_n_list):
logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt']))) logger.info('top {:<2} 的召回数 {:<4} 召回率 {:.2%}'.format(n, metrics['top_n_recall_cnt'][i], _sd(metrics['top_n_recall_cnt'][i], metrics['all_test_cnt'])))
return metrics
def do_test(args): def do_test(args):
user_profile_dict = load_local_user_profile() user_profile_dict = load_local_user_profile()
...@@ -147,7 +149,7 @@ def do_test(args): ...@@ -147,7 +149,7 @@ def do_test(args):
continue continue
is_merge = args.mode == 0 is_merge = args.mode == 0
size = 10 if args.mode == 0 else 0 size = 100 if args.mode == 0 else 0
recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge) recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)
recall_resons = [] recall_resons = []
for rr in recommend_result: for rr in recommend_result:
...@@ -216,6 +218,104 @@ def update_test_data(args): ...@@ -216,6 +218,104 @@ def update_test_data(args):
logger.info('测试数据更新完成') logger.info('测试数据更新完成')
def batch_test(args):
user_profile_dict = load_local_user_profile()
try:
old_users, test_orders = load_test_data()
except Exception as e:
logger.error('测试数据加载出错,请确认测试数据已经下载到本地,或启动命令中增加参数 "--do_update_test_data"', exc_info=True)
return
def _test(case):
logger.info('##' * 20 + ' 测试case ' + '##' * 20)
logger.info(json.dumps(case, ensure_ascii=False))
logger.info('##' * 20 + ' 测试case ' + '##' * 20)
recommender = UserCFRecommender(**case)
result_detail = []
for index, order_info in test_orders.iterrows():
if args.max_test > 0:
if index >= args.max_test:
break
uid = order_info['uid']
profile = user_profile_dict.get(uid)
if profile is None:
continue
is_merge = args.mode == 0
size = 100 if args.mode == 0 else 0
recommend_result = recommender.recommend_with_profile(profile, size=size, is_merge=is_merge)
recall_resons = []
for rr in recommend_result:
if rr['counselor'] == order_info['supplier_id']:
recall_resons.append(rr['from'])
result_detail.append({
'uid': uid,
'supplier_id': order_info['supplier_id'],
'is_old_user': uid in old_users,
'recall_counselors': recommend_result,
'is_recall': len(recall_resons) > 0,
'recall_reason': '|'.join(recall_resons),
'update_time': order_info['update_time'],
})
return result_detail
def _evaluation(result_detail):
metrics = evaluation(result_detail)
metrics['all_recall_ratio'] = _sd(metrics['all_recall_cnt'], metrics['all_test_cnt'])
metrics['old_user_recall_ratio'] = _sd(metrics['old_user_recall_cnt'], metrics['old_user_test_cnt'])
metrics['new_user_recall_ratio'] = _sd(metrics['new_user_recall_cnt'], metrics['new_user_test_cnt'])
metrics['top1'] = _sd(metrics['top_n_recall_cnt'][0], metrics['all_test_cnt'])
metrics['top10'] = _sd(metrics['top_n_recall_cnt'][3], metrics['all_test_cnt'])
metrics['top50'] = _sd(metrics['top_n_recall_cnt'][5], metrics['all_test_cnt'])
return metrics
test_cases = [
{'top_n': 2, 'k': 50, 'u2c': 'combination', 'c2c': None, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'combination', 'c2c': None, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'combination', 'c2c': None, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'combination', 'c2c': None, 'is_use_db': False},
{'top_n': 2, 'k': 50, 'u2c': 'chat', 'c2c': None, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'chat', 'c2c': None, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'chat', 'c2c': None, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'chat', 'c2c': None, 'is_use_db': False},
{'top_n': 2, 'k': 50, 'u2c': 'order', 'c2c': None, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'order', 'c2c': None, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'order', 'c2c': None, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'order', 'c2c': None, 'is_use_db': False},
{'top_n': 2, 'k': 50, 'u2c': 'combination', 'c2c': True, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'combination', 'c2c': True, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'combination', 'c2c': True, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'combination', 'c2c': True, 'is_use_db': False},
{'top_n': 2, 'k': 50, 'u2c': 'chat', 'c2c': True, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'chat', 'c2c': True, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'chat', 'c2c': True, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'chat', 'c2c': True, 'is_use_db': False},
{'top_n': 2, 'k': 50, 'u2c': 'order', 'c2c': True, 'is_use_db': False},
{'top_n': 5, 'k': 20, 'u2c': 'order', 'c2c': True, 'is_use_db': False},
{'top_n': 10, 'k': 10, 'u2c': 'order', 'c2c': True, 'is_use_db': False},
{'top_n': 20, 'k': 5, 'u2c': 'order', 'c2c': True, 'is_use_db': False},
]
all_test_result = []
for case in test_cases:
result_detail = _test(case)
metrics = _evaluation(result_detail)
all_test_result.append((case, metrics))
with open('batch_test_result.tsv', 'w', encoding='utf-8') as f:
for (case, metrics) in all_test_result:
items = [case['k'], case['top_n'], case['u2c'], case['c2c']]
items.extend([metrics['all_recall_ratio'], metrics['old_user_recall_ratio'], metrics['new_user_recall_ratio']])
items.extend([metrics['top1'], metrics['top10'], metrics['top50']])
f.write('\t'.join(map(lambda x: str(x), items)) + '\n')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -227,6 +327,8 @@ if __name__ == '__main__': ...@@ -227,6 +327,8 @@ if __name__ == '__main__':
parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果') parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果')
parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示') parser.add_argument('--show_result_by_day', default=False, action='store_true', help='测试结果是否按天展示')
parser.add_argument('--do_batch_test', default=False, action='store_true', help='批量测试')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据') parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天') parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
...@@ -238,5 +340,9 @@ if __name__ == '__main__': ...@@ -238,5 +340,9 @@ if __name__ == '__main__':
logger.info('测试数据创建时间 %s', args.start_date) logger.info('测试数据创建时间 %s', args.start_date)
update_test_data(args) update_test_data(args)
logger.info('开始执行测试任务,测试模式为 %s', args.mode) if args.do_batch_test:
do_test(args) logger.info('执行批量测试')
\ No newline at end of file batch_test(args)
else:
logger.info('开始执行测试任务,测试模式为 %s', args.mode)
do_test(args)
\ No newline at end of file
...@@ -91,10 +91,10 @@ if __name__ == '__main__': ...@@ -91,10 +91,10 @@ if __name__ == '__main__':
if args.task == 'make_index': if args.task == 'make_index':
indexers = [ indexers = [
['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer()], ['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer(logger=logger)],
['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer()], ['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer(logger=logger)],
['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer()], ['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer(logger=logger)],
['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer()], ['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer(0.8, 0.2, logger=logger)],
] ]
logger.info('') logger.info('')
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
from collections import Counter from collections import Counter
from itertools import combinations
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -22,16 +23,34 @@ class Indexer(): ...@@ -22,16 +23,34 @@ class Indexer():
self.logger = logger self.logger = logger
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.index_file = ''
self.index_data = {}
def make_index(self) -> Dict[str, List]:
raise NotImplementedError
def load_index_data(self):
if not os.path.exists(self.index_file):
self.logger.error('%s 不存在,确认索引是否已经构建完成', self.index_file)
raise '文件不存在'
def index(self, q: str, count: int=0) -> List[Tuple[str, float]]: index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
""" """
返回值类型:[[相似id, score], [相似id, score], ...] 返回值类型:[[相似id, score], [相似id, score], ...]
""" """
raise NotImplementedError if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
def make_index(self) -> Dict[str, List]: raise
raise NotImplementedError
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorDefaultIndexer(Indexer): class UserCounselorDefaultIndexer(Indexer):
...@@ -40,7 +59,7 @@ class UserCounselorDefaultIndexer(Indexer): ...@@ -40,7 +59,7 @@ class UserCounselorDefaultIndexer(Indexer):
""" """
def __init__(self, logger=None) -> None: def __init__(self, logger=None) -> None:
super().__init__(logger) super().__init__(logger)
self.data_manager = OrderDataManager(logger) self.data_manager = OrderDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'index_list.txt') self.index_file = os.path.join(self.local_file_dir, 'index_list.txt')
self.count = 100 self.count = 100
self.index_data = [] self.index_data = []
...@@ -88,7 +107,7 @@ class UserCounselorOrderIndexer(Indexer): ...@@ -88,7 +107,7 @@ class UserCounselorOrderIndexer(Indexer):
def __init__(self, logger=None) -> None: def __init__(self, logger=None) -> None:
super().__init__(logger) super().__init__(logger)
self.data_manager = OrderDataManager(logger) self.data_manager = OrderDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_order_index.json') self.index_file = os.path.join(self.local_file_dir, 'user_counselor_order_index.json')
self.index_data = {} self.index_data = {}
self.now = datetime.now() self.now = datetime.now()
...@@ -100,13 +119,13 @@ class UserCounselorOrderIndexer(Indexer): ...@@ -100,13 +119,13 @@ class UserCounselorOrderIndexer(Indexer):
date = datetime.strptime(dt, '%Y-%m-%d') date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7): if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], price / 400) w[0] = max(w[0], min(1., price / 400))
elif (self.now - date) <= timedelta(days=30): elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], price / 400) w[1] = max(w[1], min(1., price / 400))
elif (self.now - date) <= timedelta(days=180): elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], price / 400) w[2] = max(w[2], min(1., price / 400))
else: else:
w[3] = max(1., w[3], price / 400) w[3] = max(w[3], min(1., price / 400))
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1 value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value return value
...@@ -142,22 +161,6 @@ class UserCounselorOrderIndexer(Indexer): ...@@ -142,22 +161,6 @@ class UserCounselorOrderIndexer(Indexer):
self.logger.info('基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index)) self.logger.info('基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorChatIndexer(Indexer): class UserCounselorChatIndexer(Indexer):
""" """
...@@ -166,7 +169,7 @@ class UserCounselorChatIndexer(Indexer): ...@@ -166,7 +169,7 @@ class UserCounselorChatIndexer(Indexer):
def __init__(self, logger=None) -> None: def __init__(self, logger=None) -> None:
super().__init__(logger) super().__init__(logger)
self.data_manager = ChatDataManager(logger) self.data_manager = ChatDataManager(self.logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_chat_index.json') self.index_file = os.path.join(self.local_file_dir, 'user_counselor_chat_index.json')
self.index_data = {} self.index_data = {}
self.now = datetime.now() self.now = datetime.now()
...@@ -178,13 +181,13 @@ class UserCounselorChatIndexer(Indexer): ...@@ -178,13 +181,13 @@ class UserCounselorChatIndexer(Indexer):
date = datetime.strptime(dt, '%Y-%m-%d') date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7): if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], (u2d + d2u) / 20) w[0] = max(w[0], min(1., (u2d + d2u) / 20))
elif (self.now - date) <= timedelta(days=30): elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], (u2d + d2u) / 20) w[1] = max(w[1], min(1., (u2d + d2u) / 20))
elif (self.now - date) <= timedelta(days=180): elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], (u2d + d2u) / 20) w[2] = max(w[2], min(1., (u2d + d2u) / 20))
else: else:
w[3] = max(1., w[3], (u2d + d2u) / 20) w[3] = max(w[3], min(1., (u2d + d2u) / 20))
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1 value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value return value
...@@ -221,29 +224,12 @@ class UserCounselorChatIndexer(Indexer): ...@@ -221,29 +224,12 @@ class UserCounselorChatIndexer(Indexer):
self.logger.info('基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index)) self.logger.info('基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorCombinationIndexer(Indexer): class UserCounselorCombinationIndexer(Indexer):
""" """
基于多种数据组合的[用户->咨询师]关系索引 基于多种数据组合的[用户->咨询师]关系索引
""" """
def __init__(self, order_w=0.6, chat_w=0.4, logger=None) -> None: def __init__(self, order_w=0.8, chat_w=0.2, logger=None) -> None:
super().__init__(logger) super().__init__(logger)
self.order_w = order_w self.order_w = order_w
self.chat_w = chat_w self.chat_w = chat_w
...@@ -278,21 +264,83 @@ class UserCounselorCombinationIndexer(Indexer): ...@@ -278,21 +264,83 @@ class UserCounselorCombinationIndexer(Indexer):
return index return index
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
self.logger.error('未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法')
raise
if count == 0: class CounselorCounselorCFIndexer(Indexer):
return self.index_data.get(q, []) """
else: 基于协同过滤的[咨询师->咨询师]关系索引
return self.index_data.get(q, [])[:count] """
def __init__(self, order_w=0.8, chat_w=0.2, logger=None) -> None:
super().__init__(logger)
self.order_w = order_w
self.chat_w = chat_w
self.index_file = os.path.join(self.local_file_dir, 'counselor_counselor_cf_index.json')
self.index_data = {}
def load_pair_by_chat(self):
data_manager = ChatDataManager(self.logger)
df = data_manager.load_raw_data()
counselor_user_dict = {}
for _, row in df.iterrows():
uid, supplier_id = row['uid'], row['doctor_id']
if supplier_id not in counselor_user_dict:
counselor_user_dict[supplier_id] = set()
counselor_user_dict[supplier_id].add(uid)
return counselor_user_dict
def load_pair_by_order(self):
data_manager = OrderDataManager(self.logger)
df = data_manager.load_raw_data()
counselor_user_dict = {}
for _, row in df.iterrows():
uid, supplier_id = row['uid'], row['supplier_id']
if supplier_id not in counselor_user_dict:
counselor_user_dict[supplier_id] = set()
counselor_user_dict[supplier_id].add(uid)
return counselor_user_dict
def make_index(self) -> Dict[str, List]:
self.logger.info('')
self.logger.info('开始构建基于协同过滤的[咨询师->咨询师]关系索引')
counselor_user_dict = self.load_pair_by_order()
self.logger.info('基于订单的[用户->咨询师]数据加载完成')
_counselor_user_dict2 = self.load_pair_by_chat()
self.logger.info('基于询单的[用户->咨询师]数据加载完成')
for key, val in _counselor_user_dict2.items():
if key not in counselor_user_dict:
counselor_user_dict[key] = val
else:
counselor_user_dict[key].update(val)
self.logger.info('数据合并完成,共有咨询师 %s', len(counselor_user_dict))
index = {}
for [_u1, _u2] in combinations(counselor_user_dict.keys(), 2):
u1, u2 = min(_u1, _u2), max(_u1, _u2)
sim = len(counselor_user_dict[u1] & counselor_user_dict[u2]) / (len(counselor_user_dict[u1]) * len(counselor_user_dict[u2]))
if u1 not in index:
index[u1] = []
if u2 not in index:
index[u2] = []
index[u1].append((u2, sim))
index[u2].append((u1, sim))
self.logger.info('开始咨询师相似性排序')
for key, val in index.items():
# 根据相似性得分排序后,取前100个
index[key] = sorted(val, key=lambda x: x[1], reverse=True)[:100]
self.logger.info('基于协同过滤的[咨询师->咨询师]关系索引构建完成,共构建 %s 条数据', len(index))
with open(self.index_file, 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
return index
if __name__ == '__main__': if __name__ == '__main__':
...@@ -306,4 +354,7 @@ if __name__ == '__main__': ...@@ -306,4 +354,7 @@ if __name__ == '__main__':
indexer.make_index() indexer.make_index()
indexer = UserCounselorCombinationIndexer() indexer = UserCounselorCombinationIndexer()
indexer.make_index()
indexer = CounselorCounselorCFIndexer()
indexer.make_index() indexer.make_index()
\ No newline at end of file
...@@ -7,8 +7,13 @@ from typing import List, Dict ...@@ -7,8 +7,13 @@ from typing import List, Dict
import faiss import faiss
import numpy as np import numpy as np
from ydl_ai_recommender.src.core.indexer import UserCounselorDefaultIndexer from ydl_ai_recommender.src.core.indexer import (
from ydl_ai_recommender.src.core.indexer import UserCounselorCombinationIndexer UserCounselorChatIndexer,
UserCounselorOrderIndexer,
UserCounselorDefaultIndexer,
UserCounselorCombinationIndexer,
CounselorCounselorCFIndexer,
)
from ydl_ai_recommender.src.core.profile import encode_profile from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
...@@ -26,7 +31,15 @@ class Recommender(): ...@@ -26,7 +31,15 @@ class Recommender():
class UserCFRecommender(Recommender): class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=20, is_use_db=True) -> None: def __init__(self, top_n=5, k=20, is_use_db=True, u2c='combination', c2c=None) -> None:
"""
params:
top_n: 每个召回的用户获取的相关咨询师个数
k: 召回的相似用户数
is_use_db: 是否使用数据库
u2c: [用户->咨询师] 索引方法
c2c: [咨询师->咨询师] 索引方法,None 表示不使用咨询师拓展
"""
super().__init__() super().__init__()
# 召回 top_n 个相似用户 # 召回 top_n 个相似用户
self.top_n = top_n self.top_n = top_n
...@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender): ...@@ -41,9 +54,21 @@ class UserCFRecommender(Recommender):
self.default_indexer = UserCounselorDefaultIndexer(self.logger) self.default_indexer = UserCounselorDefaultIndexer(self.logger)
self.default_indexer.load_index_data() self.default_indexer.load_index_data()
self.indexer = UserCounselorCombinationIndexer(self.logger)
if u2c == 'chat':
self.indexer = UserCounselorChatIndexer(self.logger)
elif u2c == 'order':
self.indexer = UserCounselorOrderIndexer(self.logger)
else:
self.indexer = UserCounselorCombinationIndexer(self.logger)
self.indexer.load_index_data() self.indexer.load_index_data()
if c2c:
self.c2c_indexer = CounselorCounselorCFIndexer(self.logger)
self.c2c_indexer.load_index_data()
else:
self.c2c_indexer = None
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.load_data() self.load_data()
...@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender): ...@@ -100,7 +125,18 @@ class UserCFRecommender(Recommender):
'score': score / max(0.01, float(simi_score)), 'score': score / max(0.01, float(simi_score)),
'from': 'similar_users {}'.format(similar_user_id), 'from': 'similar_users {}'.format(similar_user_id),
} for (c_id, score) in similar_user_counselor] } for (c_id, score) in similar_user_counselor]
counselors.extend(recommend_data)
supplement_data = []
if self.c2c_indexer:
for ro in recommend_data:
supplement_data.extend([{
'counselor': sc_id,
'score': ro['score'] * score,
'from': '{} supplement {}'.format(ro['from'], ro['counselor']),
} for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n))])
# } for (sc_id, score) in self.c2c_indexer.index(ro['counselor'], count=int(self.top_n / len(recommend_data)))])
counselors.extend(recommend_data + supplement_data)
counselors.sort(key=lambda x: x['score'], reverse=True) counselors.sort(key=lambda x: x['score'], reverse=True)
return counselors return counselors
......
...@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender ...@@ -14,7 +14,7 @@ from ydl_ai_recommender.src.core.recommender import UserCFRecommender
logger = create_logger(__name__, 'service.log', is_rotating=True) logger = create_logger(__name__, 'service.log', is_rotating=True)
recommender = UserCFRecommender(top_n=5, k=5) recommender = UserCFRecommender(top_n=5, k=20, c2c=True)
class RecommendHandler(tornado.web.RequestHandler): class RecommendHandler(tornado.web.RequestHandler):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment