Commit 523f5373 by 柴鹏飞

refactor code

parent 7401bbe5
...@@ -190,11 +190,11 @@ def update_test_data(args):
    # 订单数据
    manager = OrderDataManager(client)
-    manager.update_test_order_data(conditions=conditions)
+    manager.update_test_data(conditions=conditions)
    # 用户画像数据
    manager = ProfileManager(client)
-    manager.update_test_profile(conditions=conditions)
+    manager.update_test_data(conditions=conditions)
    logger.info('测试数据更新完成')
......
...@@ -4,9 +4,17 @@ import os
import argparse
from datetime import datetime

-from ydl_ai_recommender.src.core.manager import OrderDataManager
-from ydl_ai_recommender.src.core.manager import ChatDataManager
-from ydl_ai_recommender.src.core.manager import ProfileManager
+from ydl_ai_recommender.src.core.manager import (
+    OrderDataManager,
+    ChatDataManager,
+    ProfileManager,
+)
+from ydl_ai_recommender.src.core.indexer import (
+    UserCounselorDefaultIndexer,
+    UserCounselorOrderIndexer,
+    UserCounselorChatIndexer,
+    UserCounselorCombinationIndexer,
+)
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_project_path
from ydl_ai_recommender.src.utils.log import create_logger
...@@ -18,79 +26,92 @@ logger = create_logger(__name__, 'update.log')
parser = argparse.ArgumentParser(description='壹点灵 咨询师推荐 算法召回 离线更新数据模型')
parser.add_argument(
    '-t', '--task', type=str, required=True,
-    choices=('load_db_data', 'make_embedding', 'make_virtual_embedding'), help='执行任务名称'
+    choices=('load_db_data', 'make_embedding', 'make_index'), help='执行任务名称'
)
parser.add_argument('--index_last_date', default=None, type=str, help='构建索引最后日期,超过该日期的数据不使用')
args = parser.parse_args()

-if __name__ == '__main__':
-
-    if args.task == 'load_db_data':
-        logger.info('')
-
-        # 创建数据目录
-        now = datetime.now()
-        really_data_dir = os.path.join(get_project_path(), 'data_{}'.format(now.strftime('%Y%m%d_%H%M%S')))
-        default_data_dir = os.path.join(get_project_path(), 'data')
-
-        # 判断data目录是否存在
-        if os.path.exists(default_data_dir):
-            if os.path.islink(default_data_dir):
-                os.unlink(default_data_dir)
-            else:
-                logger.error('%s 目录已经存在!请备份后删除该目录再重新执行本操作', default_data_dir)
-
-        os.mkdir(really_data_dir)
-        logger.info('创建数据保存目录成功')
-
-        # 创建软连接
-        os.symlink(really_data_dir, default_data_dir)
-        logger.info('创建软连接成功 %s -> %s', really_data_dir, default_data_dir)
-
-        # TODO 历史数据删除
+def initialize_dir():
+    # 创建数据目录
+    now = datetime.now()
+    really_data_dir = os.path.join(get_project_path(), 'data_{}'.format(now.strftime('%Y%m%d_%H%M%S')))
+    default_data_dir = os.path.join(get_project_path(), 'data')
+
+    # 判断data目录是否存在
+    if os.path.exists(default_data_dir):
+        if os.path.islink(default_data_dir):
+            os.unlink(default_data_dir)
+        else:
+            logger.error('%s 目录已经存在!请备份后删除该目录再重新执行本操作', default_data_dir)
+
+    os.mkdir(really_data_dir)
+    logger.info('创建数据保存目录成功')
+
+    # 创建软连接
+    os.symlink(really_data_dir, default_data_dir)
+    logger.info('创建软连接成功 %s -> %s', really_data_dir, default_data_dir)
+
+    # TODO 历史数据删除
+
+
+if __name__ == '__main__':
+    logger.info('')
+    if args.task == 'load_db_data':
+        initialize_dir()

        logger.info('开始从数据库中更新数据')
        client = MySQLClient.create_from_config_file(get_conf_path())

-        logger.info('开始从数据库中更新画像数据')
-        profile_manager = ProfileManager(client)
-        profile_manager.update_profile()
-
-        logger.info('开始从数据库中更新订单数据')
-        order_data_manager = OrderDataManager(client)
-        order_data_manager.update_order_data()
-
-        logger.info('开始从数据库中更新询单数据')
-        chat_data_manager = ChatDataManager(client)
-        chat_data_manager.update_data()
+        managers = [
+            ['画像数据', ProfileManager(client)],
+            ['订单数据', OrderDataManager(client)],
+            ['询单数据', ChatDataManager(client)],
+        ]
+
+        for [name, manager] in managers:
+            logger.info('开始更新 %s', name)
+            manager.update_data()
+            logger.info('%s 更新完成', name)
+            logger.info('')

        logger.info('所有数据更新数据完成')

    if args.task == 'make_embedding':
        logger.info('')
+        logger.info('--' * 50)
        logger.info('开始构建用户特征 embedding')
        manager = ProfileManager()
        manager.make_embeddings()
        logger.info('用户特征 embedding 构建完成')

-        logger.info('开始构建订单相关索引')
-        manager = OrderDataManager()
-        manager.make_index()
-        logger.info('订单相关索引 构建完成')
-
-        logger.info('开始构建询单相关索引')
-        chat_data_manager = ChatDataManager()
-        chat_data_manager.make_index()
-        logger.info('询单相关索引 构建完成')
-
-    if args.task == 'make_virtual_embedding':
-        logger.info('')
-        logger.info('开始构建用户特征虚拟embedding')
-        manager = ProfileManager()
-        manager.make_virtual_embedding()
-        logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
+    if args.task == 'make_index':
+        indexers = [
+            ['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer()],
+            ['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer()],
+            ['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer()],
+            ['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer()],
+        ]
+
+        logger.info('')
+        logger.info('--' * 50)
+
+        for [name, indexer] in indexers:
+            logger.info('开始构建 %s', name)
+            indexer.make_index()
+            logger.info('%s 构建完成', name)
+            logger.info('')
+
+        logger.info('所有索引更新数据完成')
+
+    # if args.task == 'make_virtual_embedding':
+    #     logger.info('')
+    #     logger.info('开始构建用户特征虚拟embedding')
+    #     manager = ProfileManager()
+    #     manager.make_virtual_embedding()
+    #     logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from collections import Counter
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class Indexer():
"""
索引构建、管理类
"""
def __init__(self, logger=None) -> None:
if logger is None:
self.logger = create_logger(__name__)
else:
self.logger = logger
self.local_file_dir = get_data_path()
def index(self, q: str, count: int=0) -> List[Tuple[str, float]]:
"""
返回值类型:[[相似id, score], [相似id, score], ...]
"""
raise NotImplementedError
def make_index(self) -> Dict[str, List]:
raise NotImplementedError
class UserCounselorDefaultIndexer(Indexer):
"""
[用户->咨询师]兜底关系索引
"""
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = OrderDataManager(logger)
self.index_file = os.path.join(self.local_file_dir, 'index_list.txt')
self.count = 100
self.index_data = []
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
            index = [(line.strip(), 0.01 - 0.0001 * i) for i, line in enumerate(f)]
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
            self.logger.error('未加载索引数据,使用 `index` 函数之前,请确认已执行 `load_index_data()` 方法')
            raise RuntimeError('索引数据未加载')
if count == 0:
return self.index_data
else:
return self.index_data[:count]
def make_index(self) -> Dict[str, List]:
self.logger.info('')
self.logger.info('开始构建[用户->咨询师]兜底关系索引')
df = self.data_manager.load_raw_data()
self.logger.info('构建索引数据加载完成,共加载 %s 条数据', len(df))
supplier_cnter = Counter(df['supplier_id'])
index_list = []
for key, _ in supplier_cnter.most_common(self.count):
index_list.append(str(key))
self.logger.info('[用户->咨询师]兜底关系索引构建完成')
with open(self.index_file, 'w', encoding='utf-8') as f:
f.write('\n'.join(index_list))
        self.logger.info('[用户->咨询师]兜底关系索引已保存')
return index_list
class UserCounselorOrderIndexer(Indexer):
"""
基于订单数据的[用户->咨询师]关系索引
"""
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = OrderDataManager(logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_order_index.json')
self.index_data = {}
self.now = datetime.now()
def _compute_score(self, infos: List) -> float:
w = [0, 0, 0, 0]
for [_price, dt] in infos:
price = float(_price)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], price / 400)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], price / 400)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], price / 400)
else:
w[3] = max(1., w[3], price / 400)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
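    # Worked example for _compute_score above (illustrative numbers): a single order
    # priced 500 within the last 7 days gives w = [1.25, 0, 0, 0] via max(1., 500 / 400),
    # so the score is 1.25 * 0.5 = 0.625; the same order older than 180 days only
    # contributes 1.25 * 0.1 = 0.125. Recent, higher-priced orders therefore rank higher,
    # with 1.0 acting as a floor for any matched time bucket.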
def make_index(self) -> Dict[str, List]:
self.logger.info('')
self.logger.info('开始构建基于订单数据的[用户->咨询师]关系索引')
df = self.data_manager.load_raw_data()
self.logger.info('构建索引数据加载完成,共加载 %s 条数据', len(df))
user_order = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['supplier_id']
if uid not in user_order:
user_order[uid] = {}
if supplier_id not in user_order[uid]:
user_order[uid][supplier_id] = []
user_order[uid][supplier_id].append([row['price'], row['update_time']])
index = {}
for uid, orders in user_order.items():
supplier_values = []
for supplier_id, infos in orders.items():
value = self._compute_score(infos)
supplier_values.append((supplier_id, value))
index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('基于订单数据的[用户->咨询师]关系索引构建完成,共构建 %s 条数据', len(index))
with open(self.index_file, 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
self.logger.info('基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
            self.logger.error('未加载索引数据,使用 `index` 函数之前,请确认已执行 `load_index_data()` 方法')
            raise RuntimeError('索引数据未加载')
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorChatIndexer(Indexer):
"""
基于询单数据的[用户->咨询师]关系索引
"""
def __init__(self, logger=None) -> None:
super().__init__(logger)
self.data_manager = ChatDataManager(logger)
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_chat_index.json')
self.index_data = {}
self.now = datetime.now()
def _compute_score(self, infos: List) -> float:
w = [0, 0, 0, 0]
for [dt, _u2d, _d2u] in infos:
u2d, d2u = int(_u2d), int(_d2u)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], (u2d + d2u) / 20)
else:
w[3] = max(1., w[3], (u2d + d2u) / 20)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
def make_index(self) -> Dict[str, List]:
self.logger.info('')
self.logger.info('开始构建基于询单数据的[用户->咨询师]关系索引')
df = self.data_manager.load_raw_data()
self.logger.info('构建索引数据加载完成,共加载 %s 条数据', len(df))
user_chat = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['doctor_id']
if uid not in user_chat:
user_chat[uid] = {}
if supplier_id not in user_chat[uid]:
user_chat[uid][supplier_id] = []
user_chat[uid][supplier_id].append([row['dt'], row['user_to_doctor'], row['doctor_to_user']])
index = {}
for uid, chats in user_chat.items():
supplier_values = []
for supplier_id, infos in chats.items():
value = self._compute_score(infos)
supplier_values.append([supplier_id, value])
index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('基于询单数据的[用户->咨询师]关系索引构建完成,共构建 %s 条数据', len(index))
with open(self.index_file, 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
self.logger.info('基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户 %s', len(index))
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
            self.logger.error('未加载索引数据,使用 `index` 函数之前,请确认已执行 `load_index_data()` 方法')
            raise RuntimeError('索引数据未加载')
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
class UserCounselorCombinationIndexer(Indexer):
"""
基于多种数据组合的[用户->咨询师]关系索引
"""
def __init__(self, order_w=0.6, chat_w=0.4, logger=None) -> None:
super().__init__(logger)
self.order_w = order_w
self.chat_w = chat_w
self.index_file = os.path.join(self.local_file_dir, 'user_counselor_combination_index.json')
self.index_data = {}
def make_index(self) -> Dict[str, List]:
self.logger.info('')
self.logger.info('开始构建基于多种数据组合的[用户->咨询师]关系索引')
order_indexer = UserCounselorOrderIndexer(self.logger)
chat_indexer = UserCounselorChatIndexer(self.logger)
order_indexer.load_index_data()
chat_indexer.load_index_data()
self.logger.info('构建索引数据加载完成')
index = {}
for uid, counselors in order_indexer.index_data.items():
chat_index = {
c_id: value for c_id, value in chat_indexer.index_data.get(uid, [])
}
new_counselors = []
for (c_id, value) in counselors:
merge_value = value * self.order_w + chat_index.get(c_id, 0) * self.chat_w
new_counselors.append((c_id, merge_value))
index[uid] = sorted(new_counselors, key=lambda x: x[1], reverse=True)
self.logger.info('基于多种数据组合的[用户->咨询师]关系索引构建完成,共构建 %s 条数据', len(index))
with open(self.index_file, 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
return index
def load_index_data(self):
index = []
with open(self.index_file, 'r', encoding='utf-8') as f:
index = json.load(f)
self.index_data = index
def index(self, q='', count=0) -> List[Tuple[str, float]]:
if len(self.index_data) == 0:
            self.logger.error('未加载索引数据,使用 `index` 函数之前,请确认已执行 `load_index_data()` 方法')
            raise RuntimeError('索引数据未加载')
if count == 0:
return self.index_data.get(q, [])
else:
return self.index_data.get(q, [])[:count]
if __name__ == '__main__':
indexer = UserCounselorDefaultIndexer()
indexer.make_index()
indexer = UserCounselorOrderIndexer()
indexer.make_index()
indexer = UserCounselorChatIndexer()
indexer.make_index()
indexer = UserCounselorCombinationIndexer()
indexer.make_index()
\ No newline at end of file
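For reference, a minimal usage sketch of the indexer API above, mirroring the `__main__` block and the way the recommender consumes it (the uid is a placeholder, and the order data exported by OrderDataManager is assumed to already exist under the data directory):

    from ydl_ai_recommender.src.core.indexer import (
        UserCounselorOrderIndexer,
        UserCounselorDefaultIndexer,
    )

    # Build the order-based [用户->咨询师] index from the locally saved order data,
    # then reload it and query the top 5 counselors for one user.
    order_indexer = UserCounselorOrderIndexer()
    order_indexer.make_index()
    order_indexer.load_index_data()
    print(order_indexer.index(q='123456', count=5))   # [[counselor_id, score], ...]

    # Fallback list used when a user has no order/chat history.
    default_indexer = UserCounselorDefaultIndexer()
    default_indexer.load_index_data()
    print(default_indexer.index(count=10))            # [(counselor_id, score), ...]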
...@@ -5,6 +5,4 @@ from .database_manager import DatabaseDataManager
from .profile_manager import ProfileManager
from .chat_data_manager import ChatDataManager
-from .order_data_manager import OrderDataManager
-from .user_counselor_index_manager import UserCounselorIndexManager
\ No newline at end of file
+from .order_data_manager import OrderDataManager
\ No newline at end of file
# -*- coding: utf-8 -*-
-import os
-import json
-from datetime import datetime, timedelta

import pandas as pd

from ydl_ai_recommender.src.utils.log import create_logger
...@@ -13,8 +9,6 @@ from ydl_ai_recommender.src.core.manager import DatabaseDataManager

class ChatDataManager(DatabaseDataManager):

    def __init__(self, client=None) -> None:
        super().__init__(client, create_logger(__name__, 'chat_data_manager.log'))
-        self.now = datetime.now()

    def _make_query_sql(self, conditions=None):
...@@ -30,7 +24,6 @@ class ChatDataManager(DatabaseDataManager):
        sql += condition_sql
        return sql

    def update_data(self):
        """ 从数据库中拉取最新询单数据并保存 """
        sql = self._make_query_sql()
...@@ -40,7 +33,6 @@ class ChatDataManager(DatabaseDataManager):
        self.save_csv_data(df, 'all_chat_info.csv')
        return df

    def update_test_data(self, conditions):
        """ 从数据库中拉取指定条件询单数据用于测试 """
...@@ -50,69 +42,13 @@ class ChatDataManager(DatabaseDataManager):
        df = pd.DataFrame(all_data)
        self.save_csv_data(df, 'test_chat_info.csv')

    def load_raw_data(self):
        return self.load_csv_data('all_chat_info.csv')

    def load_test_data(self):
        return self.load_csv_data('test_chat_info.csv')
def make_index(self):
"""
构建索引
用户-咨询师 索引
"""
self.logger.info('')
self.logger.info('开始构建 用户-咨询师 索引')
df = self.load_raw_data()
self.logger.info('本地用户咨询师对话数据加载完成,共加载 %s 条数据', len(df))
user_chat = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['doctor_id']
if uid not in user_chat:
user_chat[uid] = {}
if supplier_id not in user_chat[uid]:
user_chat[uid][supplier_id] = []
user_chat[uid][supplier_id].append([row['dt'], row['user_to_doctor'], row['doctor_to_user']])
def compute_score(infos):
w = [0, 0, 0, 0]
for [dt, _u2d, _d2u] in infos:
u2d, d2u = int(_u2d), int(_d2u)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], (u2d + d2u) / 20)
else:
w[3] = max(1., w[3], (u2d + d2u) / 20)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
index = {}
for uid, chats in user_chat.items():
supplier_values = []
for supplier_id, infos in chats.items():
# 日期越近权重越大
value = compute_score(infos)
supplier_values.append([supplier_id, value])
index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 询单索引构建完成,共构建 %s 条数据', len(index))
with open(os.path.join(self.local_file_dir, 'user_doctor_chat_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
if __name__ == '__main__':
    from ydl_ai_recommender.src.data.mysql_client import MySQLClient
......
...@@ -8,20 +8,16 @@ import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
+from .manager import Manager

-class Manager():
-    def __init__(self, logger=None) -> None:
-        if logger is None:
-            self.logger = create_logger(__name__)
-        else:
-            self.logger = logger
-        self.local_file_dir = get_data_path()
-
-    def make_index(self):
-        raise NotImplemented
+# class Manager():
+#     def __init__(self, logger=None) -> None:
+#         if logger is None:
+#             self.logger = create_logger(__name__)
+#         else:
+#             self.logger = logger
+#         self.local_file_dir = get_data_path()

class DatabaseDataManager(Manager):
...@@ -29,7 +25,6 @@ class DatabaseDataManager(Manager):
        super().__init__(logger)
        self.client = client

    def fetch_data_from_db(self, sql: str) -> List:
        if self.client is None:
            self.logger.error('未连接数据库')
...@@ -37,34 +32,28 @@ class DatabaseDataManager(Manager):
        return self.client.query(sql)

    def load_xlsx_data(self, filename):
        return pd.read_excel(os.path.join(self.local_file_dir, filename), dtype=str)

    def save_xlsx_data(self, df, filename):
        df.to_excel(os.path.join(self.local_file_dir, filename), index=None)

    def load_csv_data(self, filename):
        return pd.read_csv(os.path.join(self.local_file_dir, filename), dtype=str)

    def save_csv_data(self, df, filename):
        df.to_csv(os.path.join(self.local_file_dir, filename), encoding='utf-8', index=False)

    def load_json_data(self, filename):
        with open(os.path.join(self.local_file_dir, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    def save_json_data(self, data, filename):
        with open(os.path.join(self.local_file_dir, filename), 'w', encoding='utf-8') as f:
            return json.dump(data, f, ensure_ascii=False)

-    def update_data(self, sql, filename):
-        _, all_data = self.fetch_data_from_db(sql)
-        df = pd.DataFrame(all_data)
-        self.save_xlsx_data(df, filename)
+    def update_data(self):
+        raise NotImplementedError
+
+    def update_test_data(self, conditions):
+        raise NotImplementedError
...@@ -15,4 +15,4 @@ class Manager():
    def make_index(self):
-        raise NotImplemented
+        raise NotImplementedError
\ No newline at end of file
...@@ -16,7 +16,6 @@ class OrderDataManager(DatabaseDataManager):
        super().__init__(client, create_logger(__name__, 'order_data_manager.log'))
        self.now = datetime.now()

    def _make_query_sql(self, conditions=None):
        condition_sql = ''
        if conditions:
...@@ -28,8 +27,7 @@ class OrderDataManager(DatabaseDataManager):
        sql += condition_sql
        return sql

-    def update_order_data(self):
+    def update_data(self):
        """ 从数据库中拉取最新订单数据并保存 """
        sql = self._make_query_sql()
        _, all_data = self.fetch_data_from_db(sql)
...@@ -38,7 +36,7 @@ class OrderDataManager(DatabaseDataManager):
        self.save_xlsx_data(df, 'all_order_info.xlsx')

-    def update_test_order_data(self, conditions):
+    def update_test_data(self, conditions):
        """ 从数据库中拉取指定条件订单用于测试 """
        sql = self._make_query_sql(conditions)
...@@ -56,82 +54,6 @@ class OrderDataManager(DatabaseDataManager):
        return self.load_xlsx_data('test_order_info.xlsx')
def make_index(self):
"""
构建索引
用户-咨询师 索引
top100 咨询师列表 用于冷启动
"""
self.logger.info('')
self.logger.info('开始构建 用户-咨询师 索引')
df = self.load_raw_data()
self.logger.info('本地订单加载数据完成,共加载 %s 条数据', len(df))
user_order = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['supplier_id']
if uid not in user_order:
user_order[uid] = {}
if supplier_id not in user_order[uid]:
user_order[uid][supplier_id] = []
user_order[uid][supplier_id].append([row['price'], row['update_time']])
def compute_score(infos):
w = [0, 0, 0, 0]
for [_price, dt] in infos:
price = float(_price)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], price / 400)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], price / 400)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], price / 400)
else:
w[3] = max(1., w[3], price / 400)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
index = {}
for uid, orders in user_order.items():
supplier_values = []
for supplier_id, infos in orders.items():
# 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前
value = compute_score(infos)
supplier_values.append([supplier_id, value])
# value = len(infos)
# latest_time = max([info[1] for info in infos])
# supplier_values.append([supplier_id, value, latest_time])
# index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 索引构建完成,共构建 %s 条数据', len(index))
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
self.logger.info('用户-咨询师 索引数据已保存,共有用户 %s', len(index))
# 订单最多的咨询师
supplier_cnter = Counter(df['supplier_id'])
top100_supplier = []
for key, _ in supplier_cnter.most_common(100):
top100_supplier.append(str(key))
self.logger.info('top100 订单量咨询师统计完成')
with open(os.path.join(self.local_file_dir, 'top100_supplier.txt'), 'w', encoding='utf-8') as f:
f.write('\n'.join(top100_supplier))
self.logger.info('top100 订单量咨询师列表已保存')
if __name__ == '__main__':
    manager = OrderDataManager()
    print(manager.make_index())
\ No newline at end of file
...@@ -6,7 +6,7 @@ from typing import List
import pandas as pd

-from ydl_ai_recommender.src.core.profile import profile_converters
+from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.core.manager import DatabaseDataManager
from ydl_ai_recommender.src.utils.log import create_logger
...@@ -29,8 +29,7 @@ class ProfileManager(DatabaseDataManager):
        sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
        return sql

-    def update_profile(self):
+    def update_data(self):
        """ 从数据库中拉取最新画像特征并保存 """
        sql = self._make_query_sql()
...@@ -39,8 +38,7 @@ class ProfileManager(DatabaseDataManager):
        df = pd.DataFrame(all_data)
        self.save_xlsx_data(df, 'all_profile.xlsx')

-    def update_test_profile(self, conditions):
+    def update_test_data(self, conditions):
        """ 从数据库中拉取指定条件画像信息用于测试 """
        sql = self._make_query_sql(conditions)
...@@ -48,39 +46,13 @@ class ProfileManager(DatabaseDataManager):
        df = pd.DataFrame(all_data)
        self.save_xlsx_data(df, 'test_profile.xlsx')

    def _load_profile_data(self):
        return self.load_xlsx_data('all_profile.xlsx')

    def load_test_profile_data(self):
        return self.load_xlsx_data('test_profile.xlsx')
def profile_to_embedding(self, profile):
"""
将用户画像信息转换为向量
"""
embedding = []
for [name, converter] in profile_converters:
embedding.extend(converter.convert(profile[name]))
return embedding
def embedding_to_profile(self, embedding):
"""
向量转换为用户画像
"""
ret = {}
si = 0
for [name, converter] in profile_converters:
ei = si + converter.dim
ret[name] = converter.inconvert(embedding[si: ei])
si = ei
return ret
    def make_embeddings(self):
        user_profiles = self._load_profile_data()
        self.logger.info('订单用户画像数据加载完成,共加载 %s 条', len(user_profiles))
...@@ -88,7 +60,7 @@ class ProfileManager(DatabaseDataManager):
        self.logger.info('开始构建订单用户的用户画像向量')
        for _, profile in user_profiles.iterrows():
            user_ids.append(str(profile['uid']))
-            embeddings.append(self.profile_to_embedding(profile))
+            embeddings.append(encode_profile(profile))
        self.logger.info('用户画像向量构建完成,共构建 %s 用户', len(user_ids))
...@@ -100,7 +72,6 @@ class ProfileManager(DatabaseDataManager):
        return embeddings

    def make_virtual_embedding(self):
        user_ids = []
        embeddings = []
......
# -*- coding: utf-8 -*-
import os
import json
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import Manager
# from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class UserCounselorIndexManager(Manager):
def __init__(self) -> None:
super().__init__(create_logger(__name__, 'order_data_manager.log'))
# self.order_data_manager = OrderDataManager()
# self.chat_data_manager = ChatDataManager()
def make_index(self):
self.logger.info('开始构建用户-咨询师 合并索引')
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
user_doctor_index = json.load(f)
with open(os.path.join(self.local_file_dir, 'user_doctor_chat_index.json'), encoding='utf-8') as f:
user_doctor_chat_index = json.load(f)
merged_index = {}
for uid, counselors in user_doctor_index.items():
chat_index = {
c_id: value for c_id, value in user_doctor_chat_index.get(uid, [])
}
new_counselors = []
for [c_id, value] in counselors:
if c_id in chat_index:
merge_value = value * 0.6 + chat_index[c_id] * 0.4
else:
merge_value = value * 0.6
new_counselors.append([c_id, merge_value])
merged_index[uid] = sorted(new_counselors, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 合并索引构建完成,共构建 %s 条数据', len(merged_index))
with open(os.path.join(self.local_file_dir, 'merged_user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(merged_index, f, ensure_ascii=False)
if __name__ == '__main__':
manager = UserCounselorIndexManager()
manager.make_index()
\ No newline at end of file
...@@ -5,20 +5,17 @@ from typing import Dict, List, Any, Union
import pandas as pd

-# from .country_code_profile import CountryCodeProfile
-# from .profile import ChannelIdTypeProfile

class BaseProfile():
    def __init__(self) -> None:
-        self.dim :int = 0
+        self.dim: int = 0

    def convert(self, value):
-        raise NotImplemented
+        raise NotImplementedError

    def inconvert(self, embedding: List[Union[int, float]]) -> str:
-        raise NotImplemented
+        raise NotImplementedError

class CountryCodeProfile(BaseProfile):
...@@ -29,7 +26,7 @@ class CountryCodeProfile(BaseProfile):
    def convert(self, value):
        try:
            value = int(value)
-        except Exception as e:
+        except Exception:
            return [0, 0, 1]
        if value == 86:
            return [1, 0, 0]
...@@ -53,9 +50,9 @@ class ChannelIdTypeProfile(BaseProfile):
    def convert(self, value):
        try:
            value = int(value)
-        except Exception as e:
+        except Exception:
            return [0, 0, 1]
        if value == 1:
            return [1, 0, 0]
        elif value == 2:
...@@ -86,7 +83,7 @@ class FfromLoginProfile(BaseProfile):
        ret = [0, 0, 0, 0, 0]
        try:
            value = value.lower()
-        except Exception as e:
+        except Exception:
            return ret
        for i, v in enumerate(self.brand_list):
...@@ -119,11 +116,11 @@ class UserPreferenceCateProfile(BaseProfile):
        ret = [0.] * 8
        if pd.isnull(value):
            return ret
        if isinstance(value, str):
            try:
                value = json.loads(value)
-            except Exception as e:
+            except Exception:
                return ret
        for info in value:
...@@ -167,11 +164,10 @@ class NumClassProfile(BaseProfile):
            value = float(value)
            index = self.value_index(value)
            ret[index] = 1
-        except:
+        except Exception:
            return ret
        return ret

    def inconvert(self, embedding):
        ret = ''
        # 确保embedding中有包含1的值
...@@ -245,7 +241,6 @@ class MultiChoiceProfile(BaseProfile):
        self.option_dict = option_dict
        self.re_option_dict = {v: k for k, v in self.option_dict.items()}

    def convert(self, value: List):
        ret = [0] * len(self.option_dict)
        if pd.isnull(value):
...@@ -261,7 +256,6 @@ class MultiChoiceProfile(BaseProfile):
            pass
        return ret

    def inconvert(self, embedding):
        ret = []
...@@ -284,7 +278,6 @@ class CityProfile(BaseProfile):
        self.level = level
        self.dim = self.level * 10

    def convert(self, value):
        ret = [0] * self.dim
...@@ -297,11 +290,10 @@ class CityProfile(BaseProfile):
                n = int(_n)
                ret[i * 10 + n] = 1
-        except Exception as e:
+        except Exception:
            pass
        return ret

    def inconvert(self, embedding):
        # 邮编固定都是6
        ret = [0] * 6
...@@ -318,7 +310,6 @@ class AidiCstBiasCityProfile(CityProfile):
    def __init__(self, level=2) -> None:
        super().__init__(level=level)

    def convert(self, value_object):
        ret = [0] * self.dim
...@@ -331,7 +322,7 @@ class AidiCstBiasCityProfile(CityProfile):
        if isinstance(value_object, str):
            try:
                value_object = json.loads(value_object)
-            except Exception as e:
+            except Exception:
                pass
        if isinstance(value_object, dict):
...@@ -340,7 +331,7 @@ class AidiCstBiasCityProfile(CityProfile):
            for i, _n in enumerate(value[:self.level]):
                n = int(_n)
                ret[i * 10 + n] = 1
-            except Exception as e:
+            except Exception:
                pass
        return ret
...@@ -367,3 +358,25 @@ profile_converters = [
    ['d30_session_num', NumClassProfile([0, 1])],
]
def encode_profile(profile):
"""
将用户画像信息转换为向量
"""
embedding = []
for [name, converter] in profile_converters:
embedding.extend(converter.convert(profile[name]))
return embedding
def decode_profile(embedding):
"""
向量转换为用户画像
"""
ret = {}
si = 0
for [name, converter] in profile_converters:
ei = si + converter.dim
ret[name] = converter.inconvert(embedding[si: ei])
si = ei
return ret
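As a small illustration of how encode_profile composes the per-field converters (the inputs below are made-up examples, and only the country_code fragment of the full embedding is shown):

    from ydl_ai_recommender.src.core.profile import CountryCodeProfile

    converter = CountryCodeProfile()
    print(converter.convert('86'))    # [1, 0, 0] -> mainland China country code
    print(converter.convert('852'))   # [0, 1, 0] -> any other country code
    print(converter.convert(None))    # [0, 0, 1] -> value that cannot be parsed as int

encode_profile(profile) simply concatenates the output of every converter registered in profile_converters, so the embedding length is the sum of the converters' dim values, and decode_profile reverses the mapping slice by slice.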
# -*- coding: utf-8 -*-
class CountryCodeProfile():
def __init__(self) -> None:
pass
def convert(self, value):
try:
value = int(value)
except Exception as e:
return [0, 0, 1]
if value == 86:
return [1, 0, 0]
else:
return [0, 1, 0]
\ No newline at end of file
# -*- coding: utf-8 -*-
class ChannelIdTypeProfile():
def __init__(self) -> None:
pass
def convert(self, value):
try:
value = int(value)
except Exception as e:
return [0, 0, 1]
if value == 1:
return [1, 0, 0]
elif value == 2:
return [0, 1, 0]
else:
return [0, 0, 1]
\ No newline at end of file
...@@ -7,7 +7,9 @@ from typing import List, Dict
import faiss
import numpy as np

-from ydl_ai_recommender.src.core.manager import ProfileManager
+from ydl_ai_recommender.src.core.indexer import UserCounselorDefaultIndexer
+from ydl_ai_recommender.src.core.indexer import UserCounselorCombinationIndexer
+from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
...@@ -19,7 +21,7 @@ class Recommender():
        pass

    def recommend(self, user) -> List:
-        raise NotImplemented
+        raise NotImplementedError

class UserCFRecommender(Recommender):
...@@ -37,7 +39,11 @@ class UserCFRecommender(Recommender):
        else:
            self.logger.warn('未连接数据库')

-        self.manager = ProfileManager()
+        self.default_indexer = UserCounselorDefaultIndexer(self.logger)
+        self.default_indexer.load_index_data()
+        self.indexer = UserCounselorCombinationIndexer(self.logger)
+        self.indexer.load_index_data()

        self.local_file_dir = get_data_path()
        self.load_data()
...@@ -45,8 +51,6 @@ class UserCFRecommender(Recommender):
    def load_data(self):
        order_user_embedding = []
        order_user_ids = []
-        order_user_counselor_index = {}
-        default_counselor = []

        with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
            order_user_ids = [line.strip() for line in f]
...@@ -54,21 +58,8 @@ class UserCFRecommender(Recommender):
        with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
            order_user_embedding = json.load(f)

-        with open(os.path.join(self.local_file_dir, 'merged_user_doctor_index.json'), encoding='utf-8') as f:
-        # with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
-            order_user_counselor_index = json.load(f)
-
-        with open(os.path.join(self.local_file_dir, 'top100_supplier.txt'), 'r', encoding='utf-8') as f:
-            default_counselor = [line.strip() for line in f]

        self.order_user_embedding = order_user_embedding
        self.order_user_ids = order_user_ids
-        self.order_user_counselor_index = order_user_counselor_index
-        self.default_counselor = [{
-            'counselor': str(user),
-            'score': 1 - 0.01 * index,
-            'from': 'top_100',
-        } for index, user in enumerate(default_counselor)]

        self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
        self.index.add(np.array(self.order_user_embedding))
...@@ -88,28 +79,27 @@ class UserCFRecommender(Recommender):
            return []

-    def user_token(self, user_profile):
-        return self.manager.profile_to_embedding(user_profile)
-
    def _recommend_top(self, size=100):
-        return self.default_counselor[:size]
+        return [{
+            'counselor': str(c_id),
+            'score': score,
+            'from': 'default',
+        } for [c_id, score] in self.default_indexer.index(count=size)]

    def _recommend(self, user_embedding):
        D, I = self.index.search(np.array([user_embedding]), self.k)

        counselors = []
-        for idx, score in zip(I[0], D[0]):
+        for idx, simi_score in zip(I[0], D[0]):
            # 相似用户uid
            similar_user_id = self.order_user_ids[idx]
-            similar_user_counselor = self.order_user_counselor_index.get(similar_user_id, [])
+            similar_user_counselor = self.indexer.index(q=similar_user_id, count=self.top_n)
            recommend_data = [{
-                'counselor': str(user[0]),
-                'score': 1 / max(0.01, float(score) * (index + 1)),
+                'counselor': c_id,
+                'score': score / max(0.01, float(simi_score)),
                'from': 'similar_users {}'.format(similar_user_id),
-            } for index, user in enumerate(similar_user_counselor[:self.top_n])]
+            } for (c_id, score) in similar_user_counselor]
            counselors.extend(recommend_data)

        counselors.sort(key=lambda x: x['score'], reverse=True)
...@@ -117,7 +107,7 @@ class UserCFRecommender(Recommender):
    def recommend_with_profile(self, user_profile, size=0, is_merge=True):
-        user_embedding = self.user_token(user_profile)
+        user_embedding = encode_profile(user_profile)
        counselors = self._recommend(user_embedding)
        # size == 0 时,不追加默认推荐咨询师
......
...@@ -40,7 +40,8 @@ class MySQLClient():
        try:
            self.cursor.close()
            self.connection.close()
-            self.logger.info('dataset disconnected')
+            # 容易触发 NameError: name 'open' is not defined
+            # self.logger.info('dataset disconnected')
        except Exception as e:
            self.logger.error('销毁 MySQLClient 失败', exc_info=True)
            print(e)
......
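The commented-out log line above touches a real pitfall: this cleanup appears to run from a destructor, and at interpreter shutdown the logging file handler may call open() after built-ins have already been torn down, which is exactly the NameError mentioned in the new comment. A hedged sketch of the safer pattern (the __del__ wrapper is an assumption, since the method definition sits outside this hunk):

    def __del__(self):
        # Close quietly; avoid logging here because the interpreter may already
        # be shutting down when the client is garbage-collected.
        try:
            self.cursor.close()
            self.connection.close()
        except Exception:
            pass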