Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
Y
ydl_ai_recommender
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
闫发泽
ydl_ai_recommender
Commits
523f5373
Commit
523f5373
authored
Dec 19, 2022
by
柴鹏飞
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor code
parent
7401bbe5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
153 additions
and
403 deletions
+153
-403
test.py
bin/test.py
+2
-2
update.py
bin/update.py
+72
-51
indexer.py
src/core/indexer.py
+0
-0
__init__.py
src/core/manager/__init__.py
+2
-4
chat_data_manager.py
src/core/manager/chat_data_manager.py
+0
-64
database_manager.py
src/core/manager/database_manager.py
+12
-23
manager.py
src/core/manager/manager.py
+2
-2
order_data_manager.py
src/core/manager/order_data_manager.py
+2
-81
profile_manager.py
src/core/manager/profile_manager.py
+4
-33
user_counselor_index_manager.py
src/core/manager/user_counselor_index_manager.py
+0
-53
profile.py
src/core/profile.py
+35
-22
country_code_profile.py
src/core/profile/country_code_profile.py
+0
-17
profile.py
src/core/profile/profile.py
+0
-20
recommender.py
src/core/recommender.py
+20
-30
mysql_client.py
src/data/mysql_client.py
+2
-1
No files found.
bin/test.py
View file @
523f5373
...
@@ -190,11 +190,11 @@ def update_test_data(args):
...
@@ -190,11 +190,11 @@ def update_test_data(args):
# 订单数据
# 订单数据
manager
=
OrderDataManager
(
client
)
manager
=
OrderDataManager
(
client
)
manager
.
update_test_
order_
data
(
conditions
=
conditions
)
manager
.
update_test_data
(
conditions
=
conditions
)
# 用户画像数据
# 用户画像数据
manager
=
ProfileManager
(
client
)
manager
=
ProfileManager
(
client
)
manager
.
update_test_
profile
(
conditions
=
conditions
)
manager
.
update_test_
data
(
conditions
=
conditions
)
logger
.
info
(
'测试数据更新完成'
)
logger
.
info
(
'测试数据更新完成'
)
...
...
bin/update.py
View file @
523f5373
...
@@ -4,9 +4,17 @@ import os
...
@@ -4,9 +4,17 @@ import os
import
argparse
import
argparse
from
datetime
import
datetime
from
datetime
import
datetime
from
ydl_ai_recommender.src.core.manager
import
OrderDataManager
from
ydl_ai_recommender.src.core.manager
import
(
from
ydl_ai_recommender.src.core.manager
import
ChatDataManager
OrderDataManager
,
from
ydl_ai_recommender.src.core.manager
import
ProfileManager
ChatDataManager
,
ProfileManager
,
)
from
ydl_ai_recommender.src.core.indexer
import
(
UserCounselorDefaultIndexer
,
UserCounselorOrderIndexer
,
UserCounselorChatIndexer
,
UserCounselorCombinationIndexer
,
)
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_project_path
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_project_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
@@ -18,79 +26,92 @@ logger = create_logger(__name__, 'update.log')
...
@@ -18,79 +26,92 @@ logger = create_logger(__name__, 'update.log')
parser
=
argparse
.
ArgumentParser
(
description
=
'壹点灵 咨询师推荐 算法召回 离线更新数据模型'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'壹点灵 咨询师推荐 算法召回 离线更新数据模型'
)
parser
.
add_argument
(
parser
.
add_argument
(
'-t'
,
'--task'
,
type
=
str
,
required
=
True
,
'-t'
,
'--task'
,
type
=
str
,
required
=
True
,
choices
=
(
'load_db_data'
,
'make_embedding'
,
'make_
virtual_embedding
'
),
help
=
'执行任务名称'
choices
=
(
'load_db_data'
,
'make_embedding'
,
'make_
index
'
),
help
=
'执行任务名称'
)
)
parser
.
add_argument
(
'--index_last_date'
,
default
=
None
,
type
=
str
,
help
=
'构建索引最后日期,超过该日期的数据不使用'
)
parser
.
add_argument
(
'--index_last_date'
,
default
=
None
,
type
=
str
,
help
=
'构建索引最后日期,超过该日期的数据不使用'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
__name__
==
'__main__'
:
def
initialize_dir
():
# 创建数据目录
now
=
datetime
.
now
()
really_data_dir
=
os
.
path
.
join
(
get_project_path
(),
'data_{}'
.
format
(
now
.
strftime
(
'
%
Y
%
m
%
d_
%
H
%
M
%
S'
)))
default_data_dir
=
os
.
path
.
join
(
get_project_path
(),
'data'
)
if
args
.
task
==
'load_db_data'
:
# 判断data目录是否存在
logger
.
info
(
''
)
if
os
.
path
.
exists
(
default_data_dir
):
if
os
.
path
.
islink
(
default_data_dir
):
os
.
unlink
(
default_data_dir
)
else
:
logger
.
error
(
'
%
s 目录已经存在!请备份后删除该目录再重新执行本操作'
,
default_data_dir
)
# 创建数据目录
os
.
mkdir
(
really_data_dir
)
now
=
datetime
.
now
()
logger
.
info
(
'创建数据保存目录成功'
)
really_data_dir
=
os
.
path
.
join
(
get_project_path
(),
'data_{}'
.
format
(
now
.
strftime
(
'
%
Y
%
m
%
d_
%
H
%
M
%
S'
)))
default_data_dir
=
os
.
path
.
join
(
get_project_path
(),
'data'
)
# 创建软连接
os
.
symlink
(
really_data_dir
,
default_data_dir
)
logger
.
info
(
'创建软连接成功
%
s ->
%
s'
,
really_data_dir
,
default_data_dir
)
# 判断data目录是否存在
# TODO 历史数据删除
if
os
.
path
.
exists
(
default_data_dir
):
if
os
.
path
.
islink
(
default_data_dir
):
os
.
unlink
(
default_data_dir
)
else
:
logger
.
error
(
'
%
s 目录已经存在!请备份后删除该目录再重新执行本操作'
,
default_data_dir
)
os
.
mkdir
(
really_data_dir
)
logger
.
info
(
'创建数据保存目录成功'
)
# 创建软连接
os
.
symlink
(
really_data_dir
,
default_data_dir
)
logger
.
info
(
'创建软连接成功
%
s ->
%
s'
,
really_data_dir
,
default_data_dir
)
# TODO 历史数据删除
if
__name__
==
'__main__'
:
logger
.
info
(
''
)
if
args
.
task
==
'load_db_data'
:
initialize_dir
()
logger
.
info
(
'开始从数据库中更新数据'
)
logger
.
info
(
'开始从数据库中更新数据'
)
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
logger
.
info
(
'开始从数据库中更新画像数据'
)
managers
=
[
profile_manager
=
ProfileManager
(
client
)
[
'画像数据'
,
ProfileManager
(
client
)],
profile_manager
.
update_profile
()
[
'订单数据'
,
OrderDataManager
(
client
)],
[
'询单数据'
,
ChatDataManager
(
client
)],
logger
.
info
(
'开始从数据库中更新订单数据'
)
]
order_data_manager
=
OrderDataManager
(
client
)
order_data_manager
.
update_order_data
()
logger
.
info
(
'开始从数据库中更新询单数据'
)
for
[
name
,
manager
]
in
managers
:
chat_data_manager
=
ChatDataManager
(
client
)
logger
.
info
(
'开始更新
%
s'
,
name
)
chat_data_manager
.
update_data
()
manager
.
update_data
()
logger
.
info
(
'
%
s 更新完成'
,
name
)
logger
.
info
(
''
)
logger
.
info
(
'所有数据更新数据完成'
)
logger
.
info
(
'所有数据更新数据完成'
)
if
args
.
task
==
'make_embedding'
:
if
args
.
task
==
'make_embedding'
:
logger
.
info
(
''
)
logger
.
info
(
''
)
logger
.
info
(
'--'
*
50
)
logger
.
info
(
'开始构建用户特征 embedding'
)
logger
.
info
(
'开始构建用户特征 embedding'
)
manager
=
ProfileManager
()
manager
=
ProfileManager
()
manager
.
make_embeddings
()
manager
.
make_embeddings
()
logger
.
info
(
'用户特征 embedding 构建完成'
)
logger
.
info
(
'用户特征 embedding 构建完成'
)
logger
.
info
(
'开始构建订单相关索引'
)
if
args
.
task
==
'make_index'
:
manager
=
OrderDataManager
()
indexers
=
[
manager
.
make_index
()
[
'[用户->咨询师]兜底关系索引'
,
UserCounselorDefaultIndexer
()],
logger
.
info
(
'订单相关索引 构建完成'
)
[
'基于订单数据的[用户->咨询师]关系索引'
,
UserCounselorOrderIndexer
()],
[
'基于询单数据的[用户->咨询师]关系索引'
,
UserCounselorChatIndexer
()],
[
'基于多种数据组合的[用户->咨询师]关系索引'
,
UserCounselorCombinationIndexer
()],
]
logger
.
info
(
'开始构建询单相关索引'
)
logger
.
info
(
''
)
chat_data_manager
=
ChatDataManager
()
logger
.
info
(
'--'
*
50
)
chat_data_manager
.
make_index
()
logger
.
info
(
'询单相关索引 构建完成'
)
for
[
name
,
indexer
]
in
indexers
:
logger
.
info
(
'开始构建
%
s'
,
name
)
indexer
.
make_index
()
logger
.
info
(
'
%
s 构建完成'
,
name
)
logger
.
info
(
''
)
if
args
.
task
==
'make_virtual_embedding'
:
logger
.
info
(
'所有索引更新数据完成'
)
logger
.
info
(
''
)
logger
.
info
(
'开始构建用户特征虚拟embedding'
)
manager
=
ProfileManager
()
# if args.task == 'make_virtual_embedding':
manager
.
make_virtual_embedding
()
# logger.info('')
logger
.
info
(
'用户特征虚拟 embedding 构建完成'
)
# logger.info('开始构建用户特征虚拟embedding')
\ No newline at end of file
# manager = ProfileManager()
# manager.make_virtual_embedding()
# logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
src/core/indexer.py
0 → 100644
View file @
523f5373
This diff is collapsed.
Click to expand it.
src/core/manager/__init__.py
View file @
523f5373
...
@@ -5,6 +5,4 @@ from .database_manager import DatabaseDataManager
...
@@ -5,6 +5,4 @@ from .database_manager import DatabaseDataManager
from
.profile_manager
import
ProfileManager
from
.profile_manager
import
ProfileManager
from
.chat_data_manager
import
ChatDataManager
from
.chat_data_manager
import
ChatDataManager
from
.order_data_manager
import
OrderDataManager
from
.order_data_manager
import
OrderDataManager
\ No newline at end of file
from
.user_counselor_index_manager
import
UserCounselorIndexManager
\ No newline at end of file
src/core/manager/chat_data_manager.py
View file @
523f5373
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
os
import
json
from
datetime
import
datetime
,
timedelta
import
pandas
as
pd
import
pandas
as
pd
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
@@ -13,8 +9,6 @@ from ydl_ai_recommender.src.core.manager import DatabaseDataManager
...
@@ -13,8 +9,6 @@ from ydl_ai_recommender.src.core.manager import DatabaseDataManager
class
ChatDataManager
(
DatabaseDataManager
):
class
ChatDataManager
(
DatabaseDataManager
):
def
__init__
(
self
,
client
=
None
)
->
None
:
def
__init__
(
self
,
client
=
None
)
->
None
:
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'chat_data_manager.log'
))
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'chat_data_manager.log'
))
self
.
now
=
datetime
.
now
()
def
_make_query_sql
(
self
,
conditions
=
None
):
def
_make_query_sql
(
self
,
conditions
=
None
):
...
@@ -30,7 +24,6 @@ class ChatDataManager(DatabaseDataManager):
...
@@ -30,7 +24,6 @@ class ChatDataManager(DatabaseDataManager):
sql
+=
condition_sql
sql
+=
condition_sql
return
sql
return
sql
def
update_data
(
self
):
def
update_data
(
self
):
""" 从数据库中拉取最新订单数据并保存 """
""" 从数据库中拉取最新订单数据并保存 """
sql
=
self
.
_make_query_sql
()
sql
=
self
.
_make_query_sql
()
...
@@ -40,7 +33,6 @@ class ChatDataManager(DatabaseDataManager):
...
@@ -40,7 +33,6 @@ class ChatDataManager(DatabaseDataManager):
self
.
save_csv_data
(
df
,
'all_chat_info.csv'
)
self
.
save_csv_data
(
df
,
'all_chat_info.csv'
)
return
df
return
df
def
update_test_data
(
self
,
conditions
):
def
update_test_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件订单用于测试 """
""" 从数据库中拉取指定条件订单用于测试 """
...
@@ -50,69 +42,13 @@ class ChatDataManager(DatabaseDataManager):
...
@@ -50,69 +42,13 @@ class ChatDataManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_csv_data
(
df
,
'test_chat_info.csv'
)
self
.
save_csv_data
(
df
,
'test_chat_info.csv'
)
def
load_raw_data
(
self
):
def
load_raw_data
(
self
):
return
self
.
load_csv_data
(
'all_chat_info.csv'
)
return
self
.
load_csv_data
(
'all_chat_info.csv'
)
def
load_test_data
(
self
):
def
load_test_data
(
self
):
return
self
.
load_csv_data
(
'test_chat_info.csv'
)
return
self
.
load_csv_data
(
'test_chat_info.csv'
)
def
make_index
(
self
):
"""
构建索引
用户-咨询师 索引
"""
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建 用户-咨询师 索引'
)
df
=
self
.
load_raw_data
()
self
.
logger
.
info
(
'本地用户咨询师对话数据加载完成,共加载
%
s 条数据'
,
len
(
df
))
user_chat
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'doctor_id'
]
if
uid
not
in
user_chat
:
user_chat
[
uid
]
=
{}
if
supplier_id
not
in
user_chat
[
uid
]:
user_chat
[
uid
][
supplier_id
]
=
[]
user_chat
[
uid
][
supplier_id
]
.
append
([
row
[
'dt'
],
row
[
'user_to_doctor'
],
row
[
'doctor_to_user'
]])
def
compute_score
(
infos
):
w
=
[
0
,
0
,
0
,
0
]
for
[
dt
,
_u2d
,
_d2u
]
in
infos
:
u2d
,
d2u
=
int
(
_u2d
),
int
(
_d2u
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
(
u2d
+
d2u
)
/
20
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
(
u2d
+
d2u
)
/
20
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
index
=
{}
for
uid
,
chats
in
user_chat
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
chats
.
items
():
# 日期越近权重越大
value
=
compute_score
(
infos
)
supplier_values
.
append
([
supplier_id
,
value
])
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 询单索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_chat_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
...
...
src/core/manager/database_manager.py
View file @
523f5373
...
@@ -8,20 +8,16 @@ import pandas as pd
...
@@ -8,20 +8,16 @@ import pandas as pd
from
ydl_ai_recommender.src.utils
import
get_data_path
from
ydl_ai_recommender.src.utils
import
get_data_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
.manager
import
Manager
# class Manager():
# def __init__(self, logger=None) -> None:
# if logger is None:
# self.logger = create_logger(__name__)
# else:
# self.logger = logger
class
Manager
():
# self.local_file_dir = get_data_path()
def
__init__
(
self
,
logger
=
None
)
->
None
:
if
logger
is
None
:
self
.
logger
=
create_logger
(
__name__
)
else
:
self
.
logger
=
logger
self
.
local_file_dir
=
get_data_path
()
def
make_index
(
self
):
raise
NotImplemented
class
DatabaseDataManager
(
Manager
):
class
DatabaseDataManager
(
Manager
):
...
@@ -29,7 +25,6 @@ class DatabaseDataManager(Manager):
...
@@ -29,7 +25,6 @@ class DatabaseDataManager(Manager):
super
()
.
__init__
(
logger
)
super
()
.
__init__
(
logger
)
self
.
client
=
client
self
.
client
=
client
def
fetch_data_from_db
(
self
,
sql
:
str
)
->
List
:
def
fetch_data_from_db
(
self
,
sql
:
str
)
->
List
:
if
self
.
client
is
None
:
if
self
.
client
is
None
:
self
.
logger
.
error
(
'未连接数据库'
)
self
.
logger
.
error
(
'未连接数据库'
)
...
@@ -37,34 +32,28 @@ class DatabaseDataManager(Manager):
...
@@ -37,34 +32,28 @@ class DatabaseDataManager(Manager):
return
self
.
client
.
query
(
sql
)
return
self
.
client
.
query
(
sql
)
def
load_xlsx_data
(
self
,
filename
):
def
load_xlsx_data
(
self
,
filename
):
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
def
save_xlsx_data
(
self
,
df
,
filename
):
def
save_xlsx_data
(
self
,
df
,
filename
):
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
index
=
None
)
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
index
=
None
)
def
load_csv_data
(
self
,
filename
):
def
load_csv_data
(
self
,
filename
):
return
pd
.
read_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
return
pd
.
read_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
def
save_csv_data
(
self
,
df
,
filename
):
def
save_csv_data
(
self
,
df
,
filename
):
df
.
to_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
encoding
=
'utf-8'
,
index
=
False
)
df
.
to_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
encoding
=
'utf-8'
,
index
=
False
)
def
load_json_data
(
self
,
filename
):
def
load_json_data
(
self
,
filename
):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
return
json
.
load
(
f
)
return
json
.
load
(
f
)
def
save_json_data
(
self
,
data
,
filename
):
def
save_json_data
(
self
,
data
,
filename
):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
return
json
.
dump
(
data
,
f
,
ensure_ascii
=
False
)
return
json
.
dump
(
data
,
f
,
ensure_ascii
=
False
)
def
update_data
(
self
):
raise
NotImplementedError
def
update_data
(
self
,
sql
,
filename
):
def
update_test_data
(
self
,
conditions
):
_
,
all_data
=
self
.
fetch_data_from_db
(
sql
)
raise
NotImplementedError
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
filename
)
src/core/manager/manager.py
View file @
523f5373
...
@@ -15,4 +15,4 @@ class Manager():
...
@@ -15,4 +15,4 @@ class Manager():
def
make_index
(
self
):
def
make_index
(
self
):
raise
NotImplemented
raise
NotImplementedError
\ No newline at end of file
\ No newline at end of file
src/core/manager/order_data_manager.py
View file @
523f5373
...
@@ -16,7 +16,6 @@ class OrderDataManager(DatabaseDataManager):
...
@@ -16,7 +16,6 @@ class OrderDataManager(DatabaseDataManager):
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'order_data_manager.log'
))
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'order_data_manager.log'
))
self
.
now
=
datetime
.
now
()
self
.
now
=
datetime
.
now
()
def
_make_query_sql
(
self
,
conditions
=
None
):
def
_make_query_sql
(
self
,
conditions
=
None
):
condition_sql
=
''
condition_sql
=
''
if
conditions
:
if
conditions
:
...
@@ -28,8 +27,7 @@ class OrderDataManager(DatabaseDataManager):
...
@@ -28,8 +27,7 @@ class OrderDataManager(DatabaseDataManager):
sql
+=
condition_sql
sql
+=
condition_sql
return
sql
return
sql
def
update_data
(
self
):
def
update_order_data
(
self
):
""" 从数据库中拉取最新订单数据并保存 """
""" 从数据库中拉取最新订单数据并保存 """
sql
=
self
.
_make_query_sql
()
sql
=
self
.
_make_query_sql
()
_
,
all_data
=
self
.
fetch_data_from_db
(
sql
)
_
,
all_data
=
self
.
fetch_data_from_db
(
sql
)
...
@@ -38,7 +36,7 @@ class OrderDataManager(DatabaseDataManager):
...
@@ -38,7 +36,7 @@ class OrderDataManager(DatabaseDataManager):
self
.
save_xlsx_data
(
df
,
'all_order_info.xlsx'
)
self
.
save_xlsx_data
(
df
,
'all_order_info.xlsx'
)
def
update_test_
order_
data
(
self
,
conditions
):
def
update_test_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件订单用于测试 """
""" 从数据库中拉取指定条件订单用于测试 """
sql
=
self
.
_make_query_sql
(
conditions
)
sql
=
self
.
_make_query_sql
(
conditions
)
...
@@ -56,82 +54,6 @@ class OrderDataManager(DatabaseDataManager):
...
@@ -56,82 +54,6 @@ class OrderDataManager(DatabaseDataManager):
return
self
.
load_xlsx_data
(
'test_order_info.xlsx'
)
return
self
.
load_xlsx_data
(
'test_order_info.xlsx'
)
def
make_index
(
self
):
"""
构建索引
用户-咨询师 索引
top100 咨询师列表 用于冷启动
"""
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建 用户-咨询师 索引'
)
df
=
self
.
load_raw_data
()
self
.
logger
.
info
(
'本地订单加载数据完成,共加载
%
s 条数据'
,
len
(
df
))
user_order
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'supplier_id'
]
if
uid
not
in
user_order
:
user_order
[
uid
]
=
{}
if
supplier_id
not
in
user_order
[
uid
]:
user_order
[
uid
][
supplier_id
]
=
[]
user_order
[
uid
][
supplier_id
]
.
append
([
row
[
'price'
],
row
[
'update_time'
]])
def
compute_score
(
infos
):
w
=
[
0
,
0
,
0
,
0
]
for
[
_price
,
dt
]
in
infos
:
price
=
float
(
_price
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
price
/
400
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
price
/
400
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
index
=
{}
for
uid
,
orders
in
user_order
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
orders
.
items
():
# 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前
value
=
compute_score
(
infos
)
supplier_values
.
append
([
supplier_id
,
value
])
# value = len(infos)
# latest_time = max([info[1] for info in infos])
# supplier_values.append([supplier_id, value, latest_time])
# index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
self
.
logger
.
info
(
'用户-咨询师 索引数据已保存,共有用户
%
s'
,
len
(
index
))
# 订单最多的咨询师
supplier_cnter
=
Counter
(
df
[
'supplier_id'
])
top100_supplier
=
[]
for
key
,
_
in
supplier_cnter
.
most_common
(
100
):
top100_supplier
.
append
(
str
(
key
))
self
.
logger
.
info
(
'top100 订单量咨询师统计完成'
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'top100_supplier.txt'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
'
\n
'
.
join
(
top100_supplier
))
self
.
logger
.
info
(
'top100 订单量咨询师列表已保存'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
manager
=
OrderDataManager
()
manager
=
OrderDataManager
()
print
(
manager
.
make_index
())
print
(
manager
.
make_index
())
\ No newline at end of file
src/core/manager/profile_manager.py
View file @
523f5373
...
@@ -6,7 +6,7 @@ from typing import List
...
@@ -6,7 +6,7 @@ from typing import List
import
pandas
as
pd
import
pandas
as
pd
from
ydl_ai_recommender.src.core.profile
import
profile_converters
from
ydl_ai_recommender.src.core.profile
import
encode_profile
from
ydl_ai_recommender.src.core.manager
import
DatabaseDataManager
from
ydl_ai_recommender.src.core.manager
import
DatabaseDataManager
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
@@ -29,8 +29,7 @@ class ProfileManager(DatabaseDataManager):
...
@@ -29,8 +29,7 @@ class ProfileManager(DatabaseDataManager):
sql
+=
' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'
.
format
(
condition_sql
)
sql
+=
' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'
.
format
(
condition_sql
)
return
sql
return
sql
def
update_data
(
self
):
def
update_profile
(
self
):
""" 从数据库中拉取最新画像特征并保存 """
""" 从数据库中拉取最新画像特征并保存 """
sql
=
self
.
_make_query_sql
()
sql
=
self
.
_make_query_sql
()
...
@@ -39,8 +38,7 @@ class ProfileManager(DatabaseDataManager):
...
@@ -39,8 +38,7 @@ class ProfileManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
'all_profile.xlsx'
)
self
.
save_xlsx_data
(
df
,
'all_profile.xlsx'
)
def
update_test_data
(
self
,
conditions
):
def
update_test_profile
(
self
,
conditions
):
""" 从数据库中拉取指定条件画像信息用于测试 """
""" 从数据库中拉取指定条件画像信息用于测试 """
sql
=
self
.
_make_query_sql
(
conditions
)
sql
=
self
.
_make_query_sql
(
conditions
)
...
@@ -48,39 +46,13 @@ class ProfileManager(DatabaseDataManager):
...
@@ -48,39 +46,13 @@ class ProfileManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
'test_profile.xlsx'
)
self
.
save_xlsx_data
(
df
,
'test_profile.xlsx'
)
def
_load_profile_data
(
self
):
def
_load_profile_data
(
self
):
return
self
.
load_xlsx_data
(
'all_profile.xlsx'
)
return
self
.
load_xlsx_data
(
'all_profile.xlsx'
)
def
load_test_profile_data
(
self
):
def
load_test_profile_data
(
self
):
return
self
.
load_xlsx_data
(
'test_profile.xlsx'
)
return
self
.
load_xlsx_data
(
'test_profile.xlsx'
)
def
profile_to_embedding
(
self
,
profile
):
"""
将用户画像信息转换为向量
"""
embedding
=
[]
for
[
name
,
converter
]
in
profile_converters
:
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
return
embedding
def
embedding_to_profile
(
self
,
embedding
):
"""
向量转换为用户画像
"""
ret
=
{}
si
=
0
for
[
name
,
converter
]
in
profile_converters
:
ei
=
si
+
converter
.
dim
ret
[
name
]
=
converter
.
inconvert
(
embedding
[
si
:
ei
])
si
=
ei
return
ret
def
make_embeddings
(
self
):
def
make_embeddings
(
self
):
user_profiles
=
self
.
_load_profile_data
()
user_profiles
=
self
.
_load_profile_data
()
self
.
logger
.
info
(
'订单用户画像数据加载完成,共加载
%
s 条'
,
len
(
user_profiles
))
self
.
logger
.
info
(
'订单用户画像数据加载完成,共加载
%
s 条'
,
len
(
user_profiles
))
...
@@ -88,7 +60,7 @@ class ProfileManager(DatabaseDataManager):
...
@@ -88,7 +60,7 @@ class ProfileManager(DatabaseDataManager):
self
.
logger
.
info
(
'开始构建订单用户的用户画像向量'
)
self
.
logger
.
info
(
'开始构建订单用户的用户画像向量'
)
for
_
,
profile
in
user_profiles
.
iterrows
():
for
_
,
profile
in
user_profiles
.
iterrows
():
user_ids
.
append
(
str
(
profile
[
'uid'
]))
user_ids
.
append
(
str
(
profile
[
'uid'
]))
embeddings
.
append
(
self
.
profile_to_embedding
(
profile
))
embeddings
.
append
(
encode_profile
(
profile
))
self
.
logger
.
info
(
'用户画像向量构建完成,共构建
%
s 用户'
,
len
(
user_ids
))
self
.
logger
.
info
(
'用户画像向量构建完成,共构建
%
s 用户'
,
len
(
user_ids
))
...
@@ -100,7 +72,6 @@ class ProfileManager(DatabaseDataManager):
...
@@ -100,7 +72,6 @@ class ProfileManager(DatabaseDataManager):
return
embeddings
return
embeddings
def
make_virtual_embedding
(
self
):
def
make_virtual_embedding
(
self
):
user_ids
=
[]
user_ids
=
[]
embeddings
=
[]
embeddings
=
[]
...
...
src/core/manager/user_counselor_index_manager.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
import
os
import
json
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.core.manager
import
Manager
# from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class
UserCounselorIndexManager
(
Manager
):
def
__init__
(
self
)
->
None
:
super
()
.
__init__
(
create_logger
(
__name__
,
'order_data_manager.log'
))
# self.order_data_manager = OrderDataManager()
# self.chat_data_manager = ChatDataManager()
def
make_index
(
self
):
self
.
logger
.
info
(
'开始构建用户-咨询师 合并索引'
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
encoding
=
'utf-8'
)
as
f
:
user_doctor_index
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_chat_index.json'
),
encoding
=
'utf-8'
)
as
f
:
user_doctor_chat_index
=
json
.
load
(
f
)
merged_index
=
{}
for
uid
,
counselors
in
user_doctor_index
.
items
():
chat_index
=
{
c_id
:
value
for
c_id
,
value
in
user_doctor_chat_index
.
get
(
uid
,
[])
}
new_counselors
=
[]
for
[
c_id
,
value
]
in
counselors
:
if
c_id
in
chat_index
:
merge_value
=
value
*
0.6
+
chat_index
[
c_id
]
*
0.4
else
:
merge_value
=
value
*
0.6
new_counselors
.
append
([
c_id
,
merge_value
])
merged_index
[
uid
]
=
sorted
(
new_counselors
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 合并索引构建完成,共构建
%
s 条数据'
,
len
(
merged_index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'merged_user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
merged_index
,
f
,
ensure_ascii
=
False
)
if
__name__
==
'__main__'
:
manager
=
UserCounselorIndexManager
()
manager
.
make_index
()
\ No newline at end of file
src/core/profile
/__init__
.py
→
src/core/profile.py
View file @
523f5373
...
@@ -5,20 +5,17 @@ from typing import Dict, List, Any, Union
...
@@ -5,20 +5,17 @@ from typing import Dict, List, Any, Union
import
pandas
as
pd
import
pandas
as
pd
# from .country_code_profile import CountryCodeProfile
# from .profile import ChannelIdTypeProfile
class
BaseProfile
():
class
BaseProfile
():
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
dim
:
int
=
0
self
.
dim
:
int
=
0
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
raise
NotImplemented
raise
NotImplemented
Error
def
inconvert
(
self
,
embedding
:
List
[
Union
[
int
,
float
]])
->
str
:
def
inconvert
(
self
,
embedding
:
List
[
Union
[
int
,
float
]])
->
str
:
raise
NotImplemented
raise
NotImplemented
Error
class
CountryCodeProfile
(
BaseProfile
):
class
CountryCodeProfile
(
BaseProfile
):
...
@@ -29,7 +26,7 @@ class CountryCodeProfile(BaseProfile):
...
@@ -29,7 +26,7 @@ class CountryCodeProfile(BaseProfile):
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
try
:
try
:
value
=
int
(
value
)
value
=
int
(
value
)
except
Exception
as
e
:
except
Exception
:
return
[
0
,
0
,
1
]
return
[
0
,
0
,
1
]
if
value
==
86
:
if
value
==
86
:
return
[
1
,
0
,
0
]
return
[
1
,
0
,
0
]
...
@@ -53,9 +50,9 @@ class ChannelIdTypeProfile(BaseProfile):
...
@@ -53,9 +50,9 @@ class ChannelIdTypeProfile(BaseProfile):
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
try
:
try
:
value
=
int
(
value
)
value
=
int
(
value
)
except
Exception
as
e
:
except
Exception
:
return
[
0
,
0
,
1
]
return
[
0
,
0
,
1
]
if
value
==
1
:
if
value
==
1
:
return
[
1
,
0
,
0
]
return
[
1
,
0
,
0
]
elif
value
==
2
:
elif
value
==
2
:
...
@@ -86,7 +83,7 @@ class FfromLoginProfile(BaseProfile):
...
@@ -86,7 +83,7 @@ class FfromLoginProfile(BaseProfile):
ret
=
[
0
,
0
,
0
,
0
,
0
]
ret
=
[
0
,
0
,
0
,
0
,
0
]
try
:
try
:
value
=
value
.
lower
()
value
=
value
.
lower
()
except
Exception
as
e
:
except
Exception
:
return
ret
return
ret
for
i
,
v
in
enumerate
(
self
.
brand_list
):
for
i
,
v
in
enumerate
(
self
.
brand_list
):
...
@@ -119,11 +116,11 @@ class UserPreferenceCateProfile(BaseProfile):
...
@@ -119,11 +116,11 @@ class UserPreferenceCateProfile(BaseProfile):
ret
=
[
0.
]
*
8
ret
=
[
0.
]
*
8
if
pd
.
isnull
(
value
):
if
pd
.
isnull
(
value
):
return
ret
return
ret
if
isinstance
(
value
,
str
):
if
isinstance
(
value
,
str
):
try
:
try
:
value
=
json
.
loads
(
value
)
value
=
json
.
loads
(
value
)
except
Exception
as
e
:
except
Exception
:
return
ret
return
ret
for
info
in
value
:
for
info
in
value
:
...
@@ -167,11 +164,10 @@ class NumClassProfile(BaseProfile):
...
@@ -167,11 +164,10 @@ class NumClassProfile(BaseProfile):
value
=
float
(
value
)
value
=
float
(
value
)
index
=
self
.
value_index
(
value
)
index
=
self
.
value_index
(
value
)
ret
[
index
]
=
1
ret
[
index
]
=
1
except
:
except
Exception
:
return
ret
return
ret
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
def
inconvert
(
self
,
embedding
):
ret
=
''
ret
=
''
# 确保embedding中有包含1的值
# 确保embedding中有包含1的值
...
@@ -245,7 +241,6 @@ class MultiChoiceProfile(BaseProfile):
...
@@ -245,7 +241,6 @@ class MultiChoiceProfile(BaseProfile):
self
.
option_dict
=
option_dict
self
.
option_dict
=
option_dict
self
.
re_option_dict
=
{
v
:
k
for
k
,
v
in
self
.
option_dict
.
items
()}
self
.
re_option_dict
=
{
v
:
k
for
k
,
v
in
self
.
option_dict
.
items
()}
def
convert
(
self
,
value
:
List
):
def
convert
(
self
,
value
:
List
):
ret
=
[
0
]
*
len
(
self
.
option_dict
)
ret
=
[
0
]
*
len
(
self
.
option_dict
)
if
pd
.
isnull
(
value
):
if
pd
.
isnull
(
value
):
...
@@ -261,7 +256,6 @@ class MultiChoiceProfile(BaseProfile):
...
@@ -261,7 +256,6 @@ class MultiChoiceProfile(BaseProfile):
pass
pass
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
def
inconvert
(
self
,
embedding
):
ret
=
[]
ret
=
[]
...
@@ -284,7 +278,6 @@ class CityProfile(BaseProfile):
...
@@ -284,7 +278,6 @@ class CityProfile(BaseProfile):
self
.
level
=
level
self
.
level
=
level
self
.
dim
=
self
.
level
*
10
self
.
dim
=
self
.
level
*
10
def
convert
(
self
,
value
):
def
convert
(
self
,
value
):
ret
=
[
0
]
*
self
.
dim
ret
=
[
0
]
*
self
.
dim
...
@@ -297,11 +290,10 @@ class CityProfile(BaseProfile):
...
@@ -297,11 +290,10 @@ class CityProfile(BaseProfile):
n
=
int
(
_n
)
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
:
pass
pass
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
def
inconvert
(
self
,
embedding
):
# 邮编固定都是6
# 邮编固定都是6
ret
=
[
0
]
*
6
ret
=
[
0
]
*
6
...
@@ -318,7 +310,6 @@ class AidiCstBiasCityProfile(CityProfile):
...
@@ -318,7 +310,6 @@ class AidiCstBiasCityProfile(CityProfile):
def
__init__
(
self
,
level
=
2
)
->
None
:
def
__init__
(
self
,
level
=
2
)
->
None
:
super
()
.
__init__
(
level
=
level
)
super
()
.
__init__
(
level
=
level
)
def
convert
(
self
,
value_object
):
def
convert
(
self
,
value_object
):
ret
=
[
0
]
*
self
.
dim
ret
=
[
0
]
*
self
.
dim
...
@@ -331,7 +322,7 @@ class AidiCstBiasCityProfile(CityProfile):
...
@@ -331,7 +322,7 @@ class AidiCstBiasCityProfile(CityProfile):
if
isinstance
(
value_object
,
str
):
if
isinstance
(
value_object
,
str
):
try
:
try
:
value_object
=
json
.
loads
(
value_object
)
value_object
=
json
.
loads
(
value_object
)
except
Exception
as
e
:
except
Exception
:
pass
pass
if
isinstance
(
value_object
,
dict
):
if
isinstance
(
value_object
,
dict
):
...
@@ -340,7 +331,7 @@ class AidiCstBiasCityProfile(CityProfile):
...
@@ -340,7 +331,7 @@ class AidiCstBiasCityProfile(CityProfile):
for
i
,
_n
in
enumerate
(
value
[:
self
.
level
]):
for
i
,
_n
in
enumerate
(
value
[:
self
.
level
]):
n
=
int
(
_n
)
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
:
pass
pass
return
ret
return
ret
...
@@ -367,3 +358,25 @@ profile_converters = [
...
@@ -367,3 +358,25 @@ profile_converters = [
[
'd30_session_num'
,
NumClassProfile
([
0
,
1
])],
[
'd30_session_num'
,
NumClassProfile
([
0
,
1
])],
]
]
def
encode_profile
(
profile
):
"""
将用户画像信息转换为向量
"""
embedding
=
[]
for
[
name
,
converter
]
in
profile_converters
:
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
return
embedding
def
decode_profile
(
embedding
):
"""
向量转换为用户画像
"""
ret
=
{}
si
=
0
for
[
name
,
converter
]
in
profile_converters
:
ei
=
si
+
converter
.
dim
ret
[
name
]
=
converter
.
inconvert
(
embedding
[
si
:
ei
])
si
=
ei
return
ret
src/core/profile/country_code_profile.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
class
CountryCodeProfile
():
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
86
:
return
[
1
,
0
,
0
]
else
:
return
[
0
,
1
,
0
]
\ No newline at end of file
src/core/profile/profile.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
class
ChannelIdTypeProfile
():
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
1
:
return
[
1
,
0
,
0
]
elif
value
==
2
:
return
[
0
,
1
,
0
]
else
:
return
[
0
,
0
,
1
]
\ No newline at end of file
src/core/recommender.py
View file @
523f5373
...
@@ -7,7 +7,9 @@ from typing import List, Dict
...
@@ -7,7 +7,9 @@ from typing import List, Dict
import
faiss
import
faiss
import
numpy
as
np
import
numpy
as
np
from
ydl_ai_recommender.src.core.manager
import
ProfileManager
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorDefaultIndexer
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorCombinationIndexer
from
ydl_ai_recommender.src.core.profile
import
encode_profile
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_data_path
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_data_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
@@ -19,7 +21,7 @@ class Recommender():
...
@@ -19,7 +21,7 @@ class Recommender():
pass
pass
def
recommend
(
self
,
user
)
->
List
:
def
recommend
(
self
,
user
)
->
List
:
raise
NotImplemented
raise
NotImplemented
Error
class
UserCFRecommender
(
Recommender
):
class
UserCFRecommender
(
Recommender
):
...
@@ -37,7 +39,11 @@ class UserCFRecommender(Recommender):
...
@@ -37,7 +39,11 @@ class UserCFRecommender(Recommender):
else
:
else
:
self
.
logger
.
warn
(
'未连接数据库'
)
self
.
logger
.
warn
(
'未连接数据库'
)
self
.
manager
=
ProfileManager
()
self
.
default_indexer
=
UserCounselorDefaultIndexer
(
self
.
logger
)
self
.
default_indexer
.
load_index_data
()
self
.
indexer
=
UserCounselorCombinationIndexer
(
self
.
logger
)
self
.
indexer
.
load_index_data
()
self
.
local_file_dir
=
get_data_path
()
self
.
local_file_dir
=
get_data_path
()
self
.
load_data
()
self
.
load_data
()
...
@@ -45,8 +51,6 @@ class UserCFRecommender(Recommender):
...
@@ -45,8 +51,6 @@ class UserCFRecommender(Recommender):
def
load_data
(
self
):
def
load_data
(
self
):
order_user_embedding
=
[]
order_user_embedding
=
[]
order_user_ids
=
[]
order_user_ids
=
[]
order_user_counselor_index
=
{}
default_counselor
=
[]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings_ids.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings_ids.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_ids
=
[
line
.
strip
()
for
line
in
f
]
order_user_ids
=
[
line
.
strip
()
for
line
in
f
]
...
@@ -54,21 +58,8 @@ class UserCFRecommender(Recommender):
...
@@ -54,21 +58,8 @@ class UserCFRecommender(Recommender):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings.json'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings.json'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_embedding
=
json
.
load
(
f
)
order_user_embedding
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'merged_user_doctor_index.json'
),
encoding
=
'utf-8'
)
as
f
:
# with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
order_user_counselor_index
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'top100_supplier.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
default_counselor
=
[
line
.
strip
()
for
line
in
f
]
self
.
order_user_embedding
=
order_user_embedding
self
.
order_user_embedding
=
order_user_embedding
self
.
order_user_ids
=
order_user_ids
self
.
order_user_ids
=
order_user_ids
self
.
order_user_counselor_index
=
order_user_counselor_index
self
.
default_counselor
=
[{
'counselor'
:
str
(
user
),
'score'
:
1
-
0.01
*
index
,
'from'
:
'top_100'
,
}
for
index
,
user
in
enumerate
(
default_counselor
)]
self
.
index
=
faiss
.
IndexFlatL2
(
len
(
self
.
order_user_embedding
[
0
]))
self
.
index
=
faiss
.
IndexFlatL2
(
len
(
self
.
order_user_embedding
[
0
]))
self
.
index
.
add
(
np
.
array
(
self
.
order_user_embedding
))
self
.
index
.
add
(
np
.
array
(
self
.
order_user_embedding
))
...
@@ -88,28 +79,27 @@ class UserCFRecommender(Recommender):
...
@@ -88,28 +79,27 @@ class UserCFRecommender(Recommender):
return
[]
return
[]
def
user_token
(
self
,
user_profile
):
return
self
.
manager
.
profile_to_embedding
(
user_profile
)
def
_recommend_top
(
self
,
size
=
100
):
def
_recommend_top
(
self
,
size
=
100
):
return
self
.
default_counselor
[:
size
]
return
[{
'counselor'
:
str
(
c_id
),
'score'
:
score
,
'from'
:
'default'
,
}
for
[
c_id
,
score
]
in
self
.
default_indexer
.
index
(
size
)]
def
_recommend
(
self
,
user_embedding
):
def
_recommend
(
self
,
user_embedding
):
D
,
I
=
self
.
index
.
search
(
np
.
array
([
user_embedding
]),
self
.
k
)
D
,
I
=
self
.
index
.
search
(
np
.
array
([
user_embedding
]),
self
.
k
)
counselors
=
[]
counselors
=
[]
for
idx
,
score
in
zip
(
I
[
0
],
D
[
0
]):
for
idx
,
s
imi_s
core
in
zip
(
I
[
0
],
D
[
0
]):
# 相似用户uid
# 相似用户uid
similar_user_id
=
self
.
order_user_ids
[
idx
]
similar_user_id
=
self
.
order_user_ids
[
idx
]
similar_user_counselor
=
self
.
order_user_counselor_index
.
get
(
similar_user_id
,
[])
similar_user_counselor
=
self
.
indexer
.
index
(
q
=
similar_user_id
,
count
=
self
.
top_n
)
recommend_data
=
[{
recommend_data
=
[{
'counselor'
:
str
(
user
[
0
])
,
'counselor'
:
c_id
,
'score'
:
1
/
max
(
0.01
,
float
(
score
)
*
(
index
+
1
)),
'score'
:
score
/
max
(
0.01
,
float
(
simi_score
)),
'from'
:
'similar_users {}'
.
format
(
similar_user_id
),
'from'
:
'similar_users {}'
.
format
(
similar_user_id
),
}
for
index
,
user
in
enumerate
(
similar_user_counselor
[:
self
.
top_n
])
]
}
for
(
c_id
,
score
)
in
similar_user_counselor
]
counselors
.
extend
(
recommend_data
)
counselors
.
extend
(
recommend_data
)
counselors
.
sort
(
key
=
lambda
x
:
x
[
'score'
],
reverse
=
True
)
counselors
.
sort
(
key
=
lambda
x
:
x
[
'score'
],
reverse
=
True
)
...
@@ -117,7 +107,7 @@ class UserCFRecommender(Recommender):
...
@@ -117,7 +107,7 @@ class UserCFRecommender(Recommender):
def
recommend_with_profile
(
self
,
user_profile
,
size
=
0
,
is_merge
=
True
):
def
recommend_with_profile
(
self
,
user_profile
,
size
=
0
,
is_merge
=
True
):
user_embedding
=
self
.
user_token
(
user_profile
)
user_embedding
=
encode_profile
(
user_profile
)
counselors
=
self
.
_recommend
(
user_embedding
)
counselors
=
self
.
_recommend
(
user_embedding
)
# size == 0 时,不追加默认推荐咨询师
# size == 0 时,不追加默认推荐咨询师
...
...
src/data/mysql_client.py
View file @
523f5373
...
@@ -40,7 +40,8 @@ class MySQLClient():
...
@@ -40,7 +40,8 @@ class MySQLClient():
try
:
try
:
self
.
cursor
.
close
()
self
.
cursor
.
close
()
self
.
connection
.
close
()
self
.
connection
.
close
()
self
.
logger
.
info
(
'dataset disconnected'
)
# 容易触发 NameError: name 'open' is not defined
# self.logger.info('dataset disconnected')
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
'销毁 MySQLClient 失败'
,
exc_info
=
True
)
self
.
logger
.
error
(
'销毁 MySQLClient 失败'
,
exc_info
=
True
)
print
(
e
)
print
(
e
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment