Commit cfdc63bd by 柴鹏飞

基于 u2u 的推荐

parent 6e946e5a
...@@ -4,9 +4,9 @@ channels: ...@@ -4,9 +4,9 @@ channels:
- pytorch - pytorch
- defaults - defaults
dependencies: dependencies:
- python==3.8 - python==3.9
- ipykernel - ipykernel
- faiss-cpu - faiss-cpu
- pip - pip
- pip: - pip:
- -r requirements.txt - -r requirements.txt
\ No newline at end of file
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
import logging import logging
from datetime import datetime, timedelta
import pandas as pd import pandas as pd
...@@ -35,7 +36,7 @@ class DBDataManager(): ...@@ -35,7 +36,7 @@ class DBDataManager():
self.logger.info('开始保存 %s 到本地', name) self.logger.info('开始保存 %s 到本地', name)
with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f:
f.write('\n'.join(lines)) f.write('\n'.join(lines))
self.logger.info('%s 保存成功,共保存 %s 行内人', name, len(lines)) self.logger.info('%s 保存成功,共保存 %s 行数据', name, len(lines))
def _save_json_data(self, data, name): def _save_json_data(self, data, name):
...@@ -63,10 +64,42 @@ class DBDataManager(): ...@@ -63,10 +64,42 @@ class DBDataManager():
_, all_data = self.client.query(sql) _, all_data = self.client.query(sql)
df = pd.DataFrame(all_data) df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None) df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
# self._save_json_data(all_data, 'all_order_info.json')
def _load_order_data(self, conditions=None):
select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
if conditions:
sql += ' WHERE {}'.format('AND '.join(conditions))
self.logger.info('开始执行sql %s', sql)
cnt, data = self.client.query(sql)
self.logger.info('sql执行成功,共获取 %s 条数据', cnt)
return data
def load_test_data(self, days=5):
    """Export recent orders and the matching user profiles as local test data.

    Writes test_order_info.xlsx and test_profile.xlsx under local_file_dir.
    """
    cutoff = datetime.now() - timedelta(days=days)
    cutoff_str = cutoff.strftime('%Y-%m-%d')

    # Orders created within the window.
    order_data = self._load_order_data(
        conditions=['create_time >= "{}"'.format(cutoff_str)],
    )
    pd.DataFrame(order_data).to_excel(
        os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)

    # Profiles of the users that placed those orders.
    sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(['*']))
    sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order WHERE create_time >= "{}")'.format(cutoff_str)
    _, all_data = self.client.query(sql)
    pd.DataFrame(all_data).to_excel(
        os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
if __name__ == '__main__': if __name__ == '__main__':
manager = DBDataManager() manager = DBDataManager()
manager.load_test_data()
# manager.update_local_data() # manager.update_local_data()
# print(manager.make_index()) # print(manager.make_index())
\ No newline at end of file
...@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索 ...@@ -29,10 +29,12 @@ parser.add_argument('--index_last_date', default=None, type=str, help='构建索
parser.add_argument( parser.add_argument(
'-t', '--task', type=str, required=True, '-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_profile_index'), help='执行任务名称' choices=('load_db_data', 'make_profile_index', 'do_test'), help='执行任务名称'
) )
parser.add_argument('--output_dir', default='outputs', type=str, help='模型训练中间结果和训练好的模型保存目录') parser.add_argument('--test_start_date', default='-3', type=str, help='测试任务 - 开始日期')
parser.add_argument('--max_seq_length', default=128, type=int, help='tokenization 之后序列最大长度。超过会被截断,小于会补齐') parser.add_argument('--test_end_date', default='0', type=str, help='测试任务 - 结束日期')
parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据') parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据')
parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率') parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率')
parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token') parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token')
...@@ -46,10 +48,15 @@ args = parser.parse_args() ...@@ -46,10 +48,15 @@ args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
if args.task == 'load_db_data': if args.task == 'load_db_data':
# 从数据库中导出信息
manager = DBDataManager() manager = DBDataManager()
manager.update_order_info() manager.update_order_info()
manager.update_profile() manager.update_profile()
if args.task == 'make_profile_index': if args.task == 'make_profile_index':
manager = ProfileManager() manager = ProfileManager()
manager.make_embeddings() manager.make_embeddings()
\ No newline at end of file manager.make_virtual_embedding()
if args.task == 'make_similarity':
pass
\ No newline at end of file
...@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path ...@@ -11,13 +11,51 @@ from ydl_ai_recommender.src.utils import get_data_path
class OrderDataManager(): class OrderDataManager():
def __init__(self) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.client = client
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def _fetch_data_from_db(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
condition_sql = ''
if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions)
select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
sql += condition_sql
_, all_data = self.client.query(sql)
return all_data
def update_order_data(self):
    """Pull the latest order data from the DB and save it locally (all_order_info.xlsx)."""
    all_data = self._fetch_data_from_db()
    df = pd.DataFrame(all_data)
    # index=None: don't write the DataFrame index as a column
    df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)

def update_test_order_data(self, conditions):
    """Pull orders matching *conditions* and save them for testing (test_order_info.xlsx)."""
    all_data = self._fetch_data_from_db(conditions)
    df = pd.DataFrame(all_data)
    df.to_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), index=None)
def load_raw_data(self): def load_raw_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx')) df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), dtype=str)
return df
def load_test_order_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'test_order_info.xlsx'), dtype=str)
return df return df
...@@ -49,7 +87,7 @@ class OrderDataManager(): ...@@ -49,7 +87,7 @@ class OrderDataManager():
latest_time = max([info[1] for info in infos]) latest_time = max([info[1] for info in infos])
supplier_values.append([supplier_id, value, latest_time]) supplier_values.append([supplier_id, value, latest_time])
index[uid] = sorted(supplier_values, key=lambda x: (x[1], x[2]), reverse=True) index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False) json.dump(index, f, ensure_ascii=False)
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
from typing import Dict, List, Any from typing import Dict, List, Any, Union
import pandas as pd import pandas as pd
from .country_code_profile import CountryCodeProfile # from .country_code_profile import CountryCodeProfile
from .profile import ChannelIdTypeProfile # from .profile import ChannelIdTypeProfile
class BaseProfile():
    """Base class for profile feature converters (raw value <-> one-hot embedding)."""

    def __init__(self) -> None:
        # number of dimensions this converter contributes to the full embedding
        self.dim: int = 0

    def convert(self, value):
        """Convert a raw profile value into a one-hot list of length ``dim``."""
        # was `raise NotImplemented`, which raises TypeError because
        # NotImplemented is not an exception class
        raise NotImplementedError

    def inconvert(self, embedding: List[Union[int, float]]) -> str:
        """Convert a one-hot embedding slice back into a readable value."""
        raise NotImplementedError


class CountryCodeProfile(BaseProfile):
    """Country-code feature: china / abroad / unknown."""

    def __init__(self) -> None:
        super().__init__()
        self.dim = 3

    def convert(self, value):
        try:
            value = int(value)
        except Exception:
            return [0, 0, 1]  # unparsable -> unknown
        if value == 86:
            return [1, 0, 0]  # mainland-China dialing code
        else:
            return [0, 1, 0]

    def inconvert(self, embedding):
        if embedding[0] == 1:
            return 'china'
        # was `embedding[1] == 2`: one-hot entries are 0/1, so the
        # 'abroad' branch could never trigger
        elif embedding[1] == 1:
            return 'abroad'
        else:
            return 'unknown'


class ChannelIdTypeProfile(BaseProfile):
    """Registration-channel feature: ios / android / other."""

    def __init__(self) -> None:
        super().__init__()
        self.dim = 3

    def convert(self, value):
        try:
            value = int(value)
        except Exception:
            return [0, 0, 1]
        if value == 1:
            return [1, 0, 0]
        elif value == 2:
            return [0, 1, 0]
        else:
            return [0, 0, 1]

    def inconvert(self, embedding):
        if embedding[0] == 1:
            return 'ios'
        # was `embedding[1] == 2`: compare against 1, not 2
        elif embedding[1] == 1:
            return 'android'
        else:
            return 'other'
class FfromLoginProfile(BaseProfile):
""" """
登录来源 登录来源
主要是不同android品牌特征,android和ios特征在其他画像中已有,这里不重复构建 主要是不同android品牌特征,android和ios特征在其他画像中已有,这里不重复构建
""" """
def __init__(self) -> None: def __init__(self) -> None:
pass super().__init__()
self.brand_list = ['huawei', 'vivo', 'oppo', 'xiaomi']
self.dim = len(self.brand_list) + 1
def convert(self, value): def convert(self, value):
ret = [0, 0, 0, 0, 0] ret = [0, 0, 0, 0, 0]
...@@ -24,25 +89,30 @@ class FfromLoginProfile(): ...@@ -24,25 +89,30 @@ class FfromLoginProfile():
except Exception as e: except Exception as e:
return ret return ret
if 'huawei' in value: for i, v in enumerate(self.brand_list):
ret[0] = 1 if v in value:
elif 'vivo' in value: ret[i] = 1
ret[1] = 1 break
elif 'oppo' in value:
ret[2] = 1
elif 'xiaomi' in value:
ret[3] = 1
else: else:
ret[4] = 1 ret[self.dim - 1] = 1
return ret return ret
def inconvert(self, embedding):
    """Map a one-hot brand embedding back to its brand name ('other' when no brand bit is set)."""
    for flag, brand in zip(embedding, self.brand_list):
        if flag == 1:
            return brand
    return 'other'
class UserPreferenceCateProfile():
class UserPreferenceCateProfile(BaseProfile):
def __init__(self) -> None: def __init__(self) -> None:
cate_list = ['个人成长', '亲子教育', '人际关系', '婚姻家庭', '心理健康', '恋爱情感', '情绪压力', '职场发展'] super().__init__()
self.dim = 8
self.cate_list = ['个人成长', '亲子教育', '人际关系', '婚姻家庭', '心理健康', '恋爱情感', '情绪压力', '职场发展']
self.cate_index = { self.cate_index = {
cate: index for index, cate in enumerate(cate_list) cate: index for index, cate in enumerate(self.cate_list)
} }
def convert(self, value): def convert(self, value):
...@@ -61,13 +131,21 @@ class UserPreferenceCateProfile(): ...@@ -61,13 +131,21 @@ class UserPreferenceCateProfile():
return ret return ret
def inconvert(self, embedding):
    """Render the category-preference embedding as a JSON object keyed by category name."""
    mapping = {name: emb for emb, name in zip(embedding, self.cate_list)}
    return json.dumps(mapping, ensure_ascii=False)
class NumClassProfile(): class NumClassProfile(BaseProfile):
def __init__(self, split_points, mode='le') -> None: def __init__(self, split_points, mode='le') -> None:
""" """
mode : le 前开后闭; be: 前闭后开;默认前开后闭 mode : le 前开后闭; be: 前闭后开;默认前开后闭
""" """
super().__init__()
self.dim = len(split_points) + 1
self.split_points = split_points self.split_points = split_points
self.mode = mode self.mode = mode
...@@ -94,31 +172,87 @@ class NumClassProfile(): ...@@ -94,31 +172,87 @@ class NumClassProfile():
return ret return ret
def inconvert(self, embedding):
    """Describe which numeric bucket the one-hot embedding selects, e.g. '<= 10' or '10 < 20'.

    Returns '' when no bucket bit is set.
    """
    if embedding.count(1) == 0:
        return ''
    idx = embedding.index(1)
    # mode 'be' means buckets are closed-open, so boundaries use strict '<'
    closed_start = (self.mode == 'be')
    if idx == 0:
        op = '<' if closed_start else '<='
        return '{} {}'.format(op, self.split_points[0])
    if idx == self.dim - 1:
        op = '>=' if closed_start else '>'
        return '{} {}'.format(op, self.split_points[-1])
    op = '<' if closed_start else '<='
    return '{} {} {}'.format(self.split_points[idx - 1], op, self.split_points[idx])
class AidiCstBiasPriceProfile(BaseProfile):
    """User preferred-price feature: one-hot over six price bands."""

    def __init__(self) -> None:
        super().__init__()
        self.dim = 6
        # human-readable names for the six price levels (index = level - 1)
        self.price_groups = [
            '[0, 50)',
            '[50, 100)',
            '[100, 200)',
            '[200, 500)',
            '[500, 1000)',
            '[1000, )',
        ]

    def convert(self, value):
        """value is a JSON string like '[{"level": 2}, ...]'; set one bit per level."""
        ret = [0] * 6
        if pd.isnull(value):
            return ret
        json_object = json.loads(value)
        for v in json_object:
            try:
                ret[v['level'] - 1] = 1
            except Exception:
                # ignore malformed entries / out-of-range levels
                pass
        return ret

    def inconvert(self, embedding):
        """Return the band name of the first set bit ('' when none is set)."""
        if embedding.count(1) == 0:
            return ''
        return self.price_groups[embedding.index(1)]
class MultiChoiceProfile(BaseProfile):
    """Multi-choice feature: one-hot slots defined by an option -> index mapping."""

    def __init__(self, option_dict: Dict[Any, int]) -> None:
        super().__init__()
        self.option_dict = option_dict
        # reverse lookup for inconvert: index -> option value
        self.re_option_dict = {v: k for k, v in self.option_dict.items()}
        # dim must track the option count (was hard-coded to 6, which breaks
        # embedding slicing for option sets of any other size)
        self.dim = len(self.option_dict)

    def convert(self, value: List):
        """value is a JSON array of chosen options; unknown options are ignored."""
        ret = [0] * len(self.option_dict)
        if pd.isnull(value):
            return ret
        value = json.loads(value)
        for v in value:
            try:
                i = self.option_dict[v]
                ret[i] = 1
            except Exception:
                pass
        return ret

    def inconvert(self, embedding):
        """Render chosen options as a bracketed, comma-separated string."""
        ret = []
        for index, emb in enumerate(embedding):
            if emb == 1:
                ret.append(self.re_option_dict.get(index, ''))
        return '[{}]'.format(', '.join(map(str, ret)))
class CityProfile(BaseProfile):
    """Postal-code based city feature.

    Each of the first ``level`` digits of the (6-digit) postal code is
    one-hot encoded into its own group of 10 slots.
    """

    def __init__(self, level=2) -> None:
        """
        level: digits to encode; 2 - province/municipality, 3/4 - district,
        6 - delivery zone.
        """
        super().__init__()
        self.level = level
        self.dim = self.level * 10

    def convert(self, value):
        ret = [0] * self.dim
        if pd.isnull(value):
            return ret
        value = str(value)
        try:
            for i, digit in enumerate(value[:self.level]):
                ret[i * 10 + int(digit)] = 1
        except Exception:
            # non-numeric code: leave the remaining groups all-zero
            pass
        return ret

    def inconvert(self, embedding):
        # postal codes are always 6 digits; digits beyond `level` stay 0
        digits = [0] * 6
        for index, raw in enumerate(embedding):
            if int(raw) == 1:
                digits[index // 10] = index % 10
        return ''.join(map(str, digits))
class AidiCstBiasCityProfile(CityProfile):
    """City preference feature that accepts multiple city codes."""

    def __init__(self, level=2) -> None:
        super().__init__(level=level)

    def convert(self, value_object):
        ret = [0] * self.dim
        if pd.isnull(value_object):
            return ret
        if not value_object:
            return ret
        if isinstance(value_object, str):
            try:
                value_object = json.loads(value_object)
            except Exception:
                pass
        if isinstance(value_object, dict):
            # assumes shape {"in": ["330500", ...]} — codes the user prefers;
            # TODO(review): confirm against the upstream profile schema
            for value in value_object.get('in', []):
                try:
                    for i, digit in enumerate(value[:self.level]):
                        ret[i * 10 + int(digit)] = 1
                except Exception:
                    pass
        return ret
......
...@@ -15,16 +15,49 @@ class ProfileManager(): ...@@ -15,16 +15,49 @@ class ProfileManager():
订单用户画像数据管理 订单用户画像数据管理
""" """
def __init__(self) -> None: def __init__(self, client=None) -> None:
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json') self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
self.client = client
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def _fetch_data_from_db(self, conditions=None):
if self.client is None:
self.logger.error('未连接数据库')
raise
condition_sql = ''
if conditions:
condition_sql = ' WHERE ' + ' AND '.join(conditions)
sql = 'SELECT * FROM ads.ads_register_user_profiles'
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'.format(condition_sql)
_, all_data = self.client.query(sql)
return all_data
def update_profile(self):
    """Pull the latest user profiles from the DB and save them locally (all_profile.xlsx)."""
    all_data = self._fetch_data_from_db()
    df = pd.DataFrame(all_data)
    df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None)

def update_test_profile(self, conditions):
    """Pull profiles for orders matching *conditions* and save them for testing (test_profile.xlsx)."""
    all_data = self._fetch_data_from_db(conditions)
    df = pd.DataFrame(all_data)
    df.to_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), index=None)
def _load_profile_data(self): def _load_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx')) return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), dtype=str)
# with open(self.profile_file_path, 'r', encoding='utf-8') as f:
# return json.load(f)
def load_test_profile_data(self):
    """Load the locally saved test profiles; dtype=str keeps uids/codes as text."""
    return pd.read_excel(os.path.join(self.local_file_dir, 'test_profile.xlsx'), dtype=str)
def profile_to_embedding(self, profile): def profile_to_embedding(self, profile):
...@@ -36,6 +69,18 @@ class ProfileManager(): ...@@ -36,6 +69,18 @@ class ProfileManager():
embedding.extend(converter.convert(profile[name])) embedding.extend(converter.convert(profile[name]))
return embedding return embedding
def embedding_to_profile(self, embedding):
    """
    Decode an embedding vector back into a user-profile dict.

    Walks the converters in order, letting each consume its own
    ``dim``-sized slice of the embedding.
    """
    decoded = {}
    offset = 0
    for name, converter in profile_converters:
        decoded[name] = converter.inconvert(embedding[offset:offset + converter.dim])
        offset += converter.dim
    return decoded
def make_embeddings(self): def make_embeddings(self):
user_profiles = self._load_profile_data() user_profiles = self._load_profile_data()
...@@ -61,10 +106,33 @@ class ProfileManager(): ...@@ -61,10 +106,33 @@ class ProfileManager():
user_ids = [] user_ids = []
embeddings = [] embeddings = []
with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
user_ids = [line.strip() for line in f]
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
embeddings = json.load(f)
v_embedding_set = {}
for user_id, embedding in zip(user_ids, embeddings):
key = '_'.join(map(str, embedding))
if key not in v_embedding_set:
v_embedding_set[key] = {
'embedding': embedding,
'user_ids': [],
}
v_embedding_set[key]['user_ids'].append(str(user_id))
v_embedding_list = []
with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings_ids.txt'), 'w', encoding='utf-8') as f:
for info in v_embedding_set.values():
f.write(','.join(info['user_ids']) + '\n')
v_embedding_list.append(info['embedding'])
with open(os.path.join(self.local_file_dir, 'virtual_user_embeddings.json'), 'w', encoding='utf-8') as f:
json.dump(v_embedding_list, f, ensure_ascii=False)
if __name__ == '__main__': if __name__ == '__main__':
manager = ProfileManager() manager = ProfileManager()
manager.make_embeddings() # manager.make_embeddings()
# manager.update_local_data() manager.make_virtual_embedding()
# print(manager.make_index()) \ No newline at end of file
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
from typing import List, Dict
import faiss
import numpy as np
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
class Recommender():
    """Abstract recommender interface."""

    def __init__(self) -> None:
        pass

    def recommend(self, user) -> List:
        """Return a list of recommended counselors for *user*."""
        # was `raise NotImplemented`, which raises TypeError because
        # NotImplemented is not an exception
        raise NotImplementedError
class UserCFRecommender(Recommender):
    """User-based collaborative-filtering recommender.

    Finds the nearest ordered users by profile-embedding L2 distance (faiss)
    and recommends the counselors those users ordered.
    """

    def __init__(self, top_n=5, k=5, is_lazy=True) -> None:
        super().__init__()
        # NOTE(review): per actual usage in _recommend, `k` is the number of
        # similar users searched and `top_n` the counselors kept per similar
        # user — the reverse of the original inline comments. Confirm intent.
        self.top_n = top_n
        self.k = k
        if is_lazy is False:
            # DB client is only needed for on-demand profile lookup in recommend()
            self.client = MySQLClient.create_from_config_file(get_conf_path())
        self.manager = ProfileManager()
        self.local_file_dir = get_data_path()
        self.load_data()

    def load_data(self):
        """Load ordered-user embeddings, the user->counselor index and the
        default counselor list, then build the faiss search index."""
        with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'r', encoding='utf-8') as f:
            order_user_ids = [line.strip() for line in f]
        with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'r', encoding='utf-8') as f:
            order_user_embedding = json.load(f)
        with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
            order_user_counselor_index = json.load(f)
        with open(os.path.join(self.local_file_dir, 'top50_supplier.txt'), 'r', encoding='utf-8') as f:
            default_counselor = [line.strip() for line in f]

        self.order_user_embedding = order_user_embedding
        self.order_user_ids = order_user_ids
        self.order_user_counselor_index = order_user_counselor_index
        self.default_counselor = default_counselor

        self.index = faiss.IndexFlatL2(len(self.order_user_embedding[0]))
        self.index.add(np.array(self.order_user_embedding))

    def get_user_profile(self, user_id):
        """Fetch a single user profile row from the DB ([] when absent).

        NOTE(review): user_id is interpolated directly into SQL and the client
        has no parameter support — ensure user_id is trusted/validated upstream.
        """
        sql = 'SELECT * FROM ads.ads_register_user_profiles'
        sql += ' WHERE uid={}'.format(user_id)
        _, all_data = self.client.query(sql)
        if len(all_data) == 0:
            return []
        return all_data[0]

    def user_token(self, user_profile):
        """Convert a profile row into its embedding vector."""
        return self.manager.profile_to_embedding(user_profile)

    def _recommend(self, user_embedding):
        """Recall counselors ordered by the k nearest ordered users."""
        D, I = self.index.search(np.array([user_embedding]), self.k)
        counselors = []
        for idx, score in zip(I[0], D[0]):
            # uid of the similar ordered user
            similar_user_id = self.order_user_ids[idx]
            similar_user_counselor = self.order_user_counselor_index.get(similar_user_id, [])
            recommend_data = [{
                'counselor': str(user[0]),
                'score': float(score),
                'from': 'similar_users {}'.format(similar_user_id),
            } for user in similar_user_counselor[:self.top_n]]
            counselors.extend(recommend_data)
        return counselors

    def recommend_with_profile(self, user_profile):
        """Recommend counselors for an already-loaded profile row."""
        user_embedding = self.user_token(user_profile)
        counselors = self._recommend(user_embedding)
        return counselors

    def recommend(self, user_id):
        """
        Recommend counselors by user id.

        When the user's profile cannot be found, fall back to the default
        (most-ordered) counselors — as this method's original docstring
        promised but the previous code (returning []) did not do.
        """
        user_profile = self.get_user_profile(user_id)
        if not user_profile:
            return [{
                'counselor': str(counselor),
                'score': 0.0,
                'from': 'default',
            } for counselor in self.default_counselor[:self.top_n]]
        return self.recommend_with_profile(user_profile)
if __name__ == '__main__':
recommender = UserCFRecommender()
print(recommender.recommend('10957910'))
\ No newline at end of file
# -*- coding: utf-8 -*-
import re
import json
import logging
import argparse
from datetime import datetime, timedelta
from ydl_ai_recommender.src.core.order_data_manager import OrderDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.core.recommender import UserCFRecommender
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
logger = logging.getLogger(__name__)
def _safe_ratio(num, den):
    """Return num/den, or 0.0 when den is 0 (avoids ZeroDivisionError on empty splits)."""
    return num / den if den else 0.0


