terrible hack
This commit is contained in:
parent
cd6c55c1ab
commit
29eff4a194
1 changed files with 9 additions and 2 deletions
|
@ -39,8 +39,15 @@ def torch_gc():
|
||||||
|
|
||||||
def enable_tf32():
|
def enable_tf32():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.backends.cudnn.benchmark = True
|
#TODO: make this better; find a way to check if it is a turing card
|
||||||
torch.backends.cudnn.enabled = True
|
turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"]
|
||||||
|
for devid in range(0,torch.cuda.device_count()):
|
||||||
|
for i in turing:
|
||||||
|
if i in torch.cuda.get_device_name(devid):
|
||||||
|
shd = True
|
||||||
|
if shd:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue