Commit 6697d73b by 柴鹏飞

创建接口服务

parent 5b07caca
...@@ -105,7 +105,7 @@ def do_test(args): ...@@ -105,7 +105,7 @@ def do_test(args):
user_profile_dict = load_local_user_profile() user_profile_dict = load_local_user_profile()
old_users, test_orders = load_test_data() old_users, test_orders = load_test_data()
recommender = UserCFRecommender(top_n=args.top_n, k=args.k) recommender = UserCFRecommender(top_n=args.top_n, k=args.k, is_use_db=False)
result_detail = [] result_detail = []
for index, order_info in test_orders.iterrows(): for index, order_info in test_orders.iterrows():
...@@ -141,120 +141,6 @@ def do_test(args): ...@@ -141,120 +141,6 @@ def do_test(args):
json.dump(result_detail, f, ensure_ascii=False, indent=2) json.dump(result_detail, f, ensure_ascii=False, indent=2)
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def test(args):
    """Run the offline recall test and log a summary report.

    Loads test user profiles and orders, asks ``UserCFRecommender`` for
    candidates for each test order, records whether the ordered supplier was
    recalled, logs aggregate recall metrics and dumps the per-order details
    to ``result_detail.json``.

    Args:
        args: parsed CLI arguments; reads ``top_n``, ``k`` and ``max_test``.

    NOTE(review): the block nesting was reconstructed from a flattened diff
    view; confirm that the same/similar-user counters are meant to be
    updated only for recalled orders.
    """
    # Build the uid -> profile dict once instead of hitting the database per order.
    profile_manager = ProfileManager()
    df = profile_manager.load_test_profile_data()
    user_profile_dict = {}
    for _, row in df.iterrows():
        user_profile_dict[row['uid']] = row

    manager = OrderDataManager()
    # Load the training orders so we can tell later whether a test user is new.
    train_orders = manager.load_raw_data()
    old_users = set(train_orders['uid'])
    logger.info('订单用户数 %s ', len(old_users))

    # Load the test data.
    test_orders = manager.load_test_order_data()
    logger.info('加载测试数据成功,共加载 %s 条', len(test_orders))

    recommender = UserCFRecommender(top_n=args.top_n, k=args.k)

    result_detail = []
    for index, order_info in test_orders.iterrows():
        # max_test <= 0 means "no limit".
        if args.max_test > 0 and index >= args.max_test:
            break

        uid = order_info['uid']
        profile = user_profile_dict.get(uid)
        if profile is None:
            continue

        recommend_result = recommender.recommend_with_profile(profile, is_merge=False)
        recall_reasons = [
            rr['from']
            for rr in recommend_result
            if rr['counselor'] == order_info['supplier_id']
        ]
        result_detail.append({
            'uid': uid,
            'supplier_id': order_info['supplier_id'],
            'is_old_user': uid in old_users,
            'recall_counselors': recommend_result,
            'is_recall': len(recall_reasons) > 0,
            'recall_reason': '|'.join(recall_reasons),
        })

    # Result report.
    metrics = {
        'all_test_cnt': 0,
        'all_recall_cnt': 0,
        'old_user_test_cnt': 0,
        'old_user_recall_cnt': 0,
        'new_user_test_cnt': 0,
        'new_user_recall_cnt': 0,
        'same_user_recall_cnt': 0,
        'similar_user_recall_cnt': 0,
    }
    for rd in result_detail:
        metrics['all_test_cnt'] += 1
        if rd['is_recall']:
            metrics['all_recall_cnt'] += 1
            is_same_user, is_similar_user = False, False
            # 'from' is '<source> <uid>' for user-based recalls; compare as
            # strings because the id parsed out of 'from' is always a str.
            own_uid = str(rd['uid'])
            for counselor in rd['recall_counselors']:
                parts = counselor['from'].split(' ')
                if len(parts) < 2:
                    # Sources without a user id (e.g. 'top_50') belong to neither bucket.
                    continue
                if parts[1] == own_uid:
                    is_same_user = True
                else:
                    is_similar_user = True
            if is_same_user:
                metrics['same_user_recall_cnt'] += 1
            if is_similar_user:
                metrics['similar_user_recall_cnt'] += 1

        if rd['is_old_user']:
            metrics['old_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['old_user_recall_cnt'] += 1
        else:
            metrics['new_user_test_cnt'] += 1
            if rd['is_recall']:
                metrics['new_user_recall_cnt'] += 1

    # Every ratio goes through _safe_ratio: an empty bucket (no old users, no
    # new users, no recalls, or no samples at all) must not abort the report.
    logger.info('==' * 20 + ' 测试结果 ' + '==' * 20)
    logger.info('')
    logger.info('相关参数配置: 相似用户数(k) %s ;每个相似用户召回咨询师数(top_n) %s', args.k, args.top_n)
    logger.info('--' * 45)
    logger.info('')
    logger.info('{:<13}{:<7}{:<7}{:<7}'.format('', '样本数', '召回数', '召回率'))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('整体\u3000', metrics['all_test_cnt'], metrics['all_recall_cnt'], _safe_ratio(metrics['all_recall_cnt'], metrics['all_test_cnt'])))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('老用户', metrics['old_user_test_cnt'], metrics['old_user_recall_cnt'], _safe_ratio(metrics['old_user_recall_cnt'], metrics['old_user_test_cnt'])))
    logger.info('{:<10}{:<10}{:<10}{:<10.2%}'.format('新用户', metrics['new_user_test_cnt'], metrics['new_user_recall_cnt'], _safe_ratio(metrics['new_user_recall_cnt'], metrics['new_user_test_cnt'])))
    logger.info('--' * 45)
    logger.info('')
    logger.info('用户自己召回数 {} 占总召回比例 {:.2%}'.format(metrics['same_user_recall_cnt'], _safe_ratio(metrics['same_user_recall_cnt'], metrics['all_recall_cnt'])))
    logger.info('相似用户召回数 {} 占总召回比例 {:.2%}'.format(metrics['similar_user_recall_cnt'], _safe_ratio(metrics['similar_user_recall_cnt'], metrics['all_recall_cnt'])))

    # Recall-position metrics: dump the per-order details for offline inspection.
    with open('result_detail.json', 'w', encoding='utf-8') as f:
        json.dump(result_detail, f, ensure_ascii=False, indent=2)
def update_test_data(args): def update_test_data(args):
""" 更新测试数据,没有就新建 """ """ 更新测试数据,没有就新建 """
start_date = '' start_date = ''
...@@ -300,5 +186,5 @@ if __name__ == '__main__': ...@@ -300,5 +186,5 @@ if __name__ == '__main__':
logger.info('测试数据创建时间 %s', args.start_date) logger.info('测试数据创建时间 %s', args.start_date)
update_test_data(args) update_test_data(args)
logger.info('开始执行测试任务') logger.info('开始执行测试任务,测试模式为 %s', args.mode)
do_test(args) do_test(args)
\ No newline at end of file
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
from ydl_ai_recommender.src.core.profile_manager import ProfileManager from ydl_ai_recommender.src.core.profile_manager import ProfileManager
from ydl_ai_recommender.src.data.mysql_client import MySQLClient from ydl_ai_recommender.src.data.mysql_client import MySQLClient
from ydl_ai_recommender.src.utils import get_conf_path, get_data_path from ydl_ai_recommender.src.utils import get_conf_path, get_data_path
from ydl_ai_recommender.src.utils.log import create_logger
class Recommender(): class Recommender():
...@@ -23,14 +24,19 @@ class Recommender(): ...@@ -23,14 +24,19 @@ class Recommender():
class UserCFRecommender(Recommender): class UserCFRecommender(Recommender):
def __init__(self, top_n=5, k=5, is_lazy=True) -> None: def __init__(self, top_n=5, k=5, is_use_db=True) -> None:
super().__init__() super().__init__()
# 召回 top_n 个相似用户 # 召回 top_n 个相似用户
self.top_n = top_n self.top_n = top_n
# 每个召回的用户取 k 个相关咨询师 # 每个召回的用户取 k 个相关咨询师
self.k = k self.k = k
if is_lazy is False:
self.logger = create_logger(__name__, 'recommender.log')
if is_use_db:
self.client = MySQLClient.create_from_config_file(get_conf_path()) self.client = MySQLClient.create_from_config_file(get_conf_path())
else:
self.logger.warn('未连接数据库')
self.manager = ProfileManager() self.manager = ProfileManager()
self.local_file_dir = get_data_path() self.local_file_dir = get_data_path()
self.load_data() self.load_data()
...@@ -59,7 +65,7 @@ class UserCFRecommender(Recommender): ...@@ -59,7 +65,7 @@ class UserCFRecommender(Recommender):
self.order_user_counselor_index = order_user_counselor_index self.order_user_counselor_index = order_user_counselor_index
self.default_counselor = [{ self.default_counselor = [{
'counselor': str(user), 'counselor': str(user),
'score': index + 1, 'score': 1 - 0.01 * index,
'from': 'top_50', 'from': 'top_50',
} for index, user in enumerate(default_counselor)] } for index, user in enumerate(default_counselor)]
...@@ -76,7 +82,7 @@ class UserCFRecommender(Recommender): ...@@ -76,7 +82,7 @@ class UserCFRecommender(Recommender):
return [] return []
return all_data[0] return all_data[0]
except Exception as e: except Exception as e:
self.logging.exception("Exception occurred") self.logger.error('获取用户画像数据失败', exc_info=True)
return [] return []
...@@ -100,12 +106,12 @@ class UserCFRecommender(Recommender): ...@@ -100,12 +106,12 @@ class UserCFRecommender(Recommender):
recommend_data = [{ recommend_data = [{
'counselor': str(user[0]), 'counselor': str(user[0]),
'score': float(score) * (index + 1), 'score': 1 / max(0.01, float(score) * (index + 1)),
'from': 'similar_users {}'.format(similar_user_id), 'from': 'similar_users {}'.format(similar_user_id),
} for index, user in enumerate(similar_user_counselor[:self.top_n])] } for index, user in enumerate(similar_user_counselor[:self.top_n])]
counselors.extend(recommend_data) counselors.extend(recommend_data)
counselors.sort(key=lambda x: x['score']) counselors.sort(key=lambda x: x['score'], reverse=True)
return counselors return counselors
...@@ -113,6 +119,7 @@ class UserCFRecommender(Recommender): ...@@ -113,6 +119,7 @@ class UserCFRecommender(Recommender):
user_embedding = self.user_token(user_profile) user_embedding = self.user_token(user_profile)
counselors = self._recommend(user_embedding) counselors = self._recommend(user_embedding)
# count == 0 时,不追加默认推荐咨询师
if count > 0: if count > 0:
counselors.extend(self._recommend_top()) counselors.extend(self._recommend_top())
......
...@@ -20,43 +20,27 @@ class MySQLClient(): ...@@ -20,43 +20,27 @@ class MySQLClient():
cursorclass=pymysql.cursors.DictCursor cursorclass=pymysql.cursors.DictCursor
) )
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
self._log_info('数据库连接成功') self.logger.info('数据库连接成功')
def _connect(self):
    """Open a pymysql connection from the stored credentials and create its cursor."""
    connect_kwargs = {
        'host': self.host,
        'port': self.port,
        'user': self.user,
        'password': self.password,
        'charset': 'utf8mb4',
        # Rows come back as dicts keyed by column name.
        'cursorclass': pymysql.cursors.DictCursor,
    }
    self.connection = pymysql.connect(**connect_kwargs)
    self.cursor = self.connection.cursor()
    self._log_info('数据库连接成功')
def _log_info(self, text, *args, **params):
    """Forward a message to ``self.logger.debug`` when a logger is set; otherwise do nothing."""
    if not self.logger:
        return
    self.logger.debug(text, *args, **params)
def query(self, sql): def query(self, sql):
if self.cursor is None:
self._connect()
# sql += ' limit 1000' # sql += ' limit 1000'
self._log_info('begin execute sql: %s', sql) self.logger.debug('begin execute sql: %s', sql)
row_count = self.cursor.execute(sql) row_count = self.cursor.execute(sql)
data = self.cursor.fetchall() data = self.cursor.fetchall()
self._log_info('fetch row count: %s', row_count) self.logger.debug('fetch row count: %s', row_count)
return row_count, data return row_count, data
def __del__(self): def __del__(self):
try: try:
self.cursor.close() self.cursor.close()
self.connection.close() self.connection.close()
self._log_info('dataset disconnected') self.logger.info('dataset disconnected')
except Exception as e: except Exception as e:
print(e) print(e)
@classmethod @classmethod
def create_from_config_file(cls, config_file, section='ADB'): def create_from_config_file(cls, config_file, section='ADB'):
config = configparser.RawConfigParser() config = configparser.RawConfigParser()
......
# -*- coding: utf-8 -*-
import json
from concurrent.futures import ThreadPoolExecutor
import tornado.web
import tornado.ioloop
import tornado.options
import tornado.httpserver
from tornado.concurrent import run_on_executor
from ydl_ai_recommender.src.utils.log import create_logger
from ydl_ai_recommender.src.core.recommender import UserCFRecommender
# Module-level logger for the service; 'service.log' is presumably the log file
# name and is_rotating=True is forwarded to the project's create_logger helper.
logger = create_logger(__name__, 'service.log', is_rotating=True)
# One recommender instance, built at import time and shared by every request
# handled in this process.
recommender = UserCFRecommender(top_n=5, k=5)
class RecommendHandler(tornado.web.RequestHandler):
    """HTTP handler that returns counselor recommendations for a user.

    Accepts ``uid`` and an optional ``row`` (number of results, default 10)
    either as GET query arguments or as a JSON object in the POST body, and
    always answers with a JSON document containing ``status``, ``code``,
    ``data`` and ``row``.
    """

    # Calls decorated with ``run_on_executor`` are submitted to this pool.
    executor = ThreadPoolExecutor(2)

    # Number of results used when ``row`` is missing or not numeric.
    DEFAULT_ROW = 10

    @staticmethod
    def _parse_row(row):
        """Coerce ``row`` to int, falling back to ``DEFAULT_ROW`` when it is not numeric."""
        try:
            return int(row)
        except (TypeError, ValueError, OverflowError):
            logger.warning('row=%s 不是数字', row)
            return RecommendHandler.DEFAULT_ROW

    @staticmethod
    def _error_payload():
        """Return the response body used for every failure."""
        return {
            'status': 'error',
            'code': 1,
            'data': [],
            'row': 0,
        }

    def _write_json(self, body):
        """Write an already-serialized JSON string with a JSON Content-Type."""
        self.set_header('Content-Type', 'application/json; charset=UTF-8')
        self.write(body)

    @tornado.gen.coroutine
    def get(self):
        """Handle ``GET ?uid=...&row=...``."""
        uid = self.get_argument('uid', None)
        if uid is None:
            logger.warning('请求参数不正确,无uid')
        row = self._parse_row(self.get_argument('row', self.DEFAULT_ROW))
        ret = yield self.run(uid, row)
        self._write_json(ret)

    @tornado.gen.coroutine
    def post(self):
        """Handle a POST whose body is a JSON object with ``uid`` and optional ``row``."""
        try:
            param = json.loads(self.request.body.decode('utf-8'))
        except ValueError:
            # Covers both invalid UTF-8 and invalid JSON.
            param = None
        if not isinstance(param, dict):
            # Answer with the regular error payload instead of an unhandled exception.
            logger.warning('请求参数不正确,请求体不是合法的 JSON 对象')
            self._write_json(json.dumps(self._error_payload(), ensure_ascii=False))
            return

        uid = param.get('uid', None)
        if uid is None:
            logger.warning('请求参数不正确,无uid')
        # Validate ``row`` the same way the GET handler does.
        row = self._parse_row(param.get('row', self.DEFAULT_ROW))
        ret = yield self.run(uid, row)
        self._write_json(ret)

    @run_on_executor
    def run(self, uid, row=10):
        """Compute recommendations for ``uid`` and return the response as a JSON string.

        Any exception raised by the recommender or by serialization is logged
        and turned into the error payload, so the caller always gets a string.
        """
        logger.info('request@@uid=%s@@row=%s', uid, row)
        try:
            recommend_result = recommender.recommend(uid, count=row, is_merge=True)
            ret_str = json.dumps({
                'status': 'success',
                'code': 0,
                'data': recommend_result,
                'row': len(recommend_result),
            }, ensure_ascii=False)
        except Exception:
            logger.error('执行推荐函数报错', exc_info=True)
            ret_str = json.dumps(self._error_payload(), ensure_ascii=False)
        logger.info('response@@uid=%s@@ret=%s', uid, ret_str)
        return ret_str
if __name__ == '__main__':
    # Register the --port option first, then let tornado parse the command line.
    tornado.options.define('port', default=8868, type=int, help='服务启动的端口号')
    tornado.options.parse_command_line()

    # Single route: the recommendation endpoint.
    routes = [(r'/ai_counselor_recommend', RecommendHandler)]
    app = tornado.web.Application(handlers=routes, autoreload=False, debug=False)

    http_server = tornado.httpserver.HTTPServer(app)
    listen_port = tornado.options.options.port
    http_server.listen(listen_port)

    # Blocks until the IOLoop is stopped.
    tornado.ioloop.IOLoop.instance().start()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment