{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "FZEco2HK6D57"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "a3nCUqopXHwv"
},
"outputs": [],
"source": [
"RANDOM_SEED = 0x0"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "jjTkUw7BWulH"
},
"source": [
"# Lab 03: Linear Regression"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "gNnZUk36Xz7_"
},
"source": [
"For the first few tasks, we work with synthetic univariate data.\n",
"We generate $100$ input values $x_i \\in [-1, 1)$ as `x` and two different\n",
"regression targets: `y1` (linear in `x`) and `y2` (non-linear in `x`)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "Ojta777H2ulb"
},
"outputs": [],
"source": [
"data_rng = np.random.default_rng(RANDOM_SEED)\n",
"n = 100\n",
"x = 2 * data_rng.random(n) - 1 # create n points between -1 and 1\n",
"\n",
"# setup synthetic linear data\n",
"true_offset = 0.5\n",
"true_slope = 1.25\n",
"noise = data_rng.normal(loc=0., scale=0.25, size=(n,))\n",
"\n",
"y1 = true_offset + true_slope * x + noise\n",
"\n",
"\n",
"# setup synthetic non-linear data\n",
"y2 = true_offset + np.sin(np.pi * x) + noise"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "ntdpTWzqZqAU"
},
"source": [
"# Task 1 (1 Point): Pearson Correlation"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "JbNJ7WhzbAtm"
},
"source": [
"### Task 1a\n",
"\n",
"Plot `x` against the target variable `y1`.\n",
"\n",
"* use `plt.scatter`\n",
"\n",
"\n",
"Do you think there is a linear relationship between `x` and the target?"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "MxYMdhfxyYAd"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7fbd80918090>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x, y1)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
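"_Yes: the points lie close to a straight line with positive slope, and the remaining scatter comes from the added noise._"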
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "6Ak0nQ0PDGpm"
},
"source": [
"Plot `x` against the target variable `y2`.\n",
"\n",
"Do you think there is a linear relationship between `x` and the target?"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "HpzwoBdQDd-d"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7f2336ae7550>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(x, y2)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"_No: the points follow a sine-shaped curve rather than a straight line, although `y2` still tends to increase with `x` overall._"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "HycYQm3tbvyf"
},
"source": [
"### Task 1b\n",
"\n",
"In class you have seen the formula for the Pearson Correlation:\n",
"$\\rho(a, b) = \\frac{\\sum_{i = 1}^{m} (a_i - \\bar{a})(b_i - \\bar{b})}{\\sqrt{\\sum_{i=1}^{m} (a_i - \\bar{a})^2\\sum_{i = 1}^{m}(b_i - \\bar{b})^2}} $, where $\\bar{a} = \\frac{1}{m}\\sum_{i=1}^{m} a_i$ and $\\bar{b} = \\frac{1}{m}\\sum_{i=1}^{m} b_i$.\n",
"\n",
"* Compute the Pearson Correlation $\\rho$ between `x` and the target `y1`.\n",
"* Compute the Pearson Correlation between `x` and `y2`.\n",
"* Check that you get the same result as the reference implementation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "EUoJXIrCy0p6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rho(x, y1): 0.9513\n",
"rho(x, y2): 0.7052\n"
]
}
],
"source": [
"def pearson(a, b):\n",
"    # centre both variables around their means\n",
"    a_c = a - a.mean()\n",
"    b_c = b - b.mean()\n",
"    # sum of products of deviations, divided by the square root of both sums of squared deviations\n",
"    return np.sum(a_c * b_c) / np.sqrt(np.sum(a_c ** 2) * np.sum(b_c ** 2))\n",
"\n",
"print(f\"rho(x, y1): {pearson(x, y1):.4f}\")\n",
"print(f\"rho(x, y2): {pearson(x, y2):.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "L_NesuDQddHS"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rho(x, y1): 0.9513\n",
"rho(x, y2): 0.7052\n"
]
}
],
"source": [
"# Refer to the output of this cell to check whether your implementation of rho\n",
"# is correct.\n",
"\n",
"from scipy.stats import pearsonr\n",
"\n",
"print(f\"rho(x, y1): {pearsonr(x, y1)[0]:.4f}\")\n",
"print(f\"rho(x, y2): {pearsonr(x, y2)[0]:.4f}\")"
]
},
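{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"_Optional cross-check (not part of the hand-in):_ NumPy's `np.corrcoef` returns the full correlation matrix, and its off-diagonal entry is Pearson's $\\rho$. The sketch below compares it with a vectorised version of the formula from Task 1b. It uses its own hypothetical demo data (`x_demo`, `y_demo`, arbitrary seed), so `x`, `y1` and `y2` are not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical demo data, generated independently of x, y1 and y2\n",
"demo_rng = np.random.default_rng(42)\n",
"x_demo = 2 * demo_rng.random(50) - 1\n",
"y_demo = 0.5 + 1.25 * x_demo + demo_rng.normal(scale=0.25, size=50)\n",
"\n",
"# Vectorised Pearson correlation: centre both variables, then normalise the covariance term\n",
"a_c = x_demo - x_demo.mean()\n",
"b_c = y_demo - y_demo.mean()\n",
"rho_manual = np.sum(a_c * b_c) / np.sqrt(np.sum(a_c ** 2) * np.sum(b_c ** 2))\n",
"\n",
"# np.corrcoef returns a 2x2 matrix; entry [0, 1] is rho(x_demo, y_demo)\n",
"rho_numpy = np.corrcoef(x_demo, y_demo)[0, 1]\n",
"\n",
"print(f'manual: {rho_manual:.4f}, np.corrcoef: {rho_numpy:.4f}')\n",
"assert np.isclose(rho_manual, rho_numpy)"
]
},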
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Kr9OWmCilrAv"
},
"source": [
"## 📢 **HAND-IN** 📢: Report in Moodle whether you solved this task."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "rbjhdwFceHlL"
},
"source": [
"# Task 2 (2 Points): Univariate Linear Regression"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "ucnYGKbmecz_"
},
"source": [
"### Task 2a\n",
"\n",
"You will now implement Linear Regression with a single variable. In class you have seen that the underlying model is: $y = \\theta_0 + \\theta_1x$.\n",
"You also derived the maximum likelihood estimates for $\\theta_0$ and $\\theta_1$:\n",
"\n",
"* $\\hat{\\theta}_1 = \\frac{\\sum_{i=1}^{m} (x_i - \\bar{x})(y_i - \\bar{y})}{\\sum_{i=1}^{m}(x_i - \\bar{x})^2}$ with $\\bar{x} = \\frac{1}{m}\\sum_{i=1}^{m} x_i$ and $\\bar{y} = \\frac{1}{m}\\sum_{i=1}^{m} y_i$.\n",
"* $\\hat{\\theta}_0 = \\bar{y} - \\hat{\\theta}_1\\bar{x}$\n",
"\n",
"In the following cell, implement the `.fit` and `.predict` methods:\n",
"* In the `.predict` method you will have to apply the model to the input `x`\n",
"* In the `.fit` method you will have to compute $\\hat{\\theta}_0$ and $\\hat{\\theta}_1$."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "qS0Oa5Btgk74"
},
"outputs": [],
"source": [
"class UnivariateLinearRegression:\n",
"\n",
"    def __init__(self):\n",
"        self.theta_0: float = 0.\n",
"        self.theta_1: float = 0.\n",
"\n",
"    def predict(self, x):\n",
"        # y = theta_0 + theta_1 * x\n",
"        return self.theta_0 + self.theta_1 * x\n",
"\n",
"    def fit(self, x, y):\n",
"        # closed-form estimates from Task 2a, computed on the centred inputs\n",
"        x_c = x - x.mean()\n",
"        self.theta_1 = np.sum(x_c * (y - y.mean())) / np.sum(x_c ** 2)\n",
"        self.theta_0 = y.mean() - self.theta_1 * x.mean()\n",
"\n",
"        return self"
]
},
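{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"_Optional cross-check:_ for a straight-line fit, `np.polyfit(x, y, deg=1)` solves the same least-squares problem. It returns the coefficients highest power first, i.e. `[slope, intercept]`. Below is a minimal sketch on hypothetical demo data (true $\\theta_0 = 0.5$, $\\theta_1 = 1.25$, arbitrary seed). It compares `np.polyfit` with a vectorised version of the closed-form estimates above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical demo data with known parameters theta_0 = 0.5, theta_1 = 1.25\n",
"demo_rng = np.random.default_rng(1)\n",
"x_demo = 2 * demo_rng.random(200) - 1\n",
"y_demo = 0.5 + 1.25 * x_demo + demo_rng.normal(scale=0.25, size=200)\n",
"\n",
"# Closed-form estimates from Task 2a, vectorised\n",
"x_c = x_demo - x_demo.mean()\n",
"theta_1_hat = np.sum(x_c * (y_demo - y_demo.mean())) / np.sum(x_c ** 2)\n",
"theta_0_hat = y_demo.mean() - theta_1_hat * x_demo.mean()\n",
"\n",
"# np.polyfit with deg=1 returns [slope, intercept]\n",
"slope, intercept = np.polyfit(x_demo, y_demo, deg=1)\n",
"\n",
"print(f'theta_0: {theta_0_hat:.4f} (polyfit: {intercept:.4f})')\n",
"print(f'theta_1: {theta_1_hat:.4f} (polyfit: {slope:.4f})')\n",
"assert np.isclose(theta_0_hat, intercept) and np.isclose(theta_1_hat, slope)"
]
},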
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "9LzenH1UhLOs"
},
"source": [
"### Task 2b\n",
"\n",
"Fit your linear model to `x` and the target `y1`.\n",
"\n",
"* Create an instance of the class `UnivariateLinearRegression`\n",
"* fit the model using its `.fit` method\n",
"* get the predicted values, using `.predict`\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "UHGuDWAntd8R"
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.81972252, -0.08621518, -0.65077025, -0.71108605, 1.25473192,\n",
" 1.50019396, 0.74489883, 1.04803555, 0.58943114, 1.55525673,\n",
" 1.26110578, -0.74510824, 1.36362466, -0.66899868, 1.04842757,\n",
" -0.31846659, 1.37787255, 0.58409243, -0.01238023, 0.29103911,\n",
" -0.68199134, -0.44521854, 0.9027792 , 0.84495782, 0.76648623,\n",
" 0.19478983, 1.70856978, 1.66816843, 0.9395856 , 0.85302538,\n",
" 0.94675253, 0.20772813, -0.41853885, 1.02827672, 0.54435158,\n",
" 0.0136006 , 0.4468457 , 1.44278502, 1.55271809, 0.1309298 ,\n",
" 0.65828127, 0.04228939, 0.71446262, 0.08186971, 0.21438391,\n",
" 1.44472561, -0.1913948 , 0.78573633, -0.54457236, 1.30253353,\n",
" 1.19015742, -0.16126428, 1.41070099, -0.60735898, 0.07744293,\n",
" -0.38107765, 0.35926577, 1.21292081, -0.18279715, -0.62351186,\n",
" 0.24629335, -0.26207004, -0.5279483 , 0.67999998, -0.01488642,\n",
" 0.90616057, -0.2595968 , 1.57262835, 0.14897817, -0.49157451,\n",
" 0.80034535, 1.53572082, 0.33468582, 1.60341403, 0.48153732,\n",
" 0.29730957, 0.77839929, 1.70335527, 1.58948153, 0.383213 ,\n",
" 1.1176936 , 0.47543535, 0.55411682, 1.1869188 , 0.27122316,\n",
" 1.0603401 , 1.00275116, 1.54782335, -0.46828955, 1.04684768,\n",
" 1.53638546, 1.63631745, -0.71557985, 1.3790104 , 1.66905593,\n",
" 1.60987763, -0.38481676, 1.64792032, 1.44388969, 1.27719337])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = UnivariateLinearRegression()\n",
"model.fit(x, y1)\n",
"model.predict(x)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "elE3OfjHjBRO"
},
"source": [
"* implement the function `plot_model`\n",
"* use `plot_model` to plot your linear regression model given the true datapoints"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "T0eKDuRt1YOF"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_model(x, y_pred, y_true, title):\n",
"    # regression line in black, true data points as a scatter plot\n",
"    plt.title(title)\n",
"    plt.plot(x, y_pred, label=\"prediction\", c=\"black\")\n",
"    plt.scatter(x, y_true, label=\"true\")\n",
"    plt.legend()\n",
"    plt.show()\n",
"\n",
"plot_model(x, model.predict(x), y1, \"Prediction of y_1\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "tt2RnAwAG1n9"
},
"source": [
"* Fit another linear model to `x` and `y2`\n",
"* get the predicted values\n",
"* plot the model with `plot_model`"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "Ccq3GI17Ga2x"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = UnivariateLinearRegression()\n",
"model.fit(x, y2)\n",
"plot_model(x, model.predict(x), y2, \"Prediction of y_2\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "E0i3gWvIl7nY"
},
"source": [
"## 📢 **HAND-IN** 📢: A PDF document containing the following:\n",
"\n",
"* both plots containing the linear regression model and true datapoints\n",
"* a short (2-3 sentences) interpretation of the curves: why do you think they look the way\n",
"they do? Can you draw any conclusions?\n",
"\n",
"**Solutions for Tasks 2, 3 and 4 should be in the same document: you will only upload 1 document with your solutions for all 3 tasks!**"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
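" - _The prediction of $y_1$ is very accurate because the relation between $x$ and $y_1$ is linear._\n",
" - _The prediction of $y_2$ is much less accurate: the relation between $x$ and $y_2$ is non-linear (sine-shaped), which a straight line cannot capture._\n",
" - _For simple linear regression $R^2 = \\rho^2$, so the fit for $y_2$ still explains about half of the variance ($R^2 \\approx 0.7052^2 \\approx 0.50$), compared with $R^2 \\approx 0.9513^2 \\approx 0.90$ for $y_1$._"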
]
},
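{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For simple linear regression with an intercept, the coefficient of determination equals the squared Pearson correlation, $R^2 = \\rho^2$. This means the correlations from Task 1b already tell you how much variance each straight-line fit explains. The sketch below checks the identity numerically on hypothetical sine-shaped demo data built like `y2`. It uses its own seed, so the notebook variables are untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical non-linear demo data, built like y2 but with its own generator\n",
"demo_rng = np.random.default_rng(2)\n",
"x_demo = 2 * demo_rng.random(100) - 1\n",
"y_demo = 0.5 + np.sin(np.pi * x_demo) + demo_rng.normal(scale=0.25, size=100)\n",
"\n",
"# Straight-line fit (np.polyfit returns [slope, intercept] for deg=1)\n",
"slope, intercept = np.polyfit(x_demo, y_demo, deg=1)\n",
"y_hat = intercept + slope * x_demo\n",
"\n",
"# R^2 = 1 - SS_res / SS_tot\n",
"ss_res = np.sum((y_demo - y_hat) ** 2)\n",
"ss_tot = np.sum((y_demo - y_demo.mean()) ** 2)\n",
"r2 = 1 - ss_res / ss_tot\n",
"\n",
"rho = np.corrcoef(x_demo, y_demo)[0, 1]\n",
"print(f'R^2 = {r2:.4f}, rho^2 = {rho ** 2:.4f}')\n",
"assert np.isclose(r2, rho ** 2)"
]
},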
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "0TK0Pi4ClphY"
},
"source": [
"# Task 3 (4 Points): Univariate Linear Regression using Stochastic Gradient Descent"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "YL31gChVqLpC"
},
"source": [
"### Task 3a\n",
"\n",
"In class you have seen an alternative way to estimate the parameters $\\theta_i$ of linear regression models: Gradient Descent.\n",
"\n",
"For the univariate linear regression model, the stochastic gradient descent updates look like this:\n",
"* $\\theta_{0}^{(t+1)} = \\theta_{0}^{(t)} - \\alpha (\\theta_{0}^{(t)} + \\theta_{1}^{(t)} x_t - y_t)$\n",
"* $\\theta_{1}^{(t+1)} = \\theta_{1}^{(t)} - \\alpha (\\theta_{0}^{(t)} + \\theta_{1}^{(t)} x_t - y_t) x_t$\n",
"\n",
"Here $\\alpha$ is the learning rate, and $(x_t, y_t)$ is the data point sampled\n",
"at time $t$.\n",
"\n",
"\n",
"In the following cell, implement the `.fit` and `.predict` methods:\n",
"* In the `.predict` method you will have to apply the model to the input `x`.\n",
"* In the `.fit` method you will have to implement the update equations for\n",
"$\\theta_0$ and $\\theta_1$."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "wJMHvQmXmVKr"
},
"outputs": [],
"source": [
"class SGDUnivariateLinearRegression:\n",
"\n",
"    def __init__(self):\n",
"        self.theta_0: float = 0.\n",
"        self.theta_1: float = 0.\n",
"        self.rng = np.random.default_rng(RANDOM_SEED)\n",
"\n",
"    def predict(self, x):\n",
"        # y = theta_0 + theta_1 * x\n",
"        return self.theta_0 + self.theta_1 * x\n",
"\n",
"    def fit(self, x, y, n_iter: int = 100, learning_rate: float = 1.0):\n",
"        for t in range(n_iter):\n",
"            sample_ix = self.rng.integers(0, len(x))\n",
"\n",
"            xt = x[sample_ix]\n",
"            yt = y[sample_ix]\n",
"\n",
"            # update self.theta_0 and self.theta_1 SIMULTANEOUSLY (!!!):\n",
"            # both updates use the residual computed from the old parameters\n",
"            error = self.theta_0 + self.theta_1 * xt - yt\n",
"            self.theta_0 = self.theta_0 - learning_rate * error\n",
"            self.theta_1 = self.theta_1 - learning_rate * error * xt\n",
"\n",
"        return self"
]
},
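{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how close SGD gets to the closed-form solution from Task 2, the sketch below runs the same update rule on hypothetical demo data and prints both results side by side. The helper `sgd_univariate` and its defaults (`learning_rate=0.05`, `n_iter=2000`) are illustrative choices, not part of the assignment. With a constant learning rate, SGD keeps fluctuating around the least-squares solution instead of converging to it exactly, so expect the two to agree only approximately."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def sgd_univariate(x, y, n_iter=2000, learning_rate=0.05, seed=0):\n",
"    # Hypothetical helper with the same update rule as SGDUnivariateLinearRegression.fit\n",
"    rng = np.random.default_rng(seed)\n",
"    theta_0, theta_1 = 0.0, 0.0\n",
"    for _ in range(n_iter):\n",
"        i = rng.integers(0, len(x))\n",
"        # residual is computed once with the OLD parameters, so both updates are simultaneous\n",
"        error = theta_0 + theta_1 * x[i] - y[i]\n",
"        theta_0 = theta_0 - learning_rate * error\n",
"        theta_1 = theta_1 - learning_rate * error * x[i]\n",
"    return theta_0, theta_1\n",
"\n",
"\n",
"# Hypothetical demo data with known parameters theta_0 = 0.5, theta_1 = 1.25\n",
"demo_rng = np.random.default_rng(3)\n",
"x_demo = 2 * demo_rng.random(100) - 1\n",
"y_demo = 0.5 + 1.25 * x_demo + demo_rng.normal(scale=0.25, size=100)\n",
"\n",
"t0, t1 = sgd_univariate(x_demo, y_demo)\n",
"slope, intercept = np.polyfit(x_demo, y_demo, deg=1)\n",
"print(f'SGD:         theta_0 = {t0:.3f}, theta_1 = {t1:.3f}')\n",
"print(f'closed form: theta_0 = {intercept:.3f}, theta_1 = {slope:.3f}')"
]
},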
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "MHLBmTm4vK9p"
},
"source": [
"### Task 3b\n",
"\n",
"Run SGD for `x` and the target `y1` and compute the mean squared error (MSE).\n",
"The MSE is defined as: $\\frac{1}{n}\\sum_{i=1}^{n} (\\hat{y}_i - y_i)^2$, where\n",
"$\\hat{y}$ are the model predictions.\n",
"\n",
"* Create an instance of the class `SGDUnivariateLinearRegression`\n",
"* fit the model using its `.fit` method\n",
"* get the predicted values, using `.predict`\n",
"* implement the `mse` function\n",
"* compute the MSE of your predictions"
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 25,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "CZ1szyQhK9so"
},
"outputs": [],
"source": [
"def mse(y_pred, y_true):\n",
2023-06-17 13:35:15 +00:00
" return 1 / len(y_pred) * sum([(y_pred[i] - y_true[i]) ** 2 for i in range(0, len(y_pred))])"
2023-06-17 11:47:27 +00:00
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 26,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "V35vBU5Yti8Z"
},
2023-06-17 13:35:15 +00:00
"outputs": [
{
"data": {
"text/plain": [
"0.2393998747452736"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = SGDUnivariateLinearRegression()\n",
"model.fit(x, y1)\n",
"mse(y1, model.predict(x))"
]
2023-06-17 11:47:27 +00:00
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "hSsE1o6GwA3K"
},
"source": [
"### Task 3c\n",
"\n",
"You will now plot the learning curves for different learning rates $\\alpha$.\n",
"A learning curves shows how a model's performance changes with increasing number of update steps.\n",
"In our case we will plot the model's MSE as a function of the number of update\n",
"steps `n_iter` for different values of `learning_rate`.\n",
"\n",
"In the following cell we setup most of the scaffold to create this plot. Follow\n",
"the instructions in the comments to finish the plots."
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 27,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "4Rr5ix7LNISB"
},
2023-06-17 13:35:15 +00:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGxCAYAAACeKZf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB+DUlEQVR4nO3dd3xT9frA8U+SNknbdJcOoOy9SxniArXIEBUvVxFZohcHoiCKigMUvYIKKiqCoIiXIag/t4giiohUkELZe5XRQaF7ZZ3fH6Ghkba00PS06fP2lZfNOd9zznOStHn4To2iKApCCCGEEB5Cq3YAQgghhBBVSZIbIYQQQngUSW6EEEII4VEkuRFCCCGER5HkRgghhBAeRZIbIYQQQngUSW6EEEII4VEkuRFCCCGER/FSO4DqZrfbOX36NP7+/mg0GrXDEUIIIUQFKIpCTk4O9evXR6stv26mziU3p0+fJjo6Wu0whBBCCHEZTpw4QcOGDcstU+eSG39/f8Dx4gQEBKgcjRBCCCEqIjs7m+joaOf3eHnqXHJT3BQVEBAgyY0QQghRy1SkS4l0KBZCCCGER5HkRgghhBAeRZIbIYQQQniUOtfnRghPY7PZsFgsaochaihvb290Op3aYQhRrSS5EaKWUhSFlJQUMjMz1Q5F1HBBQUFERkbK3F6izpDkRohaqjixCQ8Px9fXV764xEUURSE/P5+0tDQAoqKiVI5IiOohyY0QtZDNZnMmNqGhoWqHI2owHx8fANLS0ggPD5cmKlEnSIdiIWqh4j42vr6+KkciaoPiz4n0zRJ1hSQ3QtRi0hQlKkI+J6KukeRGCCGEEB5FkhshRI1x7NgxNBoNiYmJNTqG/Px8hgwZQkBAABqNRkasCVHDSIdiIYSopE8++YQ//viDjRs3EhYWRmBgoNohCSFKkOSmimQXWthzOhutRkOPpiFqhyOEcKPDhw/Ttm1bOnTocNnnsNlsaDQatFqpQBeiqslvVRXZl5zD3Qv+4un/26F2KELUaHa7nddff50WLVpgMBho1KgR//3vf8ss//vvv9OjRw8MBgNRUVE888wzWK1W5/4vvviCjh074uPjQ2hoKHFxceTl5Tn3f/jhh7Rt2xaj0UibNm14//33Xc6/efNmYmJiMBqNdOvWjW3btpUbf58+fZg9ezbr169Ho9HQp08fADIyMhg1ahTBwcH4+voyYMAADh486Dxu8eLFBAUF8e2339KuXTsMBgNJSUk0adKEV155hVGjRmEymWjcuDHffvstZ86c4fbbb8dkMtGpUye2bNlSmZdZiDpNkpsq4mdwzB2RV2S9REkh3Csvr+xHYWHFyxYUVKxsZU2ZMoWZM2fywgsvsGfPHpYvX05ERESpZU+dOsXAgQPp3r0727dvZ968eXz00Ue88sorACQnJzNs2DDuu+8+9u7dy7p16/jXv/6FoigALFu2jKlTp/Lf//6XvXv38uqrr/LCCy/wySefAJCbm8ugQYNo164dCQkJvPjiizz55JPlxv/ll18yduxYevXqRXJyMl9++SUA9957L1u2bOHbb78lPj4eRVEYOHCgy/Dr/Px8XnvtNT788EN2795NeHg4AG+99RbXXHMN27Zt45ZbbmHkyJGMGjWKESNGsHXrVpo3b86oUaOc9yWEuASljsnKylIAJSsrq0rPe/RMrtL46e+Vdi/8WKXnFaI0BQUFyp49e5SCgoKL9kHZj4EDXcv6+pZdtndv17JhYaWXq4zs7GzFYDAoCxcuLHX/0aNHFUDZtm2boiiK8uyzzyqtW7dW7Ha7s8zcuXMVk8mk2Gw2JSEhQQGUY8eOlXq+5s2bK8uXL3fZ9vLLLyu9evVSFEVRPvjgAyU0NNTldZw3b55LDKWZMGGC0rvEC3TgwAEFUP7880/ntvT0dMXHx0f57LPPFEVRlI8//lgBlMTERJdzNW7cWBkxYoTzeXJysgIoL7zwgnNbfHy8AijJycllxl
Se8j4vQtQWlfn+lj43VcTP4Hgp88w27HYFrVbmlRDin/bu3UtRURE33XRThcv36tXLZZ6Wa665htzcXE6ePEnnzp256aab6NixI/369ePmm2/m3//+N8HBweTl5XH48GHuv/9+xo4d6zzearU6OwDv3buXTp06YTQanft79ep1Wffl5eVFz549ndtCQ0Np3bo1e/fudW7T6/V06tTpouNLbiuuxerYseNF29LS0oiMjKx0fELUNZLcVJHiZimAAovNmewIUd1yc8ve98+Z988vOVSqf/ZzPXbsskNyKl4KoKrodDrWrFnDxo0b+fnnn3n33Xd57rnn2LRpk3NW3oULF7okHcXHqcHHx6fUCfW8vb2dPxfvL22b3W53c4RCeAbpc1NFfLx1FP/Nkn43Qk1+fmU/SlRQXLLsP/OQsspVRsuWLfHx8WHt2rUVKt+2bVtn/5Vif/75J/7+/jRs2BBwfPFfc801vPTSS2zbtg29Xs9XX31FREQE9evX58iRI7Ro0cLl0bRpU+f5d+zYQWGJzkh//fVX5W7q/HmsViubNm1ybjt79iz79++nXbt2lT6fEOLKSHJTRTQaDX76C01TQoiLGY1Gnn76aZ566in+97//cfjwYf766y8++uijUsuPGzeOEydO8Oijj7Jv3z6++eYbpk2bxqRJk9BqtWzatIlXX32VLVu2kJSUxJdffsmZM2do27YtAC+99BIzZszgnXfe4cCBA+zcuZOPP/6YN998E4B77rkHjUbD2LFj2bNnD6tWrWLWrFmVvq+WLVty++23M3bsWDZs2MD27dsZMWIEDRo04Pbbb7/8F0wIcVmk7aQK+Rl05BZZpeZGiHK88MILeHl5MXXqVE6fPk1UVBQPPfRQqWUbNGjAqlWrmDx5Mp07dyYkJIT777+f559/HoCAgADWr1/P22+/TXZ2No0bN2b27NkMGDAAgP/85z/4+vryxhtvMHnyZPz8/OjYsSMTJ04EwGQy8d133/HQQw8RExNDu3bteO211xgyZEil7+vjjz9mwoQJDBo0CLPZzPXXX8+qVatcmpeEENVDoyh1a2xhdnY2gYGBZGVlERAQUKXnvnHWOo6k57Hygavo2Sy0Ss8tREmFhYUcPXqUpk2bunSGFaI08nkRnqAy39/SLFWFLoyYkpobIYQQQi2S3FQhX33xRH7S50YIIYRQiyQ3VchUXHMjfW6EEEII1UhyU4VKTuQnhBBCCHVIclOFZH0pIYQQQn2S3FShC/PcSHIjhBBCqEWSmyrkK31uhBBCCNVJclOFTOebpfJltJQQQgihGkluqpDv+WapXKm5EeKyHTt2DI1GQ2JiYoWPWbx4MUFBQW6LSQhRu0hyU4VMMomfEKKE3bt3M2TIEJo0aYJGo+Htt9+u0HE7duzguuuuw2g0Eh0dzeuvv35Rmc8//5w2bdpgNBrp2LEjq1atquLohai9JLmpQjKJnxCipPz8fJo1a8bMmTOJjIys0DHZ2dncfPPNNG7cmISEBN544w1efPFFFixY4CyzceNGhg0bxv3338+2bdsYPHgwgwcPZteuXe66FSFqFUluqpBM4ifEpa1evZprr72WoKAgQkNDGTRoEIcPHy6z/Lp169BoNPzwww906tQJo9HIVVddVeoX+U8//UTbtm0xmUz079+f5ORk576///6bvn37EhYWRmBgIL1792br1q1uucdi3bt354033uDuu+/GYDBU6Jhly5ZhNptZtGgR7du35+677+axxx5zrmQOMGfOHPr378/kyZNp27YtL7/8Ml27duW9995z160IUatIclOFikdL5cskfkKUKS8vj0mTJrFlyxbWrl2LVqvljjvuwG63l3vc5MmTmT17Nn///Tf16tXj1ltvxWKxOPfn5+cza9YslixZwvr160lKSuLJJ5907s/JyWH06NFs2LCBv/76i5YtWzJw4EBycnLKvOayZcswmUzlPv74448rf1FKiI+P5/rrr0ev1zu39evXj/
3795ORkeEsExcX53Jcv379iI+Pr9JYhKitvNQOwJMUj5aSDsVCDYqiUGBRJ7H28dah0WgqVHbIkCEuzxctW
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2023-06-17 11:47:27 +00:00
"source": [
"n_iters = [50, 100, 200, 500, 1000, 2000]\n",
"learning_rates = [1., .1, .01]\n",
"\n",
"# we plot the MSE achieved by the closed form model as a reference\n",
"closed_form = UnivariateLinearRegression()\n",
"closed_form.fit(x, y1)\n",
"mse_base = mse(y_pred=closed_form.predict(x), y_true=y1)\n",
"plt.plot(n_iters, np.ones_like(n_iters) * mse_base, label=\"closed form\", linestyle='--', c='b')\n",
"\n",
"for alpha in learning_rates:\n",
" mses = []\n",
" for n_iter in n_iters:\n",
" # fit a SGDUnivariateLinearRegression model using n_iter=n_iter and\n",
" # learning_rate=alpha\n",
" # compute its mse and append the mse value to the mses list\n",
2023-06-17 13:35:15 +00:00
" model = SGDUnivariateLinearRegression()\n",
" model.fit(x, y1, n_iter=n_iter, learning_rate=alpha)\n",
2023-06-17 11:47:27 +00:00
"\n",
2023-06-17 13:35:15 +00:00
" mse_ = mse(model.predict(x), y1)\n",
2023-06-17 11:47:27 +00:00
" mses.append(mse_)\n",
" plt.plot(n_iters, mses, label=f\"alpha = {alpha:.2f}\")\n",
"\n",
"plt.xlabel(\"n_iter\")\n",
"plt.ylabel(\"MSE\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "SmCkMMJyEEgV"
},
"source": [
"## 📢 **HAND-IN** 📢: A PDF document containing the following:\n",
"\n",
"* the final plot containing learning curves\n",
"* a short (2-3 sentences) interpretation of the curves: why do you think they look the way\n",
"they do? can you draw any conclusions?\n",
"\n",
"In case you were not able to arrive at the final plot:\n",
"\n",
"* include screenshots of the code you wrote so we can assign partial credit\n",
"\n",
"**Solutions for Tasks 2, 3 and 4 should be in the same document: you will only upload 1 document with your solutions for all 3 tasks!**\n"
]
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"- _The learning rate $1.0$ is too high causing the function to diverge from the actual results_\n",
"- _The best learning rate seems to be $0.01$._"
]
},
{
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "dgrNtwsPyigH"
},
"source": [
"# Task 4 (3 Points): Multivariate Linear Regression"
]
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "_sPWegXCg2y1"
},
"source": [
"In this task we will apply linear regression to non-synthetic data.\n",
"The variable `X` is a `pandas` `Dataframe` containing features and `y` contains\n",
"the target. Read through the description to get an idea of the different variables."
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 28,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "djGUQ3kVx9ob"
},
2023-06-17 13:35:15 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".. _diabetes_dataset:\n",
"\n",
"Diabetes dataset\n",
"----------------\n",
"\n",
"Ten baseline variables, age, sex, body mass index, average blood\n",
"pressure, and six blood serum measurements were obtained for each of n =\n",
"442 diabetes patients, as well as the response of interest, a\n",
"quantitative measure of disease progression one year after baseline.\n",
"\n",
"**Data Set Characteristics:**\n",
"\n",
" :Number of Instances: 442\n",
"\n",
" :Number of Attributes: First 10 columns are numeric predictive values\n",
"\n",
" :Target: Column 11 is a quantitative measure of disease progression one year after baseline\n",
"\n",
" :Attribute Information:\n",
" - age age in years\n",
" - sex\n",
" - bmi body mass index\n",
" - bp average blood pressure\n",
" - s1 tc, total serum cholesterol\n",
" - s2 ldl, low-density lipoproteins\n",
" - s3 hdl, high-density lipoproteins\n",
" - s4 tch, total cholesterol / HDL\n",
" - s5 ltg, possibly log of serum triglycerides level\n",
" - s6 glu, blood sugar level\n",
"\n",
"Note: Each of these 10 feature variables have been mean centered and scaled by the standard deviation times the square root of `n_samples` (i.e. the sum of squares of each column totals 1).\n",
"\n",
"Source URL:\n",
"https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html\n",
"\n",
"For more information see:\n",
"Bradley Efron, Trevor Hastie, Iain Johnstone and Robert Tibshirani (2004) \"Least Angle Regression,\" Annals of Statistics (with discussion), 407-499.\n",
"(https://web.stanford.edu/~hastie/Papers/LARS/LeastAngle_2002.pdf)\n",
"\n"
]
}
],
2023-06-17 11:47:27 +00:00
"source": [
"from sklearn.datasets import load_diabetes\n",
"\n",
"data = load_diabetes(as_frame=True)\n",
"\n",
"X = data['data']\n",
"y = data['target']\n",
"description = data['DESCR']\n",
"\n",
"print(description)"
]
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "byOVt9t9_2c7"
},
"source": [
"### Task 4a\n",
"\n",
"Implement linear regression using `sklearn`.\n",
"\n",
"* create an instance of the class `sklearn.linear_model.LinearRegression`. Refer to the documentation at: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n",
"* call its `.fit` method\n",
"* get the predicted values with `.predict`"
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 30,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "eyiU4nCQBovr"
},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression"
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 32,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "G4AktC189PAc"
},
2023-06-17 13:35:15 +00:00
"outputs": [
{
"data": {
"text/plain": [
"array([206.11667725, 68.07103297, 176.88279035, 166.91445843,\n",
" 128.46225834, 106.35191443, 73.89134662, 118.85423042,\n",
" 158.80889721, 213.58462442, 97.07481511, 95.10108423,\n",
" 115.06915952, 164.67656842, 103.07814257, 177.17487964,\n",
" 211.7570922 , 182.84134823, 148.00326937, 124.01754066,\n",
" 120.33362197, 85.80068961, 113.1134589 , 252.45225837,\n",
" 165.48779206, 147.71997564, 97.12871541, 179.09358468,\n",
" 129.05345958, 184.7811403 , 158.71516713, 69.47575778,\n",
" 261.50385365, 112.82234716, 78.37318279, 87.66360785,\n",
" 207.92114668, 157.87641942, 240.84708073, 136.93257456,\n",
" 153.48044608, 74.15426666, 145.62742227, 77.82978811,\n",
" 221.07832768, 125.21957584, 142.6029986 , 109.49562511,\n",
" 73.14181818, 189.87117754, 157.9350104 , 169.55699526,\n",
" 134.1851441 , 157.72539008, 139.11104979, 72.73116856,\n",
" 207.82676612, 80.11171342, 104.08335958, 134.57871054,\n",
" 114.23552012, 180.67628279, 61.12935368, 98.72404613,\n",
" 113.79577026, 189.95771575, 148.98351571, 124.34152283,\n",
" 114.8395504 , 121.99957578, 73.91017087, 236.71054289,\n",
" 142.31126791, 124.51672384, 150.84073896, 127.75230658,\n",
" 191.16896496, 77.05671154, 166.82164929, 91.00591229,\n",
" 174.75156797, 122.83451589, 63.27231315, 151.99867317,\n",
" 53.72959077, 166.0050229 , 42.6491333 , 153.04229493,\n",
" 80.54701716, 106.90148495, 79.93968011, 187.1672654 ,\n",
" 192.5989033 , 61.07398313, 107.4076912 , 125.04307496,\n",
" 207.72402726, 214.21248827, 123.47464895, 139.16439034,\n",
" 168.21372017, 106.92902558, 150.64748328, 157.92364009,\n",
" 152.75958287, 116.22381927, 73.03167734, 155.67052006,\n",
" 230.1417777 , 143.49797317, 38.09587272, 121.8593267 ,\n",
" 152.79404663, 207.99702587, 291.23106133, 189.17571129,\n",
" 214.02877593, 235.18106509, 165.38480498, 151.2469168 ,\n",
" 156.57659557, 200.44066818, 219.35193167, 174.78830391,\n",
" 169.23118221, 187.87537099, 57.49340026, 108.54836058,\n",
" 92.68731024, 210.87347343, 245.47097701, 69.84285129,\n",
" 113.03485904, 68.42650654, 141.69639374, 239.46240737,\n",
" 58.37858726, 235.47123197, 254.92309543, 253.30708899,\n",
" 155.51063293, 230.55961445, 170.44330954, 117.9953395 ,\n",
" 178.55406527, 240.07119308, 190.33892524, 228.66470581,\n",
" 114.24456339, 178.36552308, 209.091817 , 144.85615197,\n",
" 200.65926745, 121.34295733, 150.50993019, 199.01879825,\n",
" 146.27926469, 124.02163345, 85.25913019, 235.16173729,\n",
" 82.1730808 , 231.29474031, 144.36940116, 197.04628448,\n",
" 146.99841953, 77.18813284, 59.37368356, 262.68557988,\n",
" 225.12900796, 220.20301952, 46.59651844, 88.10194612,\n",
" 221.77450036, 97.25199783, 164.48838425, 119.90096817,\n",
" 157.80220788, 223.08012207, 99.59081773, 165.84386951,\n",
" 179.47680741, 89.83353846, 171.82590335, 158.36419935,\n",
" 201.48185539, 186.39194958, 197.47424761, 66.57371647,\n",
" 154.59985312, 116.18319159, 195.91755793, 128.04834496,\n",
" 91.20395862, 140.57223765, 155.22669143, 169.70326581,\n",
" 98.7573858 , 190.14568824, 142.51704894, 177.27157771,\n",
" 95.30812216, 69.06191507, 164.16391317, 198.0659024 ,\n",
" 178.25996632, 228.58539684, 160.67104137, 212.28734795,\n",
" 222.4833913 , 172.85421282, 125.27946793, 174.72103207,\n",
" 152.38094643, 98.58135665, 99.73771331, 262.29507095,\n",
" 223.74033222, 221.33976142, 133.61470602, 145.42828204,\n",
" 53.04569008, 141.82052358, 153.68617582, 125.22290891,\n",
" 77.25168449, 230.26180811, 78.9090807 , 105.2051755 ,\n",
" 117.99622779, 99.06233889, 166.55796947, 159.34137227,\n",
" 158.27448255, 143.05684078, 231.55890118, 176.64724258,\n",
" 187.23580712, 65.39099908, 190.66218796, 179.75181691,\n",
" 234.9080532 , 119.15669025, 85.63551834, 100.8597527 ,\n",
" 140.41937377, 101.83524022, 120.66560385, 83.0664276 ,\n",
" 234.58488012, 245.15862773, 263.26954282, 274.87127261,\n",
" 180.67257769, 203.05642297, 254.21625849, 118.44300922,\n",
" 268.45369506, 104.83843473, 115.86820464, 140.45857194,\n",
" 58.46948192, 129.83145265, 263.78607272, 45.00934573,\n",
" 123.28890007, 131.0856888 , 34.89181681, 138.35467112,\n",
" 244.30103923, 89.95923929, 192.07096194, 164.33017386,\n",
" 147.74779723, 191.89092557, 176.44360299, 158.3490221 ,\n",
" 189.19166962, 116.58117777, 111.449754 , 117.45232726,\n",
" 165.79598354, 97.80405886, 139.54451791, 84.17319946,\n",
" 159.93677518, 202.39971737, 80.48131518, 146.64558568,\n",
" 79.05314048, 191.33777472, 220.67516721, 203.75017281,\n",
" 92.86459928, 179.15576252, 81.79874055, 152.8290929 ,\n",
" 76.80052219, 97.79590831, 106.8371012 , 123.83461591,\n",
" 218.13908293, 126.01937664, 206.7587966 , 230.5767944 ,\n",
" 122.05921633, 135.67824405, 126.37042532, 148.49374458,\n",
" 88.07147107, 138.95823614, 203.8691938 , 172.55288732,\n",
" 122.95701477, 213.92310163, 174.89158814, 110.07294222,\n",
" 198.36584973, 173.25229067, 162.64748776, 193.31578983,\n",
" 191.53493643, 284.13932209, 279.31133207, 216.00823829,\n",
" 210.08668656, 216.21612991, 157.01450004, 224.06431372,\n",
" 189.06103154, 103.56515315, 178.70270016, 111.81862434,\n",
" 291.00196609, 182.64651752, 79.33315426, 86.33029851,\n",
" 249.1510082 , 174.51537682, 122.10291074, 146.2718871 ,\n",
" 170.65483847, 183.497196 , 163.36806262, 157.03297709,\n",
" 144.42614949, 125.30053093, 177.50251197, 104.57681546,\n",
" 132.17560518, 95.06210623, 249.89755705, 86.23824126,\n",
" 61.99847009, 156.81295053, 192.32218372, 133.85525804,\n",
" 93.67249793, 202.49572354, 52.54148927, 174.82799914,\n",
" 196.91468873, 118.06336979, 235.29941812, 165.09438096,\n",
" 160.41761959, 162.37786753, 254.05587268, 257.23492156,\n",
" 197.5039462 , 184.06877122, 58.62131994, 194.39216636,\n",
" 110.775815 , 142.20991224, 128.82520996, 180.13082199,\n",
" 211.26488624, 169.59494046, 164.33851796, 136.23374077,\n",
" 174.51001028, 74.67587343, 246.29432383, 114.14494406,\n",
" 111.54552901, 140.0224376 , 109.99895704, 91.37283987,\n",
" 163.01540596, 75.16804478, 254.06119047, 53.47338214,\n",
" 98.48397565, 100.66315554, 258.58683032, 170.67256752,\n",
" 61.91771186, 182.31148421, 171.26948629, 189.19505093,\n",
" 187.18494664, 87.12170524, 148.37964317, 251.35815403,\n",
" 199.69656904, 283.63576862, 50.85911237, 172.14766276,\n",
" 204.05976093, 174.16540137, 157.93182911, 150.50028158,\n",
" 232.97445368, 121.5814873 , 164.54245461, 172.67625919,\n",
" 226.7768891 , 149.46832104, 99.13924946, 80.43418456,\n",
" 140.16148637, 191.90710484, 199.28001608, 153.63277325,\n",
" 171.80344337, 112.11054883, 162.60002916, 129.84290324,\n",
" 258.03100468, 100.70810916, 115.87608197, 122.53559675,\n",
" 218.1797988 , 60.94350929, 131.09296884, 119.48376601,\n",
" 52.60911672, 193.01756549, 101.05581371, 121.22668124,\n",
" 211.85894518, 53.44727472])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = LinearRegression()\n",
"model.fit(X, y)\n",
"model.predict(X)"
]
2023-06-17 11:47:27 +00:00
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "qQUdYHOXpeLd"
},
"source": [
"### Task 4b\n",
"\n",
"The estimated parameters $\\theta$ of the linear model can be found in the `.coef_` member variable. The feature names can be found in the `.feature_names_in_` member variable. They are the same as the names of the columns of `X` and should be in the same order.\n",
"\n",
"Visualize the estimated parameters and the feature names in a bar plot.\n",
"\n",
"Using these, answer the following questions:\n",
"\n",
"* Which are the 3 most influential features?\n",
"* How do you interpret the sign of the coefficients?\n",
"* If you had to exclude 1 feature, which one would you select and why?"
]
},
{
"cell_type": "code",
2023-06-17 13:35:15 +00:00
"execution_count": 37,
2023-06-17 11:47:27 +00:00
"metadata": {
"id": "odXnubfHqrfc"
},
2023-06-17 13:35:15 +00:00
"outputs": [
{
"data": {
"text/plain": [
"<BarContainer object of 10 artists>"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjMAAAGdCAYAAADnrPLBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwP0lEQVR4nO3de1hU9aLG8ZeLDHgZRBPwguKthPJOx0Yru7DFwoq9rZ1lpWWahiVqKuQlzYrStNRM00qt9LHbrl1qJOnWLpLmNbXEPGq6tUH3URizQpHf+aPDnCYVdW+H4Wffz/OsJ2et31rrXTQyr2vWrAkyxhgBAABYKjjQAQAAAP4TlBkAAGA1ygwAALAaZQYAAFiNMgMAAKxGmQEAAFajzAAAAKtRZgAAgNVCAx2gIpSWlmr//v2qUaOGgoKCAh0HAACcBWOMjhw5onr16ik4+PTnX/4QZWb//v2Ki4sLdAwAAPBv2Lt3rxo0aHDa5X+IMlOjRg1Jv/4wnE5ngNMAAICz4fF4FBcX530dPx2/lpkTJ05o7NixeuONN+R2u1WvXj317t1bo0aN8r7dY4zRY489ptmzZ6uwsFCdOnXSjBkz1Lx5c+92Dh06pIceekgffvihgoOD1b17d02ZMkXVq1c/qxxl+3I6nZQZAAAsc6ZLRPx6AfAzzzyjGTNm6IUXXtC3336rZ555RhMmTNC0adO8YyZMmKCpU6dq5syZWr16tapVq6aUlBT98ssv3jE9e/bU1q1blZubq0WLFunTTz9Vv379/BkdAABYIsif35rdrVs3xcTE6JVXXvHO6969uyIiIvTGG2/IGKN69epp6NCheuSRRyRJRUVFiomJ0dy5c9WjRw99++23SkxM1FdffaWkpCRJUk5Ojm688Ub985//VL169c6Yw+PxKDIyUkVFRZyZAQDAEmf7+u3XMzMdO3bUsmXLtH37dknSpk2b9Pnnn+uGG26QJO3atUtut1vJycnedSIjI9WhQwfl5eVJkvLy8lSzZk1vkZGk5ORkBQcHa/Xq1afcb3FxsTwej88EAAAuTH69ZiYzM1Mej0ctWrRQSEiITpw4oSeffFI9e/aUJLndbklSTEyMz3oxMTHeZW63W9HR0b6hQ0NVq1Yt75jfy87O1rhx48734QAAgErIr2dm3nrrLc2fP18LFizQ+vXrNW/ePD377LOaN2+eP3errKwsFRUVeae9e/f6dX8AACBw/HpmZtiwYcrMzFSPHj0kSS1bttT333+v7Oxs9erVS7GxsZKkgoIC1a1b17teQUGB2rRpI0mKjY3VgQMHfLZbUlKiQ4cOedf/PYfDIYfD4YcjAgAAlY1fz8z89NNPJ92xLyQkRKWlpZKkxo0bKzY2VsuWLfMu93g8Wr16tVwulyTJ5XKpsLBQ69at845Zvny5SktL1aFDB3/GBwAAFvDrmZmbbrpJTz75pBo2bKhLL71UGzZs0OTJk3XfffdJ+vVz4xkZGXriiSfUvHlzNW7cWKNHj1a9evWUlpYmSUpISFDXrl3Vt29fzZw5U8ePH9fAgQPVo0ePs/okEwAAuLD5tcxMmzZNo0eP1oMPPqgDBw6oXr16euCBBzRmzBjvmOHDh+vo0aPq16+fCgsLdeWVVyonJ0fh4eHeMfPnz9fAgQN1/fXXe2+aN3XqVH9GBwAAlvDrfWYqC+4zAwCAfSrFfWYAAAD8jTIDAACsRpkBAABW8+sFwAAAVFbxmYsDHeEku59ODXQEK3FmBgAAWI0yAwAArEaZAQAAVqPMAAAAq1FmAACA1SgzAADAapQZAABgNcoMAACwGmUGAABYjTIDAACsRpkBAABWo8wAAACrUWYAAIDVKDMAAMBqlBkAAGA1ygwAALAaZQYAAFiNMgMAAKxGmQEAAFajzAAAAKtRZgAAgNUoMwAAwGqUGQAAYDXKDAAAsBplBgAAWI0yAwAArEaZAQAAVqPMAAAAq/m9zOzbt0933XWXateurY
iICLVs2VJr1671LjfGaMyYMapbt64iIiKUnJys7777zmcbhw4dUs+ePeV0OlWzZk316dNHP/74o7+jAwAAC/i1zBw+fFidOnVSlSpV9NFHH+mbb77RpEmTFBUV5R0zYcIETZ06VTNnztTq1atVrVo1paSk6JdffvGO6dmzp7Zu3arc3FwtWrRIn376qfr16+fP6AAAwBJBxhjjr41nZmbqiy++0GeffXbK5cYY1atXT0OHDtUjjzwiSSoqKlJMTIzmzp2rHj166Ntvv1ViYqK++uorJSUlSZJycnJ044036p///Kfq1at3xhwej0eRkZEqKiqS0+k8fwcIALBWfObiQEc4ye6nUwMdoVI529dvv56Z+eCDD5SUlKTbbrtN0dHRatu2rWbPnu1dvmvXLrndbiUnJ3vnRUZGqkOHDsrLy5Mk5eXlqWbNmt4iI0nJyckKDg7W6tWrT7nf4uJieTwenwkAAFyY/Fpmdu7cqRkzZqh58+b6+OOPNWDAAD388MOaN2+eJMntdkuSYmJifNaLiYnxLnO73YqOjvZZHhoaqlq1annH/F52drYiIyO9U1xc3Pk+NAAAUEn4tcyUlpaqXbt2euqpp9S2bVv169dPffv21cyZM/25W2VlZamoqMg77d2716/7AwAAgePXMlO3bl0lJib6zEtISNCePXskSbGxsZKkgoICnzEFBQXeZbGxsTpw4IDP8pKSEh06dMg75vccDoecTqfPBAAALkx+LTOdOnVSfn6+z7zt27erUaNGkqTGjRsrNjZWy5Yt8y73eDxavXq1XC6XJMnlcqmwsFDr1q3zjlm+fLlKS0vVoUMHf8YHAAAWCPXnxgcPHqyOHTvqqaee0l//+letWbNGs2bN0qxZsyRJQUFBysjI0BNPPKHmzZurcePGGj16tOrVq6e0tDRJv57J6dq1q/ftqePHj2vgwIHq0aPHWX2SCQAAXNj8WmYuv/xyvffee8rKytLjjz+uxo0b6/nnn1fPnj29Y4YPH66jR4+qX79+Kiws1JVXXqmcnByFh4d7x8yfP18DBw7U9ddfr+DgYHXv3l1Tp071Z3QAAGAJv95nprLgPjMAgN/jPjOVX6W4zwwAAIC/UWYAAIDVKDMAAMBqlBkAAGA1ygwAALAaZQYAAFiNMgMAAKxGmQEAAFajzAAAAKtRZgAAgNUoMwAAwGp+/aJJ4Hzju1QAAL/HmRkAAGA1ygwAALAaZQYAAFiNMgMAAKxGmQEAAFajzAAAAKtRZgAAgNUoMwAAwGrcNA/ABYebKwJ/LJyZAQAAVqPMAAAAq1FmAACA1SgzAADAapQZAABgNcoMAACwGmUGAABYjTIDAACsRpkBAABWo8wAAACrUWYAAIDVKDMAAMBqFVZmnn76aQUFBSkjI8M775dfflF6erpq166t6tWrq3v37iooKPBZb8+ePUpNTVXVqlUVHR2tYcOGqaSkpKJiAwCASq5CysxXX32ll156Sa1atfKZP3jwYH344Yd6++23tXLlSu3fv19/+ctfvMtPnDih1NRUHTt2TKtWrdK8efM0d+5cjRkzpiJiAwAAC/i9zPz444/q2bOnZs+eraioKO/8oqIivfLKK5o8ebKuu+46tW/fXnPmzNGqVav05ZdfSpKWLl2qb775Rm+88YbatGmjG264QePHj9f06dN17Ngxf0cHAAAW8HuZSU9PV2pqqpKTk33mr1u3TsePH/eZ36JFCzVs2FB5eXmSpLy8PLVs2VIxMTHeMSkpKfJ4PNq6detp91lcXCyPx+MzAQCAC1OoPze+cOFCrV+/Xl999dVJy9xut8LCwlSzZk2f+TExMXK73d4xvy0yZcvLlp1Odna2xo0b9x+mBwAANvDbmZm9e/dq0KBBmj9/vsLDw/21m1PKyspSUVGRd9q7d2+F7h8AAFQcv5WZdevW6cCBA2rXrp1CQ0MVGhqqlStXaurUqQoNDVVMTIyOHTumwsJCn/UKCgoUGxsrSYqNjT3p00
1lj8vGnIrD4ZDT6fSZAADAhclvZeb666/X5s2btXHjRu+UlJSknj17ev9cpUoVLVu2zLtOfn6+9uzZI5fLJ
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.bar(model.feature_names_in_, height=model.coef_)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
2023-06-17 11:47:27 +00:00
"source": [
2023-06-17 13:35:15 +00:00
"- _`bmi`, `s1`, `s5`_\n",
"- Negative coefficients such as the one in `s1` indicate a negative correlation\n",
"- `age`, since it has nearly no influence"
2023-06-17 11:47:27 +00:00
]
},
{
2023-06-17 13:35:15 +00:00
"attachments": {},
2023-06-17 11:47:27 +00:00
"cell_type": "markdown",
"metadata": {
"id": "xa_HDxFeolBj"
},
"source": [
"## 📢 **HAND-IN** 📢: A PDF document containing the following:\n",
"\n",
"* the bar plot\n",
"* your answers to the questions in Task 4b\n",
"\n",
"**Solutions for Tasks 2, 3 and 4 should be in the same document: you will only upload 1 document with your solutions for all 3 tasks!**\n"
]
}
],
"metadata": {
"colab": {
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2023-06-17 13:35:15 +00:00
"version": "3.11.3"
2023-06-17 11:47:27 +00:00
}
},
"nbformat": 4,
"nbformat_minor": 0
2023-06-17 13:35:15 +00:00
}