TrafficWheel/model/ARIMA/torch_utils.py

30 lines
760 B
Python
Executable File

import functools

import torch as pt
class JitEnabledProxy:
    """Mutable flag object whose truthiness mirrors its ``enabled`` attribute.

    Holding the flag inside an object (instead of a plain module-level bool)
    lets helper functions flip it in place via attribute assignment, and any
    code that evaluates ``if proxy:`` sees the value current at that moment.
    """

    def __init__(self) -> None:
        # New proxies start out enabled.
        self.enabled: bool = True

    def __bool__(self) -> bool:
        # Truth testing delegates directly to the stored flag.
        return self.enabled
def disable_jit():
    """Make ``jit_script``-decorated functions run their plain Python bodies.

    Clears the module-level ``_jit_enabled`` flag. The decorated wrappers
    consult that flag on every call, so the switch applies immediately,
    including to functions that were decorated earlier.
    """
    flag = _jit_enabled
    flag.enabled = False
def enable_jit():
    """Make ``jit_script``-decorated functions run their TorchScript versions.

    Sets the module-level ``_jit_enabled`` flag. The decorated wrappers
    consult that flag on every call, so the switch applies immediately,
    including to functions that were decorated earlier.
    """
    flag = _jit_enabled
    flag.enabled = True
def jit_script(func):
    """Decorate ``func`` so each call picks the TorchScript or Python version.

    ``func`` is compiled once, eagerly, with ``torch.jit.script`` when the
    decorator is applied, so scripting errors surface at decoration time
    regardless of the current flag. On every call, the returned wrapper
    checks the module-level ``_jit_enabled`` flag: if truthy it runs the
    compiled function, otherwise it runs the original Python ``func``.

    Args:
        func: The function to decorate; it must be compilable by
            ``torch.jit.script``.

    Returns:
        A wrapper accepting the same arguments as ``func``. ``functools.wraps``
        copies ``func``'s ``__name__``, ``__qualname__``, ``__doc__`` and
        ``__module__`` onto it and exposes the original as ``__wrapped__``.
    """
    jit_script_func = pt.jit.script(func)

    @functools.wraps(func)
    def dynamic_jit_enabled_func(*args, **kwargs):
        # Read the flag on each call (not once at decoration time) so that
        # enable_jit()/disable_jit() also affect already-decorated functions.
        if _jit_enabled:
            return jit_script_func(*args, **kwargs)
        return func(*args, **kwargs)

    return dynamic_jit_enabled_func
# Module-wide switch read by every jit_script-decorated wrapper on each call;
# toggled by enable_jit()/disable_jit(). Starts enabled, since
# JitEnabledProxy.__init__ sets enabled = True.
_jit_enabled: JitEnabledProxy = JitEnabledProxy()