30 lines
760 B
Python
Executable File
30 lines
760 B
Python
Executable File
import functools

import torch as pt
|
|
|
|
class JitEnabledProxy:
    """Mutable on/off switch whose truthiness mirrors its ``enabled`` attribute.

    Truth-testing an instance (``if proxy:``) yields the current value of
    ``proxy.enabled``. Code holding a reference therefore observes any later
    in-place change to the flag.
    """

    def __init__(self):
        # A freshly created proxy starts out switched on.
        self.enabled = True

    def __bool__(self):
        """Return the current ``enabled`` value for truth testing."""
        current = self.enabled
        return current
|
|
|
|
def disable_jit():
    """Turn the JIT switch off.

    Afterwards, ``jit_script``-decorated functions call their original Python
    implementation instead of the TorchScript-compiled one.
    """
    proxy = _jit_enabled
    proxy.enabled = False
|
|
|
|
def enable_jit():
    """Turn the JIT switch on.

    Afterwards, ``jit_script``-decorated functions dispatch to their
    TorchScript-compiled version.
    """
    proxy = _jit_enabled
    proxy.enabled = True
|
|
|
|
def jit_script(func):
    """Decorate *func* so each call chooses between a TorchScript and a Python run.

    ``func`` is compiled once, at decoration time, with ``torch.jit.script``.
    On every call the wrapper checks ``_jit_enabled``. If it is truthy, the
    compiled version runs; otherwise the original Python ``func`` runs.
    Toggle the flag at runtime with ``enable_jit()`` / ``disable_jit()``.

    Args:
        func: A function that ``torch.jit.script`` is able to compile.

    Returns:
        A wrapper carrying ``func``'s name, docstring and other metadata
        (via ``functools.wraps``). The undecorated function is available as
        ``wrapper.__wrapped__``.

    Raises:
        Whatever ``torch.jit.script`` raises when ``func`` cannot be
        compiled. This happens at decoration time, even if the JIT switch
        is currently off.
    """
    # Compile eagerly so scripting errors surface at decoration time rather
    # than on the first call.
    jit_script_func = pt.jit.script(func)

    @functools.wraps(func)
    def dynamic_jit_enabled_func(*args, **kwargs):
        # The global is read on every call rather than captured here, so
        # flipping the switch also affects functions decorated earlier.
        if _jit_enabled:
            return jit_script_func(*args, **kwargs)
        return func(*args, **kwargs)

    return dynamic_jit_enabled_func
|
|
|
|
# Module-level JIT switch read by every jit_script wrapper on each call.
# It starts enabled; flip it with enable_jit() / disable_jit().
_jit_enabled = JitEnabledProxy()