From 7ba7fa320d2ce0b41db06afa322461f6610d5a39 Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Sat, 7 Sep 2019 17:53:46 -0400 Subject: [PATCH] Using pytorch image --- DCRNN_CPU | 2 +- dcrnn_train.py | 2 +- model/pytorch/__init__.py | 0 model/pytorch/dcrnn_cell.py | 1 + model/tf/__init__.py | 0 model/{ => tf}/dcrnn_cell.py | 0 model/{ => tf}/dcrnn_model.py | 3 +-- model/{ => tf}/dcrnn_supervisor.py | 2 +- requirements.txt | 3 ++- run_demo.py | 2 +- 10 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 model/pytorch/__init__.py create mode 100644 model/pytorch/dcrnn_cell.py create mode 100644 model/tf/__init__.py rename model/{ => tf}/dcrnn_cell.py (100%) rename model/{ => tf}/dcrnn_model.py (98%) rename model/{ => tf}/dcrnn_supervisor.py (99%) diff --git a/DCRNN_CPU b/DCRNN_CPU index 2a7e78a..de4e4e8 100644 --- a/DCRNN_CPU +++ b/DCRNN_CPU @@ -1,3 +1,3 @@ -FROM tensorflow/tensorflow:latest-py3 +FROM ufoym/deepo:cpu COPY requirements.txt . RUN pip install -r requirements.txt diff --git a/dcrnn_train.py b/dcrnn_train.py index de75465..0c28dfb 100644 --- a/dcrnn_train.py +++ b/dcrnn_train.py @@ -7,7 +7,7 @@ import tensorflow as tf import yaml from lib.utils import load_graph_data -from model.dcrnn_supervisor import DCRNNSupervisor +from model.tf.dcrnn_supervisor import DCRNNSupervisor def main(args): diff --git a/model/pytorch/__init__.py b/model/pytorch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py new file mode 100644 index 0000000..aaa3285 --- /dev/null +++ b/model/pytorch/dcrnn_cell.py @@ -0,0 +1 @@ +import torch diff --git a/model/tf/__init__.py b/model/tf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/dcrnn_cell.py b/model/tf/dcrnn_cell.py similarity index 100% rename from model/dcrnn_cell.py rename to model/tf/dcrnn_cell.py diff --git a/model/dcrnn_model.py b/model/tf/dcrnn_model.py similarity index 98% rename from model/dcrnn_model.py rename to model/tf/dcrnn_model.py index 43797fa..16fa379 100644 --- a/model/dcrnn_model.py +++ b/model/tf/dcrnn_model.py @@ -6,8 +6,7 @@ import tensorflow as tf from tensorflow.contrib import legacy_seq2seq -from lib.metrics import masked_mae_loss -from model.dcrnn_cell import DCGRUCell +from model.tf.dcrnn_cell import DCGRUCell class DCRNNModel(object): diff --git a/model/dcrnn_supervisor.py b/model/tf/dcrnn_supervisor.py similarity index 99% rename from model/dcrnn_supervisor.py rename to model/tf/dcrnn_supervisor.py index 18a399e..8024628 100644 --- a/model/dcrnn_supervisor.py +++ b/model/tf/dcrnn_supervisor.py @@ -13,7 +13,7 @@ from lib import utils, metrics from lib.AMSGrad import AMSGrad from lib.metrics import masked_mae_loss -from model.dcrnn_model import DCRNNModel +from model.tf.dcrnn_model import DCRNNModel class DCRNNSupervisor(object): diff --git a/requirements.txt b/requirements.txt index f577b89..989b6e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ numpy>=1.12.1 pandas>=0.19.2 pyyaml statsmodels -tensorflow>=1.3.0 \ No newline at end of file +tensorflow>=1.3.0 +torch \ No newline at end of file diff --git a/run_demo.py b/run_demo.py index ecbbe86..05b617f 100644 --- a/run_demo.py +++ b/run_demo.py @@ -6,7 +6,7 @@ import tensorflow as tf import yaml from lib.utils import load_graph_data -from model.dcrnn_supervisor import DCRNNSupervisor +from model.tf.dcrnn_supervisor import DCRNNSupervisor def run_dcrnn(args):