Commit cfdc63bd by 柴鹏飞

基于 u2u 的推荐

parent 6e946e5a
...@@ -4,7 +4,7 @@ channels: ...@@ -4,7 +4,7 @@ channels:
- pytorch - pytorch
- defaults - defaults
dependencies: dependencies:
- python==3.8 - python==3.9
- ipykernel - ipykernel
- faiss-cpu - faiss-cpu
- pip - pip
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
import logging import logging
from datetime import datetime, timedelta
import pandas as pd import pandas as pd
...@@ -35,7 +36,7 @@ class DBDataManager(): ...@@ -35,7 +36,7 @@ class DBDataManager():
self.logger.info('开始保存 %s 到本地', name) self.logger.info('开始保存 %s 到本地', name)
with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f:
f.write('\n'.join(lines)) f.write('\n'.join(lines))
self.logger.info('%s 保存成功,共保存 %s 行内人', name, len(lines)) self.logger.info('%s 保存成功,共保存 %s 行数据', name, len(lines))
def _save_json_data(self, data, name): def _save_json_data(self, data, name):
...@@ -63,10 +64,42 @@ class DBDataManager(): ...@@ -63,10 +64,42 @@ class DBDataManager():
_, all_data = self.client.query(sql) _, all_data = self.client.query(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None) df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
# self._save_json_data(all_data, 'all_order_info.json')
def _load_order_data(self, conditions=None):
select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
if conditions:
sql += ' WHERE {}'.format('AND '.join(conditions))
self.logger.info('开始执行sql %s', sql)
cnt, data = self.client.query(sql)
self.logger.info('sql执行成功,共获取 %s 条数据', cnt)
return data
def load_test_data(self, days=5):
now = datetime.now()
start_time = now - timedelta(days=days)
conditions = [
'create_time >= "{}"'.format(start_time.strftime('%Y-%m-%d')),
]
order_data = self._load_order_data(conditions=conditions)
df = pd.DataFrame(order_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
select_fields = ['*']
sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(select_fields))
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order WHERE create_time >= "{}")'.format(start_time.strftime('%Y-%m-%d'))
_, all_data = self.client.query(sql)
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
if __name__ == '__main__': if __name__ == '__main__':
manager = DBDataManager() manager = DBDataManager()
manager.load_test_data()
# manager.update_local_data() # manager.update_local_data()
# print(manager.make_index()) # print(manager.make_index())
\ No newline at end of file
...@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索 ...@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索
parser.add_argument( parser.add_argument(
'-t', '--task', type=str, required=True, '-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_profile_index'), help='执行任务名称' choices=('load_db_data', 'make_profile_index', 'do_test'), help='执行任务名称'
) )
parser.add_argument('--output_dir', default='outputs', type=str, help='模型训练中间结果和训练好的模型保存目录') parser.add_argument('--test_start_date', default='-3', type=str, help='测试任务 - 开始日期')
parser.add_argument('--max_seq_length', default=128, type=int, help='tokenization 之后序列最大长度。超过会被截断,小于会补齐') parser.add_argument('--test_end_date', default='0', type=str, help='测试任务 - 结束日期')
parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据') parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据')
parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率') parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率')
parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token') parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token')
...@@ -46,6 +48,7 @@ args = parser.parse_args() ...@@ -46,6 +48,7 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
if args.task == 'load_db_data': if args.task == 'load_db_data':
# 从数据库中导出信息
manager = DBDataManager() manager = DBDataManager()
manager.update_order_info() manager.update_order_info()
manager.update_profile() manager.update_profile()
...@@ -53,3 +56,7 @@ if __name__ == '__main__': ...@@ -53,3 +56,7 @@ if __name__ == '__main__':
if args.task == 'make_profile_index': if args.task == 'make_profile_index':
manager = ProfileManager() manager = ProfileManager()
manager.make_embeddings() manager.make_embeddings()
manager.make_virtual_embedding()
if args.task == 'make_similarity':
pass
\ No newline at end of file
...@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path ...@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path
class OrderDataManager(): class OrderDataManager():
def __init__(self) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.client = client
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def _fetch_data_from_db(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
condition_sql = ''
if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions)
select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
sql += condition_sql
_, all_data = self.client.query(sql)
return all_data
def update_order_data(self):
""" 从数据库中拉取最新订单数据并保存 """
all_data = self._fetch_data_from_db()
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
def update_test_order_data(self, conditions):
""" 从数据库中拉取指定条件订单用于测试 """
all_data = self._fetch_data_from_db(conditions)
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
def load_raw_data(self): def load_raw_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx')) df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), dtype=str)
return df
def load_test_order_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), dtype=str)
return df return df
...@@ -49,7 +87,7 @@ class OrderDataManager(): ...@@ -49,7 +87,7 @@ class OrderDataManager():
latest_time = max([info[1] for info in infos]) latest_time = max([info[1] for info in infos])
supplier_values.append([supplier_id, value, latest_time]) supplier_values.append([supplier_id, value, latest_time])
index[uid] = sorted(supplier_values, key=lambda x: (x[1], x[2]), reverse=True) index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False) json.dump(index, f, ensure_ascii=False)
......
...@@ -15,16 +15,49 @@ class ProfileManager(): ...@@ -15,16 +15,49 @@ class ProfileManager():
订单用户画像数据管理 订单用户画像数据管理
""" """
def __init__(self) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json') self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
self.client = client
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def _fetch_data_from_db(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
condition_sql = ''
if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions)
sql = 'SELECT * FROM ads.ads_register_user_profiles'
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
_, all_data = self.client.query(sql)
return all_data
def update_profile(self):
""" 从数据库中拉取最新画像特征并保存 """
all_data = self._fetch_data_from_db()
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None)
def update_test_profile(self, conditions):
""" 从数据库中拉取指定条件画像信息用于测试 """
all_data = self._fetch_data_from_db(conditions)
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
def _load_profile_data(self): def _load_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx')) return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), dtype=str)
# with open(self.profile_file_path, 'r', encoding='utf-8') as f:
# return json.load(f)
def load_test_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), dtype=str)
def profile_to_embedding(self, profile): def profile_to_embedding(self, profile):
...@@ -36,6 +69,18 @@ class ProfileManager(): ...@@ -36,6 +69,18 @@ class ProfileManager():
embedding.extend(converter.convert(profile[name])) embedding.extend(converter.convert(profile[name]))
return embedding return embedding
def embedding_to_profile(self, embedding):
"""
向量转换为用户画像
"""
ret = {}
si = 0
for [name, converter] in profile_converters:
ei = si + converter.dim
ret[name] = converter.inconvert(embedding[si: ei])
si = ei
return ret
def make_embeddings(self): def make_embeddings(self):
user_profiles = self._load_profile_data() user_profiles = self._load_profile_data()
...@@ -61,10 +106,33 @@ class ProfileManager(): ...@@ -61,10 +106,33 @@ class ProfileManager():
user_ids = [] user_ids = []
embeddings = [] embeddings = []
with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
user_ids = [line.strip() for line in f]
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
embeddings = json.load(f)
v_embedding_set = {}
for user_id, embedding in zip(user_ids, embeddings):
key = '_'.join(map(str, embedding))
if key not in v_embedding_set:
v_embedding_set[key] = {
'embedding': embedding,
'user_ids': [],
}
v_embedding_set[key]['user_ids'].append(str(user_id))
v_embedding_list = []
with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings_ids.txt'), 'w', encoding='utf-8') as f:
for info in v_embedding_set.values():
f.write(','.join(info['user_ids']) + '\n')
v_embedding_list.append(info['embedding'])
with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings.json'), 'w', encoding='utf-8') as f:
json.dump(v_embedding_list, f, ensure_ascii=False)
if __name__ == '__main__': if __name__ == '__main__':
manager = ProfileManager() manager = ProfileManager()
manager.make_embeddings() # manager.make_embeddings()
# manager.update_local_data() manager.make_virtual_embedding()
# print(manager.make_index()) \ No newline at end of file
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from typing import List, Dict
import faiss
import numpy as np
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
class Recommender():
def __init__(self) -> None:
pass
def recommend(self, user) -> List:
raise NotImplemented
class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=5, is_lazy=True) -> None:
super().__init__()
# 召回 top_n 个相似用户
self.top_n = top_n
# 每个召回的用户取 k 个相关咨询师
self.k = k
if is_lazy is False:
self.client = MySQLClient.create_from_config_file(get_conf_path())
self.manager = ProfileManager()
self.local_file_dir = get_data_path()
self.load_data()
def load_data(self):
order_user_embedding = []
order_user_ids = []
order_user_counselor_index = {}
default_counselor = []
with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
order_user_ids = [line.strip() for line in f]
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
order_user_embedding = json.load(f)
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
order_user_counselor_index = json.load(f)
with open(os.path.join(self.local_file_dir, 'top50_supplier.txt'), 'r', encoding='utf-8') as f:
default_counselor = [line.strip() for line in f]
self.order_user_embedding = order_user_embedding
self.order_user_ids = order_user_ids
self.order_user_counselor_index = order_user_counselor_index
self.default_counselor = default_counselor
self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
self.index.add(np.array(self.order_user_embedding))
def get_user_profile(self, user_id):
sql = 'SELECT * FROM ads.ads_register_user_profiles'
sql += ' WHERE uid={}'.format(user_id)
_, all_data = self.client.query(sql)
if len(all_data) == 0:
return []
return all_data[0]
def user_token(self, user_profile):
return self.manager.profile_to_embedding(user_profile)
def _recommend(self, user_embedding):
D, I = self.index.search(np.array([user_embedding]), self.k)
counselors = []
for idx, score in zip(I[0], D[0]):
# 相似用户uid
similar_user_id = self.order_user_ids[idx]
similar_user_counselor = self.order_user_counselor_index.get(similar_user_id, [])
recommend_data = [{
'counselor': str(user[0]),
'score': float(score),
'from': 'similar_users {}'.format(similar_user_id),
} for user in similar_user_counselor[:self.top_n]]
counselors.extend(recommend_data)
return counselors
def recommend_with_profile(self, user_profile):
user_embedding = self.user_token(user_profile)
counselors = self._recommend(user_embedding)
return counselors
def recommend(self, user_id):
"""
根据用户画像,推荐咨询师
若获取不到用户画像,推荐默认咨询师(订单最多的)
"""
user_profile = self.get_user_profile(user_id)
if not user_profile:
return []
return self.recommend_with_profile(user_profile)
if __name__ == '__main__':
recommender = UserCFRecommender()
print(recommender.recommend('10957910'))
\ No newline at end of file
# -*- coding: utf-8 -*-
import re
import json
import logging
import argparse
from datetime import datetime, timedelta
from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.core.recommender import UserCFRecommender
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
logger = logging.getLogger(__name__)
def main(args):
# 构建用户画像字典,不用每次都从数据库中获取
profile_manager = ProfileManager()
df = profile_manager.load_test_profile_data()
user_profile_dict = {}
for _, row in df.iterrows():
user_profile_dict[row['uid']] = row
manager = OrderDataManager()
# 加载训练订单数据,为后面判断用户是否为新用户
train_orders = manager.load_raw_data()
old_users = set(train_orders['uid'])
logger.info('订单用户数 %s ', len(old_users))
# 加载测试数据
test_orders = manager.load_test_order_data()
logger.info('加载测试数据成功,共加载 %s 条', len(test_orders))
recommender = UserCFRecommender(top_n=args.top_n, k=args.k)
result_detail = []
for index, order_info in test_orders.iterrows():
if args.max_test > 0:
if index >= args.max_test:
break
uid = order_info['uid']
profile = user_profile_dict.get(uid)
if profile is None:
continue
recommend_result = recommender.recommend_with_profile(profile)
recall_resons = []
for rr in recommend_result:
if rr['counselor'] == order_info['supplier_id']:
recall_resons.append(rr['from'])
result_detail.append({
'uid': uid,
'supplier_id': order_info['supplier_id'],
'is_old_user': uid in old_users,
'recall_counselors': recommend_result,
'is_recall': len(recall_resons) > 0,
'recall_reason': '|'.join(recall_resons),
})
# 结果报告
metrics = {
'all_test_cnt': 0,
'all_recall_cnt': 0,
'old_user_test_cnt': 0,
'old_user_recall_cnt': 0,
'new_user_test_cnt': 0,
'new_user_recall_cnt': 0,
'same_user_recall_cnt': 0,
'similar_user_recall_cnt': 0,
}
for rd in result_detail:
metrics['all_test_cnt'] += 1
if rd['is_recall']:
metrics['all_recall_cnt'] += 1
is_same_user, is_similar_user = False, False
for counselor in rd['recall_counselors']:
from_id = counselor['from'].split(' ')[1]
if from_id == rd['uid']:
is_same_user = True
if from_id != rd['uid']:
is_similar_user = True
if is_same_user:
metrics['same_user_recall_cnt'] += 1
if is_similar_user:
metrics['similar_user_recall_cnt'] += 1
if rd['is_old_user']:
metrics['old_user_test_cnt'] += 1
if rd['is_recall']:
metrics['old_user_recall_cnt'] += 1
else:
metrics['new_user_test_cnt'] += 1
if rd['is_recall']:
metrics['new_user_recall_cnt'] += 1
logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
logger.info('--' * 45)
logger.info('{:<10}{:<10}{:<10}{:<10}'.format('', '样本数', '召回数', '召回率'))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体 ', metrics['all_test_cnt'], metrics['all_recall_cnt'], metrics['all_recall_cnt'] / metrics['all_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], metrics['old_user_recall_cnt'] / metrics['old_user_test_cnt']))
logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], metrics['new_user_recall_cnt'] / metrics['new_user_test_cnt']))
logger.info('')
logger.info('--' * 45)
logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], metrics['same_user_recall_cnt'] / metrics['all_recall_cnt']))
logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], metrics['similar_user_recall_cnt'] / metrics['all_recall_cnt']))
with open('result_detail.json', 'w', encoding='utf-8') as f:
json.dump(result_detail, f, ensure_ascii=False, indent=2)
def update_test_data(args):
start_date = ''
if re.match(r'20\d\d-[01]\d-\d\d', args.start_date):
start_date = args.start_date
elif re.match(r'-\d+', args.start_date):
now = datetime.now()
start_date = (now - timedelta(days=int(args.start_date[1:]))).strftime('%Y-%m-%d')
else:
logger.error('args.start_date 参数格式错误,%s', args.start_date)
raise
conditions = ['create_time >= "{}"'.format(start_date)]
client = MySQLClient.create_from_config_file(get_conf_path())
# 订单数据
manager = OrderDataManager(client)
manager.update_test_order_data(conditions=conditions)
# 用户画像数据
manager = ProfileManager(client)
manager.update_test_profile(conditions=conditions)
logger.info('测试数据更新完成')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--k', default=5, type=int, help='召回相似用户的数量')
parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%Y-%m-%d"格式,也可以是 -3 表示前3天')
args = parser.parse_args()
if args.do_update_test_data:
logger.info('更新测试数据')
update_test_data(args)
main(args)
\ No newline at end of file
...@@ -8,6 +8,13 @@ from itertools import combinations ...@@ -8,6 +8,13 @@ from itertools import combinations
from ydl_ai_recommender.src.utils import get_data_path from ydl_ai_recommender.src.utils import get_data_path
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
class BasicUserSimilarity(): class BasicUserSimilarity():
def __init__(self) -> None: def __init__(self) -> None:
...@@ -24,7 +31,7 @@ class BasicUserSimilarity(): ...@@ -24,7 +31,7 @@ class BasicUserSimilarity():
counselor_user_index = {} counselor_user_index = {}
for user, counselors in user_counselor_index.items(): for user, counselors in user_counselor_index.items():
user_like_set[user] = len(counselors) user_like_set[user] = set([c[0] for c in counselors])
for [counselor, _, _] in counselors: for [counselor, _, _] in counselors:
if counselor not in counselor_user_index: if counselor not in counselor_user_index:
...@@ -32,20 +39,38 @@ class BasicUserSimilarity(): ...@@ -32,20 +39,38 @@ class BasicUserSimilarity():
counselor_user_index[counselor].append(user) counselor_user_index[counselor].append(user)
# 两个用户与同一个咨询师有订单,就认为两个用户相似 # 两个用户与同一个咨询师有订单,就认为两个用户相似
self.logger.info('开始构建用户相似性关系') self.logger.info('开始构建用户之间相似性关系')
relations = {} relations = {}
user_index = {}
for users in counselor_user_index.values(): for users in counselor_user_index.values():
for [_u1, _u2] in combinations(users, 2): for [_u1, _u2] in combinations(users, 2):
u1, u2 = min(_u1, _u2), max(_u1, _u2) u1, u2 = min(_u1, _u2), max(_u1, _u2)
key = '{}_{}'.format(u1, u2) key = '{}_{}'.format(u1, u2)
if key in relations: if key in relations:
continue continue
relations[key] = 1.0 / (user_like_set[u1] * user_like_set[u2]) sim = len(user_like_set[u1] & user_like_set[u2]) / (len(user_like_set[u1]) * len(user_like_set[u2]))
relations[key] = sim
if u1 not in user_index:
user_index[u1] = {}
if u2 not in user_index:
user_index[u2] = {}
if u2 not in user_index[u1]:
user_index[u1][u2] = sim
if u1 not in user_index[u2]:
user_index[u2][u1] = sim
user_counselor_index = {}
self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations)) self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations))
with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f:
json.dump(relations, f, ensure_ascii=False, indent=2) json.dump(relations, f, ensure_ascii=False, indent=2)
def recall(self, user, N=10, top_k=10):
pass
bs = BasicUserSimilarity() bs = BasicUserSimilarity()
bs.compute_similarity() bs.compute_similarity()
\ No newline at end of file
...@@ -21,11 +21,26 @@ class MySQLClient(): ...@@ -21,11 +21,26 @@ class MySQLClient():
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
self._log_info('数据库连接成功') self._log_info('数据库连接成功')
def _connect(self):
self.connection = pymysql.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor
)
self.cursor = self.connection.cursor()
self._log_info('数据库连接成功')
def _log_info(self, text, *args, **params): def _log_info(self, text, *args, **params):
if self.logger: if self.logger:
self.logger.info(text, *args, **params) self.logger.info(text, *args, **params)
def query(self, sql): def query(self, sql):
if self.cursor is None:
self._connect()
sql += ' limit 1000' sql += ' limit 1000'
self._log_info('begin execute sql: %s', sql) self._log_info('begin execute sql: %s', sql)
row_count = self.cursor.execute(sql) row_count = self.cursor.execute(sql)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment