From 8085ecfbfc32523ed434c4c93020b5a56bc1f626 Mon Sep 17 00:00:00 2001 From: DzvinkaYarish Date: Tue, 14 Jun 2022 12:53:45 +0300 Subject: [PATCH] Minor fixes --- chemprop/utils.py | 2 +- train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/chemprop/utils.py b/chemprop/utils.py index 96e0a33..a990e49 100644 --- a/chemprop/utils.py +++ b/chemprop/utils.py @@ -167,7 +167,7 @@ def get_loss_func(args: Namespace) -> nn.Module: return nn.BCEWithLogitsLoss(reduction='none') if args.dataset_type == 'regression': - return nn.L1Loss(reduction='none') + return nn.L2Loss(reduction='none') if args.dataset_type == 'multiclass': return nn.CrossEntropyLoss(reduction='none') diff --git a/train.py b/train.py index bcaf38f..fb25f83 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ -"""Trains a model on a dataset.""" +"""Entrypoint fot training.""" +"""Single or cross-validation training""" from chemprop.parsing import parse_train_args from chemprop.train import cross_validate