Using pytorch image
This commit is contained in:
parent
69d6c0e053
commit
7ba7fa320d
|
|
@ -1,3 +1,3 @@
|
||||||
FROM tensorflow/tensorflow:latest-py3
|
FROM ufoym/deepo:cpu
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN pip install -r requirements.txt
|
RUN pip install -r requirements.txt
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import tensorflow as tf
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from lib.utils import load_graph_data
|
from lib.utils import load_graph_data
|
||||||
from model.dcrnn_supervisor import DCRNNSupervisor
|
from model.tf.dcrnn_supervisor import DCRNNSupervisor
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
import torch
|
||||||
|
|
@ -6,8 +6,7 @@ import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib import legacy_seq2seq
|
from tensorflow.contrib import legacy_seq2seq
|
||||||
|
|
||||||
from lib.metrics import masked_mae_loss
|
from model.tf.dcrnn_cell import DCGRUCell
|
||||||
from model.dcrnn_cell import DCGRUCell
|
|
||||||
|
|
||||||
|
|
||||||
class DCRNNModel(object):
|
class DCRNNModel(object):
|
||||||
|
|
@ -13,7 +13,7 @@ from lib import utils, metrics
|
||||||
from lib.AMSGrad import AMSGrad
|
from lib.AMSGrad import AMSGrad
|
||||||
from lib.metrics import masked_mae_loss
|
from lib.metrics import masked_mae_loss
|
||||||
|
|
||||||
from model.dcrnn_model import DCRNNModel
|
from model.tf.dcrnn_model import DCRNNModel
|
||||||
|
|
||||||
|
|
||||||
class DCRNNSupervisor(object):
|
class DCRNNSupervisor(object):
|
||||||
|
|
@ -3,4 +3,5 @@ numpy>=1.12.1
|
||||||
pandas>=0.19.2
|
pandas>=0.19.2
|
||||||
pyyaml
|
pyyaml
|
||||||
statsmodels
|
statsmodels
|
||||||
tensorflow>=1.3.0
|
tensorflow>=1.3.0
|
||||||
|
torch
|
||||||
|
|
@ -6,7 +6,7 @@ import tensorflow as tf
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from lib.utils import load_graph_data
|
from lib.utils import load_graph_data
|
||||||
from model.dcrnn_supervisor import DCRNNSupervisor
|
from model.tf.dcrnn_supervisor import DCRNNSupervisor
|
||||||
|
|
||||||
|
|
||||||
def run_dcrnn(args):
|
def run_dcrnn(args):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue