Commit 6e946e5a by 柴鹏飞

用户画像embedding构建

parent 9acf0300
# run: conda env create --file environment.yaml # run: conda env create --file environment.yaml
name: yar name: yar
channels: channels:
- pytorch
- defaults - defaults
dependencies: dependencies:
- python==3.8 - python==3.8
- ipykernel - ipykernel
- faiss-cpu
- pip - pip
- pip: - pip:
- -r requirements.txt - -r requirements.txt
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
import logging
import pandas as pd
from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
class DBDataManager():
"""
导出/更新数据库中的数据到本地文件
"""
def __init__(self) -> None:
self.local_file_dir = get_data_path()
self.logger = logging.getLogger(__name__)
self.client = MySQLClient.create_from_config_file(get_conf_path())
def _load_local_plain_file(self, name):
with open(os.path.join(self.local_file_dir, name), 'r', encoding='utf-8') as f:
return f.readlines()
def _load_local_json_file(self, name):
with open(os.path.join(self.local_file_dir, name), 'r', encoding='utf-8') as f:
return json.load(f)
def _save_plain_data(self, lines, name, mode='w'):
self.logger.info('开始保存 %s 到本地', name)
with open(os.path.join(self.local_file_dir, name), mode, encoding='utf-8') as f:
f.write('\n'.join(lines))
self.logger.info('%s 保存成功,共保存 %s 行内人', name, len(lines))
def _save_json_data(self, data, name):
self.logger.info('开始保存 %s 到本地', name)
with open(os.path.join(self.local_file_dir, name), 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False)
self.logger.info('%s 保存成功,共保存 %s 个对象', name, len(data))
def update_profile(self):
# TODO 用户画像字段
select_fields = ['*']
sql = 'SELECT {} FROM ads.ads_register_user_profiles'.format(', '.join(select_fields))
sql += ' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order)'
_, all_data = self.client.query(sql)
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'), index=None)
# self._save_json_data(all_data, 'all_profile.json')
def update_order_info(self):
select_fields = ['main_order_id', 'uid', 'supplier_id', 'price', 'standard_order_type']
select_fields.append('DATE_FORMAT(update_time, "%Y-%m-%d") AS update_time')
sql = 'SELECT {} FROM ods.ods_ydl_standard_order'.format(', '.join(select_fields))
_, all_data = self.client.query(sql)
df = pd.DataFrame(all_data)
df.to_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'), index=None)
# self._save_json_data(all_data, 'all_order_info.json')
if __name__ == '__main__':
manager = DBDataManager()
# manager.update_local_data()
# print(manager.make_index())
\ No newline at end of file
# -*- coding: utf-8 -*-
import logging
import argparse
from ydl_ai_recommender.src.core.db_data_manager import DBDataManager
from ydl_ai_recommender.src.core.profile_manager import ProfileManager
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
parser = argparse.ArgumentParser(description='壹点灵 咨询师推荐 算法召回')
parser.add_argument('--index_last_date', default=None, type=str, help='构建索引最后日期,超过该日期的数据不使用')
# 数据自动备份
# 新建索引数据
# 更新索引
# 性能测试
# 的
parser.add_argument(
'-t', '--task', type=str, required=True,
choices=('load_db_data', 'make_profile_index'), help='执行任务名称'
)
parser.add_argument('--output_dir', default='outputs', type=str, help='模型训练中间结果和训练好的模型保存目录')
parser.add_argument('--max_seq_length', default=128, type=int, help='tokenization 之后序列最大长度。超过会被截断,小于会补齐')
parser.add_argument('--batch_size', default=128, type=int, help='训练时一个 batch 包含多少条数据')
parser.add_argument('--learning_rate', default=1e-3, type=float, help='Adam 优化器的学习率')
parser.add_argument('--add_special_tokens', default=True, type=bool, help='bert encode 时前后是否添加特殊token')
parser.add_argument('--num_train_epochs', default=2, type=int, help='训练时运行的 epoch 数量', )
parser.add_argument('--seed', type=int, default=42, help='随机数种子,保证结果可复现')
parser.add_argument('--do_train', action='store_true', default=False)
parser.add_argument('--do_test', action='store_true', default=False)
parser.add_argument('--do_predict', action='store_true', default=False)
args = parser.parse_args()
if __name__ == '__main__':
if args.task == 'load_db_data':
manager = DBDataManager()
manager.update_order_info()
manager.update_profile()
if args.task == 'make_profile_index':
manager = ProfileManager()
manager.make_embeddings()
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
import logging
from collections import Counter
import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
class OrderDataManager():
def __init__(self) -> None:
self.local_file_dir = get_data_path()
self.logger = logging.getLogger(__name__)
def load_raw_data(self):
df = pd.read_excel(os.path.join(self.local_file_dir, 'all_order_info.xlsx'))
return df
def make_index(self):
"""
构建索引
用户-咨询师 索引
top50 咨询师列表 用于冷启动
"""
df = self.load_raw_data()
self.logger.info('加载数据完成,共加载 %s 条数据', len(df))
user_order = {}
for index, row in df.iterrows():
uid, supplier_id = row['uid'], row['supplier_id']
if uid not in user_order:
user_order[uid] = {}
if supplier_id not in user_order[uid]:
user_order[uid][supplier_id] = []
user_order[uid][supplier_id].append([row['price'], row['update_time']])
index = {}
for uid, orders in user_order.items():
supplier_values = []
for supplier_id, infos in orders.items():
# 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前
value = len(infos)
latest_time = max([info[1] for info in infos])
supplier_values.append([supplier_id, value, latest_time])
index[uid] = sorted(supplier_values, key=lambda x: (x[1], x[2]), reverse=True)
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), 'w', encoding='utf-8') as f:
json.dump(index, f, ensure_ascii=False)
# 订单最多的咨询师
supplier_cnter = Counter(df['supplier_id'])
top50_supplier = []
for key, _ in supplier_cnter.most_common(50):
top50_supplier.append(str(key))
with open(os.path.join(self.local_file_dir, 'top50_supplier.txt'), 'w', encoding='utf-8') as f:
f.write('\n'.join(top50_supplier))
if __name__ == '__main__':
manager = OrderDataManager()
print(manager.make_index())
\ No newline at end of file
# -*- coding: utf-8 -*-
import json
from typing import Dict, List, Any
import pandas as pd
from .country_code_profile import CountryCodeProfile
from .profile import ChannelIdTypeProfile
class FfromLoginProfile():
"""
登录来源
主要是不同android品牌特征,android和ios特征在其他画像中已有,这里不重复构建
"""
def __init__(self) -> None:
pass
def convert(self, value):
ret = [0, 0, 0, 0, 0]
try:
value = value.lower()
except Exception as e:
return ret
if 'huawei' in value:
ret[0] = 1
elif 'vivo' in value:
ret[1] = 1
elif 'oppo' in value:
ret[2] = 1
elif 'xiaomi' in value:
ret[3] = 1
else:
ret[4] = 1
return ret
class UserPreferenceCateProfile():
def __init__(self) -> None:
cate_list = ['个人成长', '亲子教育', '人际关系', '婚姻家庭', '心理健康', '恋爱情感', '情绪压力', '职场发展']
self.cate_index = {
cate: index for index, cate in enumerate(cate_list)
}
def convert(self, value):
ret = [0.] * 8
if pd.isnull(value):
return ret
if isinstance(value, str):
try:
value = json.loads(value)
except Exception as e:
return ret
for info in value:
ret[self.cate_index[info['cate_name']]] = info['preference_score']
return ret
class NumClassProfile():
def __init__(self, split_points, mode='le') -> None:
"""
mode : le 前开后闭; be: 前闭后开;默认前开后闭
"""
self.split_points = split_points
self.mode = mode
def value_index(self, value):
for i, v in enumerate(self.split_points):
if self.mode == 'be':
if value < v:
return i
else:
if value <= v:
return i
return len(self.split_points)
def convert(self, value):
ret = [0] * (len(self.split_points) + 1)
if pd.isnull(value):
return ret
try:
value = float(value)
index = self.value_index(value)
ret[index] = 1
except:
return ret
return ret
class AidiCstBiasPriceProfile():
def convert(self, value):
ret = [0] * 6
if pd.isnull(value):
return ret
for v in value:
try:
ret[v['level'] - 1] = 1
except Exception:
pass
return ret
class MultiChoiceProfile():
def __init__(self, option_dict: Dict[Any, int]) -> None:
self.option_dict = option_dict
def convert(self, value: List):
ret = [0] * len(self.option_dict)
if pd.isnull(value):
return ret
for v in value:
try:
i = self.option_dict[v]
ret[i] = 1
except Exception as e:
pass
return ret
class CityProfile():
def __init__(self) -> None:
self.default_city_codes = [
'330500', '640200', '130200', '620200', '321000', '530600', '650500', '410200', '511400', '450200',
'610800', '220400', '430400', '320500', '410400', '341100', '420300', '410500', '640100', '440100',
'420500', '650200', '441900', '211200', '210400', '140700', '131000', '440700', '340100', '350200',
'371100', '370900', '130500', '451300', '331100', '320700', '710100', '500000', '610300', '370100',
'610500', '450300', '520100', '140200', '320400', '210500', '440300', '610700', '341800', '210300',
'340200', '120100', '340500', '210200', '222400', '370600', '110100', '441200', '230500', '510100',
'330700', '330600', '370300', '230600', '450100', '340300', '651800', '340800', '430200', '421300',
'220600', '150200', '433100', '440500', '620100', '710200', '130100', '131100', '150600', '430100',
'150100', '130400', '140600', '140300', '410300', '620600', '330300', '321100', '320900', '630100',
'320100', '410900', '510400', '620800', '610600', '220300', '420600', '510700', '130300', '411400',
'310000', '341200', '370500', '710500', '231100', '152900', '371500', '220100', '360700', '150500',
'331000', '360600', '371000', '341600', '130600', '230100', '410800', '370700', '410700', '430800',
'410100', '210800', '330400', '460200', '650100', '310100', '350500', '360400', '320300', '500100',
'360900', '610100', '350100', '350400', '530100', '320600', '130900', '371300', '421200', '210700',
'220200', '130700', '320800', '420100', '110000', '150400', '442000', '469002', '360100', '150800',
'441300', '460100', '610200', '210100', '210900', '371400', '621000', '141000', '330100', '220700',
'371700', '370800', '211400', '330200', '140400', '120000', '231200', '140100', '431100', '320200',
'451000', '370200', '511900', '361100', '610400', '440600', '411100', '231000', '360300'
]
self.city_codes = self.default_city_codes
self.city_code_dict = {
code: index for index, code in enumerate(self.city_codes)
}
def convert(self, value):
ret = [0] * len(self.city_code_dict)
if pd.isnull(value):
return ret
value = str(value)
try:
i = self.city_code_dict[value]
ret[i] = 1
except Exception as e:
pass
return ret
class AidiCstBiasCityProfile(CityProfile):
def __init__(self) -> None:
super().__init__()
def convert(self, value):
ret = [0] * len(self.city_code_dict)
if pd.isnull(value):
return ret
if not value:
return ret
if isinstance(value, str):
try:
value = json.loads(value)
except Exception as e:
pass
for v in value:
try:
ret[self.city_code_dict[v]] = 1
except Exception as e:
pass
return ret
profile_converters = [
['country_code', CountryCodeProfile()],
['channel_id_type', ChannelIdTypeProfile()],
['ffrom_login', FfromLoginProfile()],
['user_preference_cate', UserPreferenceCateProfile()],
['consult_pay_money', NumClassProfile([100, 300, 1000])],
['listen_pay_money', NumClassProfile([0, 50])],
['test_items_pay_money', NumClassProfile([0, 10])],
['course_pay_money', NumClassProfile([0])],
['consult_order_num', NumClassProfile([0, 1])],
['listen_order_num', NumClassProfile([0])],
['test_items_order_num', NumClassProfile([0])],
['course_order_num', NumClassProfile([0])],
['aidi_cst_bias_city', AidiCstBiasCityProfile()],
['aidi_cst_bias_sex', MultiChoiceProfile({i + 1: i for i in range(2)})],
['aidi_cst_bias_price', AidiCstBiasPriceProfile()],
['aidi_cst_bias_server_type', MultiChoiceProfile({i + 1: i for i in range(4)})],
['user_login_city', CityProfile()],
['d30_inquire_order_num', NumClassProfile([0, 1])],
['d30_session_num', NumClassProfile([0, 1])],
]
# -*- coding: utf-8 -*-
class CountryCodeProfile():
def __init__(self) -> None:
pass
def convert(self, value):
try:
value = int(value)
except Exception as e:
return [0, 0, 1]
if value == 86:
return [1, 0, 0]
else:
return [0, 1, 0]
\ No newline at end of file
# -*- coding: utf-8 -*-
class ChannelIdTypeProfile():
def __init__(self) -> None:
pass
def convert(self, value):
try:
value = int(value)
except Exception as e:
return [0, 0, 1]
if value == 1:
return [1, 0, 0]
elif value == 2:
return [0, 1, 0]
else:
return [0, 0, 1]
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
import logging
import pandas as pd
from ydl_ai_recommender.src.utils import get_data_path
from ydl_ai_recommender.src.core.profile import profile_converters
class ProfileManager():
"""
订单用户画像数据管理
"""
def __init__(self) -> None:
self.local_file_dir = get_data_path()
self.profile_file_path = os.path.join(self.local_file_dir, 'all_profile.json')
self.logger = logging.getLogger(__name__)
def _load_profile_data(self):
return pd.read_excel(os.path.join(self.local_file_dir, 'all_profile.xlsx'))
# with open(self.profile_file_path, 'r', encoding='utf-8') as f:
# return json.load(f)
def profile_to_embedding(self, profile):
"""
将用户画像信息转换为向量
"""
embedding = []
for [name, converter] in profile_converters:
embedding.extend(converter.convert(profile[name]))
return embedding
def make_embeddings(self):
user_profiles = self._load_profile_data()
self.logger.info('订单用户画像数据加载完成,共加载 %s 条', len(user_profiles))
user_ids, embeddings = [], []
self.logger.info('开始构建订单用户的用户画像向量')
for _, profile in user_profiles.iterrows():
user_ids.append(str(profile['uid']))
embeddings.append(self.profile_to_embedding(profile))
self.logger.info('用户画像向量构建完成,共构建 %s 用户', len(user_ids))
with open(os.path.join(self.local_file_dir, 'user_embeddings_ids.txt'), 'w', encoding='utf-8') as f:
f.write('\n'.join(user_ids))
with open(os.path.join(self.local_file_dir, 'user_embeddings.json'), 'w', encoding='utf-8') as f:
json.dump(embeddings, f, ensure_ascii=False)
return embeddings
def make_virtual_embedding(self):
user_ids = []
embeddings = []
if __name__ == '__main__':
manager = ProfileManager()
manager.make_embeddings()
# manager.update_local_data()
# print(manager.make_index())
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
import json
import logging
from itertools import combinations
from ydl_ai_recommender.src.utils import get_data_path
class BasicUserSimilarity():
def __init__(self) -> None:
self.local_file_dir = get_data_path()
self.logger = logging.getLogger(__name__)
def compute_similarity(self):
user_counselor_index = {}
with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
user_counselor_index = json.load(f)
user_like_set = {}
counselor_user_index = {}
for user, counselors in user_counselor_index.items():
user_like_set[user] = len(counselors)
for [counselor, _, _] in counselors:
if counselor not in counselor_user_index:
counselor_user_index[counselor] = []
counselor_user_index[counselor].append(user)
# 两个用户与同一个咨询师有订单,就认为两个用户相似
self.logger.info('开始构建用户相似性关系')
relations = {}
for users in counselor_user_index.values():
for [_u1, _u2] in combinations(users, 2):
u1, u2 = min(_u1, _u2), max(_u1, _u2)
key = '{}_{}'.format(u1, u2)
if key in relations:
continue
relations[key] = 1.0 / (user_like_set[u1] * user_like_set[u2])
self.logger.info('用户相似性关系构建完成,共有 %s 对关系', len(relations))
with open(os.path.join(self.local_file_dir, 'user_similarity.json'), 'w', encoding='utf-8') as f:
json.dump(relations, f, ensure_ascii=False, indent=2)
bs = BasicUserSimilarity()
bs.compute_similarity()
\ No newline at end of file
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging
import configparser
import pymysql import pymysql
class MySQLClient(): class MySQLClient():
def __init__(self, host, port, user, password, logger=None) -> None: def __init__(self, host, port, user, password) -> None:
self.logger = logger self.logger = logging.getLogger(__name__)
self.connection = pymysql.connect( self.connection = pymysql.connect(
host=host, host=host,
port=port, port=port,
...@@ -16,18 +19,35 @@ class MySQLClient(): ...@@ -16,18 +19,35 @@ class MySQLClient():
cursorclass=pymysql.cursors.DictCursor cursorclass=pymysql.cursors.DictCursor
) )
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
self._log_info('数据库连接成功')
def _log_info(self, text, *args, **params):
if self.logger: if self.logger:
self.logger.info('数据库连接成功') self.logger.info(text, *args, **params)
def query(self, sql): def query(self, sql):
sql += ' limit 1000'
self._log_info('begin execute sql: %s', sql)
row_count = self.cursor.execute(sql) row_count = self.cursor.execute(sql)
data = self.cursor.fetchall() data = self.cursor.fetchall()
self._log_info('fetch row count: %s', row_count)
return row_count, data return row_count, data
def __del__(self): def __del__(self):
try: try:
self.cursor.close() self.cursor.close()
self.connection.close() self.connection.close()
print('dataset disconnected') self._log_info('dataset disconnected')
except Exception as e: except Exception as e:
print(e) print(e)
\ No newline at end of file
@classmethod
def create_from_config_file(cls, config_file, section='ADB'):
config = configparser.RawConfigParser()
config.read(config_file)
return cls(
config.get(section, 'host'),
config.getint(section, 'port'),
config.get(section, 'user'),
config.get(section, 'password')
)
\ No newline at end of file
# -*- coding: utf-8 -*-
import os
def get_project_path():
current_path = os.path.abspath(__file__)
current_dir = os.path.split(current_path)[0]
return os.path.abspath(os.path.join(current_dir, '../..'))
def get_data_path():
project_path = get_project_path()
return os.path.join(project_path, 'data')
def get_conf_path():
project_path = get_project_path()
return os.path.join(project_path, 'conf/private.conf')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment