Using pytorch image

This commit is contained in:
Chintan Shah 2019-09-07 17:53:46 -04:00
parent 69d6c0e053
commit 7ba7fa320d
10 changed files with 8 additions and 7 deletions

View File

@ -1,3 +1,3 @@
FROM tensorflow/tensorflow:latest-py3 FROM ufoym/deepo:cpu
COPY requirements.txt . COPY requirements.txt .
RUN pip install -r requirements.txt RUN pip install -r requirements.txt

View File

@ -7,7 +7,7 @@ import tensorflow as tf
import yaml import yaml
from lib.utils import load_graph_data from lib.utils import load_graph_data
from model.dcrnn_supervisor import DCRNNSupervisor from model.tf.dcrnn_supervisor import DCRNNSupervisor
def main(args): def main(args):

View File

View File

@ -0,0 +1 @@
import torch

0
model/tf/__init__.py Normal file
View File

View File

@ -6,8 +6,7 @@ import tensorflow as tf
from tensorflow.contrib import legacy_seq2seq from tensorflow.contrib import legacy_seq2seq
from lib.metrics import masked_mae_loss from model.tf.dcrnn_cell import DCGRUCell
from model.dcrnn_cell import DCGRUCell
class DCRNNModel(object): class DCRNNModel(object):

View File

@ -13,7 +13,7 @@ from lib import utils, metrics
from lib.AMSGrad import AMSGrad from lib.AMSGrad import AMSGrad
from lib.metrics import masked_mae_loss from lib.metrics import masked_mae_loss
from model.dcrnn_model import DCRNNModel from model.tf.dcrnn_model import DCRNNModel
class DCRNNSupervisor(object): class DCRNNSupervisor(object):

View File

@ -3,4 +3,5 @@ numpy>=1.12.1
pandas>=0.19.2 pandas>=0.19.2
pyyaml pyyaml
statsmodels statsmodels
tensorflow>=1.3.0 tensorflow>=1.3.0
torch

View File

@ -6,7 +6,7 @@ import tensorflow as tf
import yaml import yaml
from lib.utils import load_graph_data from lib.utils import load_graph_data
from model.dcrnn_supervisor import DCRNNSupervisor from model.tf.dcrnn_supervisor import DCRNNSupervisor
def run_dcrnn(args): def run_dcrnn(args):