Commit ae5af70b by 柴鹏飞

精排模型调整

parent 9664b944
...@@ -5,3 +5,5 @@ numpy ...@@ -5,3 +5,5 @@ numpy
openpyxl openpyxl
tornado==6.2 tornado==6.2
scikit-learn
...@@ -237,7 +237,7 @@ class MultiChoiceProfile(BaseProfile): ...@@ -237,7 +237,7 @@ class MultiChoiceProfile(BaseProfile):
def __init__(self, option_dict: Dict[Any, int]) -> None: def __init__(self, option_dict: Dict[Any, int]) -> None:
super().__init__() super().__init__()
self.dim = 6 self.dim = len(option_dict)
self.option_dict = option_dict self.option_dict = option_dict
self.re_option_dict = {v: k for k, v in self.option_dict.items()} self.re_option_dict = {v: k for k, v in self.option_dict.items()}
......
...@@ -34,7 +34,7 @@ class Ranker(): ...@@ -34,7 +34,7 @@ class Ranker():
self.counselor_embeddings_dimension = 171 self.counselor_embeddings_dimension = 171
self.counselor_embeddings = self._load_counselor_embeddings() self.counselor_embeddings = self._load_counselor_embeddings()
model_path = os.path.join(get_model_path(), 'ranker', 'LR_v1.1.pkl') model_path = os.path.join(get_model_path(), 'ranker', 'RF.1.0.pkl')
with open(model_path, 'rb') as f: with open(model_path, 'rb') as f:
self.model = pickle.load(f) self.model = pickle.load(f)
......
...@@ -41,7 +41,7 @@ def main(args): ...@@ -41,7 +41,7 @@ def main(args):
model.fit(x_train, y_train) model.fit(x_train, y_train)
print('模型训练完成') print('模型训练完成')
model_save_path = os.path.join(args.save_path, args.model + '.pkl') model_save_path = os.path.join(args.save_path, f'{args.model}.{args.save_version}.pkl')
with open(model_save_path, 'wb') as f: with open(model_save_path, 'wb') as f:
pickle.dump(model, f) pickle.dump(model, f)
print('模型已保存至 ', model_save_path) print('模型已保存至 ', model_save_path)
...@@ -63,6 +63,7 @@ if __name__ == '__main__': ...@@ -63,6 +63,7 @@ if __name__ == '__main__':
parser.add_argument('-m', '--model', type=str, default='LR', choices=list(models.keys()), help='模型类型') parser.add_argument('-m', '--model', type=str, default='LR', choices=list(models.keys()), help='模型类型')
parser.add_argument('--data_path', type=str, default='./data', help='训练数据存放目录') parser.add_argument('--data_path', type=str, default='./data', help='训练数据存放目录')
parser.add_argument('--save_path', type=str, default='./model', help='训练好的模型存放目录') parser.add_argument('--save_path', type=str, default='./model', help='训练好的模型存放目录')
parser.add_argument('-v', '--save_version', type=str, default='latest', help='模型版本')
parser.add_argument('--do_test', default=False, action='store_true', help='训练完成后执行测试程序') parser.add_argument('--do_test', default=False, action='store_true', help='训练完成后执行测试程序')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment