from collections import deque
from functools import lru_cache
import math
import os
import random
import time
import warnings
import heapq
import numpy as np
import torch
import torch.nn as nn
import json
from collections import defaultdict
# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module-level alias for torch.long.
TORCH_LONG = torch.long
# Hard-coded flag; nothing in the visible part of this module reads it.
# NOTE(review): the double leading underscore suggests it is private -- confirm it is still used.
__TORCH_AVAILABLE = True
# NOTE(review): the bare string below is an expression statement, not a real
# docstring; it describes NEAR_INF, which neginf() negates for non-float16 dtypes.
"""Near infinity, useful as a large penalty for scoring when inf is bad."""
NEAR_INF = 1e20
# Float16 counterpart of NEAR_INF (65504 is the largest finite half-precision
# value); neginf() negates it for torch.float16.
NEAR_INF_FP16 = 65504
def neginf(dtype):
    """Return a representable finite number near -inf for ``dtype``.

    Args:
        dtype: A torch dtype. Only ``torch.float16`` is special-cased.

    Returns:
        ``-NEAR_INF_FP16`` when ``dtype`` is ``torch.float16``, otherwise
        ``-NEAR_INF``.
    """
    # Half precision gets its own constant instead of the general NEAR_INF.
    if dtype is torch.float16:
        return -NEAR_INF_FP16
    return -NEAR_INF
def _create_embeddings(dictionary, embedding_size, padding_idx,
                       embedding_path='../kgsfdata/word2vec_redial.npy'):
    """Create word embeddings initialised from a saved ``.npy`` matrix.

    Args:
        dictionary: Sized vocabulary container. The table gets
            ``len(dictionary) + 4`` rows (the 4 extra rows are presumably
            special tokens -- TODO confirm against the dictionary builder).
        embedding_size: Dimensionality of each embedding vector.
        padding_idx: Padding index forwarded to ``nn.Embedding``.
        embedding_path: Path of the NumPy file holding the pretrained matrix.
            Defaults to the location that used to be hard-coded.

    Returns:
        An ``nn.Embedding`` whose weights are a copy of the loaded matrix.

    Raises:
        ValueError: If the loaded matrix is not shaped
            ``(len(dictionary) + 4, embedding_size)``.
    """
    vocab_size = len(dictionary) + 4
    embedding = nn.Embedding(vocab_size, embedding_size, padding_idx)

    pretrained = torch.from_numpy(np.load(embedding_path))
    # ``copy_`` broadcasts, so a wrongly shaped file (e.g. a single row) could
    # otherwise be accepted silently; insist on an exact match instead.
    expected_shape = (vocab_size, embedding_size)
    if tuple(pretrained.shape) != expected_shape:
        raise ValueError(
            'Pretrained embeddings in {!r} have shape {}, expected {}'.format(
                embedding_path, tuple(pretrained.shape), expected_shape))

    # The whole table is overwritten, including the padding row that
    # nn.Embedding zero-initialised.
    with torch.no_grad():
        embedding.weight.copy_(pretrained)
    return embedding
def _create_entity_embeddings(entity_num, embedding_size, padding_idx):
    """Create and randomly initialise entity embeddings.

    Args:
        entity_num: Number of rows (entities) in the embedding table.
        embedding_size: Dimensionality of each embedding vector.
        padding_idx: Row that is zeroed after initialisation.

    Returns:
        An ``nn.Embedding`` of shape ``(entity_num, embedding_size)`` drawn
        from a normal distribution with std ``embedding_size ** -0.5``, with
        row ``padding_idx`` set to zero.
    """
    # NOTE(review): padding_idx is deliberately NOT passed to nn.Embedding, so
    # the zeroed row still receives gradients during training -- confirm that
    # this is intended before changing it.
    embedding = nn.Embedding(entity_num, embedding_size)
    nn.init.normal_(embedding.weight, mean=0, std=embedding_size ** -0.5)
    nn.init.constant_(embedding.weight[padding_idx], 0)
    return embedding