def main(args):
    """Run the offline recall evaluation over the locally saved test data."""
    # Build a uid -> profile dict once, so we don't hit the DB per order.
    profile_manager = ProfileManager()
    df = profile_manager.load_test_profile_data()
    user_profile_dict = {}
    for _, row in df.iterrows():
        user_profile_dict[row['uid']] = row

    manager = OrderDataManager()
    # Training orders tell us which uids count as "old" (already seen) users.
    train_orders = manager.load_raw_data()
    old_users = set(train_orders['uid'])
    logger.info('订单用户数 %s ', len(old_users))

    # Load the test orders to evaluate against.
    test_orders = manager.load_test_order_data()
    logger.info('加载测试数据成功,共加载 %s 条', len(test_orders))

    recommender = UserCFRecommender(top_n=args.top_n, k=args.k)

    result_detail = []
    for index, order_info in test_orders.iterrows():
        if args.max_test > 0 and index >= args.max_test:
            break
        uid = order_info['uid']
        profile = user_profile_dict.get(uid)
        if profile is None:
            continue
        recommend_result = recommender.recommend_with_profile(profile)
        # Reasons whose recommended counselor matches the actually ordered one.
        recall_resons = [
            rr['from'] for rr in recommend_result
            if rr['counselor'] == order_info['supplier_id']
        ]
        result_detail.append({
            'uid': uid,
            'supplier_id': order_info['supplier_id'],
            'is_old_user': uid in old_users,
            'recall_counselors': recommend_result,
            'is_recall': len(recall_resons) > 0,
            'recall_reason': '|'.join(recall_resons),
        })

    # Aggregate the report metrics.
    metrics = {
        'all_test_cnt': 0,
        'all_recall_cnt': 0,
        'old_user_test_cnt': 0,
        'old_user_recall_cnt': 0,
        'new_user_test_cnt': 0,
        'new_user_recall_cnt': 0,
        'same_user_recall_cnt': 0,
        'similar_user_recall_cnt': 0,
    }
    for rd in result_detail:
        metrics['all_test_cnt'] += 1
        if rd['is_recall']:
            metrics['all_recall_cnt'] += 1
            is_same_user, is_similar_user = False, False
            for counselor in rd['recall_counselors']:
                # 'from' looks like 'similar_users <uid>'
                from_id = counselor['from'].split(' ')[1]
                if from_id == rd['uid']:
                    is_same_user = True
                if from_id != rd['uid']:
                    is_similar_user = True
            if is_same_user:
                metrics['same_user_recall_cnt'] += 1
            if is_similar_user:
                metrics['similar_user_recall_cnt'] += 1
        if rd['is_old_user']:
            metrics['old_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['old_user_recall_cnt'] += 1
        else:
            metrics['new_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['new_user_recall_cnt'] += 1

    # Report — ratios guarded against empty splits (no tests / no recalls).
    logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
    logger.info('--' * 45)
    logger.info('{:<10}{:<10}{:<10}{:<10}'.format('', '样本数', '召回数', '召回率'))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体 ', metrics['all_test_cnt'], metrics['all_recall_cnt'], _safe_ratio(metrics['all_recall_cnt'], metrics['all_test_cnt'])))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], _safe_ratio(metrics['old_user_recall_cnt'], metrics['old_user_test_cnt'])))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], _safe_ratio(metrics['new_user_recall_cnt'], metrics['new_user_test_cnt'])))
    logger.info('')
    logger.info('--' * 45)
    logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], _safe_ratio(metrics['same_user_recall_cnt'], metrics['all_recall_cnt'])))
    logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], _safe_ratio(metrics['similar_user_recall_cnt'], metrics['all_recall_cnt'])))

    with open('result_detail.json', 'w', encoding='utf-8') as f:
        json.dump(result_detail, f, ensure_ascii=False, indent=2)
def update_test_data(args):
    """Refresh the local test order/profile exports for orders created since args.start_date.

    args.start_date is either an absolute 'YYYY-MM-DD' date or '-N',
    meaning N days before now.

    Raises:
        ValueError: when args.start_date matches neither format.
    """
    if re.match(r'20\d\d-[01]\d-\d\d', args.start_date):
        start_date = args.start_date
    elif re.match(r'-\d+', args.start_date):
        days_back = int(args.start_date[1:])
        start_date = (datetime.now() - timedelta(days=days_back)).strftime('%Y-%m-%d')
    else:
        logger.error('args.start_date 参数格式错误,%s', args.start_date)
        # bare `raise` outside an except block is invalid; raise a real exception
        raise ValueError('args.start_date 参数格式错误: {}'.format(args.start_date))

    conditions = ['create_time >= "{}"'.format(start_date)]
    client = MySQLClient.create_from_config_file(get_conf_path())

    # Order data.
    manager = OrderDataManager(client)
    manager.update_test_order_data(conditions=conditions)

    # User profile data.
    manager = ProfileManager(client)
    manager.update_test_profile(conditions=conditions)
    logger.info('测试数据更新完成')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--k', default=5, type=int, help='召回相似用户的数量')
    parser.add_argument('--top_n', default=5, type=int, help='每个相似用户召回的咨询师数量')
    parser.add_argument('--max_test', default=0, type=int, help='最多测试数据量')
    parser.add_argument('--do_update_test_data', default=False, action='store_true', help='是否更新测试数据')
    # argparse %-expands help strings, so a literal % must be doubled —
    # the previous '"%Y-%m-%d"' made `--help` crash with ValueError
    parser.add_argument('--start_date', default='-1', type=str, help='测试订单创建的开始时间,可以是"%%Y-%%m-%%d"格式,也可以是 -3 表示前3天')

    args = parser.parse_args()
    if args.do_update_test_data:
        logger.info('更新测试数据')
        update_test_data(args)

    main(args)
\ No newline at end of file
...@@ -8,6 +8,13 @@ from itertools import combinations ...@@ -8,6 +8,13 @@ from itertools import combinations
from ydl_ai_recommender.src.utils import get_data_path from ydl_ai_recommender.src.utils import get_data_path
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
class BasicUserSimilarity(): class BasicUserSimilarity():
def __init__(self) -> None: def __init__(self) -> None:
...@@ -24,7 +31,7 @@ class BasicUserSimilarity(): ...@@ -24,7 +31,7 @@ class BasicUserSimilarity():
counselor_user_index = {} counselor_user_index = {}
for user, counselors in user_counselor_index.items(): for user, counselors in user_counselor_index.items():
user_like_set[user] = len(counselors) user_like_set[user] = set([c[0] for c in counselors])
for [counselor, _, _] in counselors: for [counselor, _, _] in counselors:
if counselor not in counselor_user_index: if counselor not in counselor_user_index:
...@@ -32,19 +39,37 @@ class BasicUserSimilarity(): ...@@ -32,19 +39,37 @@ class BasicUserSimilarity():
counselor_user_index[counselor].append(user) counselor_user_index[counselor].append(user)
# 两个用户与同一个咨询师有订单,就认为两个用户相似 # 两个用户与同一个咨询师有订单,就认为两个用户相似
self.logger.info('开始构建用户相似性关系') self.logger.info('开始构建用户之间相似性关系')
relations = {} relations = {}
user_index = {}
for users in counselor_user_index.values(): for users in counselor_user_index.values():
for [_u1, _u2] in combinations(users, 2): for [_u1, _u2] in combinations(users, 2):
u1, u2 = min(_u1, _u2), max(_u1, _u2) u1, u2 = min(_u1, _u2), max(_u1, _u2)
key = '{}_{}'.format(u1, u2) key = '{}_{}'.format(u1, u2)
if key in relations: if key in relations:
continue continue
relations[key] = 1.0 / (user_like_set[u1] * user_like_set[u2]) sim = len(user_like_set[u1] & user_like_set[u2]) / (len(user_like_set[u1]) * len(user_like_set[u2]))
relations[key] = sim
if u1 not in user_index:
user_index[u1] = {}
if u2 not in user_index:
user_index[u2] = {}
if u2 not in user_index[u1]:
user_index[u1][u2] = sim
if u1 not in user_index[u2]:
user_index[u2][u1] = sim
user_counselor_index = {}
self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations)) self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations))
with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f: with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f:
json.dump(relations, f, ensure_ascii=False, indent=2) json.dump(relations, f, ensure_ascii=False, indent=2)
def recall(self, user, N=10, top_k=10):
pass
bs = BasicUserSimilarity() bs = BasicUserSimilarity()
......
...@@ -21,11 +21,26 @@ class MySQLClient(): ...@@ -21,11 +21,26 @@ class MySQLClient():
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
self._log_info('数据库连接成功') self._log_info('数据库连接成功')
def _connect(self):
    """(Re)establish the MySQL connection and cursor.

    DictCursor makes query results come back as dicts keyed by column name.
    """
    self.connection = pymysql.connect(
        host=self.host,
        port=self.port,
        user=self.user,
        password=self.password,
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor
    )
    self.cursor = self.connection.cursor()
    self._log_info('数据库连接成功')
def _log_info(self, text, *args, **params): def _log_info(self, text, *args, **params):
if self.logger: if self.logger:
self.logger.info(text, *args, **params) self.logger.info(text, *args, **params)
def query(self, sql): def query(self, sql):
if self.cursor is None:
self._connect()
sql += ' limit 1000' sql += ' limit 1000'
self._log_info('begin execute sql: %s', sql) self._log_info('begin execute sql: %s', sql)
row_count = self.cursor.execute(sql) row_count = self.cursor.execute(sql)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment