Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
Y
ydl_ai_recommender
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
闫发泽
ydl_ai_recommender
Commits
523f5373
Commit
523f5373
authored
Dec 19, 2022
by
柴鹏飞
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor code
parent
7401bbe5
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
443 additions
and
384 deletions
+443
-384
test.py
bin/test.py
+2
-2
update.py
bin/update.py
+56
-35
indexer.py
src/core/indexer.py
+310
-0
__init__.py
src/core/manager/__init__.py
+0
-3
chat_data_manager.py
src/core/manager/chat_data_manager.py
+0
-64
database_manager.py
src/core/manager/database_manager.py
+12
-23
manager.py
src/core/manager/manager.py
+2
-2
order_data_manager.py
src/core/manager/order_data_manager.py
+2
-81
profile_manager.py
src/core/manager/profile_manager.py
+4
-33
user_counselor_index_manager.py
src/core/manager/user_counselor_index_manager.py
+0
-53
profile.py
src/core/profile.py
+33
-20
country_code_profile.py
src/core/profile/country_code_profile.py
+0
-17
profile.py
src/core/profile/profile.py
+0
-20
recommender.py
src/core/recommender.py
+20
-30
mysql_client.py
src/data/mysql_client.py
+2
-1
No files found.
bin/test.py
View file @
523f5373
...
...
@@ -190,11 +190,11 @@ def update_test_data(args):
# 订单数据
manager
=
OrderDataManager
(
client
)
manager
.
update_test_
order_
data
(
conditions
=
conditions
)
manager
.
update_test_data
(
conditions
=
conditions
)
# 用户画像数据
manager
=
ProfileManager
(
client
)
manager
.
update_test_
profile
(
conditions
=
conditions
)
manager
.
update_test_
data
(
conditions
=
conditions
)
logger
.
info
(
'测试数据更新完成'
)
...
...
bin/update.py
View file @
523f5373
...
...
@@ -4,9 +4,17 @@ import os
import
argparse
from
datetime
import
datetime
from
ydl_ai_recommender.src.core.manager
import
OrderDataManager
from
ydl_ai_recommender.src.core.manager
import
ChatDataManager
from
ydl_ai_recommender.src.core.manager
import
ProfileManager
from
ydl_ai_recommender.src.core.manager
import
(
OrderDataManager
,
ChatDataManager
,
ProfileManager
,
)
from
ydl_ai_recommender.src.core.indexer
import
(
UserCounselorDefaultIndexer
,
UserCounselorOrderIndexer
,
UserCounselorChatIndexer
,
UserCounselorCombinationIndexer
,
)
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_project_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
...
@@ -18,18 +26,14 @@ logger = create_logger(__name__, 'update.log')
parser
=
argparse
.
ArgumentParser
(
description
=
'壹点灵 咨询师推荐 算法召回 离线更新数据模型'
)
parser
.
add_argument
(
'-t'
,
'--task'
,
type
=
str
,
required
=
True
,
choices
=
(
'load_db_data'
,
'make_embedding'
,
'make_
virtual_embedding
'
),
help
=
'执行任务名称'
choices
=
(
'load_db_data'
,
'make_embedding'
,
'make_
index
'
),
help
=
'执行任务名称'
)
parser
.
add_argument
(
'--index_last_date'
,
default
=
None
,
type
=
str
,
help
=
'构建索引最后日期,超过该日期的数据不使用'
)
args
=
parser
.
parse_args
()
if
__name__
==
'__main__'
:
if
args
.
task
==
'load_db_data'
:
logger
.
info
(
''
)
def
initialize_dir
():
# 创建数据目录
now
=
datetime
.
now
()
really_data_dir
=
os
.
path
.
join
(
get_project_path
(),
'data_{}'
.
format
(
now
.
strftime
(
'
%
Y
%
m
%
d_
%
H
%
M
%
S'
)))
...
...
@@ -51,46 +55,63 @@ if __name__ == '__main__':
# TODO 历史数据删除
logger
.
info
(
'开始从数据库中更新数据'
)
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
logger
.
info
(
'开始从数据库中更新画像数据'
)
profile_manager
=
ProfileManager
(
client
)
profile_manager
.
update_profile
()
if
__name__
==
'__main__'
:
logger
.
info
(
'开始从数据库中更新订单数据'
)
order_data_manager
=
OrderDataManager
(
client
)
order_data_manager
.
update_order_data
()
logger
.
info
(
''
)
if
args
.
task
==
'load_db_data'
:
logger
.
info
(
'开始从数据库中更新询单数据'
)
chat_data_manager
=
ChatDataManager
(
client
)
chat_data_manager
.
update_data
()
initialize_dir
()
logger
.
info
(
'开始从数据库中更新数据'
)
client
=
MySQLClient
.
create_from_config_file
(
get_conf_path
())
logger
.
info
(
'所有数据更新数据完成'
)
managers
=
[
[
'画像数据'
,
ProfileManager
(
client
)],
[
'订单数据'
,
OrderDataManager
(
client
)],
[
'询单数据'
,
ChatDataManager
(
client
)],
]
for
[
name
,
manager
]
in
managers
:
logger
.
info
(
'开始更新
%
s'
,
name
)
manager
.
update_data
()
logger
.
info
(
'
%
s 更新完成'
,
name
)
logger
.
info
(
''
)
logger
.
info
(
'所有数据更新数据完成'
)
if
args
.
task
==
'make_embedding'
:
logger
.
info
(
''
)
logger
.
info
(
'--'
*
50
)
logger
.
info
(
'开始构建用户特征 embedding'
)
manager
=
ProfileManager
()
manager
.
make_embeddings
()
logger
.
info
(
'用户特征 embedding 构建完成'
)
logger
.
info
(
'开始构建订单相关索引'
)
manager
=
OrderDataManager
()
manager
.
make_index
()
logger
.
info
(
'订单相关索引 构建完成'
)
logger
.
info
(
'开始构建询单相关索引'
)
chat_data_manager
=
ChatDataManager
()
chat_data_manager
.
make_index
()
logger
.
info
(
'询单相关索引 构建完成'
)
if
args
.
task
==
'make_index'
:
indexers
=
[
[
'[用户->咨询师]兜底关系索引'
,
UserCounselorDefaultIndexer
()],
[
'基于订单数据的[用户->咨询师]关系索引'
,
UserCounselorOrderIndexer
()],
[
'基于询单数据的[用户->咨询师]关系索引'
,
UserCounselorChatIndexer
()],
[
'基于多种数据组合的[用户->咨询师]关系索引'
,
UserCounselorCombinationIndexer
()],
]
logger
.
info
(
''
)
logger
.
info
(
'--'
*
50
)
if
args
.
task
==
'make_virtual_embedding'
:
for
[
name
,
indexer
]
in
indexers
:
logger
.
info
(
'开始构建
%
s'
,
name
)
indexer
.
make_index
()
logger
.
info
(
'
%
s 构建完成'
,
name
)
logger
.
info
(
''
)
logger
.
info
(
'开始构建用户特征虚拟embedding'
)
manager
=
ProfileManager
()
manager
.
make_virtual_embedding
()
logger
.
info
(
'用户特征虚拟 embedding 构建完成'
)
\ No newline at end of file
logger
.
info
(
'所有索引更新数据完成'
)
# if args.task == 'make_virtual_embedding':
# logger.info('')
# logger.info('开始构建用户特征虚拟embedding')
# manager = ProfileManager()
# manager.make_virtual_embedding()
# logger.info('用户特征虚拟 embedding 构建完成')
\ No newline at end of file
src/core/indexer.py
0 → 100644
View file @
523f5373
# -*- coding: utf-8 -*-
import
os
import
json
from
collections
import
Counter
from
datetime
import
datetime
,
timedelta
from
typing
import
Dict
,
List
,
Tuple
from
ydl_ai_recommender.src.utils
import
get_data_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.core.manager
import
OrderDataManager
,
ChatDataManager
class
Indexer
():
"""
索引构建、管理类
"""
def
__init__
(
self
,
logger
=
None
)
->
None
:
if
logger
is
None
:
self
.
logger
=
create_logger
(
__name__
)
else
:
self
.
logger
=
logger
self
.
local_file_dir
=
get_data_path
()
def
index
(
self
,
q
:
str
,
count
:
int
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
"""
返回值类型:[[相似id, score], [相似id, score], ...]
"""
raise
NotImplementedError
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
raise
NotImplementedError
class
UserCounselorDefaultIndexer
(
Indexer
):
"""
[用户->咨询师]兜底关系索引
"""
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
OrderDataManager
(
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'index_list.txt'
)
self
.
count
=
100
self
.
index_data
=
[]
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
[(
line
.
strip
(),
0.01
-
0.0001
*
index
)
for
index
,
line
in
enumerate
(
f
)]
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
else
:
return
self
.
index_data
[:
count
]
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建[用户->咨询师]兜底关系索引'
)
df
=
self
.
data_manager
.
load_raw_data
()
self
.
logger
.
info
(
'构建索引数据加载完成,共加载
%
s 条数据'
,
len
(
df
))
supplier_cnter
=
Counter
(
df
[
'supplier_id'
])
index_list
=
[]
for
key
,
_
in
supplier_cnter
.
most_common
(
self
.
count
):
index_list
.
append
(
str
(
key
))
self
.
logger
.
info
(
'[用户->咨询师]兜底关系索引构建完成'
)
with
open
(
self
.
index_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
'
\n
'
.
join
(
index_list
))
self
.
logger
.
info
(
'[用户->咨询师]兜底关系索引构建完成已保存'
)
return
index_list
class
UserCounselorOrderIndexer
(
Indexer
):
"""
基于订单数据的[用户->咨询师]关系索引
"""
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
OrderDataManager
(
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'user_counselor_order_index.json'
)
self
.
index_data
=
{}
self
.
now
=
datetime
.
now
()
def
_compute_score
(
self
,
infos
:
List
)
->
float
:
w
=
[
0
,
0
,
0
,
0
]
for
[
_price
,
dt
]
in
infos
:
price
=
float
(
_price
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
price
/
400
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
price
/
400
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建基于订单数据的[用户->咨询师]关系索引'
)
df
=
self
.
data_manager
.
load_raw_data
()
self
.
logger
.
info
(
'构建索引数据加载完成,共加载
%
s 条数据'
,
len
(
df
))
user_order
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'supplier_id'
]
if
uid
not
in
user_order
:
user_order
[
uid
]
=
{}
if
supplier_id
not
in
user_order
[
uid
]:
user_order
[
uid
][
supplier_id
]
=
[]
user_order
[
uid
][
supplier_id
]
.
append
([
row
[
'price'
],
row
[
'update_time'
]])
index
=
{}
for
uid
,
orders
in
user_order
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
orders
.
items
():
value
=
self
.
_compute_score
(
infos
)
supplier_values
.
append
((
supplier_id
,
value
))
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'基于订单数据的[用户->咨询师]关系索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
self
.
index_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
self
.
logger
.
info
(
'基于订单数据的[用户->咨询师]关系索引数据已保存,共有用户
%
s'
,
len
(
index
))
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
class
UserCounselorChatIndexer
(
Indexer
):
"""
基于询单数据的[用户->咨询师]关系索引
"""
def
__init__
(
self
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
data_manager
=
ChatDataManager
(
logger
)
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'user_counselor_chat_index.json'
)
self
.
index_data
=
{}
self
.
now
=
datetime
.
now
()
def
_compute_score
(
self
,
infos
:
List
)
->
float
:
w
=
[
0
,
0
,
0
,
0
]
for
[
dt
,
_u2d
,
_d2u
]
in
infos
:
u2d
,
d2u
=
int
(
_u2d
),
int
(
_d2u
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
(
u2d
+
d2u
)
/
20
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
(
u2d
+
d2u
)
/
20
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建基于询单数据的[用户->咨询师]关系索引'
)
df
=
self
.
data_manager
.
load_raw_data
()
self
.
logger
.
info
(
'构建索引数据加载完成,共加载
%
s 条数据'
,
len
(
df
))
user_chat
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'doctor_id'
]
if
uid
not
in
user_chat
:
user_chat
[
uid
]
=
{}
if
supplier_id
not
in
user_chat
[
uid
]:
user_chat
[
uid
][
supplier_id
]
=
[]
user_chat
[
uid
][
supplier_id
]
.
append
([
row
[
'dt'
],
row
[
'user_to_doctor'
],
row
[
'doctor_to_user'
]])
index
=
{}
for
uid
,
chats
in
user_chat
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
chats
.
items
():
value
=
self
.
_compute_score
(
infos
)
supplier_values
.
append
([
supplier_id
,
value
])
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'基于询单数据的[用户->咨询师]关系索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
self
.
index_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
self
.
logger
.
info
(
'基于询单数据的[用户->咨询师]关系索引数据已保存,共有用户
%
s'
,
len
(
index
))
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
class
UserCounselorCombinationIndexer
(
Indexer
):
"""
基于多种数据组合的[用户->咨询师]关系索引
"""
def
__init__
(
self
,
order_w
=
0.6
,
chat_w
=
0.4
,
logger
=
None
)
->
None
:
super
()
.
__init__
(
logger
)
self
.
order_w
=
order_w
self
.
chat_w
=
chat_w
self
.
index_file
=
os
.
path
.
join
(
self
.
local_file_dir
,
'user_counselor_combination_index.json'
)
self
.
index_data
=
{}
def
make_index
(
self
)
->
Dict
[
str
,
List
]:
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建基于多种数据组合的[用户->咨询师]关系索引'
)
order_indexer
=
UserCounselorOrderIndexer
(
self
.
logger
)
chat_indexer
=
UserCounselorChatIndexer
(
self
.
logger
)
order_indexer
.
load_index_data
()
chat_indexer
.
load_index_data
()
self
.
logger
.
info
(
'构建索引数据加载完成'
)
index
=
{}
for
uid
,
counselors
in
order_indexer
.
index_data
.
items
():
chat_index
=
{
c_id
:
value
for
c_id
,
value
in
chat_indexer
.
index_data
.
get
(
uid
,
[])
}
new_counselors
=
[]
for
(
c_id
,
value
)
in
counselors
:
merge_value
=
value
*
self
.
order_w
+
chat_index
.
get
(
c_id
,
0
)
*
self
.
chat_w
new_counselors
.
append
((
c_id
,
merge_value
))
index
[
uid
]
=
sorted
(
new_counselors
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'基于多种数据组合的[用户->咨询师]关系索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
self
.
index_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
return
index
def
load_index_data
(
self
):
index
=
[]
with
open
(
self
.
index_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
index
=
json
.
load
(
f
)
self
.
index_data
=
index
def
index
(
self
,
q
=
''
,
count
=
0
)
->
List
[
Tuple
[
str
,
float
]]:
if
len
(
self
.
index_data
)
==
0
:
self
.
logger
.
error
(
'未加载索引数据,使用`index`函数之前,确认对应已执行执行 `load_index_data()` 方法'
)
raise
if
count
==
0
:
return
self
.
index_data
.
get
(
q
,
[])
else
:
return
self
.
index_data
.
get
(
q
,
[])[:
count
]
if
__name__
==
'__main__'
:
indexer
=
UserCounselorDefaultIndexer
()
indexer
.
make_index
()
indexer
=
UserCounselorOrderIndexer
()
indexer
.
make_index
()
indexer
=
UserCounselorChatIndexer
()
indexer
.
make_index
()
indexer
=
UserCounselorCombinationIndexer
()
indexer
.
make_index
()
\ No newline at end of file
src/core/manager/__init__.py
View file @
523f5373
...
...
@@ -6,5 +6,3 @@ from .database_manager import DatabaseDataManager
from
.profile_manager
import
ProfileManager
from
.chat_data_manager
import
ChatDataManager
from
.order_data_manager
import
OrderDataManager
\ No newline at end of file
from
.user_counselor_index_manager
import
UserCounselorIndexManager
\ No newline at end of file
src/core/manager/chat_data_manager.py
View file @
523f5373
# -*- coding: utf-8 -*-
import
os
import
json
from
datetime
import
datetime
,
timedelta
import
pandas
as
pd
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
...
@@ -13,8 +9,6 @@ from ydl_ai_recommender.src.core.manager import DatabaseDataManager
class
ChatDataManager
(
DatabaseDataManager
):
def
__init__
(
self
,
client
=
None
)
->
None
:
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'chat_data_manager.log'
))
self
.
now
=
datetime
.
now
()
def
_make_query_sql
(
self
,
conditions
=
None
):
...
...
@@ -30,7 +24,6 @@ class ChatDataManager(DatabaseDataManager):
sql
+=
condition_sql
return
sql
def
update_data
(
self
):
""" 从数据库中拉取最新订单数据并保存 """
sql
=
self
.
_make_query_sql
()
...
...
@@ -40,7 +33,6 @@ class ChatDataManager(DatabaseDataManager):
self
.
save_csv_data
(
df
,
'all_chat_info.csv'
)
return
df
def
update_test_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件订单用于测试 """
...
...
@@ -50,69 +42,13 @@ class ChatDataManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_csv_data
(
df
,
'test_chat_info.csv'
)
def
load_raw_data
(
self
):
return
self
.
load_csv_data
(
'all_chat_info.csv'
)
def
load_test_data
(
self
):
return
self
.
load_csv_data
(
'test_chat_info.csv'
)
def
make_index
(
self
):
"""
构建索引
用户-咨询师 索引
"""
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建 用户-咨询师 索引'
)
df
=
self
.
load_raw_data
()
self
.
logger
.
info
(
'本地用户咨询师对话数据加载完成,共加载
%
s 条数据'
,
len
(
df
))
user_chat
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'doctor_id'
]
if
uid
not
in
user_chat
:
user_chat
[
uid
]
=
{}
if
supplier_id
not
in
user_chat
[
uid
]:
user_chat
[
uid
][
supplier_id
]
=
[]
user_chat
[
uid
][
supplier_id
]
.
append
([
row
[
'dt'
],
row
[
'user_to_doctor'
],
row
[
'doctor_to_user'
]])
def
compute_score
(
infos
):
w
=
[
0
,
0
,
0
,
0
]
for
[
dt
,
_u2d
,
_d2u
]
in
infos
:
u2d
,
d2u
=
int
(
_u2d
),
int
(
_d2u
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
(
u2d
+
d2u
)
/
20
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
(
u2d
+
d2u
)
/
20
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
(
u2d
+
d2u
)
/
20
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
index
=
{}
for
uid
,
chats
in
user_chat
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
chats
.
items
():
# 日期越近权重越大
value
=
compute_score
(
infos
)
supplier_values
.
append
([
supplier_id
,
value
])
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 询单索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_chat_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
if
__name__
==
'__main__'
:
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
...
...
src/core/manager/database_manager.py
View file @
523f5373
...
...
@@ -8,20 +8,16 @@ import pandas as pd
from
ydl_ai_recommender.src.utils
import
get_data_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
.manager
import
Manager
# class Manager():
# def __init__(self, logger=None) -> None:
# if logger is None:
# self.logger = create_logger(__name__)
# else:
# self.logger = logger
class
Manager
():
def
__init__
(
self
,
logger
=
None
)
->
None
:
if
logger
is
None
:
self
.
logger
=
create_logger
(
__name__
)
else
:
self
.
logger
=
logger
self
.
local_file_dir
=
get_data_path
()
def
make_index
(
self
):
raise
NotImplemented
# self.local_file_dir = get_data_path()
class
DatabaseDataManager
(
Manager
):
...
...
@@ -29,7 +25,6 @@ class DatabaseDataManager(Manager):
super
()
.
__init__
(
logger
)
self
.
client
=
client
def
fetch_data_from_db
(
self
,
sql
:
str
)
->
List
:
if
self
.
client
is
None
:
self
.
logger
.
error
(
'未连接数据库'
)
...
...
@@ -37,34 +32,28 @@ class DatabaseDataManager(Manager):
return
self
.
client
.
query
(
sql
)
def
load_xlsx_data
(
self
,
filename
):
return
pd
.
read_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
def
save_xlsx_data
(
self
,
df
,
filename
):
df
.
to_excel
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
index
=
None
)
def
load_csv_data
(
self
,
filename
):
return
pd
.
read_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
dtype
=
str
)
def
save_csv_data
(
self
,
df
,
filename
):
df
.
to_csv
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
encoding
=
'utf-8'
,
index
=
False
)
def
load_json_data
(
self
,
filename
):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
return
json
.
load
(
f
)
def
save_json_data
(
self
,
data
,
filename
):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
filename
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
return
json
.
dump
(
data
,
f
,
ensure_ascii
=
False
)
def
update_data
(
self
):
raise
NotImplementedError
def
update_data
(
self
,
sql
,
filename
):
_
,
all_data
=
self
.
fetch_data_from_db
(
sql
)
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
filename
)
def
update_test_data
(
self
,
conditions
):
raise
NotImplementedError
src/core/manager/manager.py
View file @
523f5373
...
...
@@ -15,4 +15,4 @@ class Manager():
def
make_index
(
self
):
raise
NotImplemented
\ No newline at end of file
raise
NotImplementedError
\ No newline at end of file
src/core/manager/order_data_manager.py
View file @
523f5373
...
...
@@ -16,7 +16,6 @@ class OrderDataManager(DatabaseDataManager):
super
()
.
__init__
(
client
,
create_logger
(
__name__
,
'order_data_manager.log'
))
self
.
now
=
datetime
.
now
()
def
_make_query_sql
(
self
,
conditions
=
None
):
condition_sql
=
''
if
conditions
:
...
...
@@ -28,8 +27,7 @@ class OrderDataManager(DatabaseDataManager):
sql
+=
condition_sql
return
sql
def
update_order_data
(
self
):
def
update_data
(
self
):
""" 从数据库中拉取最新订单数据并保存 """
sql
=
self
.
_make_query_sql
()
_
,
all_data
=
self
.
fetch_data_from_db
(
sql
)
...
...
@@ -38,7 +36,7 @@ class OrderDataManager(DatabaseDataManager):
self
.
save_xlsx_data
(
df
,
'all_order_info.xlsx'
)
def
update_test_
order_
data
(
self
,
conditions
):
def
update_test_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件订单用于测试 """
sql
=
self
.
_make_query_sql
(
conditions
)
...
...
@@ -56,82 +54,6 @@ class OrderDataManager(DatabaseDataManager):
return
self
.
load_xlsx_data
(
'test_order_info.xlsx'
)
def
make_index
(
self
):
"""
构建索引
用户-咨询师 索引
top100 咨询师列表 用于冷启动
"""
self
.
logger
.
info
(
''
)
self
.
logger
.
info
(
'开始构建 用户-咨询师 索引'
)
df
=
self
.
load_raw_data
()
self
.
logger
.
info
(
'本地订单加载数据完成,共加载
%
s 条数据'
,
len
(
df
))
user_order
=
{}
for
index
,
row
in
df
.
iterrows
():
uid
,
supplier_id
=
row
[
'uid'
],
row
[
'supplier_id'
]
if
uid
not
in
user_order
:
user_order
[
uid
]
=
{}
if
supplier_id
not
in
user_order
[
uid
]:
user_order
[
uid
][
supplier_id
]
=
[]
user_order
[
uid
][
supplier_id
]
.
append
([
row
[
'price'
],
row
[
'update_time'
]])
def
compute_score
(
infos
):
w
=
[
0
,
0
,
0
,
0
]
for
[
_price
,
dt
]
in
infos
:
price
=
float
(
_price
)
date
=
datetime
.
strptime
(
dt
,
'
%
Y-
%
m-
%
d'
)
if
(
self
.
now
-
date
)
<=
timedelta
(
days
=
7
):
w
[
0
]
=
max
(
1.
,
w
[
0
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
30
):
w
[
1
]
=
max
(
1.
,
w
[
1
],
price
/
400
)
elif
(
self
.
now
-
date
)
<=
timedelta
(
days
=
180
):
w
[
2
]
=
max
(
1.
,
w
[
2
],
price
/
400
)
else
:
w
[
3
]
=
max
(
1.
,
w
[
3
],
price
/
400
)
value
=
w
[
0
]
*
0.5
+
w
[
1
]
*
0.25
+
w
[
2
]
*
0.15
+
w
[
3
]
*
0.1
return
value
index
=
{}
for
uid
,
orders
in
user_order
.
items
():
supplier_values
=
[]
for
supplier_id
,
infos
in
orders
.
items
():
# 订单越多排序约靠前,相同数量订单,最新订单约晚越靠前
value
=
compute_score
(
infos
)
supplier_values
.
append
([
supplier_id
,
value
])
# value = len(infos)
# latest_time = max([info[1] for info in infos])
# supplier_values.append([supplier_id, value, latest_time])
# index[uid] = sorted(supplier_values, key=lambda x: (x[2], x[1]), reverse=True)
index
[
uid
]
=
sorted
(
supplier_values
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 索引构建完成,共构建
%
s 条数据'
,
len
(
index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
index
,
f
,
ensure_ascii
=
False
)
self
.
logger
.
info
(
'用户-咨询师 索引数据已保存,共有用户
%
s'
,
len
(
index
))
# 订单最多的咨询师
supplier_cnter
=
Counter
(
df
[
'supplier_id'
])
top100_supplier
=
[]
for
key
,
_
in
supplier_cnter
.
most_common
(
100
):
top100_supplier
.
append
(
str
(
key
))
self
.
logger
.
info
(
'top100 订单量咨询师统计完成'
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'top100_supplier.txt'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
'
\n
'
.
join
(
top100_supplier
))
self
.
logger
.
info
(
'top100 订单量咨询师列表已保存'
)
if
__name__
==
'__main__'
:
manager
=
OrderDataManager
()
print
(
manager
.
make_index
())
\ No newline at end of file
src/core/manager/profile_manager.py
View file @
523f5373
...
...
@@ -6,7 +6,7 @@ from typing import List
import
pandas
as
pd
from
ydl_ai_recommender.src.core.profile
import
profile_converters
from
ydl_ai_recommender.src.core.profile
import
encode_profile
from
ydl_ai_recommender.src.core.manager
import
DatabaseDataManager
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
...
@@ -29,8 +29,7 @@ class ProfileManager(DatabaseDataManager):
sql
+=
' WHERE uid IN (SELECT DISTINCT uid FROM ods.ods_ydl_standard_order{})'
.
format
(
condition_sql
)
return
sql
def
update_profile
(
self
):
def
update_data
(
self
):
""" 从数据库中拉取最新画像特征并保存 """
sql
=
self
.
_make_query_sql
()
...
...
@@ -39,8 +38,7 @@ class ProfileManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
'all_profile.xlsx'
)
def
update_test_profile
(
self
,
conditions
):
def
update_test_data
(
self
,
conditions
):
""" 从数据库中拉取指定条件画像信息用于测试 """
sql
=
self
.
_make_query_sql
(
conditions
)
...
...
@@ -49,38 +47,12 @@ class ProfileManager(DatabaseDataManager):
df
=
pd
.
DataFrame
(
all_data
)
self
.
save_xlsx_data
(
df
,
'test_profile.xlsx'
)
def
_load_profile_data
(
self
):
return
self
.
load_xlsx_data
(
'all_profile.xlsx'
)
def
load_test_profile_data
(
self
):
return
self
.
load_xlsx_data
(
'test_profile.xlsx'
)
def
profile_to_embedding
(
self
,
profile
):
"""
将用户画像信息转换为向量
"""
embedding
=
[]
for
[
name
,
converter
]
in
profile_converters
:
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
return
embedding
def
embedding_to_profile
(
self
,
embedding
):
"""
向量转换为用户画像
"""
ret
=
{}
si
=
0
for
[
name
,
converter
]
in
profile_converters
:
ei
=
si
+
converter
.
dim
ret
[
name
]
=
converter
.
inconvert
(
embedding
[
si
:
ei
])
si
=
ei
return
ret
def
make_embeddings
(
self
):
user_profiles
=
self
.
_load_profile_data
()
self
.
logger
.
info
(
'订单用户画像数据加载完成,共加载
%
s 条'
,
len
(
user_profiles
))
...
...
@@ -88,7 +60,7 @@ class ProfileManager(DatabaseDataManager):
self
.
logger
.
info
(
'开始构建订单用户的用户画像向量'
)
for
_
,
profile
in
user_profiles
.
iterrows
():
user_ids
.
append
(
str
(
profile
[
'uid'
]))
embeddings
.
append
(
self
.
profile_to_embedding
(
profile
))
embeddings
.
append
(
encode_profile
(
profile
))
self
.
logger
.
info
(
'用户画像向量构建完成,共构建
%
s 用户'
,
len
(
user_ids
))
...
...
@@ -100,7 +72,6 @@ class ProfileManager(DatabaseDataManager):
return
embeddings
def
make_virtual_embedding
(
self
):
user_ids
=
[]
embeddings
=
[]
...
...
src/core/manager/user_counselor_index_manager.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
import
os
import
json
from
ydl_ai_recommender.src.utils.log
import
create_logger
from
ydl_ai_recommender.src.core.manager
import
Manager
# from ydl_ai_recommender.src.core.manager import OrderDataManager, ChatDataManager
class
UserCounselorIndexManager
(
Manager
):
def
__init__
(
self
)
->
None
:
super
()
.
__init__
(
create_logger
(
__name__
,
'order_data_manager.log'
))
# self.order_data_manager = OrderDataManager()
# self.chat_data_manager = ChatDataManager()
def
make_index
(
self
):
self
.
logger
.
info
(
'开始构建用户-咨询师 合并索引'
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_index.json'
),
encoding
=
'utf-8'
)
as
f
:
user_doctor_index
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_doctor_chat_index.json'
),
encoding
=
'utf-8'
)
as
f
:
user_doctor_chat_index
=
json
.
load
(
f
)
merged_index
=
{}
for
uid
,
counselors
in
user_doctor_index
.
items
():
chat_index
=
{
c_id
:
value
for
c_id
,
value
in
user_doctor_chat_index
.
get
(
uid
,
[])
}
new_counselors
=
[]
for
[
c_id
,
value
]
in
counselors
:
if
c_id
in
chat_index
:
merge_value
=
value
*
0.6
+
chat_index
[
c_id
]
*
0.4
else
:
merge_value
=
value
*
0.6
new_counselors
.
append
([
c_id
,
merge_value
])
merged_index
[
uid
]
=
sorted
(
new_counselors
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
self
.
logger
.
info
(
'用户-咨询师 合并索引构建完成,共构建
%
s 条数据'
,
len
(
merged_index
))
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'merged_user_doctor_index.json'
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
json
.
dump
(
merged_index
,
f
,
ensure_ascii
=
False
)
if
__name__
==
'__main__'
:
manager
=
UserCounselorIndexManager
()
manager
.
make_index
()
\ No newline at end of file
src/core/profile
/__init__
.py
→
src/core/profile.py
View file @
523f5373
...
...
@@ -5,20 +5,17 @@ from typing import Dict, List, Any, Union
import
pandas
as
pd
# from .country_code_profile import CountryCodeProfile
# from .profile import ChannelIdTypeProfile
class
BaseProfile
():
def
__init__
(
self
)
->
None
:
self
.
dim
:
int
=
0
self
.
dim
:
int
=
0
def
convert
(
self
,
value
):
raise
NotImplemented
raise
NotImplemented
Error
def
inconvert
(
self
,
embedding
:
List
[
Union
[
int
,
float
]])
->
str
:
raise
NotImplemented
raise
NotImplemented
Error
class
CountryCodeProfile
(
BaseProfile
):
...
...
@@ -29,7 +26,7 @@ class CountryCodeProfile(BaseProfile):
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
except
Exception
:
return
[
0
,
0
,
1
]
if
value
==
86
:
return
[
1
,
0
,
0
]
...
...
@@ -53,7 +50,7 @@ class ChannelIdTypeProfile(BaseProfile):
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
except
Exception
:
return
[
0
,
0
,
1
]
if
value
==
1
:
...
...
@@ -86,7 +83,7 @@ class FfromLoginProfile(BaseProfile):
ret
=
[
0
,
0
,
0
,
0
,
0
]
try
:
value
=
value
.
lower
()
except
Exception
as
e
:
except
Exception
:
return
ret
for
i
,
v
in
enumerate
(
self
.
brand_list
):
...
...
@@ -123,7 +120,7 @@ class UserPreferenceCateProfile(BaseProfile):
if
isinstance
(
value
,
str
):
try
:
value
=
json
.
loads
(
value
)
except
Exception
as
e
:
except
Exception
:
return
ret
for
info
in
value
:
...
...
@@ -167,11 +164,10 @@ class NumClassProfile(BaseProfile):
value
=
float
(
value
)
index
=
self
.
value_index
(
value
)
ret
[
index
]
=
1
except
:
except
Exception
:
return
ret
return
ret
def
inconvert
(
self
,
embedding
):
ret
=
''
# 确保embedding中有包含1的值
...
...
@@ -245,7 +241,6 @@ class MultiChoiceProfile(BaseProfile):
self
.
option_dict
=
option_dict
self
.
re_option_dict
=
{
v
:
k
for
k
,
v
in
self
.
option_dict
.
items
()}
def
convert
(
self
,
value
:
List
):
ret
=
[
0
]
*
len
(
self
.
option_dict
)
if
pd
.
isnull
(
value
):
...
...
@@ -261,7 +256,6 @@ class MultiChoiceProfile(BaseProfile):
pass
return
ret
def
inconvert
(
self
,
embedding
):
ret
=
[]
...
...
@@ -284,7 +278,6 @@ class CityProfile(BaseProfile):
self
.
level
=
level
self
.
dim
=
self
.
level
*
10
def
convert
(
self
,
value
):
ret
=
[
0
]
*
self
.
dim
...
...
@@ -297,11 +290,10 @@ class CityProfile(BaseProfile):
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
:
pass
return
ret
def
inconvert
(
self
,
embedding
):
# 邮编固定都是6
ret
=
[
0
]
*
6
...
...
@@ -318,7 +310,6 @@ class AidiCstBiasCityProfile(CityProfile):
def
__init__
(
self
,
level
=
2
)
->
None
:
super
()
.
__init__
(
level
=
level
)
def
convert
(
self
,
value_object
):
ret
=
[
0
]
*
self
.
dim
...
...
@@ -331,7 +322,7 @@ class AidiCstBiasCityProfile(CityProfile):
if
isinstance
(
value_object
,
str
):
try
:
value_object
=
json
.
loads
(
value_object
)
except
Exception
as
e
:
except
Exception
:
pass
if
isinstance
(
value_object
,
dict
):
...
...
@@ -340,7 +331,7 @@ class AidiCstBiasCityProfile(CityProfile):
for
i
,
_n
in
enumerate
(
value
[:
self
.
level
]):
n
=
int
(
_n
)
ret
[
i
*
10
+
n
]
=
1
except
Exception
as
e
:
except
Exception
:
pass
return
ret
...
...
@@ -367,3 +358,25 @@ profile_converters = [
[
'd30_session_num'
,
NumClassProfile
([
0
,
1
])],
]
def
encode_profile
(
profile
):
"""
将用户画像信息转换为向量
"""
embedding
=
[]
for
[
name
,
converter
]
in
profile_converters
:
embedding
.
extend
(
converter
.
convert
(
profile
[
name
]))
return
embedding
def
decode_profile
(
embedding
):
"""
向量转换为用户画像
"""
ret
=
{}
si
=
0
for
[
name
,
converter
]
in
profile_converters
:
ei
=
si
+
converter
.
dim
ret
[
name
]
=
converter
.
inconvert
(
embedding
[
si
:
ei
])
si
=
ei
return
ret
src/core/profile/country_code_profile.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
class
CountryCodeProfile
():
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
86
:
return
[
1
,
0
,
0
]
else
:
return
[
0
,
1
,
0
]
\ No newline at end of file
src/core/profile/profile.py
deleted
100644 → 0
View file @
7401bbe5
# -*- coding: utf-8 -*-
class
ChannelIdTypeProfile
():
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
value
):
try
:
value
=
int
(
value
)
except
Exception
as
e
:
return
[
0
,
0
,
1
]
if
value
==
1
:
return
[
1
,
0
,
0
]
elif
value
==
2
:
return
[
0
,
1
,
0
]
else
:
return
[
0
,
0
,
1
]
\ No newline at end of file
src/core/recommender.py
View file @
523f5373
...
...
@@ -7,7 +7,9 @@ from typing import List, Dict
import
faiss
import
numpy
as
np
from
ydl_ai_recommender.src.core.manager
import
ProfileManager
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorDefaultIndexer
from
ydl_ai_recommender.src.core.indexer
import
UserCounselorCombinationIndexer
from
ydl_ai_recommender.src.core.profile
import
encode_profile
from
ydl_ai_recommender.src.data.mysql_client
import
MySQLClient
from
ydl_ai_recommender.src.utils
import
get_conf_path
,
get_data_path
from
ydl_ai_recommender.src.utils.log
import
create_logger
...
...
@@ -19,7 +21,7 @@ class Recommender():
pass
def
recommend
(
self
,
user
)
->
List
:
raise
NotImplemented
raise
NotImplemented
Error
class
UserCFRecommender
(
Recommender
):
...
...
@@ -37,7 +39,11 @@ class UserCFRecommender(Recommender):
else
:
self
.
logger
.
warn
(
'未连接数据库'
)
self
.
manager
=
ProfileManager
()
self
.
default_indexer
=
UserCounselorDefaultIndexer
(
self
.
logger
)
self
.
default_indexer
.
load_index_data
()
self
.
indexer
=
UserCounselorCombinationIndexer
(
self
.
logger
)
self
.
indexer
.
load_index_data
()
self
.
local_file_dir
=
get_data_path
()
self
.
load_data
()
...
...
@@ -45,8 +51,6 @@ class UserCFRecommender(Recommender):
def
load_data
(
self
):
order_user_embedding
=
[]
order_user_ids
=
[]
order_user_counselor_index
=
{}
default_counselor
=
[]
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings_ids.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_ids
=
[
line
.
strip
()
for
line
in
f
]
...
...
@@ -54,21 +58,8 @@ class UserCFRecommender(Recommender):
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'user_embeddings.json'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
order_user_embedding
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'merged_user_doctor_index.json'
),
encoding
=
'utf-8'
)
as
f
:
# with open(os.path.join(self.local_file_dir, 'user_doctor_index.json'), encoding='utf-8') as f:
order_user_counselor_index
=
json
.
load
(
f
)
with
open
(
os
.
path
.
join
(
self
.
local_file_dir
,
'top100_supplier.txt'
),
'r'
,
encoding
=
'utf-8'
)
as
f
:
default_counselor
=
[
line
.
strip
()
for
line
in
f
]
self
.
order_user_embedding
=
order_user_embedding
self
.
order_user_ids
=
order_user_ids
self
.
order_user_counselor_index
=
order_user_counselor_index
self
.
default_counselor
=
[{
'counselor'
:
str
(
user
),
'score'
:
1
-
0.01
*
index
,
'from'
:
'top_100'
,
}
for
index
,
user
in
enumerate
(
default_counselor
)]
self
.
index
=
faiss
.
IndexFlatL2
(
len
(
self
.
order_user_embedding
[
0
]))
self
.
index
.
add
(
np
.
array
(
self
.
order_user_embedding
))
...
...
@@ -88,28 +79,27 @@ class UserCFRecommender(Recommender):
return
[]
def
user_token
(
self
,
user_profile
):
return
self
.
manager
.
profile_to_embedding
(
user_profile
)
def
_recommend_top
(
self
,
size
=
100
):
return
self
.
default_counselor
[:
size
]
return
[{
'counselor'
:
str
(
c_id
),
'score'
:
score
,
'from'
:
'default'
,
}
for
[
c_id
,
score
]
in
self
.
default_indexer
.
index
(
size
)]
def
_recommend
(
self
,
user_embedding
):
D
,
I
=
self
.
index
.
search
(
np
.
array
([
user_embedding
]),
self
.
k
)
counselors
=
[]
for
idx
,
score
in
zip
(
I
[
0
],
D
[
0
]):
for
idx
,
s
imi_s
core
in
zip
(
I
[
0
],
D
[
0
]):
# 相似用户uid
similar_user_id
=
self
.
order_user_ids
[
idx
]
similar_user_counselor
=
self
.
order_user_counselor_index
.
get
(
similar_user_id
,
[])
similar_user_counselor
=
self
.
indexer
.
index
(
q
=
similar_user_id
,
count
=
self
.
top_n
)
recommend_data
=
[{
'counselor'
:
str
(
user
[
0
])
,
'score'
:
1
/
max
(
0.01
,
float
(
score
)
*
(
index
+
1
)),
'counselor'
:
c_id
,
'score'
:
score
/
max
(
0.01
,
float
(
simi_score
)),
'from'
:
'similar_users {}'
.
format
(
similar_user_id
),
}
for
index
,
user
in
enumerate
(
similar_user_counselor
[:
self
.
top_n
])
]
}
for
(
c_id
,
score
)
in
similar_user_counselor
]
counselors
.
extend
(
recommend_data
)
counselors
.
sort
(
key
=
lambda
x
:
x
[
'score'
],
reverse
=
True
)
...
...
@@ -117,7 +107,7 @@ class UserCFRecommender(Recommender):
def
recommend_with_profile
(
self
,
user_profile
,
size
=
0
,
is_merge
=
True
):
user_embedding
=
self
.
user_token
(
user_profile
)
user_embedding
=
encode_profile
(
user_profile
)
counselors
=
self
.
_recommend
(
user_embedding
)
# size == 0 时,不追加默认推荐咨询师
...
...
src/data/mysql_client.py
View file @
523f5373
...
...
@@ -40,7 +40,8 @@ class MySQLClient():
try
:
self
.
cursor
.
close
()
self
.
connection
.
close
()
self
.
logger
.
info
(
'dataset disconnected'
)
# 容易触发 NameError: name 'open' is not defined
# self.logger.info('dataset disconnected')
except
Exception
as
e
:
self
.
logger
.
error
(
'销毁 MySQLClient 失败'
,
exc_info
=
True
)
print
(
e
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment