Commit 7401bbe5 by 柴鹏飞

用户-咨询师相关性加入了询单信息特征

parent 8bf0bb96
...@@ -160,3 +160,5 @@ cython_debug/ ...@@ -160,3 +160,5 @@ cython_debug/
#.idea/ #.idea/
!.gitkeep !.gitkeep
/log
/data*
...@@ -6,8 +6,8 @@ import argparse ...@@ -6,8 +6,8 @@ import argparse
from datetime import datetime, timedelta from datetime import datetime, timedelta
from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager from ydl_ai_recommender.src.core.manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager from ydl_ai_recommender.src.core.manager import ProfileManager
from ydl_ai_recommender.src.core.recommender import UserCFRecommender from ydl_ai_recommender.src.core.recommender import UserCFRecommender
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path from ydl_ai_recommender.src.utils import get_conf_path
...@@ -165,6 +165,8 @@ def do_test(args): ...@@ -165,6 +165,8 @@ def do_test(args):
# 测试结果统计 # 测试结果统计
evaluation(result_detail) evaluation(result_detail)
if args.save_test_result:
# 保存测试结果详情数据
with open('result_detail.json', 'w', encoding='utf-8') as f: with open('result_detail.json', 'w', encoding='utf-8') as f:
json.dump(result_detail, f, ensure_ascii=False, indent=2) json.dump(result_detail, f, ensure_ascii=False, indent=2)
...@@ -204,6 +206,7 @@ if __name__ == '__main__': ...@@ -204,6 +206,7 @@ if __name__ == '__main__':
parser.add_argument('--k', default=20, type=int, help='召回相似用户的数量') parser.add_argument('--k', default=20, type=int, help='召回相似用户的数量')
parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量') parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量') parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
parser.add_argument('--save_test_result', default=False, action='store_true', help='保存测试详情结果')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据') parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天') parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
......
...@@ -4,8 +4,9 @@ import os ...@@ -4,8 +4,9 @@ import os
import argparse import argparse
from datetime import datetime from datetime import datetime
from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager from ydl_ai_recommender.src.core.manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager from ydl_ai_recommender.src.core.manager import ChatDataManager
from ydl_ai_recommender.src.core.manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_project_path from ydl_ai_recommender.src.utils import get_conf_path, get_project_path
from ydl_ai_recommender.src.utils.log import create_logger from ydl_ai_recommender.src.utils.log import create_logger
...@@ -60,7 +61,13 @@ if __name__ == '__main__': ...@@ -60,7 +61,13 @@ if __name__ == '__main__':
logger.info('开始从数据库中更新订单数据') logger.info('开始从数据库中更新订单数据')
order_data_manager = OrderDataManager(client) order_data_manager = OrderDataManager(client)
order_data_manager.update_order_data() order_data_manager.update_order_data()
logger.info('更新数据完成')
logger.info('开始从数据库中更新询单数据')
chat_data_manager = ChatDataManager(client)
chat_data_manager.update_data()
logger.info('所有数据更新数据完成')
if args.task == 'make_embedding': if args.task == 'make_embedding':
...@@ -75,6 +82,11 @@ if __name__ == '__main__': ...@@ -75,6 +82,11 @@ if __name__ == '__main__':
manager.make_index() manager.make_index()
logger.info('订单相关索引 构建完成') logger.info('订单相关索引 构建完成')
logger.info('开始构建询单相关索引')
chat_data_manager = ChatDataManager()
chat_data_manager.make_index()
logger.info('询单相关索引 构建完成')
if args.task == 'make_virtual_embedding': if args.task == 'make_virtual_embedding':
logger.info('') logger.info('')
......
# -*- coding: utf-8 -*-
from .manager import Manager
from .database_manager import DatabaseDataManager
from .profile_manager import ProfileManager
from .chat_data_manager import ChatDataManager
from .order_data_manager import OrderDataManager
from .user_counselor_index_manager import UserCounselorIndexManager
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from datetime import datetime, timedelta
import pandas as pd
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import DatabaseDataManager
class ChatDataManager(DatabaseDataManager):
def __init__(self, client=None) -> None:
super().__init__(client, create_logger(__name__, 'chat_data_manager.log'))
self.now = datetime.now()
def _make_query_sql(self, conditions=None):
if conditions:
conditions = ['doctor_id IS NOT NULL'] + conditions
else:
conditions = ['doctor_id IS NOT NULL']
condition_sql = ' WHERE ' + ' AND '.join(conditions)
select_fields = ['dt', 'uid', 'doctor_id', 'user_to_doctor', 'doctor_to_user']
sql = 'SELECT {} FROM dws.dws_user_chat_assistant_doctor_day'.format(', '.join(select_fields))
sql += condition_sql
return sql
def update_data(self):
""" 从数据库中拉取最新订单数据并保存 """
sql = self._make_query_sql()
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data)
self.save_csv_data(df, 'all_chat_info.csv')
return df
def update_test_data(self, conditions):
""" 从数据库中拉取指定条件订单用于测试 """
sql = self._make_query_sql(conditions)
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data)
self.save_csv_data(df, 'test_chat_info.csv')
def load_raw_data(self):
return self.load_csv_data('all_chat_info.csv')
def load_test_data(self):
return self.load_csv_data('test_chat_info.csv')
def make_index(self):
"""
构建索引
用户-咨询师 索引
"""
self.logger.info('')
self.logger.info('开始构建 用户-咨询师 索引')
df = self.load_raw_data()
self.logger.info('本地用户咨询师对话数据加载完成,共加载 %s 条数据', len(df))
user_chat = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['doctor_id']
if uid not in user_chat:
user_chat[uid] = {}
if supplier_id not in user_chat[uid]:
user_chat[uid][supplier_id] = []
user_chat[uid][supplier_id].append([row['dt'], row['user_to_doctor'], row['doctor_to_user']])
def compute_score(infos):
w = [0, 0, 0, 0]
for [dt, _u2d, _d2u] in infos:
u2d, d2u = int(_u2d), int(_d2u)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], (u2d + d2u) / 20)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], (u2d + d2u) / 20)
else:
w[3] = max(1., w[3], (u2d + d2u) / 20)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
index = {}
for uid, chats in user_chat.items():
supplier_values = []
for supplier_id, infos in chats.items():
# 日期越近权重越大
value = compute_score(infos)
supplier_values.append([supplier_id, value])
index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 询单索引构建完成,共构建 %s 条数据', len(index))
with open(os.path.join(self.local_file_dir, 'user_doctor_chat_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
if __name__ == '__main__':
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path
client = MySQLClient.create_from_config_file(get_conf_path())
manager = ChatDataManager(client)
# manager.update_data()
manager.make_index()
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from typing import List
import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
class Manager():
def __init__(self, logger=None) -> None:
if logger is None:
self.logger = create_logger(__name__)
else:
self.logger = logger
self.local_file_dir = get_data_path()
def make_index(self):
raise NotImplemented
class DatabaseDataManager(Manager):
def __init__(self, client=None, logger=None) -> None:
super().__init__(logger)
self.client = client
def fetch_data_from_db(self, sql: str) -> List:
if self.client is None:
self.logger.error('未连接数据库')
raise
return self.client.query(sql)
def load_xlsx_data(self, filename):
return pd.read_excel(os.path.join(self.local_file_dir, filename), dtype=str)
def save_xlsx_data(self, df, filename):
df.to_excel(os.path.join(self.local_file_dir, filename), index=None)
def load_csv_data(self, filename):
return pd.read_csv(os.path.join(self.local_file_dir, filename), dtype=str)
def save_csv_data(self, df, filename):
df.to_csv(os.path.join(self.local_file_dir, filename), encoding='utf-8', index=False)
def load_json_data(self, filename):
with open(os.path.join(self.local_file_dir, filename), 'r', encoding='utf-8') as f:
return json.load(f)
def save_json_data(self, data, filename):
with open(os.path.join(self.local_file_dir, filename), 'r', encoding='utf-8') as f:
return json.dump(data, f, ensure_ascii=False)
def update_data(self, sql, filename):
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data)
self.save_xlsx_data(df, filename)
# -*- coding: utf-8 -*-
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
class Manager():
def __init__(self, logger=None) -> None:
if logger is None:
self.logger = create_logger(__name__)
else:
self.logger = logger
self.local_file_dir = get_data_path()
def make_index(self):
raise NotImplemented
\ No newline at end of file
...@@ -3,25 +3,21 @@ ...@@ -3,25 +3,21 @@
import os import os
import json import json
from collections import Counter from collections import Counter
from datetime import datetime, timedelta
import pandas as pd import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.utils.log import create_logger from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import DatabaseDataManager
class OrderDataManager(): class OrderDataManager(DatabaseDataManager):
def __init__(self, client=None) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() super().__init__(client, create_logger(__name__, 'order_data_manager.log'))
self.client = client self.now = datetime.now()
self.logger = create_logger(__name__, 'order_data_manager.log')
def _fetch_data_from_db(self, conditions=None): def _make_query_sql(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
condition_sql = '' condition_sql = ''
if conditions: if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions) condition_sql = ' WHERE ' + ' AND '.join(conditions)
...@@ -30,33 +26,34 @@ class OrderDataManager(): ...@@ -30,33 +26,34 @@ class OrderDataManager():
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time') select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields)) sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
sql += condition_sql sql += condition_sql
return sql
_, all_data = self.client.query(sql)
return all_data
def update_order_data(self): def update_order_data(self):
""" 从数据库中拉取最新订单数据并保存 """ """ 从数据库中拉取最新订单数据并保存 """
all_data = self._fetch_data_from_db() sql = self._make_query_sql()
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None) self.save_xlsx_data(df, 'all_order_info.xlsx')
def update_test_order_data(self, conditions): def update_test_order_data(self, conditions):
""" 从数据库中拉取指定条件订单用于测试 """ """ 从数据库中拉取指定条件订单用于测试 """
all_data = self._fetch_data_from_db(conditions)
sql = self._make_query_sql(conditions)
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None) self.save_xlsx_data(df, 'test_order_info.xlsx')
def load_raw_data(self): def load_raw_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), dtype=str) return self.load_xlsx_data('all_order_info.xlsx')
return df
def load_test_order_data(self): def load_test_order_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), dtype=str) return self.load_xlsx_data('test_order_info.xlsx')
return df
def make_index(self): def make_index(self):
...@@ -81,16 +78,41 @@ class OrderDataManager(): ...@@ -81,16 +78,41 @@ class OrderDataManager():
user_order[uid][supplier_id] = [] user_order[uid][supplier_id] = []
user_order[uid][supplier_id].append([row['price'], row['update_time']]) user_order[uid][supplier_id].append([row['price'], row['update_time']])
def compute_score(infos):
w = [0, 0, 0, 0]
for [_price, dt] in infos:
price = float(_price)
date = datetime.strptime(dt, '%Y-%m-%d')
if (self.now - date) <= timedelta(days=7):
w[0] = max(1., w[0], price / 400)
elif (self.now - date) <= timedelta(days=30):
w[1] = max(1., w[1], price / 400)
elif (self.now - date) <= timedelta(days=180):
w[2] = max(1., w[2], price / 400)
else:
w[3] = max(1., w[3], price / 400)
value = w[0] * 0.5 + w[1] * 0.25 + w[2] * 0.15 + w[3] * 0.1
return value
index = {} index = {}
for uid, orders in user_order.items(): for uid, orders in user_order.items():
supplier_values = [] supplier_values = []
for supplier_id, infos in orders.items(): for supplier_id, infos in orders.items():
# 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前 # 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前
value = len(infos) value = compute_score(infos)
latest_time = max([info[1] for info in infos]) supplier_values.append([supplier_id, value])
supplier_values.append([supplier_id, value, latest_time])
# value = len(infos)
# latest_time = max([info[1] for info in infos])
# supplier_values.append([supplier_id, value, latest_time])
# index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True) index[uid] = sorted(supplier_values, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 索引构建完成,共构建 %s 条数据', len(index)) self.logger.info('用户-咨询师 索引构建完成,共构建 %s 条数据', len(index))
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
......
...@@ -2,62 +2,60 @@ ...@@ -2,62 +2,60 @@
import os import os
import json import json
import logging from typing import List
import pandas as pd import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.core.profile import profile_converters from ydl_ai_recommender.src.core.profile import profile_converters
from ydl_ai_recommender.src.core.manager import DatabaseDataManager
from ydl_ai_recommender.src.utils.log import create_logger from ydl_ai_recommender.src.utils.log import create_logger
class ProfileManager():
class ProfileManager(DatabaseDataManager):
""" """
订单用户画像数据管理 订单用户画像数据管理
""" """
def __init__(self, client=None) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() super().__init__(client, create_logger(__name__, 'profile_manager.log'))
self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
self.client = client
self.logger = create_logger(__name__, 'profile_manager.log')
def _fetch_data_from_db(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
def _make_query_sql(self, conditions=None):
condition_sql = '' condition_sql = ''
if conditions: if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions) condition_sql = ' WHERE ' + ' AND '.join(conditions)
sql = 'SELECT * FROM ads.ads_register_user_profiles' sql = 'SELECT * FROM ads.ads_register_user_profiles'
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql) sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
return sql
_, all_data = self.client.query(sql)
return all_data
def update_profile(self): def update_profile(self):
""" 从数据库中拉取最新画像特征并保存 """ """ 从数据库中拉取最新画像特征并保存 """
all_data = self._fetch_data_from_db()
sql = self._make_query_sql()
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None) self.save_xlsx_data(df, 'all_profile.xlsx')
def update_test_profile(self, conditions): def update_test_profile(self, conditions):
""" 从数据库中拉取指定条件画像信息用于测试 """ """ 从数据库中拉取指定条件画像信息用于测试 """
all_data = self._fetch_data_from_db(conditions)
sql = self._make_query_sql(conditions)
_, all_data = self.fetch_data_from_db(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None) self.save_xlsx_data(df, 'test_profile.xlsx')
def _load_profile_data(self): def _load_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), dtype=str) return self.load_xlsx_data('all_profile.xlsx')
def load_test_profile_data(self): def load_test_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), dtype=str) return self.load_xlsx_data('test_profile.xlsx')
def profile_to_embedding(self, profile): def profile_to_embedding(self, profile):
...@@ -69,6 +67,7 @@ class ProfileManager(): ...@@ -69,6 +67,7 @@ class ProfileManager():
embedding.extend(converter.convert(profile[name])) embedding.extend(converter.convert(profile[name]))
return embedding return embedding
def embedding_to_profile(self, embedding): def embedding_to_profile(self, embedding):
""" """
向量转换为用户画像 向量转换为用户画像
......
# -*- coding: utf-8 -*-
import os
import json
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.manager import Manager
# from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class UserCounselorIndexManager(Manager):
def __init__(self) -> None:
super().__init__(create_logger(__name__, 'order_data_manager.log'))
# self.order_data_manager = OrderDataManager()
# self.chat_data_manager = ChatDataManager()
def make_index(self):
self.logger.info('开始构建用户-咨询师 合并索引')
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
user_doctor_index = json.load(f)
with open(os.path.join(self.local_file_dir, 'user_doctor_chat_index.json'), encoding='utf-8') as f:
user_doctor_chat_index = json.load(f)
merged_index = {}
for uid, counselors in user_doctor_index.items():
chat_index = {
c_id: value for c_id, value in user_doctor_chat_index.get(uid, [])
}
new_counselors = []
for [c_id, value] in counselors:
if c_id in chat_index:
merge_value = value * 0.6 + chat_index[c_id] * 0.4
else:
merge_value = value * 0.6
new_counselors.append([c_id, merge_value])
merged_index[uid] = sorted(new_counselors, key=lambda x: x[1], reverse=True)
self.logger.info('用户-咨询师 合并索引构建完成,共构建 %s 条数据', len(merged_index))
with open(os.path.join(self.local_file_dir, 'merged_user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(merged_index, f, ensure_ascii=False)
if __name__ == '__main__':
manager = UserCounselorIndexManager()
manager.make_index()
\ No newline at end of file
...@@ -7,7 +7,7 @@ from typing import List, Dict ...@@ -7,7 +7,7 @@ from typing import List, Dict
import faiss import faiss
import numpy as np import numpy as np
from ydl_ai_recommender.src.core.profile_manager import ProfileManager from ydl_ai_recommender.src.core.manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
from ydl_ai_recommender.src.utils.log import create_logger from ydl_ai_recommender.src.utils.log import create_logger
...@@ -54,7 +54,8 @@ class UserCFRecommender(Recommender): ...@@ -54,7 +54,8 @@ class UserCFRecommender(Recommender):
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
order_user_embedding = json.load(f) order_user_embedding = json.load(f)
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'merged_user_doctor_index.json'), encoding='utf-8') as f:
# with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
order_user_counselor_index = json.load(f) order_user_counselor_index = json.load(f)
with open(os.path.join(self.local_file_dir, 'top100_supplier.txt'), 'r', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'top100_supplier.txt'), 'r', encoding='utf-8') as f:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment