Commit 523f5373 by 柴鹏飞

refactor code

parent 7401bbe5
......@@ -190,11 +190,11 @@ def update_test_data(args):
# 订单数据
manager = OrderDataManager(client)
manager.update_test_order_data(conditions=conditions)
manager.update_test_data(conditions=conditions)
# 用户画像数据
manager = ProfileManager(client)
manager.update_test_profile(conditions=conditions)
manager.update_test_data(conditions=conditions)
logger.info('测试数据更新完成')
......
......@@ -4,9 +4,17 @@ import os
import argparse
from datetime import datetime
from ydl_ai_recommender.src.core.manager import OrderDataManager
from ydl_ai_recommender.src.core.manager import ChatDataManager
from ydl_ai_recommender.src.core.manager import ProfileManager
from ydl_ai_recommender.src.core.manager import (
OrderDataManager,
ChatDataManager,
ProfileManager,
)
from ydl_ai_recommender.src.core.indexer import (
UserCounselorDefaultIndexer,
UserCounselorOrderIndexer,
UserCounselorChatIndexer,
UserCounselorCombinationIndexer,
)
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_project_path
from ydl_ai_recommender.src.utils.log import create_logger
......@@ -18,18 +26,14 @@ logger = create_logger(__name__, 'update.log')
parser = argparse.ArgumentParser(description='壹点灵 咨询师推荐 算法召回 离线更新数据模型')
parser.add_argument(
'-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_embedding', 'make_virtual_embedding'), help='执行任务名称'
choices=('load_db_data', 'make_embedding', 'make_index'), help='执行任务名称'
)
parser.add_argument('--index_last_date', default=None, type=str, help='构建索引最后日期,超过该日期的数据不使用')
args = parser.parse_args()
if __name__ == '__main__':
if args.task == 'load_db_data':
logger.info('')
def initialize_dir():
# 创建数据目录
now = datetime.now()
really_data_dir = os.path.join(get_project_path(), 'data_{}'.format(now.strftime('%Y%m%d_%H%M%S')))
......@@ -51,46 +55,63 @@ if __name__ == '__main__':
# TODO 历史数据删除
logger.info('开始从数据库中更新数据')
client = MySQLClient.create_from_config_file(get_conf_path())
logger.info('开始从数据库中更新画像数据')
profile_manager = ProfileManager(client)
profile_manager.update_profile()
if __name__ == '__main__':
logger.info('开始从数据库中更新订单数据')
order_data_manager = OrderDataManager(client)
order_data_manager.update_order_data()
logger.info('')
if args.task == 'load_db_data':
logger.info('开始从数据库中更新询单数据')
chat_data_manager = ChatDataManager(client)
chat_data_manager.update_data()
initialize_dir()
logger.info('开始从数据库中更新数据')
client = MySQLClient.create_from_config_file(get_conf_path())
logger.info('所有数据更新数据完成')
managers = [
['画像数据', ProfileManager(client)],
['订单数据', OrderDataManager(client)],
['询单数据', ChatDataManager(client)],
]
for [name, manager] in managers:
logger.info('开始更新 %s', name)
manager.update_data()
logger.info('%s 更新完成', name)
logger.info('')
logger.info('所有数据更新数据完成')
if args.task == 'make_embedding':
logger.info('')
logger.info('--' * 50)
logger.info('开始构建用户特征 embedding')
manager = ProfileManager()
manager.make_embeddings()
logger.info('用户特征 embedding 构建完成')
logger.info('开始构建订单相关索引')
manager = OrderDataManager()
manager.make_index()
logger.info('订单相关索引 构建完成')
logger.info('开始构建询单相关索引')
chat_data_manager = ChatDataManager()
chat_data_manager.make_index()
logger.info('询单相关索引 构建完成')
if args.task == 'make_index':
indexers = [
['[用户->咨询师]兜底关系索引', UserCounselorDefaultIndexer()],
['基于订单数据的[用户->咨询师]关系索引', UserCounselorOrderIndexer()],
['基于询单数据的[用户->咨询师]关系索引', UserCounselorChatIndexer()],
['基于多种数据组合的[用户->咨询师]关系索引', UserCounselorCombinationIndexer()],
]
logger.info('')
logger.info('--' * 50)
if args.task == 'make_virtual_embedding':
for [name, indexer] in indexers:
logger.info('开始构建 %s', name)
indexer.make_index()
logger.info('%s 构建完成', name)
logger.info('')
logger.info('开始构建用户特征虚拟embedding')
manager = ProfileManager()
manager.make_virtual_embedding()
logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
logger.info('所有索引更新数据完成')
# if args.task == 'make_virtual_embedding':
# logger.info('')
# logger.info('开始构建用户特征虚拟embedding')
# manager = ProfileManager()
# manager.make_virtual_embedding()
# logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
......@@ -6,5 +6,3 @@ from .database_manager import DatabaseDataManager
from .profile_manager import ProfileManager
from .chat_data_manager import ChatDataManager
from .order_data_manager import OrderDataManager
\ No newline at end of file
from .user_counselor_index_manager import UserCounselorIndexManager
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from datetime import datetime, timedelta
import pandas as pd
from ydl_ai_recommender.src.utils.log import create_logger
......@@ -13,8 +9,6 @@ from ydl_ai_recommender.src.core.manager import DatabaseDataManager
class ChatDataManager(DatabaseDataManager):
def __init__(self, client=None) -> None:
super().__init__(client, create_logger(__name__, 'chat_data_manager.log'))
self.now = datetime.now()
def _make_query_sql(self, conditions=None):
......@@ -30,7 +24,6 @@ class ChatDataManager(DatabaseDataManager):
sql += condition_sql
return sql
def update_data(self):
""" 从数据库中拉取最新订单数据并保存 """
sql = self._make_query_sql()
......@@ -40,7 +33,6 @@ class ChatDataManager(DatabaseDataManager):
self.save_csv_data(df, 'all_chat_info.csv')
return df
def update_test_data(self, conditions):
""" 从数据库中拉取指定条件订单用于测试 """
......@@ -50,69 +42,13 @@ class ChatDataManager(DatabaseDataManager):
df = pd.DataFrame(all_data)
self.save_csv_data(df, 'test_chat_info.csv')
def load_raw_data(self):
    """Load the locally saved full chat export.

    Returns:
        Whatever ``self.load_csv_data`` returns for ``all_chat_info.csv``.
    """
    filename = 'all_chat_info.csv'
    return self.load_csv_data(filename)