def _edge_list(kg, n_entity, hop, self_loop_relation=185,
               min_relation_count=1000):
    """Build a relation-filtered, re-indexed edge list from a knowledge graph.

    Args:
        kg: Mapping from entity id to an iterable of items whose element 0 is
            the relation id and element 1 is the tail entity id.
        n_entity: Entities ``0 .. n_entity - 1`` are visited.
        hop: Number of times the full edge set is emitted. Every repetition
            adds the same edges again, so relation counts scale with ``hop``.
        self_loop_relation: Relation id given to the self-loop added for every
            entity; edges in ``kg`` that already carry this relation are
            skipped. Defaults to the previously hard-coded 185.
        min_relation_count: A relation is kept only if it occurs strictly more
            than this many times. Defaults to the previously hard-coded 1000.

    Returns:
        A tuple ``(edges, n_relation)``. ``edges`` is a list of
        ``(head, tail, relation_index)`` triples restricted to the kept
        relations, with relations renumbered ``0 .. n_relation - 1`` in order
        of first appearance.
    """
    raw_edges = []
    for _ in range(hop):
        for entity in range(n_entity):
            raw_edges.append((entity, entity, self_loop_relation))
            if entity not in kg:
                continue
            for item in kg[entity]:
                relation, tail = item[0], item[1]
                # Self-loops are added above; skip the ones stored in the KG.
                if entity == tail or relation == self_loop_relation:
                    continue
                # Store both directions of every edge.
                raw_edges.append((entity, tail, relation))
                raw_edges.append((tail, entity, relation))

    # dicts keep insertion order, so iterating the counts visits relations in
    # order of first appearance in raw_edges.
    relation_cnt = defaultdict(int)
    for _, _, relation in raw_edges:
        relation_cnt[relation] += 1

    relation_idx = {}
    for relation, count in relation_cnt.items():
        if count > min_relation_count:
            relation_idx[relation] = len(relation_idx)

    kept = [(head, tail, relation_idx[relation])
            for head, tail, relation in raw_edges
            if relation in relation_idx]
    return kept, len(relation_idx)
def _concept_edge_list4GCN(node2index_path='../kgsfdata/key2index_3rd.json',
                           edges_path='../kgsfdata/conceptnet_edges2nd.txt',
                           stopwords_path='../kgsfdata/stopwords.txt'):
    """Load the concept graph as a 2 x E edge-index tensor.

    Each line of the edge file is tab-separated; columns 1 and 2 (0-based)
    hold the two endpoints, and the text before the first ``/`` in each is the
    key looked up in the node index. Lines touching a stop-word are dropped,
    and every remaining edge is stored in both directions, de-duplicated.

    Args:
        node2index_path: JSON file mapping node keys to integer indices.
        edges_path: Tab-separated edge file.
        stopwords_path: Text file with one stop-word per line.
            All three default to the locations that used to be hard-coded.

    Returns:
        A ``torch.long`` tensor with two rows (heads, tails), moved to the
        module-level ``device``.

    Raises:
        KeyError: If a non-stop-word endpoint is missing from the node index.
    """
    # Context managers close the files; the original left all three open.
    with open(node2index_path, encoding='utf-8') as index_file:
        node2index = json.load(index_file)
    with open(stopwords_path, encoding='utf-8') as stopword_file:
        stopwords = {word.strip() for word in stopword_file}

    edges = set()
    with open(edges_path, encoding='utf-8') as edge_file:
        for line in edge_file:
            fields = line.strip().split('\t')
            head_key = fields[1].split('/')[0]
            tail_key = fields[2].split('/')[0]
            # Filter before the index lookup so a stop-word that is absent
            # from the index cannot raise on a line that is discarded anyway.
            if head_key in stopwords or tail_key in stopwords:
                continue
            head = node2index[head_key]
            tail = node2index[tail_key]
            edges.add((head, tail))
            edges.add((tail, head))

    # Materialise once so both rows are read from the same ordering.
    ordered = list(edges)
    edge_index = [[head for head, _ in ordered], [tail for _, tail in ordered]]
    return torch.tensor(edge_index, dtype=torch.long).to(device)
def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN reproducible.

    Args:
        seed: Seed applied to ``random``, NumPy, and PyTorch (CPU and all
            CUDA devices). It is also exported as ``PYTHONHASHSEED``.
    """
    random.seed(seed)
    # Setting this at runtime does not re-seed the current interpreter's
    # string hashing; it only reaches child processes that inherit the env.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; harmless without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per run, so the
    # choice can differ between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
# Import-time side effect: seed all RNGs with a fixed value as soon as this
# module is loaded.
seed_everything(42)