Using pytorch image
This commit is contained in:
parent
69d6c0e053
commit
7ba7fa320d
|
|
@ -1,3 +1,3 @@
|
|||
FROM tensorflow/tensorflow:latest-py3
|
||||
FROM ufoym/deepo:cpu
|
||||
COPY requirements.txt .
|
||||
RUN pip install -r requirements.txt
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import tensorflow as tf
|
|||
import yaml
|
||||
|
||||
from lib.utils import load_graph_data
|
||||
from model.dcrnn_supervisor import DCRNNSupervisor
|
||||
from model.tf.dcrnn_supervisor import DCRNNSupervisor
|
||||
|
||||
|
||||
def main(args):
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
import torch
|
||||
|
|
@ -6,8 +6,7 @@ import tensorflow as tf
|
|||
|
||||
from tensorflow.contrib import legacy_seq2seq
|
||||
|
||||
from lib.metrics import masked_mae_loss
|
||||
from model.dcrnn_cell import DCGRUCell
|
||||
from model.tf.dcrnn_cell import DCGRUCell
|
||||
|
||||
|
||||
class DCRNNModel(object):
|
||||
|
|
@ -13,7 +13,7 @@ from lib import utils, metrics
|
|||
from lib.AMSGrad import AMSGrad
|
||||
from lib.metrics import masked_mae_loss
|
||||
|
||||
from model.dcrnn_model import DCRNNModel
|
||||
from model.tf.dcrnn_model import DCRNNModel
|
||||
|
||||
|
||||
class DCRNNSupervisor(object):
|
||||
|
|
@ -3,4 +3,5 @@ numpy>=1.12.1
|
|||
pandas>=0.19.2
|
||||
pyyaml
|
||||
statsmodels
|
||||
tensorflow>=1.3.0
|
||||
tensorflow>=1.3.0
|
||||
torch
|
||||
|
|
@ -6,7 +6,7 @@ import tensorflow as tf
|
|||
import yaml
|
||||
|
||||
from lib.utils import load_graph_data
|
||||
from model.dcrnn_supervisor import DCRNNSupervisor
|
||||
from model.tf.dcrnn_supervisor import DCRNNSupervisor
|
||||
|
||||
|
||||
def run_dcrnn(args):
|
||||
|
|
|
|||
Loading…
Reference in New Issue