This commit is contained in:
Echo-Ji 2021-11-07 19:48:24 +08:00
parent 3a5f5e5170
commit 57185839ca
1 changed files with 0 additions and 535 deletions

View File

@ -1,535 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0, 1, 2, 3, 4],\n",
" [ 0, 6, 7, 8, 9],\n",
" [ 0, 0, 12, 13, 14],\n",
" [ 0, 0, 0, 18, 19],\n",
" [ 0, 0, 0, 0, 24]]),\n",
" array([[ 0, 0, 0, 0, 0],\n",
" [ 5, 6, 0, 0, 0],\n",
" [10, 11, 12, 0, 0],\n",
" [15, 16, 17, 18, 0],\n",
" [20, 21, 22, 23, 24]]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = np.arange(0, 25).reshape((5, 5))\n",
"\n",
"out = np.triu(m)\n",
"inp = np.tril(m)\n",
"out, inp"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 0, 1, 0, 0],\n",
" [1, 0, 0, 1, 1],\n",
" [1, 0, 0, 1, 0],\n",
" [1, 1, 1, 0, 0],\n",
" [1, 0, 1, 1, 1]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adj = np.random.randint(0, 2, size=(5, 5))\n",
"adj"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 2, 0, 0],\n",
" [ 5, 0, 0, 8, 9],\n",
" [10, 0, 0, 13, 0],\n",
" [15, 16, 17, 0, 0],\n",
" [20, 0, 22, 23, 48]])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"((inp + out) * adj)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1],\n",
" [2],\n",
" [3]])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = torch.tensor([1, 2, 3])\n",
"a.unsqueeze_(-1)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 5, 5])\n"
]
},
{
"ename": "IndexError",
"evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-29-95457010b19d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mr\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstart_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)"
]
}
],
"source": [
"r = torch.tensor(((inp + out) * adj))\n",
"r.unsqueeze_(0)\n",
"print(r.shape)\n",
"r[r > 0].flatten(start_dim=0, end_dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"r = r.repeat((2, 1, 1))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n",
" 0, 0, 20, 0, 22, 23, 48],\n",
" [ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n",
" 0, 0, 20, 0, 22, 23, 48]], dtype=torch.int32)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r = torch.flatten(r, start_dim=1, end_dim=-1)\n",
"r[r>0]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48],\n",
" [ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48]],\n",
" dtype=torch.int32)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r[r > 0].reshape(2, -1)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"class GraphGrad(torch.nn.Module):\n",
" def __init__(self, adj_mx):\n",
" \"\"\"Graph gradient operator that transform functions on nodes to functions on edges.\n",
" \"\"\"\n",
" super(GraphGrad, self).__init__()\n",
" self.adj_mx = adj_mx\n",
" self.grad = self._grad(adj_mx)\n",
" \n",
" @staticmethod\n",
" def _grad(adj_mx):\n",
" \"\"\"Fetch the graph gradient operator.\"\"\"\n",
" num_nodes = adj_mx.size()[-1]\n",
"\n",
" num_edges = (adj_mx > 0.).sum()\n",
" grad = torch.zeros(num_nodes, num_edges)\n",
" e = 0\n",
" for i in range(num_nodes):\n",
" for j in range(num_nodes):\n",
" if adj_mx[i, j] == 0:\n",
" continue\n",
"\n",
" grad[i, e] = 1.\n",
" grad[j, e] = -1.\n",
" e += 1\n",
" return grad\n",
"\n",
" def forward(self, z):\n",
" \"\"\"Transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.\n",
" \"\"\"\n",
" return torch.matmul(z, self.grad)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 0, 1, 0, 0],\n",
" [1, 0, 0, 1, 1],\n",
" [1, 0, 0, 1, 0],\n",
" [1, 1, 1, 0, 0],\n",
" [1, 0, 1, 1, 1]])"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"adj"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 14])"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gg = GraphGrad(torch.tensor(adj))\n",
"grad = gg.grad\n",
"grad.shape"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([14, 5])"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad.transpose(-1, -2).shape"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"14"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad.size(-1)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5, 5)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inp.shape"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0.],\n",
" [ -5., 5., 1., 6., 6., -5., 0., -5., -6., 0., -5., 0.,\n",
" 0., 0.],\n",
" [-10., -2., 1., 11., 11., 2., 12., -10., -11., -12., -10., -12.,\n",
" 0., 0.],\n",
" [-15., -2., 1., -2., 16., 2., -1., 3., 2., 1., -15., -17.,\n",
" -18., 0.],\n",
" [-20., -2., 1., -2., -3., 2., -1., 3., 2., 1., 4., 2.,\n",
" 1., -24.]])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gg(torch.tensor(inp, dtype=torch.float32))"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([5, 5])"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.tensor(inp).shape \n",
"# (grad_T.T)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1;31mDocstring:\u001b[0m\n",
"matmul(input, other, *, out=None) -> Tensor\n",
"\n",
"Matrix product of two tensors.\n",
"\n",
"The behavior depends on the dimensionality of the tensors as follows:\n",
"\n",
"- If both tensors are 1-dimensional, the dot product (scalar) is returned.\n",
"- If both arguments are 2-dimensional, the matrix-matrix product is returned.\n",
"- If the first argument is 1-dimensional and the second argument is 2-dimensional,\n",
" a 1 is prepended to its dimension for the purpose of the matrix multiply.\n",
" After the matrix multiply, the prepended dimension is removed.\n",
"- If the first argument is 2-dimensional and the second argument is 1-dimensional,\n",
" the matrix-vector product is returned.\n",
"- If both arguments are at least 1-dimensional and at least one argument is\n",
" N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first\n",
" argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the\n",
" batched matrix multiply and removed after. If the second argument is 1-dimensional, a\n",
" 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after.\n",
" The non-matrix (i.e. batch) dimensions are :ref:`broadcasted <broadcasting-semantics>` (and thus\n",
" must be broadcastable). For example, if :attr:`input` is a\n",
" :math:`(j \\times 1 \\times n \\times n)` tensor and :attr:`other` is a :math:`(k \\times n \\times n)`\n",
" tensor, :attr:`out` will be a :math:`(j \\times k \\times n \\times n)` tensor.\n",
"\n",
" Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs\n",
" are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a\n",
" :math:`(j \\times 1 \\times n \\times m)` tensor and :attr:`other` is a :math:`(k \\times m \\times p)`\n",
" tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the\n",
" matrix dimensions) are different. :attr:`out` will be a :math:`(j \\times k \\times n \\times p)` tensor.\n",
"\n",
"This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.\n",
"\n",
".. note::\n",
"\n",
" The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.\n",
"\n",
"Arguments:\n",
" input (Tensor): the first tensor to be multiplied\n",
" other (Tensor): the second tensor to be multiplied\n",
"\n",
"Keyword args:\n",
" out (Tensor, optional): the output tensor.\n",
"\n",
"Example::\n",
"\n",
" >>> # vector x vector\n",
" >>> tensor1 = torch.randn(3)\n",
" >>> tensor2 = torch.randn(3)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([])\n",
" >>> # matrix x vector\n",
" >>> tensor1 = torch.randn(3, 4)\n",
" >>> tensor2 = torch.randn(4)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([3])\n",
" >>> # batched matrix x broadcasted vector\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(4)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3])\n",
" >>> # batched matrix x batched matrix\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(10, 4, 5)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3, 5])\n",
" >>> # batched matrix x broadcasted matrix\n",
" >>> tensor1 = torch.randn(10, 3, 4)\n",
" >>> tensor2 = torch.randn(4, 5)\n",
" >>> torch.matmul(tensor1, tensor2).size()\n",
" torch.Size([10, 3, 5])\n",
"\u001b[1;31mType:\u001b[0m builtin_function_or_method\n"
]
}
],
"source": [
"torch.matmul?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "1b89aa55be347d0b8cc51b3a166e8002614a385bd8cff32165269c80e70c12a7"
},
"kernelspec": {
"display_name": "Python 3.8.5 64-bit ('base': conda)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}