def load_test_data(self):
    """Load the locally saved chat export used for testing.

    Returns:
        Whatever ``self.load_csv_data`` returns for ``test_chat_info.csv``.
    """
    filename = 'test_chat_info.csv'
    return self.load_csv_data(filename)
def make_index(self):
    """Build the user -> counselor index from chat data and save it as JSON.

    Rows returned by ``self.load_raw_data()`` are grouped by ``uid`` and
    ``doctor_id``. Every (user, counselor) pair gets a score, each user's
    counselors are sorted by that score in descending order, and the result
    is written to ``user_doctor_chat_index.json`` under
    ``self.local_file_dir``.
    """
    self.logger.info('')
    self.logger.info('开始构建 用户-咨询师 索引')

    chat_df = self.load_raw_data()
    self.logger.info('本地用户咨询师对话数据加载完成,共加载 %s 条数据', len(chat_df))

    # uid -> counselor id -> list of [dt, user_to_doctor, doctor_to_user]
    grouped = {}
    for _, record in chat_df.iterrows():
        uid, counselor_id = record['uid'], record['doctor_id']
        history = grouped.setdefault(uid, {}).setdefault(counselor_id, [])
        history.append([record['dt'], record['user_to_doctor'], record['doctor_to_user']])

    def bucket_of(age):
        """Return the age bucket: 0 (<= 7 days), 1 (<= 30), 2 (<= 180), 3 (older)."""
        if age <= timedelta(days=7):
            return 0
        if age <= timedelta(days=30):
            return 1
        if age <= timedelta(days=180):
            return 2
        return 3

    def compute_score(infos):
        """Score one (user, counselor) chat history; more recent buckets weigh more."""
        peaks = [0, 0, 0, 0]
        for dt, raw_u2d, raw_d2u in infos:
            u2d, d2u = int(raw_u2d), int(raw_d2u)
            slot = bucket_of(self.now - datetime.strptime(dt, '%Y-%m-%d'))
            # NOTE(review): max(1., ...) floors every non-empty bucket at 1.0;
            # confirm that a floor (rather than a cap) is intended.
            peaks[slot] = max(1., peaks[slot], (u2d + d2u) / 20)
        return peaks[0] * 0.5 + peaks[1] * 0.25 + peaks[2] * 0.15 + peaks[3] * 0.1

    index = {
        uid: sorted(
            ([counselor_id, compute_score(infos)] for counselor_id, infos in chats.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for uid, chats in grouped.items()
    }
    self.logger.info('用户-咨询师 询单索引构建完成,共构建 %s 条数据', len(index))

    target = os.path.join(self.local_file_dir, 'user_doctor_chat_index.json')
    with open(target, 'w', encoding='utf-8') as f:
        json.dump(index, f, ensure_ascii=False)
if __name__ == '__main__':
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
......
......@@ -8,20 +8,16 @@ import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
from .manager import Manager
# class Manager():
# def __init__(self, logger=None) -> None:
# if logger is None:
# self.logger = create_logger(__name__)
# else:
# self.logger = logger
class Manager():
    """Base class for the data managers.

    Stores a logger and the local data directory on the instance.
    """

    def __init__(self, logger=None) -> None:
        """Set up the logger and the local data directory.

        Args:
            logger: Logger to use. When ``None`` a logger is created with
                ``create_logger(__name__)``.
        """
        if logger is None:
            self.logger = create_logger(__name__)
        else:
            self.logger = logger
        self.local_file_dir = get_data_path()

    def make_index(self):
        """Build an index; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on this base class.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it fails with TypeError instead of signalling "abstract".
        raise NotImplementedError
# self.local_file_dir = get_data_path()
class DatabaseDataManager(Manager):
......@@ -29,7 +25,6 @@ class DatabaseDataManager(Manager):
super().__init__(logger)
self.client = client
def fetch_data_from_db(self, sql: str) -> List:
if self.client is None:
self.logger.error('未连接数据库')
......@@ -37,34 +32,28 @@ class DatabaseDataManager(Manager):
return self.client.query(sql)
def load_xlsx_data(self, filename):
    """Read an Excel file from the local data directory.

    Args:
        filename: File name relative to ``self.local_file_dir``.

    Returns:
        The DataFrame produced by ``pd.read_excel`` with ``dtype=str``.
    """
    source = os.path.join(self.local_file_dir, filename)
    return pd.read_excel(source, dtype=str)
def save_xlsx_data(self, df, filename):
    """Write a DataFrame to an Excel file in the local data directory.

    Args:
        df: DataFrame to save.
        filename: File name relative to ``self.local_file_dir``.
    """
    target = os.path.join(self.local_file_dir, filename)
    df.to_excel(target, index=None)
def load_csv_data(self, filename):
    """Read a CSV file from the local data directory.

    Args:
        filename: File name relative to ``self.local_file_dir``.

    Returns:
        The DataFrame produced by ``pd.read_csv`` with ``dtype=str``, so
        every column is read as text.
    """
    source = os.path.join(self.local_file_dir, filename)
    return pd.read_csv(source, dtype=str)
def save_csv_data(self, df, filename):
    """Write a DataFrame as UTF-8 CSV (without the index) to the local data directory.

    Args:
        df: DataFrame to save.
        filename: File name relative to ``self.local_file_dir``.
    """
    target = os.path.join(self.local_file_dir, filename)
    df.to_csv(target, encoding='utf-8', index=False)
def load_json_data(self, filename):
    """Read and parse a UTF-8 JSON file from the local data directory.

    Args:
        filename: File name relative to ``self.local_file_dir``.

    Returns:
        The object produced by ``json.load``.
    """
    source = os.path.join(self.local_file_dir, filename)
    with open(source, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
def save_json_data(self, data, filename):
    """Serialize ``data`` as UTF-8 JSON into the local data directory.

    Args:
        data: JSON-serializable object to save.
        filename: File name relative to ``self.local_file_dir``.

    Returns:
        ``None`` (the result of ``json.dump``).
    """
    target = os.path.join(self.local_file_dir, filename)
    # The file must be opened for writing: in read mode json.dump cannot
    # write, and opening fails outright when the file does not exist yet.
    with open(target, 'w', encoding='utf-8') as f:
        return json.dump(data, f, ensure_ascii=False)
def update_data(self):
    """Refresh this manager's data; must be overridden by subclasses.

    The subclasses shown in this change set override it to pull the latest
    data from the database and save it locally.

    Raises:
        NotImplementedError: always, on this base class.
    """
    raise NotImplementedError
def update_data(self, sql, filename):
    """Run ``sql`` and save the fetched rows as an Excel file.

    Args:
        sql: Query passed to ``self.fetch_data_from_db``, which must return
            a 2-item result whose second item holds the rows.
        filename: Target file name handed to ``self.save_xlsx_data``.
    """
    _, rows = self.fetch_data_from_db(sql)
    self.save_xlsx_data(pd.DataFrame(rows), filename)
def update_test_data(self, conditions):
    """Refresh the test data set; must be overridden by subclasses.

    The subclasses shown in this change set override it to pull the rows
    matching ``conditions`` from the database for testing.

    Args:
        conditions: Query conditions; their format is defined by the
            subclass implementations.

    Raises:
        NotImplementedError: always, on this base class.
    """
    raise NotImplementedError
......@@ -15,4 +15,4 @@ class Manager():
def make_index(self):
raise NotImplemented
\ No newline at end of file
raise NotImplementedError
\ No newline at end of file
......@@ -16,7 +16,6 @@ class OrderDataManager(DatabaseDataManager):
super().__init__(client, create_logger(__name__, 'order_data_manager.log'))
self.now = datetime.now()
def _make_query_sql(self, conditions=None):
condition_sql = ''
if conditions:
......@@ -28,8 +27,7 @@ class OrderDataManager(DatabaseDataManager):
sql += condition_sql
return sql
def update_order_data(self):
def update_data(self):
""" 从数据库中拉取最新订单数据并保存 """
sql = self._make_query_sql()
_, all_data = self.fetch_data_from_db(sql)
......@@ -38,7 +36,7 @@ class OrderDataManager(DatabaseDataManager):
self.save_xlsx_data(df, 'all_order_info.xlsx')
def update_test_order_data(self, conditions):
def update_test_data(self, conditions):
""" 从数据库中拉取指定条件订单用于测试 """
sql = self._make_query_sql(conditions)
......@@ -56,82 +54,6 @@ class OrderDataManager(DatabaseDataManager):
return self.load_xlsx_data('test_order_info.xlsx')
def make_index(self):
    """Build the order-based indexes and save them to the local data directory.

    Two files are written under ``self.local_file_dir``:

    * ``user_doctor_index.json``: for every user, the counselors they
      ordered from, sorted by score in descending order.
    * ``top100_supplier.txt``: up to 100 counselor ids with the most
      orders, one per line (used for cold start).
    """
    self.logger.info('')
    self.logger.info('开始构建 用户-咨询师 索引')

    order_df = self.load_raw_data()
    self.logger.info('本地订单加载数据完成,共加载 %s 条数据', len(order_df))

    # uid -> counselor id -> list of [price, update_time]
    grouped = {}
    for _, record in order_df.iterrows():
        uid, counselor_id = record['uid'], record['supplier_id']
        history = grouped.setdefault(uid, {}).setdefault(counselor_id, [])
        history.append([record['price'], record['update_time']])

    def bucket_of(age):
        """Return the age bucket: 0 (<= 7 days), 1 (<= 30), 2 (<= 180), 3 (older)."""
        if age <= timedelta(days=7):
            return 0
        if age <= timedelta(days=30):
            return 1
        if age <= timedelta(days=180):
            return 2
        return 3

    def compute_score(infos):
        """Score one (user, counselor) order history; more recent buckets weigh more."""
        peaks = [0, 0, 0, 0]
        for raw_price, dt in infos:
            price = float(raw_price)
            slot = bucket_of(self.now - datetime.strptime(dt, '%Y-%m-%d'))
            # NOTE(review): max(1., ...) floors every non-empty bucket at 1.0;
            # confirm that a floor (rather than a cap) is intended.
            peaks[slot] = max(1., peaks[slot], price / 400)
        return peaks[0] * 0.5 + peaks[1] * 0.25 + peaks[2] * 0.15 + peaks[3] * 0.1

    # Counselors are ranked by the weighted score only. An earlier variant
    # ranked by order count and latest order time instead.
    index = {
        uid: sorted(
            ([counselor_id, compute_score(infos)] for counselor_id, infos in orders.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for uid, orders in grouped.items()
    }
    self.logger.info('用户-咨询师 索引构建完成,共构建 %s 条数据', len(index))

    index_path = os.path.join(self.local_file_dir, 'user_doctor_index.json')
    with open(index_path, 'w', encoding='utf-8') as f:
        json.dump(index, f, ensure_ascii=False)
    self.logger.info('用户-咨询师 索引数据已保存,共有用户 %s', len(index))

    # Counselors with the most orders.
    order_counts = Counter(order_df['supplier_id'])
    top100_supplier = [str(counselor_id) for counselor_id, _ in order_counts.most_common(100)]
    self.logger.info('top100 订单量咨询师统计完成')

    top_path = os.path.join(self.local_file_dir, 'top100_supplier.txt')
    with open(top_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(top100_supplier))
    self.logger.info('top100 订单量咨询师列表已保存')
# Running this module directly rebuilds the order-based indexes from the
# locally saved data; make_index() has no return statement, so None is printed.
if __name__ == '__main__':
    manager = OrderDataManager()
    print(manager.make_index())
\ No newline at end of file
......@@ -6,7 +6,7 @@ from typing import List
import pandas as pd
from ydl_ai_recommender.src.core.profile import profile_converters
from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.core.manager import DatabaseDataManager
from ydl_ai_recommender.src.utils.log import create_logger
......@@ -29,8 +29,7 @@ class ProfileManager(DatabaseDataManager):
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
return sql
def update_profile(self):
def update_data(self):
""" 从数据库中拉取最新画像特征并保存 """
sql = self._make_query_sql()
......@@ -39,8 +38,7 @@ class ProfileManager(DatabaseDataManager):
df = pd.DataFrame(all_data)
self.save_xlsx_data(df, 'all_profile.xlsx')
def update_test_profile(self, conditions):
def update_test_data(self, conditions):
""" 从数据库中拉取指定条件画像信息用于测试 """
sql = self._make_query_sql(conditions)
......@@ -49,38 +47,12 @@ class ProfileManager(DatabaseDataManager):
df = pd.DataFrame(all_data)
self.save_xlsx_data(df, 'test_profile.xlsx')
def _load_profile_data(self):
    """Load the locally saved full profile export.

    Returns:
        Whatever ``self.load_xlsx_data`` returns for ``all_profile.xlsx``.
    """
    filename = 'all_profile.xlsx'
    return self.load_xlsx_data(filename)
def load_test_profile_data(self):
    """Load the locally saved profile export used for testing.

    Returns:
        Whatever ``self.load_xlsx_data`` returns for ``test_profile.xlsx``.
    """
    filename = 'test_profile.xlsx'
    return self.load_xlsx_data(filename)
def profile_to_embedding(self, profile):
    """Convert a user profile into a flat embedding vector.

    Each ``[name, converter]`` entry of ``profile_converters`` encodes
    ``profile[name]``; the encoded pieces are concatenated in the order of
    ``profile_converters``.

    Args:
        profile: Mapping-like object indexed by the converter names.

    Returns:
        A list holding the concatenated converter outputs.
    """
    return [
        component
        for name, converter in profile_converters
        for component in converter.convert(profile[name])
    ]
def embedding_to_profile(self, embedding):
    """Convert a flat embedding vector back into a user profile.

    The vector is consumed left to right: each converter in
    ``profile_converters`` takes the next ``converter.dim`` values and
    decodes them with ``inconvert``.

    Args:
        embedding: Sliceable sequence of embedding values.

    Returns:
        A dict mapping each converter name to its decoded value.
    """
    profile = {}
    offset = 0
    for name, converter in profile_converters:
        end = offset + converter.dim
        profile[name] = converter.inconvert(embedding[offset: end])
        offset = end
    return profile
def make_embeddings(self):
user_profiles = self._load_profile_data()
self.logger.info('订单用户画像数据加载完成,共加载 %s 条', len(user_profiles))
......@@ -88,7 +60,7 @@ class ProfileManager(DatabaseDataManager):
self.logger.info('开始构建订单用户的用户画像向量')
for _, profile in user_profiles.iterrows():
user_ids.append(str(profile['uid']))
embeddings.append(self.profile_to_embedding(profile))
embeddings.append(encode_profile(profile))
self.logger.info('用户画像向量构建完成,共构建 %s 用户', len(user_ids))
......@@ -100,7 +72,6 @@ class ProfileManager(DatabaseDataManager):
return embeddings
def make_virtual_embedding(self):
user_ids = []
embeddings = []
......
# -*- coding: utf-8 -*-
import os
import json
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import Manager
# from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class UserCounselorIndexManager(Manager):
    """Builds the merged user -> counselor index from the order and chat indexes."""

    def __init__(self) -> None:
        # NOTE(review): the log file name matches the order data manager's;
        # confirm this is intended and not a copy-paste leftover.
        super().__init__(create_logger(__name__, 'order_data_manager.log'))

    def make_index(self):
        """Merge the order index with the chat index and save the result.

        Reads ``user_doctor_index.json`` and ``user_doctor_chat_index.json``
        from ``self.local_file_dir``. Only users and counselors present in
        the order index are kept. A counselor's merged score is
        ``order * 0.6 + chat * 0.4`` when the same user also has a chat
        score for that counselor, otherwise ``order * 0.6``. Each user's
        counselors are sorted by merged score in descending order and the
        result is written to ``merged_user_doctor_index.json``.
        """
        self.logger.info('开始构建用户-咨询师 合并索引')

        order_path = os.path.join(self.local_file_dir, 'user_doctor_index.json')
        with open(order_path, encoding='utf-8') as f:
            order_index = json.load(f)

        chat_path = os.path.join(self.local_file_dir, 'user_doctor_chat_index.json')
        with open(chat_path, encoding='utf-8') as f:
            chat_index_by_user = json.load(f)

        merged_index = {}
        for uid, counselors in order_index.items():
            # counselor id -> chat score for this user (empty when the user has no chats)
            chat_scores = {c_id: score for c_id, score in chat_index_by_user.get(uid, [])}
            rescored = [
                [
                    c_id,
                    value * 0.6 + chat_scores[c_id] * 0.4 if c_id in chat_scores else value * 0.6,
                ]
                for c_id, value in counselors
            ]
            rescored.sort(key=lambda pair: pair[1], reverse=True)
            merged_index[uid] = rescored

        self.logger.info('用户-咨询师 合并索引构建完成,共构建 %s 条数据', len(merged_index))

        merged_path = os.path.join(self.local_file_dir, 'merged_user_doctor_index.json')
        with open(merged_path, 'w', encoding='utf-8') as f:
            json.dump(merged_index, f, ensure_ascii=False)
# Running this module directly rebuilds the merged user -> counselor index.
if __name__ == '__main__':
    manager = UserCounselorIndexManager()
    manager.make_index()
\ No newline at end of file
......@@ -5,20 +5,17 @@ from typing import Dict, List, Any, Union
import pandas as pd
# from .country_code_profile import CountryCodeProfile
# from .profile import ChannelIdTypeProfile
class BaseProfile():
def __init__(self) -> None:
self.dim :int = 0
self.dim: int = 0
def convert(self, value):
raise NotImplemented
raise NotImplementedError
def inconvert(self, embedding: List[Union[int, float]]) -> str:
raise NotImplemented
raise NotImplementedError
class CountryCodeProfile(BaseProfile):
......@@ -29,7 +26,7 @@ class CountryCodeProfile(BaseProfile):
def convert(self, value):
try:
value = int(value)
except Exception as e:
except Exception:
return [0, 0, 1]
if value == 86:
return [1, 0, 0]
......@@ -53,7 +50,7 @@ class ChannelIdTypeProfile(BaseProfile):
def convert(self, value):
try:
value = int(value)
except Exception as e:
except Exception:
return [0, 0, 1]
if value == 1:
......@@ -86,7 +83,7 @@ class FfromLoginProfile(BaseProfile):
ret = [0, 0, 0, 0, 0]
try:
value = value.lower()
except Exception as e:
except Exception:
return ret
for i, v in enumerate(self.brand_list):
......@@ -123,7 +120,7 @@ class UserPreferenceCateProfile(BaseProfile):
if isinstance(value, str):
try:
value = json.loads(value)
except Exception as e:
except Exception:
return ret
for info in value:
......@@ -167,11 +164,10 @@ class NumClassProfile(BaseProfile):
value = float(value)
index = self.value_index(value)
ret[index] = 1
except:
except Exception:
return ret
return ret
def inconvert(self, embedding):
ret = ''
# 确保embedding中有包含1的值
......@@ -245,7 +241,6 @@ class MultiChoiceProfile(BaseProfile):
self.option_dict = option_dict
self.re_option_dict = {v: k for k, v in self.option_dict.items()}
def convert(self, value: List):
ret = [0] * len(self.option_dict)
if pd.isnull(value):
......@@ -261,7 +256,6 @@ class MultiChoiceProfile(BaseProfile):
pass
return ret
def inconvert(self, embedding):
ret = []
......@@ -284,7 +278,6 @@ class CityProfile(BaseProfile):
self.level = level
self.dim = self.level * 10
def convert(self, value):
ret = [0] * self.dim
......@@ -297,11 +290,10 @@ class CityProfile(BaseProfile):
n = int(_n)
ret[i * 10 + n] = 1
except Exception as e:
except Exception:
pass
return ret
def inconvert(self, embedding):
# 邮编固定都是6
ret = [0] * 6
......@@ -318,7 +310,6 @@ class AidiCstBiasCityProfile(CityProfile):
def __init__(self, level=2) -> None:
super().__init__(level=level)
def convert(self, value_object):
ret = [0] * self.dim
......@@ -331,7 +322,7 @@ class AidiCstBiasCityProfile(CityProfile):
if isinstance(value_object, str):
try:
value_object = json.loads(value_object)
except Exception as e:
except Exception:
pass
if isinstance(value_object, dict):
......@@ -340,7 +331,7 @@ class AidiCstBiasCityProfile(CityProfile):
for i, _n in enumerate(value[:self.level]):
n = int(_n)
ret[i * 10 + n] = 1
except Exception as e:
except Exception:
pass
return ret
......@@ -367,3 +358,25 @@ profile_converters = [
['d30_session_num', NumClassProfile([0, 1])],
]
def encode_profile(profile):
    """Convert a user profile into a flat embedding vector.

    Each ``[name, converter]`` entry of ``profile_converters`` encodes
    ``profile[name]``; the encoded pieces are concatenated in the order of
    ``profile_converters``.

    Args:
        profile: Mapping-like object indexed by the converter names.

    Returns:
        A list holding the concatenated converter outputs.
    """
    return [
        component
        for name, converter in profile_converters
        for component in converter.convert(profile[name])
    ]
def decode_profile(embedding):
    """Convert a flat embedding vector back into a user profile.

    The vector is consumed left to right: each converter in
    ``profile_converters`` takes the next ``converter.dim`` values and
    decodes them with ``inconvert``.

    Args:
        embedding: Sliceable sequence of embedding values.

    Returns:
        A dict mapping each converter name to its decoded value.
    """
    profile = {}
    offset = 0
    for name, converter in profile_converters:
        end = offset + converter.dim
        profile[name] = converter.inconvert(embedding[offset: end])
        offset = end
    return profile
# -*- coding: utf-8 -*-
class CountryCodeProfile():
    """One-hot encoder for a country calling code."""

    def __init__(self) -> None:
        pass

    def convert(self, value):
        """Encode ``value`` as a 3-element one-hot list.

        Returns:
            ``[1, 0, 0]`` when ``int(value)`` is 86, ``[0, 1, 0]`` for any
            other integer, and ``[0, 0, 1]`` when ``value`` cannot be
            converted to an integer.
        """
        try:
            code = int(value)
        except Exception:
            return [0, 0, 1]
        return [1, 0, 0] if code == 86 else [0, 1, 0]
\ No newline at end of file
# -*- coding: utf-8 -*-
class ChannelIdTypeProfile():
    """One-hot encoder for a channel id type."""

    def __init__(self) -> None:
        pass

    def convert(self, value):
        """Encode ``value`` as a 3-element one-hot list.

        Returns:
            ``[1, 0, 0]`` when ``int(value)`` is 1, ``[0, 1, 0]`` when it is
            2, and ``[0, 0, 1]`` for any other integer or when ``value``
            cannot be converted to an integer.
        """
        try:
            channel = int(value)
        except Exception:
            return [0, 0, 1]
        # Built per call so that every caller receives a fresh list.
        known = {1: [1, 0, 0], 2: [0, 1, 0]}
        return known.get(channel, [0, 0, 1])
\ No newline at end of file
......@@ -7,7 +7,9 @@ from typing import List, Dict
import faiss
import numpy as np
from ydl_ai_recommender.src.core.manager import ProfileManager
from ydl_ai_recommender.src.core.indexer import UserCounselorDefaultIndexer
from ydl_ai_recommender.src.core.indexer import UserCounselorCombinationIndexer
from ydl_ai_recommender.src.core.profile import encode_profile
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
......@@ -19,7 +21,7 @@ class Recommender():
pass
def recommend(self, user) -> List:
raise NotImplemented
raise NotImplementedError
class UserCFRecommender(Recommender):
......@@ -37,7 +39,11 @@ class UserCFRecommender(Recommender):
else:
self.logger.warn('未连接数据库')
self.manager = ProfileManager()
self.default_indexer = UserCounselorDefaultIndexer(self.logger)
self.default_indexer.load_index_data()
self.indexer = UserCounselorCombinationIndexer(self.logger)
self.indexer.load_index_data()
self.local_file_dir = get_data_path()
self.load_data()
......@@ -45,8 +51,6 @@ class UserCFRecommender(Recommender):
def load_data(self):
order_user_embedding = []
order_user_ids = []
order_user_counselor_index = {}
default_counselor = []
with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
order_user_ids = [line.strip() for line in f]
......@@ -54,21 +58,8 @@ class UserCFRecommender(Recommender):
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
order_user_embedding = json.load(f)
with open(os.path.join(self.local_file_dir, 'merged_user_doctor_index.json'), encoding='utf-8') as f:
# with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
order_user_counselor_index = json.load(f)
with open(os.path.join(self.local_file_dir, 'top100_supplier.txt'), 'r', encoding='utf-8') as f:
default_counselor = [line.strip() for line in f]
self.order_user_embedding = order_user_embedding
self.order_user_ids = order_user_ids
self.order_user_counselor_index = order_user_counselor_index
self.default_counselor = [{
'counselor': str(user),
'score': 1 - 0.01 * index,
'from': 'top_100',
} for index, user in enumerate(default_counselor)]
self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
self.index.add(np.array(self.order_user_embedding))
......@@ -88,28 +79,27 @@ class UserCFRecommender(Recommender):
return []
def user_token(self, user_profile):
    """Encode a user profile by delegating to ``self.manager.profile_to_embedding``.

    Args:
        user_profile: Profile object passed through unchanged.

    Returns:
        Whatever ``self.manager.profile_to_embedding`` returns.
    """
    encode = self.manager.profile_to_embedding
    return encode(user_profile)
def _recommend_top(self, size=100):
return self.default_counselor[:size]
return [{
'counselor': str(c_id),
'score': score,
'from': 'default',
} for [c_id, score] in self.default_indexer.index(size)]
def _recommend(self, user_embedding):
D, I = self.index.search(np.array([user_embedding]), self.k)
counselors = []
for idx, score in zip(I[0], D[0]):
for idx, simi_score in zip(I[0], D[0]):
# 相似用户uid
similar_user_id = self.order_user_ids[idx]
similar_user_counselor = self.order_user_counselor_index.get(similar_user_id, [])
similar_user_counselor = self.indexer.index(q=similar_user_id, count=self.top_n)
recommend_data = [{
'counselor': str(user[0]),
'score': 1 / max(0.01, float(score) * (index + 1)),
'counselor': c_id,
'score': score / max(0.01, float(simi_score)),
'from': 'similar_users {}'.format(similar_user_id),
} for index, user in enumerate(similar_user_counselor[:self.top_n])]
} for (c_id, score) in similar_user_counselor]
counselors.extend(recommend_data)
counselors.sort(key=lambda x: x['score'], reverse=True)
......@@ -117,7 +107,7 @@ class UserCFRecommender(Recommender):
def recommend_with_profile(self, user_profile, size=0, is_merge=True):
user_embedding = self.user_token(user_profile)
user_embedding = encode_profile(user_profile)
counselors = self._recommend(user_embedding)
# size == 0 时,不追加默认推荐咨询师
......
......@@ -40,7 +40,8 @@ class MySQLClient():
try:
self.cursor.close()
self.connection.close()
self.logger.info('dataset disconnected')
# 容易触发 NameError: name 'open' is not defined
# self.logger.info('dataset disconnected')
except Exception as e:
self.logger.error('销毁 MySQLClient 失败', exc_info=True)
print(e)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment