-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_pyg.py
42 lines (26 loc) · 875 Bytes
/
run_pyg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from loguru import logger
from utils import create_parser, set_global_seed, setup_mlflow, setup_logger, init_config, update_config
from pyg.train import train_gnn
from pyg.inference import infer_gnn
from pyg.evaluate import eval_gnn
def main():
    """Run the PyG pipeline in the mode selected on the command line.

    Steps:
      1. Build the default configuration with ``init_config``.
      2. Parse command-line arguments with ``create_parser``.
      3. Merge the CLI arguments into the configuration with ``update_config``.
      4. Configure logging and MLflow, and set the global random seed.
      5. Dispatch on ``args.mode`` to ``train_gnn``, ``eval_gnn`` or
         ``infer_gnn``, passing the merged configuration.

    Raises:
        ValueError: if ``args.mode`` is not one of ``'train'``, ``'evaluate'``
            or ``'inference'``. Previously an unknown mode was silently
            ignored after all setup had run.
    """
    # Map each supported mode to the routine that implements it.
    runners = {
        "train": train_gnn,
        "evaluate": eval_gnn,
        "inference": infer_gnn,
    }

    # Initialize configuration
    config = init_config()
    # Setup arguments (the parser is built from the default configuration)
    args = create_parser(config)

    # Validate the mode before any side-effecting setup (logger, MLflow), so a
    # typo fails fast instead of doing nothing after initialising everything.
    runner = runners.get(args.mode)
    if runner is None:
        raise ValueError(
            f"Unknown mode {args.mode!r}; expected one of {sorted(runners)}"
        )

    # Update configurations: CLI values are merged into the default config.
    # NOTE(review): assumes init_config() returns an object supporting vars()
    # and that update_config returns a nested dict with "general_config".
    config = update_config(dict(vars(config)), vars(args))
    # Setup logger
    setup_logger(args)
    # Setup MLflow
    setup_mlflow(config)
    # Set seed
    set_global_seed(config["general_config"]["seed"])

    logger.info("Running in '{}' mode", args.mode)
    runner(config)


if __name__ == "__main__":
    main()