show(), hide(), visible. font-light, normal, medium, semibold, bold, extrabold, black. protocol connection pooling. fetch() with GURT. DNS from HTTP to GURT.

This commit is contained in:
Face
2025-08-18 17:45:46 +03:00
parent 3ed49fae0d
commit a8313ec3d8
38 changed files with 2123 additions and 2059 deletions

View File

@@ -27,6 +27,7 @@ Issues:
3. Certain properties like `scale` and `rotate` don't apply to the `active` pseudo-class because they rely on mouse_enter and mouse_exit events 3. Certain properties like `scale` and `rotate` don't apply to the `active` pseudo-class because they rely on mouse_enter and mouse_exit events
4. `<div style="bg-[#3b82f6] w-[100px] h-[100px] flex hover:scale-110 transition hover:rotate-45">Box</div>` something like this has the "Box" text (presumably the PanelContainer) as the target of the hover, not the div itself (which has the w/h size) 4. `<div style="bg-[#3b82f6] w-[100px] h-[100px] flex hover:scale-110 transition hover:rotate-45">Box</div>` something like this has the "Box" text (presumably the PanelContainer) as the target of the hover, not the div itself (which has the w/h size)
5. font in button doesn't comply with CSS, its the projects default 5. font in button doesn't comply with CSS, its the projects default
6. Flex containers, ironically enough, make the page unresponsive. This happens because of our custom `AutoSizingFlexContainer.gd` script, which aims to set a Godot UI size to the flex containers based on their content. However, they don't get resized when the window is resized, leading to unresponsiveness. The fact that we're setting the `custom_minimum_size` is not the root cause, but rather the fact that the script doesn't update the size when the window is resized - or, more likely, I just don't understand how flexbox works.
Notes: Notes:
- **< input />** is sort-of inline in normal web. We render it as a block element (new-line). - **< input />** is sort-of inline in normal web. We render it as a block element (new-line).

718
dns/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -10,16 +10,14 @@ regex = "1.10.4"
jsonwebtoken = "9.2" jsonwebtoken = "9.2"
bcrypt = "0.15" bcrypt = "0.15"
serenity = { version = "0.12", features = ["client", "gateway", "rustls_backend", "model"] } serenity = { version = "0.12", features = ["client", "gateway", "rustls_backend", "model"] }
actix-web-httpauth = "0.8"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
colored = "2.1.0" colored = "2.1.0"
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid", "migrate", "json"] } sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid", "migrate", "json"] }
anyhow = "1.0.86" anyhow = "1.0.86"
futures = "0.3.30" futures = "0.3.30"
actix-web = "4.6.0"
macros-rs = "1.2.1" macros-rs = "1.2.1"
prettytable = "0.10.0" prettytable = "0.10.0"
actix-governor = "0.5.0" gurt = { path = "../protocol/library" }
pretty_env_logger = "0.5.0" pretty_env_logger = "0.5.0"
clap-verbosity-flag = "2.2.0" clap-verbosity-flag = "2.2.0"

View File

@@ -1,44 +1,103 @@
<head> <head>
<title>Domain Dashboard</title> <title>Domain Dashboard</title>
<icon src="https://cdn-icons-png.flaticon.com/512/295/295128.png"> <icon src="https://cdn-icons-png.flaticon.com/512/295/295128.png">
<meta name="theme-color" content="#0891b2"> <meta name="theme-color" content="#0891b2">
<meta name="description" content="Manage your domains and registrations"> <meta name="description" content="Manage your domains and registrations">
<style> <style>
body { bg-[#f8fafc] p-6 font-sans } body {
h1 { text-[#0891b2] text-3xl font-bold text-center } bg-[#171616] font-sans text-white
h2 { text-[#0f766e] text-xl font-semibold } }
h3 { text-[#374151] text-lg font-medium }
.container { bg-[#ffffff] p-6 rounded-lg shadow-lg max-w-6xl mx-auto }
.primary-btn { px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#0891b2] text-white hover:bg-[#0e7490] }
.success-btn { px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#059669] text-white hover:bg-[#047857] }
.danger-btn { px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#dc2626] text-white hover:bg-[#b91c1c] }
.secondary-btn { px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#6b7280] text-white hover:bg-[#4b5563] }
.warning-btn { px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#f59e0b] text-white hover:bg-[#d97706] }
.form-group { flex flex-col gap-2 mb-4 }
.form-input { w-full p-3 border border-gray-300 rounded-md }
.card { bg-[#ffffff] p-4 rounded-lg shadow border }
.stats-card { bg-[#f0f9ff] p-4 rounded-lg border border-[#0891b2] }
.domain-item { bg-[#f8fafc] p-4 rounded-lg border mb-2 flex justify-between items-center }
.log-area { bg-[#1f2937] text-white p-4 rounded-lg font-mono text-sm max-h-64 overflow-auto }
.error-text { text-[#dc2626] text-sm }
.modal { w-full h-full bg-[rgba(0,0,0,0.5)] flex items-center justify-center z-50 }
.modal-content { bg-white p-6 rounded-lg max-w-md w-full mx-4 }
.tld-selector { flex flex-wrap gap-2 }
.tld-option { px-3 py-1 rounded border cursor-pointer hover:bg-[#f3f4f6] }
.tld-selected { bg-[#0891b2] text-white hover:bg-[#0e7490] }
</style>
<script src="dashboard.lua" /> h1 {
text-[#ef4444] text-3xl font-bold text-center
}
h2 {
text-[#dc2626] text-xl font-semibold
}
h3 {
text-[#fca5a5] text-lg font-medium
}
.container {
bg-[#262626] p-6 rounded-lg shadow-lg max-w-6xl mx-auto
}
.primary-btn {
px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#dc2626] text-white
}
.success-btn {
px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#ef4444] text-white
}
.danger-btn {
px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#b91c1c] text-white
}
.secondary-btn {
px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#525252] text-white w-32 h-12
}
.warning-btn {
px-4 py-2 rounded-lg font-medium cursor-pointer transition-colors bg-[#dc2626] text-white
}
.form-group {
flex flex-col gap-2 mb-4 w-full
}
.form-input {
w-full p-3 border border-gray-600 rounded-md bg-[#374151] text-white
}
.card {
bg-[#262626] p-4 rounded-lg shadow border border-gray-700
}
.stats-card {
bg-[#1f1f1f] p-4 rounded-lg border border-[#dc2626]
}
.domain-item {
bg-[#374151] p-4 rounded-lg border border-gray-700 mb-2 flex justify-between items-center
}
.log-area {
bg-[#111827] text-white p-4 rounded-lg font-mono text-sm max-h-64 overflow-auto
}
.error-text {
text-[#fca5a5] text-sm
}
.tld-selector {
flex flex-wrap gap-2
}
.tld-option {
px-3 py-1 rounded border border-gray-600 cursor-pointer bg-[#374151] text-white w-12 h-12
}
.tld-selected {
bg-[#dc2626] text-white
}
.invite-code-display {
bg-[#374151] p-3 rounded font-mono text-center mb-2 text-white
}
</style>
<script src="dashboard.lua" />
</head> </head>
<body> <body>
<h1>🌐 Domain Management Dashboard</h1>
<div style="container mt-6"> <div style="container mt-6">
<div style="stats-card mb-6"> <div style="stats-card mb-6">
<div style="flex justify-between items-center"> <div style="flex justify-between items-center w-full">
<div id="user-info" style="text-lg font-semibold">Loading...</div> <p id="user-info" style="text-white text-lg font-semibold">Loading...</p>
<button id="logout-btn" style="secondary-btn">Logout</button> <button id="logout-btn" style="secondary-btn">Logout</button>
</div> </div>
</div> </div>
@@ -46,17 +105,17 @@
<div style="card mb-6"> <div style="card mb-6">
<h2>Register New Domain</h2> <h2>Register New Domain</h2>
<div style="form-group"> <div style="form-group">
<label>Domain Name:</label> <p>Domain Name:</p>
<input id="domain-name" type="text" style="form-input" placeholder="myawesome" /> <input id="domain-name" type="text" style="form-input" placeholder="myawesome" />
</div> </div>
<div style="form-group"> <div style="form-group">
<label>Select TLD:</label> <p>Select TLD:</p>
<div id="tld-selector" style="tld-selector"> <div id="tld-selector" style="tld-selector">
Loading TLDs... <p id="tld-loading">Loading TLDs...</p>
</div> </div>
</div> </div>
<div style="form-group"> <div style="form-group">
<label>IP Address:</label> <p>IP Address:</p>
<input id="domain-ip" type="text" style="form-input" placeholder="192.168.1.100" /> <input id="domain-ip" type="text" style="form-input" placeholder="192.168.1.100" />
</div> </div>
<div id="domain-error" style="error-text hidden mb-2"></div> <div id="domain-error" style="error-text hidden mb-2"></div>
@@ -65,49 +124,29 @@
<div style="card mb-6"> <div style="card mb-6">
<h2>Invite System</h2> <h2>Invite System</h2>
<p style="text-[#6b7280] mb-4">Create invite codes to share with friends, or redeem codes to get more domain registrations.</p> <p style="text-[#6b7280] mb-4">Create invite codes to share with friends, or redeem codes to get more domain
registrations.</p>
<div style="flex flex-row gap-4">
<div style="flex-1"> <p id="invite-code-display" style="invite-code-display mt-2">Placeholder</p>
<h3>Create Invite</h3>
<button id="create-invite-btn" style="warning-btn">Generate Invite Code</button> <div style="flex flex-col gap-4 items-center justify-center mx-auto">
</div> <h3>Create Invite</h3>
<div style="flex-1"> <button id="create-invite-btn" style="warning-btn">Generate Invite Code</button>
<h3>Redeem Invite</h3> </div>
<div style="flex gap-2"> <div style="flex flex-col gap-4 mx-auto">
<input id="invite-code-input" type="text" style="form-input" placeholder="Enter invite code" /> <h3>Redeem Invite</h3>
<button id="redeem-invite-btn" style="primary-btn">Redeem</button> <div style="flex gap-2">
</div> <input id="invite-code-input" type="text" style="form-input" placeholder="Enter invite code" />
<div id="redeem-error" style="error-text hidden mt-2"></div> <button id="redeem-invite-btn" style="primary-btn">Redeem</button>
</div> </div>
<div id="redeem-error" style="error-text hidden mt-2"></div>
</div> </div>
</div> </div>
<div style="card mb-6"> <div style="card mb-6">
<h2>My Domains</h2> <h2>My Domains</h2>
<div id="domains-list"> <div id="domains-list">
Loading domains... <p id="domains-loading">Loading domains...</p>
</div>
</div>
<div style="card">
<h2>Activity Log</h2>
<div style="log-area">
<pre id="log-area">Initializing...</pre>
</div>
</div>
</div>
<div id="invite-modal" style="modal hidden">
<div style="modal-content">
<h3>Invite Code Generated</h3>
<p>Share this code with friends to give them 3 additional domain registrations:</p>
<div style="bg-[#f3f4f6] p-3 rounded font-mono text-center mb-4">
<span id="invite-code-display">Loading...</span>
</div>
<div style="flex gap-2 justify-center">
<button id="copy-invite-code" style="primary-btn">Copy Code</button>
<button id="close-invite-modal" style="secondary-btn">Close</button>
</div> </div>
</div> </div>
</div> </div>

View File

@@ -5,19 +5,11 @@ local authToken = nil
local userInfo = gurt.select('#user-info') local userInfo = gurt.select('#user-info')
local domainsList = gurt.select('#domains-list') local domainsList = gurt.select('#domains-list')
local logArea = gurt.select('#log-area')
local inviteModal = gurt.select('#invite-modal')
local tldSelector = gurt.select('#tld-selector') local tldSelector = gurt.select('#tld-selector')
local loadingElement = gurt.select('#tld-loading')
local displayElement = gurt.select('#invite-code-display')
local logMessages = {} displayElement:hide()
local function addLog(message)
table.insert(logMessages, Time.format(Time.now(), '%H:%M:%S') .. ' - ' .. message)
if #logMessages > 50 then
table.remove(logMessages, 1)
end
logArea.text = table.concat(logMessages, '\n')
end
local function showError(elementId, message) local function showError(elementId, message)
local element = gurt.select('#' .. elementId) local element = gurt.select('#' .. elementId)
@@ -32,189 +24,51 @@ local function hideError(elementId)
element.classList:add('hidden') element.classList:add('hidden')
end end
local function showModal(modalId)
local modal = gurt.select('#' .. modalId)
modal.classList:remove('hidden')
end
local function hideModal(modalId)
local modal = gurt.select('#' .. modalId)
modal.classList:add('hidden')
end
local function makeRequest(url, options)
options = options or {}
if authToken then
options.headers = options.headers or {}
options.headers.Authorization = 'Bearer ' .. authToken
end
return fetch(url, options)
end
local function checkAuth()
authToken = gurt.crumbs.get("auth_token")
if authToken then
addLog('Found auth token, checking validity...')
local response = makeRequest('gurt://localhost:4878/auth/me')
print(table.tostring(response))
if response:ok() then
user = response:json()
addLog('Authentication successful for user: ' .. user.username)
updateUserInfo()
loadDomains()
loadTLDs()
else
addLog('Token invalid, redirecting to login...')
--gurt.crumbs.delete('auth_token')
--gurt.location.goto('../')
end
else
addLog('No auth token found, redirecting to login...')
gurt.location.goto('../')
end
end
local function logout()
gurt.crumbs.delete('auth_token')
addLog('Logged out successfully')
gurt.location.goto("../")
end
local function loadDomains()
addLog('Loading domains...')
local response = makeRequest('gurt://localhost:4878/domains?page=1&size=100')
if response:ok() then
local data = response:json()
domains = data.domains or {}
addLog('Loaded ' .. #domains .. ' domains')
renderDomains()
else
addLog('Failed to load domains: ' .. response:text())
end
end
local function loadTLDs()
addLog('Loading available TLDs...')
local response = fetch('gurt://localhost:4878/tlds')
if response:ok() then
tlds = response:json()
addLog('Loaded ' .. #tlds .. ' TLDs')
renderTLDSelector()
else
addLog('Failed to load TLDs: ' .. response:text())
end
end
local function submitDomain(name, tld, ip)
hideError('domain-error')
addLog('Submitting domain: ' .. name .. '.' .. tld)
local response = makeRequest('gurt://localhost:4878/domain', {
method = 'POST',
headers = { ['Content-Type'] = 'application/json' },
body = JSON.stringify({ name = name, tld = tld, ip = ip })
})
if response:ok() then
local data = response:json()
addLog('Domain submitted successfully: ' .. data.domain)
-- Update user registrations remaining
user.registrations_remaining = user.registrations_remaining - 1
updateUserInfo()
-- Clear form
gurt.select('#domain-name').text = ''
gurt.select('#domain-ip').text = ''
-- Refresh domains list
loadDomains()
else
local error = response:text()
showError('domain-error', 'Domain submission failed: ' .. error)
addLog('Domain submission failed: ' .. error)
end
end
local function createInvite()
addLog('Creating invite code...')
local response = makeRequest('gurt://localhost:4878/auth/invite', { method = 'POST' })
if response:ok() then
local data = response:json()
local inviteCode = data.invite_code
gurt.select('#invite-code-display').text = inviteCode
addLog('Invite code created: ' .. inviteCode)
showModal('invite-modal')
else
addLog('Failed to create invite: ' .. response:text())
end
end
local function redeemInvite(code)
hideError('redeem-error')
addLog('Redeeming invite code: ' .. code)
local response = makeRequest('gurt://localhost:4878/auth/redeem-invite', {
method = 'POST',
headers = { ['Content-Type'] = 'application/json' },
body = JSON.stringify({ invite_code = code })
})
if response:ok() then
local data = response:json()
addLog('Invite redeemed: +' .. data.registrations_added .. ' registrations')
-- Update user info
user.registrations_remaining = user.registrations_remaining + data.registrations_added
updateUserInfo()
-- Clear form
gurt.select('#invite-code-input').text = ''
else
local error = response:text()
showError('redeem-error', 'Failed to redeem invite: ' .. error)
addLog('Failed to redeem invite: ' .. error)
end
end
-- UI rendering functions
local function updateUserInfo() local function updateUserInfo()
if user then userInfo.text = 'Welcome, ' .. user.username .. '!'
userInfo.text = 'Welcome, ' .. user.username .. ' | Registrations remaining: ' .. user.registrations_remaining
end
end end
local function renderTLDSelector() local function renderTLDSelector()
loadingElement:remove()
tldSelector.text = '' tldSelector.text = ''
for i, tld in ipairs(tlds) do local i = 1
local option = gurt.create('div', { local total = #tlds
local intervalId
intervalId = gurt.setInterval(function()
if i > total then
gurt.clearInterval(intervalId)
return
end
local tld = tlds[i]
local option = gurt.create('button', {
text = '.' .. tld, text = '.' .. tld,
style = 'tld-option', style = 'tld-option',
['data-tld'] = tld ['data-tld'] = tld
}) })
option:on('click', function() option:on('click', function()
-- Clear previous selection -- Clear previous selection
local options = gurt.selectAll('.tld-option') local options = gurt.selectAll('.tld-option')
for j = 1, #options do for j = 1, #options do
options[j].classList:remove('tld-selected') options[j].classList:remove('tld-selected')
end end
-- Select this option -- Select this option
option.classList:add('tld-selected') option.classList:add('tld-selected')
end) end)
tldSelector:append(option) tldSelector:append(option)
end i = i + 1
end, 16)
end end
local function renderDomains() local function renderDomains()
local loadingElement = gurt.select('#domains-loading')
loadingElement:remove()
domainsList.text = '' domainsList.text = ''
if #domains == 0 then if #domains == 0 then
@@ -281,35 +135,153 @@ local function renderDomains()
end end
end end
local function updateDomainIP(name, tld, ip) local function loadDomains()
addLog('Updating IP for ' .. name .. '.' .. tld .. ' to ' .. ip) print('Loading domains...')
local response = fetch('gurt://localhost:8877/domains?page=1&size=100', {
local response = makeRequest('gurt://localhost:4878/domain/' .. name .. '/' .. tld, { headers = {
method = 'PUT', Authorization = 'Bearer ' .. authToken
headers = { ['Content-Type'] = 'application/json' }, }
body = JSON.stringify({ ip = ip })
}) })
if response:ok() then if response:ok() then
addLog('Domain IP updated successfully') local data = response:json()
loadDomains() domains = data.domains or {}
print('Loaded ' .. #domains .. ' domains')
renderDomains()
else else
addLog('Failed to update domain IP: ' .. response:text()) print('Failed to load domains: ' .. response:text())
end end
end end
local function deleteDomain(name, tld) local function loadTLDs()
addLog('Deleting domain: ' .. name .. '.' .. tld) print('Loading available TLDs...')
local response = fetch('gurt://localhost:8877/tlds')
local response = makeRequest('gurt://localhost:4878/domain/' .. name .. '/' .. tld, { if response:ok() then
method = 'DELETE' tlds = response:json()
print('Loaded ' .. #tlds .. ' TLDs')
renderTLDSelector()
else
print('Failed to load TLDs: ' .. response:text())
end
end
local function checkAuth()
authToken = gurt.crumbs.get("auth_token")
if authToken then
print('Found auth token, checking validity...')
local response = fetch('gurt://localhost:8877/auth/me', {
headers = {
Authorization = 'Bearer ' .. authToken
}
})
print(table.tostring(response))
if response:ok() then
user = response:json()
print('Authentication successful for user: ' .. user.username)
updateUserInfo()
loadDomains()
loadTLDs()
else
print('Token invalid, redirecting to login...')
gurt.crumbs.delete('auth_token')
gurt.location.goto('../')
end
else
print('No auth token found, redirecting to login...')
gurt.location.goto('../')
end
end
local function logout()
gurt.crumbs.delete('auth_token')
print('Logged out successfully')
gurt.location.goto("../")
end
local function submitDomain(name, tld, ip)
hideError('domain-error')
print('Submitting domain: ' .. name .. '.' .. tld)
local response = fetch('gurt://localhost:8877/domain', {
method = 'POST',
headers = {
['Content-Type'] = 'application/json',
Authorization = 'Bearer ' .. authToken
},
body = JSON.stringify({ name = name, tld = tld, ip = ip })
}) })
if response:ok() then if response:ok() then
addLog('Domain deleted successfully') local data = response:json()
print('Domain submitted successfully: ' .. data.domain)
-- Update user registrations remaining
user.registrations_remaining = user.registrations_remaining - 1
updateUserInfo()
-- Clear form
gurt.select('#domain-name').text = ''
gurt.select('#domain-ip').text = ''
-- Refresh domains list
loadDomains() loadDomains()
else else
addLog('Failed to delete domain: ' .. response:text()) local error = response:text()
showError('domain-error', 'Domain submission failed: ' .. error)
print('Domain submission failed: ' .. error)
end
end
local function createInvite()
print('Creating invite code...')
local response = fetch('gurt://localhost:8877/auth/invite', {
method = 'POST',
headers = {
Authorization = 'Bearer ' .. authToken
}
})
if response:ok() then
local data = response:json()
local inviteCode = data.invite_code
displayElement.text = 'Invite code: ' .. inviteCode .. ' (copied to clipboard)'
displayElement:show()
Clipboard.write(inviteCode)
print('Invite code created and copied to clipboard: ' .. inviteCode)
else
print('Failed to create invite: ' .. response:text())
end
end
local function redeemInvite(code)
hideError('redeem-error')
print('Redeeming invite code: ' .. code)
local response = fetch('gurt://localhost:8877/auth/redeem-invite', {
method = 'POST',
headers = {
['Content-Type'] = 'application/json',
Authorization = 'Bearer ' .. authToken
},
body = JSON.stringify({ invite_code = code })
})
if response:ok() then
local data = response:json()
print('Invite redeemed: +' .. data.registrations_added .. ' registrations')
-- Update user info
user.registrations_remaining = user.registrations_remaining + data.registrations_added
updateUserInfo()
-- Clear form
gurt.select('#invite-code-input').text = ''
else
local error = response:text()
showError('redeem-error', 'Failed to redeem invite: ' .. error)
print('Failed to redeem invite: ' .. error)
end end
end end
@@ -349,16 +321,6 @@ gurt.select('#redeem-invite-btn'):on('click', function()
end end
end) end)
gurt.select('#close-invite-modal'):on('click', function()
hideModal('invite-modal')
end)
gurt.select('#copy-invite-code'):on('click', function()
local inviteCode = gurt.select('#invite-code-display').text
Clipboard.write(inviteCode)
addLog('Invite code copied to clipboard')
end)
-- Initialize -- Initialize
addLog('Dashboard initialized') print('Dashboard initialized')
checkAuth() checkAuth()

View File

@@ -1,3 +1,7 @@
if gurt.crumbs.get("auth_token") then
gurt.location.goto("/dashboard.html")
end
local submitBtn = gurt.select('#submit') local submitBtn = gurt.select('#submit')
local username_input = gurt.select('#username') local username_input = gurt.select('#username')
local password_input = gurt.select('#password') local password_input = gurt.select('#password')
@@ -8,7 +12,6 @@ function addLog(message)
log_output.text = log_output.text .. message .. '\\n' log_output.text = log_output.text .. message .. '\\n'
end end
print(gurt.location.href)
submitBtn:on('submit', function(event) submitBtn:on('submit', function(event)
local username = event.data.username local username = event.data.username
local password = event.data.password local password = event.data.password
@@ -18,7 +21,7 @@ submitBtn:on('submit', function(event)
password = password password = password
}) })
print(request_body) print(request_body)
local url = 'gurt://localhost:8080/auth/login' local url = 'gurt://localhost:8877/auth/login'
local headers = { local headers = {
['Content-Type'] = 'application/json' ['Content-Type'] = 'application/json'
} }

View File

@@ -1,7 +1,4 @@
use actix_web::{dev::ServiceRequest, web, Error, HttpMessage}; use gurt::prelude::*;
use actix_web_httpauth::extractors::bearer::BearerAuth;
use actix_web_httpauth::extractors::AuthenticationError;
use actix_web_httpauth::headers::www_authenticate::bearer::Bearer;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation, Algorithm}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation, Algorithm};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use bcrypt::{hash, verify, DEFAULT_COST}; use bcrypt::{hash, verify, DEFAULT_COST};
@@ -42,7 +39,7 @@ pub struct UserInfo {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
pub fn generate_jwt(user_id: i32, username: &str, secret: &str) -> Result<String, jsonwebtoken::errors::Error> { pub fn generate_jwt(user_id: i32, username: &str, secret: &str) -> std::result::Result<String, jsonwebtoken::errors::Error> {
let expiration = SystemTime::now() let expiration = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
@@ -57,7 +54,7 @@ pub fn generate_jwt(user_id: i32, username: &str, secret: &str) -> Result<String
encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_ref())) encode(&Header::default(), &claims, &EncodingKey::from_secret(secret.as_ref()))
} }
pub fn validate_jwt(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> { pub fn validate_jwt(token: &str, secret: &str) -> std::result::Result<Claims, jsonwebtoken::errors::Error> {
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
validation.validate_exp = true; validation.validate_exp = true;
@@ -65,31 +62,39 @@ pub fn validate_jwt(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::e
.map(|token_data| token_data.claims) .map(|token_data| token_data.claims)
} }
pub fn hash_password(password: &str) -> Result<String, bcrypt::BcryptError> { pub fn hash_password(password: &str) -> std::result::Result<String, bcrypt::BcryptError> {
hash(password, DEFAULT_COST) hash(password, DEFAULT_COST)
} }
pub fn verify_password(password: &str, hash: &str) -> Result<bool, bcrypt::BcryptError> { pub fn verify_password(password: &str, hash: &str) -> std::result::Result<bool, bcrypt::BcryptError> {
verify(password, hash) verify(password, hash)
} }
pub async fn jwt_middleware( pub async fn jwt_middleware_gurt(ctx: &ServerContext, jwt_secret: &str) -> Result<Claims> {
req: ServiceRequest, let start_time = std::time::Instant::now();
credentials: BearerAuth, log::info!("JWT middleware started for {} {}", ctx.method(), ctx.path());
) -> Result<ServiceRequest, (Error, ServiceRequest)> {
let jwt_secret = req let auth_header = ctx.header("authorization")
.app_data::<web::Data<String>>() .or_else(|| ctx.header("Authorization"))
.unwrap() .ok_or_else(|| {
.as_ref(); log::warn!("JWT middleware failed: Missing Authorization header in {:?}", start_time.elapsed());
GurtError::invalid_message("Missing Authorization header")
})?;
match validate_jwt(credentials.token(), jwt_secret) { if !auth_header.starts_with("Bearer ") {
Ok(claims) => { log::warn!("JWT middleware failed: Invalid header format in {:?}", start_time.elapsed());
req.extensions_mut().insert(claims); return Err(GurtError::invalid_message("Invalid Authorization header format"));
Ok(req)
}
Err(_) => {
let config = AuthenticationError::new(Bearer::default());
Err((Error::from(config), req))
}
} }
let token = &auth_header[7..]; // Remove "Bearer " prefix
let result = validate_jwt(token, jwt_secret)
.map_err(|e| GurtError::invalid_message(format!("Invalid JWT token: {}", e)));
match &result {
Ok(_) => log::info!("JWT middleware completed successfully in {:?}", start_time.elapsed()),
Err(e) => log::warn!("JWT middleware failed: {} in {:?}", e, start_time.elapsed()),
}
result
} }

View File

@@ -25,6 +25,8 @@ impl Config {
url: "postgresql://username:password@localhost/domains".into(), url: "postgresql://username:password@localhost/domains".into(),
max_connections: 10, max_connections: 10,
}, },
cert_path: "localhost+2.pem".into(),
key_path: "localhost+2-key.pem".into(),
}, },
discord: Discord { discord: Discord {
bot_token: "".into(), bot_token: "".into(),

View File

@@ -15,6 +15,8 @@ pub struct Server {
pub(crate) address: String, pub(crate) address: String,
pub(crate) port: u64, pub(crate) port: u64,
pub(crate) database: Database, pub(crate) database: Database,
pub(crate) cert_path: String,
pub(crate) key_path: String,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]

216
dns/src/gurt_server.rs Normal file
View File

@@ -0,0 +1,216 @@
mod auth_routes;
mod helpers;
mod models;
mod routes;
use crate::{auth::jwt_middleware_gurt, config::Config, discord_bot};
use colored::Colorize;
use macros_rs::fmt::{crashln, string};
use std::{net::IpAddr, str::FromStr, sync::Arc, collections::HashMap};
use gurt::prelude::*;
use gurt::{GurtStatusCode, Route};
#[derive(Clone)]
pub(crate) struct AppState {
trusted: IpAddr,
config: Config,
db: sqlx::PgPool,
jwt_secret: String,
}
impl AppState {
pub fn new(trusted: IpAddr, config: Config, db: sqlx::PgPool, jwt_secret: String) -> Self {
Self {
trusted,
config,
db,
jwt_secret,
}
}
}
#[derive(Clone)]
pub(crate) struct RateLimitState {
limits: Arc<tokio::sync::RwLock<HashMap<String, Vec<chrono::DateTime<chrono::Utc>>>>>,
}
impl RateLimitState {
pub fn new() -> Self {
Self {
limits: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub async fn check_rate_limit(&self, key: &str, window_secs: i64, max_requests: usize) -> bool {
let mut limits = self.limits.write().await;
let now = chrono::Utc::now();
let window_start = now - chrono::Duration::seconds(window_secs);
let entry = limits.entry(key.to_string()).or_insert_with(Vec::new);
entry.retain(|&timestamp| timestamp > window_start);
if entry.len() >= max_requests {
false
} else {
entry.push(now);
true
}
}
}
struct AppHandler {
app_state: AppState,
rate_limit_state: Option<RateLimitState>,
handler_type: HandlerType,
}
// Macro to reduce JWT middleware duplication
macro_rules! handle_authenticated {
($ctx:expr, $app_state:expr, $handler:expr) => {
match jwt_middleware_gurt(&$ctx, &$app_state.jwt_secret).await {
Ok(claims) => $handler(&$ctx, $app_state, claims).await,
Err(e) => Ok(GurtResponse::new(GurtStatusCode::Unauthorized)
.with_string_body(&format!("Authentication failed: {}", e))),
}
};
}
#[derive(Clone)]
enum HandlerType {
Index,
GetDomain,
GetDomains,
GetTlds,
CheckDomain,
Register,
Login,
GetUserInfo,
CreateInvite,
RedeemInvite,
CreateDomainInvite,
RedeemDomainInvite,
CreateDomain,
UpdateDomain,
DeleteDomain,
}
impl GurtHandler for AppHandler {
fn handle(&self, ctx: &ServerContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<GurtResponse>> + Send + '_>> {
let app_state = self.app_state.clone();
let rate_limit_state = self.rate_limit_state.clone();
let handler_type = self.handler_type.clone();
let ctx_data = (
ctx.remote_addr,
ctx.request.clone(),
);
Box::pin(async move {
let start_time = std::time::Instant::now();
let ctx = ServerContext {
remote_addr: ctx_data.0,
request: ctx_data.1,
};
log::info!("Handler started for {} {} from {}", ctx.method(), ctx.path(), ctx.remote_addr);
let result = match handler_type {
HandlerType::Index => routes::index(app_state).await,
HandlerType::GetDomain => routes::get_domain(&ctx, app_state).await,
HandlerType::GetDomains => routes::get_domains(&ctx, app_state).await,
HandlerType::GetTlds => routes::get_tlds(app_state).await,
HandlerType::CheckDomain => routes::check_domain(&ctx, app_state).await,
HandlerType::Register => auth_routes::register(&ctx, app_state).await,
HandlerType::Login => auth_routes::login(&ctx, app_state).await,
HandlerType::GetUserInfo => handle_authenticated!(ctx, app_state, auth_routes::get_user_info),
HandlerType::CreateInvite => handle_authenticated!(ctx, app_state, auth_routes::create_invite),
HandlerType::RedeemInvite => handle_authenticated!(ctx, app_state, auth_routes::redeem_invite),
HandlerType::CreateDomainInvite => handle_authenticated!(ctx, app_state, auth_routes::create_domain_invite),
HandlerType::RedeemDomainInvite => handle_authenticated!(ctx, app_state, auth_routes::redeem_domain_invite),
HandlerType::CreateDomain => {
// Check rate limit first
if let Some(ref rate_limit_state) = rate_limit_state {
let client_ip = ctx.client_ip().to_string();
if !rate_limit_state.check_rate_limit(&client_ip, 600, 5).await {
return Ok(GurtResponse::new(GurtStatusCode::TooLarge).with_string_body("Rate limit exceeded: 5 requests per 10 minutes"));
}
}
handle_authenticated!(ctx, app_state, routes::create_domain)
},
HandlerType::UpdateDomain => handle_authenticated!(ctx, app_state, routes::update_domain),
HandlerType::DeleteDomain => handle_authenticated!(ctx, app_state, routes::delete_domain),
};
let duration = start_time.elapsed();
match &result {
Ok(response) => {
log::info!("Handler completed for {} {} in {:?} - Status: {}",
ctx.method(), ctx.path(), duration, response.status_code);
},
Err(e) => {
log::error!("Handler failed for {} {} in {:?} - Error: {}",
ctx.method(), ctx.path(), duration, e);
}
}
result
})
}
}
pub async fn start(cli: crate::Cli) -> std::io::Result<()> {
let config = Config::new().set_path(&cli.config).read();
let trusted_ip = match IpAddr::from_str(&config.server.address) {
Ok(addr) => addr,
Err(err) => crashln!("Cannot parse address.\n{}", string!(err).white()),
};
let db = match config.connect_to_db().await {
Ok(pool) => pool,
Err(err) => crashln!("Failed to connect to PostgreSQL database.\n{}", string!(err).white()),
};
// Start Discord bot
if !config.discord.bot_token.is_empty() {
if let Err(e) = discord_bot::start_discord_bot(config.discord.bot_token.clone(), db.clone()).await {
log::error!("Failed to start Discord bot: {}", e);
}
}
let jwt_secret = config.auth.jwt_secret.clone();
let app_state = AppState::new(trusted_ip, config.clone(), db, jwt_secret);
let rate_limit_state = RateLimitState::new();
// Create GURT server
let mut server = GurtServer::new();
// Load TLS certificates
if let Err(e) = server.load_tls_certificates(&config.server.cert_path, &config.server.key_path) {
crashln!("Failed to load TLS certificates: {}", e);
}
server = server
.route(Route::get("/"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::Index })
.route(Route::get("/domain/*"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::GetDomain })
.route(Route::get("/domains"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::GetDomains })
.route(Route::get("/tlds"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::GetTlds })
.route(Route::get("/check"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::CheckDomain })
.route(Route::post("/auth/register"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::Register })
.route(Route::post("/auth/login"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::Login })
.route(Route::get("/auth/me"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::GetUserInfo })
.route(Route::post("/auth/invite"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::CreateInvite })
.route(Route::post("/auth/redeem-invite"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::RedeemInvite })
.route(Route::post("/auth/create-domain-invite"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::CreateDomainInvite })
.route(Route::post("/auth/redeem-domain-invite"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::RedeemDomainInvite })
.route(Route::post("/domain"), AppHandler { app_state: app_state.clone(), rate_limit_state: Some(rate_limit_state), handler_type: HandlerType::CreateDomain })
.route(Route::put("/domain/*"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::UpdateDomain })
.route(Route::delete("/domain/*"), AppHandler { app_state: app_state.clone(), rate_limit_state: None, handler_type: HandlerType::DeleteDomain });
log::info!("GURT server listening on {}", config.get_address());
server.listen(&config.get_address()).await.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("GURT server error: {}", e))
})
}

View File

@@ -0,0 +1,402 @@
use super::{models::*, AppState};
use crate::auth::*;
use gurt::prelude::*;
use gurt::GurtStatusCode;
use sqlx::Row;
use chrono::Utc;
pub(crate) async fn register(ctx: &ServerContext, app_state: AppState) -> Result<GurtResponse> {
let user: RegisterRequest = serde_json::from_slice(ctx.body())
.map_err(|_| GurtError::invalid_message("Invalid JSON"))?;
let registrations = 3; // New users get 3 registrations by default
// Hash password
let password_hash = match hash_password(&user.password) {
Ok(hash) => hash,
Err(_) => {
return Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Failed to hash password",
error: "HASH_ERROR".into(),
})?);
}
};
// Create user
let user_result = sqlx::query(
"INSERT INTO users (username, password_hash, registrations_remaining, domain_invite_codes) VALUES ($1, $2, $3, $4) RETURNING id"
)
.bind(&user.username)
.bind(&password_hash)
.bind(registrations)
.bind(3) // Default 3 domain invite codes
.fetch_one(&app_state.db)
.await;
match user_result {
Ok(row) => {
let user_id: i32 = row.get("id");
// Generate JWT
match generate_jwt(user_id, &user.username, &app_state.jwt_secret) {
Ok(token) => {
let response = LoginResponse {
token,
user: UserInfo {
id: user_id,
username: user.username.clone(),
registrations_remaining: registrations,
domain_invite_codes: 3,
created_at: Utc::now(),
},
};
Ok(GurtResponse::ok().with_json_body(&response)?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Failed to generate token",
error: "JWT_ERROR".into(),
})?)
}
}
}
Err(e) => {
if e.to_string().contains("duplicate key") {
Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "Username already exists",
error: "DUPLICATE_USERNAME".into(),
})?)
} else {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Failed to create user",
error: e.to_string(),
})?)
}
}
}
}
pub(crate) async fn login(ctx: &ServerContext, app_state: AppState) -> Result<GurtResponse> {
let body_bytes = ctx.body();
let login_req: LoginRequest = serde_json::from_slice(body_bytes)
.map_err(|e| {
log::error!("JSON parse error: {}", e);
GurtError::invalid_message("Invalid JSON")
})?;
// Find user
let user_result = sqlx::query_as::<_, User>(
"SELECT id, username, password_hash, registrations_remaining, domain_invite_codes, created_at FROM users WHERE username = $1"
)
.bind(&login_req.username)
.fetch_optional(&app_state.db)
.await;
match user_result {
Ok(Some(user)) => {
// Verify password
match verify_password(&login_req.password, &user.password_hash) {
Ok(true) => {
// Generate JWT
match generate_jwt(user.id, &user.username, &app_state.jwt_secret) {
Ok(token) => {
let response = LoginResponse {
token,
user: UserInfo {
id: user.id,
username: user.username,
registrations_remaining: user.registrations_remaining,
domain_invite_codes: user.domain_invite_codes,
created_at: user.created_at,
},
};
Ok(GurtResponse::ok().with_json_body(&response)?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Failed to generate token",
error: "JWT_ERROR".into(),
})?)
}
}
}
Ok(false) => {
Ok(GurtResponse::new(GurtStatusCode::Unauthorized).with_json_body(&Error {
msg: "Invalid credentials",
error: "INVALID_CREDENTIALS".into(),
})?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Password verification failed",
error: "PASSWORD_ERROR".into(),
})?)
}
}
}
Ok(None) => {
Ok(GurtResponse::new(GurtStatusCode::Unauthorized).with_json_body(&Error {
msg: "Invalid credentials",
error: "INVALID_CREDENTIALS".into(),
})?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Database error",
error: "DATABASE_ERROR".into(),
})?)
}
}
}
pub(crate) async fn get_user_info(_ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
let user_result = sqlx::query_as::<_, User>(
"SELECT id, username, password_hash, registrations_remaining, domain_invite_codes, created_at FROM users WHERE id = $1"
)
.bind(claims.user_id)
.fetch_optional(&app_state.db)
.await;
match user_result {
Ok(Some(user)) => {
let user_info = UserInfo {
id: user.id,
username: user.username,
registrations_remaining: user.registrations_remaining,
domain_invite_codes: user.domain_invite_codes,
created_at: user.created_at,
};
Ok(GurtResponse::ok().with_json_body(&user_info)?)
}
Ok(None) => {
Ok(GurtResponse::not_found().with_json_body(&Error {
msg: "User not found",
error: "USER_NOT_FOUND".into(),
})?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Database error",
error: "DATABASE_ERROR".into(),
})?)
}
}
}
pub(crate) async fn create_invite(_ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
// Generate random invite code
let invite_code: String = {
use rand::Rng;
let mut rng = rand::thread_rng();
(0..12)
.map(|_| {
let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
chars[rng.gen_range(0..chars.len())] as char
})
.collect()
};
// Insert invite code into database
let insert_result = sqlx::query(
"INSERT INTO invite_codes (code, created_by, created_at) VALUES ($1, $2, $3)"
)
.bind(&invite_code)
.bind(claims.user_id)
.bind(Utc::now())
.execute(&app_state.db)
.await;
match insert_result {
Ok(_) => {
let response = serde_json::json!({
"invite_code": invite_code
});
Ok(GurtResponse::ok().with_json_body(&response)?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Failed to create invite",
error: "DATABASE_ERROR".into(),
})?)
}
}
}
pub(crate) async fn redeem_invite(ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
let request: serde_json::Value = serde_json::from_slice(ctx.body())
.map_err(|_| GurtError::invalid_message("Invalid JSON"))?;
let invite_code = request["invite_code"].as_str()
.ok_or(GurtError::invalid_message("Missing invite_code"))?;
// Check if invite code exists and is not used
let invite_result = sqlx::query_as::<_, InviteCode>(
"SELECT id, code, created_by, used_by, created_at, used_at FROM invite_codes WHERE code = $1 AND used_by IS NULL"
)
.bind(invite_code)
.fetch_optional(&app_state.db)
.await;
match invite_result {
Ok(Some(invite)) => {
// Mark invite as used and give user 3 additional registrations
let mut tx = app_state.db.begin().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query("UPDATE invite_codes SET used_by = $1, used_at = $2 WHERE id = $3")
.bind(claims.user_id)
.bind(Utc::now())
.bind(invite.id)
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query("UPDATE users SET registrations_remaining = registrations_remaining + 3 WHERE id = $1")
.bind(claims.user_id)
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
tx.commit().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let response = serde_json::json!({
"registrations_added": 3
});
Ok(GurtResponse::ok().with_json_body(&response)?)
}
Ok(None) => {
Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "Invalid or already used invite code",
error: "INVALID_INVITE".into(),
})?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Database error",
error: "DATABASE_ERROR".into(),
})?)
}
}
}
pub(crate) async fn create_domain_invite(_ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
// Check if user has domain invite codes remaining
let user: (i32,) = sqlx::query_as("SELECT domain_invite_codes FROM users WHERE id = $1")
.bind(claims.user_id)
.fetch_one(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("User not found"))?;
if user.0 <= 0 {
return Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "No domain invite codes remaining",
error: "NO_INVITES_REMAINING".into(),
})?);
}
// Generate random domain invite code
let invite_code: String = {
use rand::Rng;
let mut rng = rand::thread_rng();
(0..12)
.map(|_| {
let chars = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
chars[rng.gen_range(0..chars.len())] as char
})
.collect()
};
// Insert domain invite code and decrease user's count
let mut tx = app_state.db.begin().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query(
"INSERT INTO domain_invite_codes (code, created_by, created_at) VALUES ($1, $2, $3)"
)
.bind(&invite_code)
.bind(claims.user_id)
.bind(Utc::now())
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query("UPDATE users SET domain_invite_codes = domain_invite_codes - 1 WHERE id = $1")
.bind(claims.user_id)
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
tx.commit().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let response = serde_json::json!({
"domain_invite_code": invite_code
});
Ok(GurtResponse::ok().with_json_body(&response)?)
}
pub(crate) async fn redeem_domain_invite(ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
let request: serde_json::Value = serde_json::from_slice(ctx.body())
.map_err(|_| GurtError::invalid_message("Invalid JSON"))?;
let invite_code = request["domain_invite_code"].as_str()
.ok_or(GurtError::invalid_message("Missing domain_invite_code"))?;
// Check if domain invite code exists and is not used
let invite_result = sqlx::query_as::<_, DomainInviteCode>(
"SELECT id, code, created_by, used_by, created_at, used_at FROM domain_invite_codes WHERE code = $1 AND used_by IS NULL"
)
.bind(invite_code)
.fetch_optional(&app_state.db)
.await;
match invite_result {
Ok(Some(invite)) => {
// Mark invite as used and give user 1 additional domain invite code
let mut tx = app_state.db.begin().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query("UPDATE domain_invite_codes SET used_by = $1, used_at = $2 WHERE id = $3")
.bind(claims.user_id)
.bind(Utc::now())
.bind(invite.id)
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
sqlx::query("UPDATE users SET domain_invite_codes = domain_invite_codes + 1 WHERE id = $1")
.bind(claims.user_id)
.execute(&mut *tx)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
tx.commit().await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let response = serde_json::json!({
"domain_invite_codes_added": 1
});
Ok(GurtResponse::ok().with_json_body(&response)?)
}
Ok(None) => {
Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "Invalid or already used domain invite code",
error: "INVALID_DOMAIN_INVITE".into(),
})?)
}
Err(_) => {
Ok(GurtResponse::internal_server_error().with_json_body(&Error {
msg: "Database error",
error: "DATABASE_ERROR".into(),
})?)
}
}
}
#[derive(serde::Serialize)]
struct Error {
msg: &'static str,
error: String,
}

View File

@@ -0,0 +1,20 @@
use gurt::prelude::*;
use std::net::IpAddr;
pub fn validate_ip(domain: &super::models::Domain) -> Result<()> {
if domain.ip.parse::<IpAddr>().is_err() {
return Err(GurtError::invalid_message("Invalid IP address"));
}
Ok(())
}
pub fn deserialize_lowercase<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let s = String::deserialize(deserializer)?;
Ok(s.to_lowercase())
}

View File

@@ -80,7 +80,7 @@ pub(crate) struct ResponseDnsRecord {
pub(crate) record_type: String, pub(crate) record_type: String,
pub(crate) name: String, pub(crate) name: String,
pub(crate) value: String, pub(crate) value: String,
pub(crate) ttl: i32, pub(crate) ttl: Option<i32>,
pub(crate) priority: Option<i32>, pub(crate) priority: Option<i32>,
} }
@@ -89,27 +89,6 @@ pub(crate) struct UpdateDomain {
pub(crate) ip: String, pub(crate) ip: String,
} }
#[derive(Serialize)]
pub(crate) struct Error {
pub(crate) msg: &'static str,
pub(crate) error: String,
}
#[derive(Serialize)]
pub(crate) struct Ratelimit {
pub(crate) msg: String,
pub(crate) error: &'static str,
pub(crate) after: u64,
}
#[derive(Deserialize)]
pub(crate) struct PaginationParams {
#[serde(alias = "p", alias = "doc")]
pub(crate) page: Option<u32>,
#[serde(alias = "s", alias = "size", alias = "l", alias = "limit")]
pub(crate) page_size: Option<u32>,
}
#[derive(Serialize)] #[derive(Serialize)]
pub(crate) struct PaginationResponse { pub(crate) struct PaginationResponse {
pub(crate) domains: Vec<ResponseDomain>, pub(crate) domains: Vec<ResponseDomain>,
@@ -117,12 +96,6 @@ pub(crate) struct PaginationResponse {
pub(crate) limit: u32, pub(crate) limit: u32,
} }
#[derive(Deserialize)]
pub(crate) struct DomainQuery {
pub(crate) name: String,
pub(crate) tld: Option<String>,
}
#[derive(Serialize)] #[derive(Serialize)]
pub(crate) struct DomainList { pub(crate) struct DomainList {
pub(crate) domain: String, pub(crate) domain: String,

View File

@@ -0,0 +1,314 @@
use super::{models::*, AppState, helpers::validate_ip};
use crate::auth::Claims;
use gurt::prelude::*;
use std::{env, collections::HashMap};
fn parse_query_string(query: &str) -> HashMap<String, String> {
let mut params = HashMap::new();
for pair in query.split('&') {
if let Some((key, value)) = pair.split_once('=') {
params.insert(key.to_string(), value.to_string());
}
}
params
}
pub(crate) async fn index(_app_state: AppState) -> Result<GurtResponse> {
let body = format!(
"GurtDNS v{}!\n\nThe available endpoints are:\n\n - [GET] /domains\n - [GET] /domain/{{name}}/{{tld}}\n - [POST] /domain\n - [PUT] /domain/{{key}}\n - [DELETE] /domain/{{key}}\n - [GET] /tlds\n\nRatelimits are as follows: 5 requests per 10 minutes on `[POST] /domain`.\n\nCode link: https://github.com/outpoot/gurted",
env!("CARGO_PKG_VERSION")
);
Ok(GurtResponse::ok().with_string_body(body))
}
pub(crate) async fn create_logic(domain: Domain, user_id: i32, app: &AppState) -> Result<Domain> {
validate_ip(&domain)?;
if !app.config.tld_list().contains(&domain.tld.as_str())
|| !domain.name.chars().all(|c| c.is_alphabetic() || c == '-')
|| domain.name.len() > 24
|| domain.name.is_empty()
|| domain.name.starts_with('-')
|| domain.name.ends_with('-') {
return Err(GurtError::invalid_message("Invalid name, non-existent TLD, or name too long (24 chars)."));
}
if app.config.offen_words().iter().any(|word| domain.name.contains(word)) {
return Err(GurtError::invalid_message("The given domain name is offensive."));
}
let existing_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM domains WHERE name = ? AND tld = ?"
)
.bind(&domain.name)
.bind(&domain.tld)
.fetch_one(&app.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
if existing_count > 0 {
return Err(GurtError::invalid_message("Domain already exists"));
}
sqlx::query(
"INSERT INTO domains (name, tld, ip, user_id, status) VALUES ($1, $2, $3, $4, 'pending')"
)
.bind(&domain.name)
.bind(&domain.tld)
.bind(&domain.ip)
.bind(user_id)
.execute(&app.db)
.await
.map_err(|_| GurtError::invalid_message("Failed to create domain"))?;
// Decrease user's registrations remaining
sqlx::query("UPDATE users SET registrations_remaining = registrations_remaining - 1 WHERE id = $1")
.bind(user_id)
.execute(&app.db)
.await
.map_err(|_| GurtError::invalid_message("Failed to update user registrations"))?;
Ok(domain)
}
pub(crate) async fn create_domain(ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
// Check if user has registrations remaining
let user: (i32,) = sqlx::query_as("SELECT registrations_remaining FROM users WHERE id = $1")
.bind(claims.user_id)
.fetch_one(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("User not found"))?;
if user.0 <= 0 {
return Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "Failed to create domain",
error: "No registrations remaining".into(),
})?);
}
let domain: Domain = serde_json::from_slice(ctx.body())
.map_err(|_| GurtError::invalid_message("Invalid JSON"))?;
match create_logic(domain.clone(), claims.user_id, &app_state).await {
Ok(created_domain) => {
Ok(GurtResponse::ok().with_json_body(&created_domain)?)
}
Err(e) => {
Ok(GurtResponse::bad_request().with_json_body(&Error {
msg: "Failed to create domain",
error: e.to_string(),
})?)
}
}
}
pub(crate) async fn get_domain(ctx: &ServerContext, app_state: AppState) -> Result<GurtResponse> {
let path_parts: Vec<&str> = ctx.path().split('/').collect();
if path_parts.len() < 4 {
return Ok(GurtResponse::bad_request().with_string_body("Invalid path format. Expected /domain/{name}/{tld}"));
}
let name = path_parts[2];
let tld = path_parts[3];
let domain: Option<Domain> = sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE name = $1 AND tld = $2 AND status = 'approved'"
)
.bind(name)
.bind(tld)
.fetch_optional(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
match domain {
Some(domain) => {
let response_domain = ResponseDomain {
name: domain.name,
tld: domain.tld,
ip: domain.ip,
records: None, // TODO: Implement DNS records
};
Ok(GurtResponse::ok().with_json_body(&response_domain)?)
}
None => Ok(GurtResponse::not_found().with_string_body("Domain not found"))
}
}
pub(crate) async fn get_domains(ctx: &ServerContext, app_state: AppState) -> Result<GurtResponse> {
// Parse pagination from query parameters
let path = ctx.path();
let query_params = if let Some(query_start) = path.find('?') {
let query_string = &path[query_start + 1..];
parse_query_string(query_string)
} else {
HashMap::new()
};
let page = query_params.get("page")
.and_then(|p| p.parse::<u32>().ok())
.unwrap_or(1)
.max(1); // Ensure page is at least 1
let page_size = query_params.get("limit")
.and_then(|l| l.parse::<u32>().ok())
.unwrap_or(100)
.clamp(1, 1000); // Limit between 1 and 1000
let offset = (page - 1) * page_size;
let domains: Vec<Domain> = sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE status = 'approved' ORDER BY created_at DESC LIMIT $1 OFFSET $2"
)
.bind(page_size as i64)
.bind(offset as i64)
.fetch_all(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let response_domains: Vec<ResponseDomain> = domains.into_iter().map(|domain| {
ResponseDomain {
name: domain.name,
tld: domain.tld,
ip: domain.ip,
records: None,
}
}).collect();
let response = PaginationResponse {
domains: response_domains,
page,
limit: page_size,
};
Ok(GurtResponse::ok().with_json_body(&response)?)
}
pub(crate) async fn get_tlds(app_state: AppState) -> Result<GurtResponse> {
Ok(GurtResponse::ok().with_json_body(&app_state.config.tld_list())?)
}
pub(crate) async fn check_domain(ctx: &ServerContext, app_state: AppState) -> Result<GurtResponse> {
let path = ctx.path();
let query_params = if let Some(query_start) = path.find('?') {
let query_string = &path[query_start + 1..];
parse_query_string(query_string)
} else {
return Ok(GurtResponse::bad_request().with_string_body("Missing query parameters. Expected ?name=<name>&tld=<tld>"));
};
let name = query_params.get("name")
.ok_or_else(|| GurtError::invalid_message("Missing 'name' parameter"))?;
let tld = query_params.get("tld")
.ok_or_else(|| GurtError::invalid_message("Missing 'tld' parameter"))?;
let domain: Option<Domain> = sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE name = $1 AND tld = $2"
)
.bind(name)
.bind(tld)
.fetch_optional(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let domain_list = DomainList {
domain: format!("{}.{}", name, tld),
taken: domain.is_some(),
};
Ok(GurtResponse::ok().with_json_body(&domain_list)?)
}
pub(crate) async fn update_domain(ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
let path_parts: Vec<&str> = ctx.path().split('/').collect();
if path_parts.len() < 4 {
return Ok(GurtResponse::bad_request().with_string_body("Invalid path format. Expected /domain/{name}/{tld}"));
}
let name = path_parts[2];
let tld = path_parts[3];
let update_data: UpdateDomain = serde_json::from_slice(ctx.body())
.map_err(|_| GurtError::invalid_message("Invalid JSON"))?;
// Verify user owns this domain
let domain: Option<Domain> = sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE name = $1 AND tld = $2 AND user_id = $3"
)
.bind(name)
.bind(tld)
.bind(claims.user_id)
.fetch_optional(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
let domain = match domain {
Some(d) => d,
None => return Ok(GurtResponse::not_found().with_string_body("Domain not found or access denied"))
};
// Validate IP
validate_ip(&Domain {
id: domain.id,
name: domain.name.clone(),
tld: domain.tld.clone(),
ip: update_data.ip.clone(),
user_id: domain.user_id,
status: domain.status,
denial_reason: domain.denial_reason,
created_at: domain.created_at,
})?;
sqlx::query("UPDATE domains SET ip = $1 WHERE name = $2 AND tld = $3 AND user_id = $4")
.bind(&update_data.ip)
.bind(name)
.bind(tld)
.bind(claims.user_id)
.execute(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Failed to update domain"))?;
Ok(GurtResponse::ok().with_string_body("Domain updated successfully"))
}
pub(crate) async fn delete_domain(ctx: &ServerContext, app_state: AppState, claims: Claims) -> Result<GurtResponse> {
let path_parts: Vec<&str> = ctx.path().split('/').collect();
if path_parts.len() < 4 {
return Ok(GurtResponse::bad_request().with_string_body("Invalid path format. Expected /domain/{name}/{tld}"));
}
let name = path_parts[2];
let tld = path_parts[3];
// Verify user owns this domain
let domain: Option<Domain> = sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE name = $1 AND tld = $2 AND user_id = $3"
)
.bind(name)
.bind(tld)
.bind(claims.user_id)
.fetch_optional(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Database error"))?;
if domain.is_none() {
return Ok(GurtResponse::not_found().with_string_body("Domain not found or access denied"));
}
sqlx::query("DELETE FROM domains WHERE name = $1 AND tld = $2 AND user_id = $3")
.bind(name)
.bind(tld)
.bind(claims.user_id)
.execute(&app_state.db)
.await
.map_err(|_| GurtError::invalid_message("Failed to delete domain"))?;
Ok(GurtResponse::ok().with_string_body("Domain deleted successfully"))
}
#[derive(serde::Serialize)]
struct Error {
msg: &'static str,
error: String,
}

View File

@@ -1,94 +0,0 @@
mod auth_routes;
mod helpers;
mod models;
mod ratelimit;
mod routes;
use crate::{auth::jwt_middleware, config::Config, discord_bot};
use actix_governor::{Governor, GovernorConfigBuilder};
use actix_web::{http::Method, web, web::Data, App, HttpServer};
use actix_web_httpauth::middleware::HttpAuthentication;
use colored::Colorize;
use macros_rs::fmt::{crashln, string};
use ratelimit::RealIpKeyExtractor;
use std::{net::IpAddr, str::FromStr, time::Duration};
// Domain struct is now defined in models.rs
#[derive(Clone)]
pub(crate) struct AppState {
trusted: IpAddr,
config: Config,
db: sqlx::PgPool,
}
#[actix_web::main]
pub async fn start(cli: crate::Cli) -> std::io::Result<()> {
let config = Config::new().set_path(&cli.config).read();
let trusted_ip = match IpAddr::from_str(&config.server.address) {
Ok(addr) => addr,
Err(err) => crashln!("Cannot parse address.\n{}", string!(err).white()),
};
let governor_builder = GovernorConfigBuilder::default()
.methods(vec![Method::POST])
.period(Duration::from_secs(600))
.burst_size(5)
.key_extractor(RealIpKeyExtractor)
.finish()
.unwrap();
let db = match config.connect_to_db().await {
Ok(pool) => pool,
Err(err) => crashln!("Failed to connect to PostgreSQL database.\n{}", string!(err).white()),
};
// Start Discord bot
if !config.discord.bot_token.is_empty() {
if let Err(e) = discord_bot::start_discord_bot(config.discord.bot_token.clone(), db.clone()).await {
log::error!("Failed to start Discord bot: {}", e);
}
}
let auth_middleware = HttpAuthentication::bearer(jwt_middleware);
let jwt_secret = config.auth.jwt_secret.clone();
let app = move || {
let data = AppState {
db: db.clone(),
trusted: trusted_ip,
config: Config::new().set_path(&cli.config).read(),
};
App::new()
.app_data(Data::new(data))
.app_data(Data::new(jwt_secret.clone()))
// Public routes
.service(routes::index)
.service(routes::get_domain)
.service(routes::get_domains)
.service(routes::get_tlds)
.service(routes::check_domain)
// Auth routes
.service(auth_routes::register)
.service(auth_routes::login)
// Protected routes
.service(
web::scope("")
.wrap(auth_middleware.clone())
.service(auth_routes::get_user_info)
.service(auth_routes::create_invite)
.service(auth_routes::redeem_invite)
.service(auth_routes::create_domain_invite)
.service(auth_routes::redeem_domain_invite)
.service(routes::update_domain)
.service(routes::delete_domain)
.route("/domain", web::post().to(routes::create_domain).wrap(Governor::new(&governor_builder)))
)
};
log::info!("Listening on {}", config.get_address());
HttpServer::new(app).bind(config.get_address())?.run().await
}

View File

@@ -1,547 +0,0 @@
use super::{models::*, AppState};
use crate::auth::*;
use actix_web::{web, HttpResponse, Responder, HttpRequest, HttpMessage};
use sqlx::Row;
use rand::Rng;
use chrono::Utc;
#[actix_web::post("/auth/register")]
pub(crate) async fn register(
user: web::Json<RegisterRequest>,
app: web::Data<AppState>
) -> impl Responder {
let registrations = 3; // New users get 3 registrations by default
// Hash password
let password_hash = match hash_password(&user.password) {
Ok(hash) => hash,
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to hash password",
error: "HASH_ERROR".into(),
});
}
};
// Create user
let user_result = sqlx::query(
"INSERT INTO users (username, password_hash, registrations_remaining, domain_invite_codes) VALUES ($1, $2, $3, $4) RETURNING id"
)
.bind(&user.username)
.bind(&password_hash)
.bind(registrations)
.bind(3) // Default 3 domain invite codes
.fetch_one(&app.db)
.await;
match user_result {
Ok(row) => {
let user_id: i32 = row.get("id");
// Generate JWT
match generate_jwt(user_id, &user.username, &app.config.auth.jwt_secret) {
Ok(token) => {
HttpResponse::Ok().json(LoginResponse {
token,
user: UserInfo {
id: user_id,
username: user.username.clone(),
registrations_remaining: registrations,
domain_invite_codes: 3,
created_at: Utc::now(),
},
})
}
Err(_) => HttpResponse::InternalServerError().json(Error {
msg: "Failed to generate token",
error: "TOKEN_ERROR".into(),
}),
}
}
Err(sqlx::Error::Database(db_err)) => {
if db_err.is_unique_violation() {
HttpResponse::Conflict().json(Error {
msg: "Username already exists",
error: "USER_EXISTS".into(),
})
} else {
HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
})
}
}
Err(_) => HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
}),
}
}
#[actix_web::post("/auth/login")]
pub(crate) async fn login(
credentials: web::Json<LoginRequest>,
app: web::Data<AppState>
) -> impl Responder {
match sqlx::query_as::<_, User>(
"SELECT id, username, password_hash, registrations_remaining, domain_invite_codes, created_at FROM users WHERE username = $1"
)
.bind(&credentials.username)
.fetch_optional(&app.db)
.await
{
Ok(Some(user)) => {
match verify_password(&credentials.password, &user.password_hash) {
Ok(true) => {
match generate_jwt(user.id, &user.username, &app.config.auth.jwt_secret) {
Ok(token) => {
HttpResponse::Ok().json(LoginResponse {
token,
user: UserInfo {
id: user.id,
username: user.username,
registrations_remaining: user.registrations_remaining,
domain_invite_codes: user.domain_invite_codes,
created_at: user.created_at,
},
})
}
Err(e) => {
eprintln!("JWT generation error: {:?}", e);
HttpResponse::InternalServerError().json(Error {
msg: "Failed to generate token",
error: "TOKEN_ERROR".into(),
})
},
}
}
Ok(false) | Err(_) => {
HttpResponse::Unauthorized().json(Error {
msg: "Invalid credentials",
error: "INVALID_CREDENTIALS".into(),
})
}
}
}
Ok(None) => {
HttpResponse::Unauthorized().json(Error {
msg: "Invalid credentials",
error: "INVALID_CREDENTIALS".into(),
})
}
Err(e) => {
eprintln!("Database error: {:?}", e);
HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
})
},
}
}
#[actix_web::get("/auth/me")]
pub(crate) async fn get_user_info(
req: HttpRequest,
app: web::Data<AppState>
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
match sqlx::query_as::<_, User>(
"SELECT id, username, password_hash, registrations_remaining, domain_invite_codes, created_at FROM users WHERE id = $1"
)
.bind(claims.user_id)
.fetch_optional(&app.db)
.await
{
Ok(Some(user)) => {
HttpResponse::Ok().json(UserInfo {
id: user.id,
username: user.username,
registrations_remaining: user.registrations_remaining,
domain_invite_codes: user.domain_invite_codes,
created_at: user.created_at,
})
}
Ok(None) => HttpResponse::NotFound().json(Error {
msg: "User not found",
error: "USER_NOT_FOUND".into(),
}),
Err(_) => HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
}),
}
}
#[actix_web::post("/auth/invite")]
pub(crate) async fn create_invite(
req: HttpRequest,
app: web::Data<AppState>
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
// Generate random invite code
let invite_code: String = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(16)
.map(char::from)
.collect();
// Create invite code (no registration cost)
match sqlx::query(
"INSERT INTO invite_codes (code, created_by) VALUES ($1, $2)"
)
.bind(&invite_code)
.bind(claims.user_id)
.execute(&app.db)
.await
{
Ok(_) => {},
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to create invite code",
error: "DB_ERROR".into(),
});
}
}
HttpResponse::Ok().json(serde_json::json!({
"invite_code": invite_code
}))
}
#[actix_web::post("/auth/redeem-invite")]
pub(crate) async fn redeem_invite(
invite_request: web::Json<serde_json::Value>,
req: HttpRequest,
app: web::Data<AppState>
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
let invite_code = match invite_request.get("invite_code").and_then(|v| v.as_str()) {
Some(code) => code,
None => {
return HttpResponse::BadRequest().json(Error {
msg: "Invite code is required",
error: "INVITE_CODE_REQUIRED".into(),
});
}
};
// Find and validate invite code
let invite = match sqlx::query_as::<_, InviteCode>(
"SELECT id, code, created_by, used_by, created_at, used_at FROM invite_codes WHERE code = $1 AND used_by IS NULL"
)
.bind(invite_code)
.fetch_optional(&app.db)
.await
{
Ok(Some(invite)) => invite,
Ok(None) => {
return HttpResponse::BadRequest().json(Error {
msg: "Invalid or already used invite code",
error: "INVALID_INVITE".into(),
});
}
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Start transaction to redeem invite
let mut tx = match app.db.begin().await {
Ok(tx) => tx,
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Mark invite as used
if let Err(_) = sqlx::query(
"UPDATE invite_codes SET used_by = $1, used_at = CURRENT_TIMESTAMP WHERE id = $2"
)
.bind(claims.user_id)
.bind(invite.id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to redeem invite code",
error: "DB_ERROR".into(),
});
}
// Add registrations to user (3 registrations per invite)
if let Err(_) = sqlx::query(
"UPDATE users SET registrations_remaining = registrations_remaining + 3 WHERE id = $1"
)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to add registrations",
error: "DB_ERROR".into(),
});
}
if let Err(_) = tx.commit().await {
return HttpResponse::InternalServerError().json(Error {
msg: "Transaction failed",
error: "DB_ERROR".into(),
});
}
HttpResponse::Ok().json(serde_json::json!({
"message": "Invite code redeemed successfully",
"registrations_added": 3
}))
}
#[actix_web::post("/auth/domain-invite")]
pub(crate) async fn create_domain_invite(
req: HttpRequest,
app: web::Data<AppState>
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
// Check if user has domain invite codes remaining
let user = match sqlx::query_as::<_, User>(
"SELECT id, username, password_hash, registrations_remaining, domain_invite_codes, created_at FROM users WHERE id = $1"
)
.bind(claims.user_id)
.fetch_optional(&app.db)
.await
{
Ok(Some(user)) => user,
Ok(None) => {
return HttpResponse::NotFound().json(Error {
msg: "User not found",
error: "USER_NOT_FOUND".into(),
});
}
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
if user.domain_invite_codes <= 0 {
return HttpResponse::BadRequest().json(Error {
msg: "No domain invite codes remaining",
error: "NO_DOMAIN_INVITES".into(),
});
}
// Generate random domain invite code
let invite_code: String = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(16)
.map(char::from)
.collect();
// Start transaction
let mut tx = match app.db.begin().await {
Ok(tx) => tx,
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Create domain invite code
if let Err(_) = sqlx::query(
"INSERT INTO domain_invite_codes (code, created_by) VALUES ($1, $2)"
)
.bind(&invite_code)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to create domain invite code",
error: "DB_ERROR".into(),
});
}
// Decrease user's domain invite codes
if let Err(_) = sqlx::query(
"UPDATE users SET domain_invite_codes = domain_invite_codes - 1 WHERE id = $1"
)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to update domain invite codes",
error: "DB_ERROR".into(),
});
}
if let Err(_) = tx.commit().await {
return HttpResponse::InternalServerError().json(Error {
msg: "Transaction failed",
error: "DB_ERROR".into(),
});
}
HttpResponse::Ok().json(serde_json::json!({
"domain_invite_code": invite_code
}))
}
#[actix_web::post("/auth/redeem-domain-invite")]
pub(crate) async fn redeem_domain_invite(
invite_request: web::Json<serde_json::Value>,
req: HttpRequest,
app: web::Data<AppState>
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
let invite_code = match invite_request.get("domain_invite_code").and_then(|v| v.as_str()) {
Some(code) => code,
None => {
return HttpResponse::BadRequest().json(Error {
msg: "Domain invite code is required",
error: "DOMAIN_INVITE_CODE_REQUIRED".into(),
});
}
};
// Find and validate domain invite code
let invite = match sqlx::query_as::<_, DomainInviteCode>(
"SELECT id, code, created_by, used_by, created_at, used_at FROM domain_invite_codes WHERE code = $1 AND used_by IS NULL"
)
.bind(invite_code)
.fetch_optional(&app.db)
.await
{
Ok(Some(invite)) => invite,
Ok(None) => {
return HttpResponse::BadRequest().json(Error {
msg: "Invalid or already used domain invite code",
error: "INVALID_DOMAIN_INVITE".into(),
});
}
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Start transaction to redeem invite
let mut tx = match app.db.begin().await {
Ok(tx) => tx,
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Mark domain invite as used
if let Err(_) = sqlx::query(
"UPDATE domain_invite_codes SET used_by = $1, used_at = CURRENT_TIMESTAMP WHERE id = $2"
)
.bind(claims.user_id)
.bind(invite.id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to redeem domain invite code",
error: "DB_ERROR".into(),
});
}
// Add domain invite codes to user (1 per domain invite)
if let Err(_) = sqlx::query(
"UPDATE users SET domain_invite_codes = domain_invite_codes + 1 WHERE id = $1"
)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to add domain invite codes",
error: "DB_ERROR".into(),
});
}
if let Err(_) = tx.commit().await {
return HttpResponse::InternalServerError().json(Error {
msg: "Transaction failed",
error: "DB_ERROR".into(),
});
}
HttpResponse::Ok().json(serde_json::json!({
"message": "Domain invite code redeemed successfully",
"domain_invite_codes_added": 1
}))
}

View File

@@ -1,72 +0,0 @@
use super::{models::*, AppState};
use actix_web::{web::Data, HttpResponse};
use regex::Regex;
use serde::Deserialize;
use std::net::{Ipv4Addr, Ipv6Addr};
pub fn validate_ip(domain: &Domain) -> Result<(), HttpResponse> {
let valid_url = Regex::new(r"(?i)\bhttps?://[-a-z0-9+&@#/%?=~_|!:,.;]*[-a-z0-9+&@#/%=~_|]").unwrap();
let is_valid_ip = domain.ip.parse::<Ipv4Addr>().is_ok() || domain.ip.parse::<Ipv6Addr>().is_ok();
let is_valid_url = valid_url.is_match(&domain.ip);
if is_valid_ip || is_valid_url {
if domain.name.len() <= 100 {
Ok(())
} else {
Err(HttpResponse::BadRequest().json(Error {
msg: "Failed to create domain",
error: "Invalid name, non-existent TLD, or name too long (100 chars).".into(),
}))
}
} else {
Err(HttpResponse::BadRequest().json(Error {
msg: "Failed to create domain",
error: "Invalid name, non-existent TLD, or name too long (100 chars).".into(),
}))
}
}
pub fn deserialize_lowercase<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(s.to_lowercase())
}
pub async fn is_domain_taken(name: &str, tld: Option<&str>, app: Data<AppState>) -> Vec<DomainList> {
if let Some(tld) = tld {
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM domains WHERE name = ? AND tld = ?"
)
.bind(name)
.bind(tld)
.fetch_one(&app.db)
.await
.unwrap_or(0);
vec![DomainList {
taken: count > 0,
domain: format!("{}.{}", name, tld),
}]
} else {
let mut result = Vec::new();
for tld in &*app.config.tld_list() {
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM domains WHERE name = ? AND tld = ?"
)
.bind(name)
.bind(tld)
.fetch_one(&app.db)
.await
.unwrap_or(0);
result.push(DomainList {
taken: count > 0,
domain: format!("{}.{}", name, tld),
});
}
result
}
}

View File

@@ -1,61 +0,0 @@
use super::models::Ratelimit;
use actix_web::{dev::ServiceRequest, web, HttpResponse, HttpResponseBuilder};
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
time::{SystemTime, UNIX_EPOCH},
};
use actix_governor::{
governor::clock::{Clock, DefaultClock, QuantaInstant},
governor::NotUntil,
KeyExtractor, SimpleKeyExtractionError,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct RealIpKeyExtractor;
impl KeyExtractor for RealIpKeyExtractor {
type Key = IpAddr;
type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
let reverse_proxy_ip = req
.app_data::<web::Data<super::AppState>>()
.map(|ip| ip.get_ref().trusted.to_owned())
.unwrap_or_else(|| IpAddr::from_str("0.0.0.0").unwrap());
let peer_ip = req.peer_addr().map(|socket| socket.ip());
let connection_info = req.connection_info();
match peer_ip {
Some(peer) if peer == reverse_proxy_ip => connection_info
.realip_remote_addr()
.ok_or_else(|| SimpleKeyExtractionError::new("Could not extract real IP address from request"))
.and_then(|str| {
SocketAddr::from_str(str)
.map(|socket| socket.ip())
.or_else(|_| IpAddr::from_str(str))
.map_err(|_| SimpleKeyExtractionError::new("Could not extract real IP address from request"))
}),
_ => connection_info
.peer_addr()
.ok_or_else(|| SimpleKeyExtractionError::new("Could not extract peer IP address from request"))
.and_then(|str| SocketAddr::from_str(str).map_err(|_| SimpleKeyExtractionError::new("Could not extract peer IP address from request")))
.map(|socket| socket.ip()),
}
}
fn exceed_rate_limit_response(&self, negative: &NotUntil<QuantaInstant>, mut response: HttpResponseBuilder) -> HttpResponse {
let current_unix_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time went backwards").as_secs();
let wait_time = negative.wait_time_from(DefaultClock::default().now()).as_secs();
let wait_time_unix = current_unix_timestamp + negative.wait_time_from(DefaultClock::default().now()).as_secs();
response.json(Ratelimit {
after: wait_time_unix,
error: "ratelimited_endpoint",
msg: format!("Too many requests, try again in {wait_time}s"),
})
}
}

View File

@@ -1,400 +0,0 @@
use super::{models::*, AppState};
use crate::{auth::Claims, discord_bot::*, http::helpers};
use std::env;
use actix_web::{
web::{self, Data},
HttpRequest, HttpResponse, Responder, HttpMessage,
};
#[actix_web::get("/")]
pub(crate) async fn index() -> impl Responder {
HttpResponse::Ok().body(format!(
"GurtDNS v{}!\n\nThe available endpoints are:\n\n - [GET] /domains\n - [GET] /domain/{{name}}/{{tld}}\n - [POST] /domain\n - [PUT] /domain/{{key}}\n - [DELETE] /domain/{{key}}\n - [GET] /tlds\n\nRatelimits are as follows: 5 requests per 10 minutes on `[POST] /domain`.\n\nCode link: https://github.com/outpoot/gurted",env!("CARGO_PKG_VERSION")),
)
}
pub(crate) async fn create_logic(domain: Domain, user_id: i32, app: &AppState) -> Result<Domain, HttpResponse> {
helpers::validate_ip(&domain)?;
if !app.config.tld_list().contains(&domain.tld.as_str()) || !domain.name.chars().all(|c| c.is_alphabetic() || c == '-') || domain.name.len() > 24 {
return Err(HttpResponse::BadRequest().json(Error {
msg: "Failed to create domain",
error: "Invalid name, non-existent TLD, or name too long (24 chars).".into(),
}));
}
if app.config.offen_words().iter().any(|word| domain.name.contains(word)) {
return Err(HttpResponse::BadRequest().json(Error {
msg: "Failed to create domain",
error: "The given domain name is offensive.".into(),
}));
}
let existing_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM domains WHERE name = ? AND tld = ?"
)
.bind(&domain.name)
.bind(&domain.tld)
.fetch_one(&app.db)
.await
.map_err(|_| HttpResponse::InternalServerError().finish())?;
if existing_count > 0 {
return Err(HttpResponse::Conflict().finish());
}
sqlx::query(
"INSERT INTO domains (name, tld, ip, user_id, status) VALUES ($1, $2, $3, $4, 'pending')"
)
.bind(&domain.name)
.bind(&domain.tld)
.bind(&domain.ip)
.bind(user_id)
.execute(&app.db)
.await
.map_err(|_| HttpResponse::Conflict().finish())?;
Ok(domain)
}
pub(crate) async fn create_domain(
domain: web::Json<Domain>,
app: Data<AppState>,
req: HttpRequest
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
// Check if user has registrations or domain invite codes remaining
let (user_registrations, user_domain_invites): (i32, i32) = match sqlx::query_as::<_, (i32, i32)>(
"SELECT registrations_remaining, domain_invite_codes FROM users WHERE id = $1"
)
.bind(claims.user_id)
.fetch_one(&app.db)
.await
{
Ok((registrations, domain_invites)) => (registrations, domain_invites),
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
if user_registrations <= 0 && user_domain_invites <= 0 {
return HttpResponse::BadRequest().json(Error {
msg: "No domain registrations or domain invite codes remaining",
error: "NO_REGISTRATIONS_OR_INVITES".into(),
});
}
let domain = domain.into_inner();
match create_logic(domain.clone(), claims.user_id, app.as_ref()).await {
Ok(_) => {
// Start transaction for domain registration
let mut tx = match app.db.begin().await {
Ok(tx) => tx,
Err(_) => {
return HttpResponse::InternalServerError().json(Error {
msg: "Database error",
error: "DB_ERROR".into(),
});
}
};
// Get the created domain ID
let domain_id: i32 = match sqlx::query_scalar(
"SELECT id FROM domains WHERE name = $1 AND tld = $2 AND user_id = $3 ORDER BY created_at DESC LIMIT 1"
)
.bind(&domain.name)
.bind(&domain.tld)
.bind(claims.user_id)
.fetch_one(&mut *tx)
.await
{
Ok(id) => id,
Err(_) => {
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to get domain ID",
error: "DB_ERROR".into(),
});
}
};
// Get user's current domain invite codes
let user_domain_invites: i32 = match sqlx::query_scalar(
"SELECT domain_invite_codes FROM users WHERE id = $1"
)
.bind(claims.user_id)
.fetch_one(&mut *tx)
.await
{
Ok(invites) => invites,
Err(_) => {
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Database error getting user domain invites",
error: "DB_ERROR".into(),
});
}
};
// Auto-consume domain invite code if available, otherwise use registration
if user_domain_invites > 0 {
// Use domain invite code
if let Err(_) = sqlx::query(
"UPDATE users SET domain_invite_codes = domain_invite_codes - 1 WHERE id = $1"
)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to consume domain invite code",
error: "DB_ERROR".into(),
});
}
} else {
// Use regular registration
if let Err(_) = sqlx::query(
"UPDATE users SET registrations_remaining = registrations_remaining - 1 WHERE id = $1"
)
.bind(claims.user_id)
.execute(&mut *tx)
.await
{
let _ = tx.rollback().await;
return HttpResponse::InternalServerError().json(Error {
msg: "Failed to consume registration",
error: "DB_ERROR".into(),
});
}
}
// Commit the transaction
if let Err(_) = tx.commit().await {
return HttpResponse::InternalServerError().json(Error {
msg: "Transaction failed",
error: "DB_ERROR".into(),
});
}
// Send to Discord for approval
let registration = DomainRegistration {
id: domain_id,
domain_name: domain.name.clone(),
tld: domain.tld.clone(),
ip: domain.ip.clone(),
user_id: claims.user_id,
username: claims.username.clone(),
};
let bot_token = app.config.discord.bot_token.clone();
let channel_id = app.config.discord.channel_id;
tokio::spawn(async move {
if let Err(e) = send_domain_approval_request(
channel_id,
registration,
&bot_token,
).await {
log::error!("Failed to send Discord message: {}", e);
}
});
HttpResponse::Ok().json(serde_json::json!({
"message": "Domain registration submitted for approval",
"domain": format!("{}.{}", domain.name, domain.tld),
"status": "pending"
}))
}
Err(error) => error,
}
}
#[actix_web::get("/domain/{name}/{tld}")]
pub(crate) async fn get_domain(path: web::Path<(String, String)>, app: Data<AppState>) -> impl Responder {
let (name, tld) = path.into_inner();
match sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE name = $1 AND tld = $2 AND status = 'approved'"
)
.bind(&name)
.bind(&tld)
.fetch_optional(&app.db)
.await
{
Ok(Some(domain)) => HttpResponse::Ok().json(ResponseDomain {
tld: domain.tld,
name: domain.name,
ip: domain.ip,
records: None,
}),
Ok(None) => HttpResponse::NotFound().finish(),
Err(_) => HttpResponse::InternalServerError().finish(),
}
}
#[actix_web::put("/domain/{name}/{tld}")]
pub(crate) async fn update_domain(
path: web::Path<(String, String)>,
domain_update: web::Json<UpdateDomain>,
app: Data<AppState>,
req: HttpRequest
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
let (name, tld) = path.into_inner();
match sqlx::query(
"UPDATE domains SET ip = $1 WHERE name = $2 AND tld = $3 AND user_id = $4 AND status = 'approved'"
)
.bind(&domain_update.ip)
.bind(&name)
.bind(&tld)
.bind(claims.user_id)
.execute(&app.db)
.await
{
Ok(result) => {
if result.rows_affected() == 1 {
HttpResponse::Ok().json(domain_update.into_inner())
} else {
HttpResponse::NotFound().json(Error {
msg: "Domain not found or not owned by user",
error: "DOMAIN_NOT_FOUND".into(),
})
}
}
Err(_) => HttpResponse::InternalServerError().finish(),
}
}
#[actix_web::delete("/domain/{name}/{tld}")]
pub(crate) async fn delete_domain(
path: web::Path<(String, String)>,
app: Data<AppState>,
req: HttpRequest
) -> impl Responder {
let extensions = req.extensions();
let claims = match extensions.get::<Claims>() {
Some(claims) => claims,
None => {
return HttpResponse::Unauthorized().json(Error {
msg: "Authentication required",
error: "AUTH_REQUIRED".into(),
});
}
};
let (name, tld) = path.into_inner();
match sqlx::query(
"DELETE FROM domains WHERE name = $1 AND tld = $2 AND user_id = $3"
)
.bind(&name)
.bind(&tld)
.bind(claims.user_id)
.execute(&app.db)
.await
{
Ok(result) => {
if result.rows_affected() == 1 {
HttpResponse::Ok().finish()
} else {
HttpResponse::NotFound().json(Error {
msg: "Domain not found or not owned by user",
error: "DOMAIN_NOT_FOUND".into(),
})
}
}
Err(_) => HttpResponse::InternalServerError().finish(),
}
}
#[actix_web::post("/domain/check")]
pub(crate) async fn check_domain(query: web::Json<DomainQuery>, app: Data<AppState>) -> impl Responder {
let DomainQuery { name, tld } = query.into_inner();
let result = helpers::is_domain_taken(&name, tld.as_deref(), app).await;
HttpResponse::Ok().json(result)
}
#[actix_web::get("/domains")]
pub(crate) async fn get_domains(query: web::Query<PaginationParams>, app: Data<AppState>) -> impl Responder {
let page = query.page.unwrap_or(1);
let limit = query.page_size.unwrap_or(15);
if page == 0 || limit == 0 {
return HttpResponse::BadRequest().json(Error {
msg: "page_size or page must be greater than 0",
error: "Invalid pagination parameters".into(),
});
}
if limit > 100 {
return HttpResponse::BadRequest().json(Error {
msg: "page_size must be greater than 0 and less than or equal to 100",
error: "Invalid pagination parameters".into(),
});
}
let offset = (page - 1) * limit;
match sqlx::query_as::<_, Domain>(
"SELECT id, name, tld, ip, user_id, status, denial_reason, created_at FROM domains WHERE status = 'approved' ORDER BY created_at DESC LIMIT $1 OFFSET $2"
)
.bind(limit as i64)
.bind(offset as i64)
.fetch_all(&app.db)
.await
{
Ok(domains) => {
let response_domains: Vec<ResponseDomain> = domains
.into_iter()
.map(|domain| ResponseDomain {
tld: domain.tld,
name: domain.name,
ip: domain.ip,
records: None,
})
.collect();
HttpResponse::Ok().json(PaginationResponse {
domains: response_domains,
page,
limit,
})
}
Err(err) => HttpResponse::InternalServerError().json(Error {
msg: "Failed to fetch domains",
error: err.to_string(),
}),
}
}
#[actix_web::get("/tlds")]
pub(crate) async fn get_tlds(app: Data<AppState>) -> impl Responder { HttpResponse::Ok().json(&*app.config.tld_list()) }

View File

@@ -1,5 +1,5 @@
mod config; mod config;
mod http; mod gurt_server;
mod secret; mod secret;
mod auth; mod auth;
mod discord_bot; mod discord_bot;
@@ -27,12 +27,11 @@ struct Cli {
#[derive(Subcommand)] #[derive(Subcommand)]
enum Commands { enum Commands {
/// Start the daemon
Start, Start,
} }
#[tokio::main]
fn main() { async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
let mut env = pretty_env_logger::formatted_builder(); let mut env = pretty_env_logger::formatted_builder();
let level = cli.verbose.log_level_filter(); let level = cli.verbose.log_level_filter();
@@ -47,7 +46,7 @@ fn main() {
match &cli.command { match &cli.command {
Commands::Start => { Commands::Start => {
if let Err(err) = http::start(cli) { if let Err(err) = gurt_server::start(cli).await {
log::error!("Failed to start server: {err}") log::error!("Failed to start server: {err}")
} }
} }

View File

@@ -0,0 +1,10 @@
{
"permissions": {
"allow": [
"WebSearch",
"WebFetch(domain:github.com)"
],
"deny": [],
"ask": []
}
}

View File

@@ -85,7 +85,6 @@ layout_mode = 1
offset_right = 200.0 offset_right = 200.0
offset_bottom = 35.0 offset_bottom = 35.0
theme = ExtResource("2_theme") theme = ExtResource("2_theme")
text = "test"
placeholder_text = "Enter text..." placeholder_text = "Enter text..."
caret_blink = true caret_blink = true

View File

@@ -462,9 +462,33 @@ static func parse_utility_class_internal(rule: CSSRule, utility_name: String) ->
return return
# Handle font weight # Handle font weight
if utility_name == "font-thin":
rule.properties["font-thin"] = true
return
if utility_name == "font-extralight":
rule.properties["font-extralight"] = true
return
if utility_name == "font-light":
rule.properties["font-light"] = true
return
if utility_name == "font-normal":
rule.properties["font-normal"] = true
return
if utility_name == "font-medium":
rule.properties["font-medium"] = true
return
if utility_name == "font-semibold":
rule.properties["font-semibold"] = true
return
if utility_name == "font-bold": if utility_name == "font-bold":
rule.properties["font-bold"] = true rule.properties["font-bold"] = true
return return
if utility_name == "font-extrabold":
rule.properties["font-extrabold"] = true
return
if utility_name == "font-black":
rule.properties["font-black"] = true
return
# Handle font family # Handle font family
if utility_name == "font-sans": if utility_name == "font-sans":
@@ -478,7 +502,7 @@ static func parse_utility_class_internal(rule: CSSRule, utility_name: String) ->
rule.properties["font-mono"] = true rule.properties["font-mono"] = true
return return
var reserved_font_styles = ["font-sans", "font-serif", "font-mono", "font-bold", "font-italic"] var reserved_font_styles = ["font-sans", "font-serif", "font-mono", "font-thin", "font-extralight", "font-light", "font-normal", "font-medium", "font-semibold", "font-bold", "font-extrabold", "font-black", "font-italic"]
# Handle custom font families like font-roboto # Handle custom font families like font-roboto
if utility_name.begins_with("font-") and not utility_name in reserved_font_styles: if utility_name.begins_with("font-") and not utility_name in reserved_font_styles:
var font_name = utility_name.substr(5) # after 'font-' var font_name = utility_name.substr(5) # after 'font-'
@@ -521,8 +545,8 @@ static func parse_utility_class_internal(rule: CSSRule, utility_name: String) ->
val = val.substr(1, val.length() - 2) val = val.substr(1, val.length() - 2)
rule.properties["width"] = SizeUtils.parse_size(val) rule.properties["width"] = SizeUtils.parse_size(val)
return return
# Height # Height, but h-full is temporarily disabled since it fucks with Yoga layout engine
if utility_name.begins_with("h-"): if utility_name.begins_with("h-") and utility_name != "h-full":
var val = utility_name.substr(2) var val = utility_name.substr(2)
if val.begins_with("[") and val.ends_with("]"): if val.begins_with("[") and val.ends_with("]"):
val = val.substr(1, val.length() - 2) val = val.substr(1, val.length() - 2)

View File

@@ -412,47 +412,69 @@ func apply_element_styles(node: Control, element: HTMLElement, parser: HTMLParse
label.text = text label.text = text
static func apply_element_bbcode_formatting(element: HTMLElement, styles: Dictionary, content: String, parser: HTMLParser = null) -> String: static func apply_element_bbcode_formatting(element: HTMLElement, styles: Dictionary, content: String, parser: HTMLParser = null) -> String:
# Apply general styling first (color, font-weight) for all elements
var formatted_content = content
# Apply font weight (bold/semibold/etc)
if styles.has("font-bold") and styles["font-bold"]:
formatted_content = "[b]" + formatted_content + "[/b]"
elif styles.has("font-semibold") and styles["font-semibold"]:
formatted_content = "[b]" + formatted_content + "[/b]" # BBCode doesn't have semibold, use bold
# Apply italic
if styles.has("font-italic") and styles["font-italic"]:
formatted_content = "[i]" + formatted_content + "[/i]"
# Apply underline
if styles.has("underline") and styles["underline"]:
formatted_content = "[u]" + formatted_content + "[/u]"
# Apply color
if styles.has("color"):
var color = styles["color"]
if typeof(color) == TYPE_COLOR:
color = "#" + color.to_html(false)
else:
color = str(color)
formatted_content = "[color=%s]%s[/color]" % [color, formatted_content]
# Apply tag-specific formatting
match element.tag_name: match element.tag_name:
"b": "b":
if styles.has("font-bold") and styles["font-bold"]: if not (styles.has("font-bold") and styles["font-bold"]):
return "[b]" + content + "[/b]" formatted_content = "[b]" + formatted_content + "[/b]"
"i": "i":
if styles.has("font-italic") and styles["font-italic"]: if not (styles.has("font-italic") and styles["font-italic"]):
return "[i]" + content + "[/i]" formatted_content = "[i]" + formatted_content + "[/i]"
"u": "u":
if styles.has("underline") and styles["underline"]: if not (styles.has("underline") and styles["underline"]):
return "[u]" + content + "[/u]" formatted_content = "[u]" + formatted_content + "[/u]"
"small": "small":
if styles.has("font-size"): if styles.has("font-size"):
return "[font_size=%d]%s[/font_size]" % [styles["font-size"], content] formatted_content = "[font_size=%d]%s[/font_size]" % [styles["font-size"], formatted_content]
else: else:
return "[font_size=20]%s[/font_size]" % content formatted_content = "[font_size=20]%s[/font_size]" % formatted_content
"mark": "mark":
if styles.has("bg"): if styles.has("bg"):
var color = styles["bg"] var bg_color = styles["bg"]
if typeof(color) == TYPE_COLOR: if typeof(bg_color) == TYPE_COLOR:
color = color.to_html(false) bg_color = bg_color.to_html(false)
return "[bgcolor=#%s]%s[/bgcolor]" % [color, content] formatted_content = "[bgcolor=#%s]%s[/bgcolor]" % [bg_color, formatted_content]
else: else:
return "[bgcolor=#FFFF00]%s[/bgcolor]" % content formatted_content = "[bgcolor=#FFFF00]%s[/bgcolor]" % formatted_content
"code": "code":
if styles.has("font-size"): if styles.has("font-size"):
return "[font_size=%d][code]%s[/code][/font_size]" % [styles["font-size"], content] formatted_content = "[font_size=%d][code]%s[/code][/font_size]" % [styles["font-size"], formatted_content]
else: else:
return "[font_size=20][code]%s[/code][/font_size]" % content formatted_content = "[font_size=20][code]%s[/code][/font_size]" % formatted_content
"a": "a":
var href = element.get_attribute("href") var href = element.get_attribute("href")
var color = "#1a0dab"
if styles.has("color"):
var c = styles["color"]
if typeof(c) == TYPE_COLOR:
color = "#" + c.to_html(false)
else:
color = str(c)
if href.length() > 0: if href.length() > 0:
# Pass raw href - URL resolution happens in handle_link_click # Pass raw href - URL resolution happens in handle_link_click
return "[color=%s][url=%s]%s[/url][/color]" % [color, href, content] formatted_content = "[url=%s]%s[/url]" % [href, formatted_content]
return content
return formatted_content
static func get_bbcode_with_styles(element: HTMLElement, styles: Dictionary, parser: HTMLParser) -> String: static func get_bbcode_with_styles(element: HTMLElement, styles: Dictionary, parser: HTMLParser) -> String:
var text = "" var text = ""

View File

@@ -661,18 +661,43 @@ func _handle_text_setting(operation: Dictionary):
if dom_node: if dom_node:
var text_node = get_dom_node(dom_node, "text") var text_node = get_dom_node(dom_node, "text")
if text_node: if text_node:
if text_node.has_method("set_text"): if text_node is RichTextLabel:
var formatted_text = element.get_bbcode_formatted_text(dom_parser)
formatted_text = "[font_size=24]%s[/font_size]" % formatted_text
text_node.text = formatted_text
text_node.call_deferred("_auto_resize_to_content")
elif text_node.has_method("set_text"):
text_node.set_text(text) text_node.set_text(text)
elif "text" in text_node: elif "text" in text_node:
text_node.text = text text_node.text = text
if text_node.has_method("_auto_resize_to_content"):
text_node.call_deferred("_auto_resize_to_content")
else:
var rich_text_label = _find_rich_text_label_recursive(dom_node)
if rich_text_label:
var formatted_text = element.get_bbcode_formatted_text(dom_parser)
formatted_text = "[font_size=24]%s[/font_size]" % formatted_text
rich_text_label.text = formatted_text
rich_text_label.call_deferred("_auto_resize_to_content")
func _find_rich_text_label_recursive(node: Node) -> RichTextLabel:
if node is RichTextLabel:
return node
for child in node.get_children():
var result = _find_rich_text_label_recursive(child)
if result:
return result
return null
func _handle_text_getting(operation: Dictionary): func _handle_text_getting(operation: Dictionary):
var selector: String = operation.selector var selector: String = operation.selector
var element = SelectorUtils.find_first_matching(selector, dom_parser.parse_result.all_elements) var element = SelectorUtils.find_first_matching(selector, dom_parser.parse_result.all_elements)
if element: if element:
# Return the element's cached text content from the HTML element
# This avoids the need for a callback system since we have the text cached
return element.text_content return element.text_content
return "" return ""

View File

@@ -6,7 +6,7 @@ const SECONDARY_COLOR = Color(43/255.0, 43/255.0, 43/255.0, 1)
const HOVER_COLOR = Color(0, 0, 0, 1) const HOVER_COLOR = Color(0, 0, 0, 1)
const DEFAULT_CSS = """ const DEFAULT_CSS = """
body { text-base text-[#000000] text-left bg-white } body { text-base text-[#000000] text-left bg-white font-serif }
h1 { text-5xl font-bold } h1 { text-5xl font-bold }
h2 { text-4xl font-bold } h2 { text-4xl font-bold }
h3 { text-3xl font-bold } h3 { text-3xl font-bold }
@@ -21,7 +21,7 @@ code { text-xl font-mono }
a { text-[#1a0dab] } a { text-[#1a0dab] }
pre { text-xl font-mono } pre { text-xl font-mono }
button { text-[16px] bg-[#1b1b1b] rounded-md text-white hover:bg-[#2a2a2a] active:bg-[#101010] } button { text-[16px] bg-[#1b1b1b] rounded-md text-white hover:bg-[#2a2a2a] active:bg-[#101010] px-3 py-1.5 }
button[disabled] { bg-[#666666] text-[#999999] cursor-not-allowed } button[disabled] { bg-[#666666] text-[#999999] cursor-not-allowed }
""" """

View File

@@ -1,7 +1,7 @@
extends RefCounted extends RefCounted
class_name GurtProtocol class_name GurtProtocol
const DNS_API_URL = "http://localhost:8080" const DNS_API_URL = "gurt://localhost:8877"
static func is_gurt_domain(url: String) -> bool: static func is_gurt_domain(url: String) -> bool:
if url.begins_with("gurt://"): if url.begins_with("gurt://"):
@@ -52,41 +52,39 @@ static func is_ip_address(address: String) -> bool:
return true return true
static func fetch_domain_info(name: String, tld: String) -> Dictionary: static func fetch_domain_info(name: String, tld: String) -> Dictionary:
var http_request = HTTPRequest.new() var path = "/domain/" + name + "/" + tld
var tree = Engine.get_main_loop() var dns_address = "localhost:8877"
tree.current_scene.add_child(http_request)
http_request.timeout = 5.0 print("DNS API URL: gurt://" + dns_address + path)
var url = DNS_API_URL + "/domain/" + name + "/" + tld var response = await fetch_content_via_gurt_direct(dns_address, path)
print("DNS API URL: ", url)
var error = http_request.request(url) if response.has("error"):
if "No response from GURT server" in response.error or "Failed to create GURT client" in response.error:
return {"error": "DNS server is not responding"}
else:
return {"error": "Failed to make DNS request"}
if error != OK: if not response.has("content"):
print("HTTP request failed with error: ", error)
http_request.queue_free()
return {"error": "Failed to make DNS request"}
var response = await http_request.request_completed
http_request.queue_free()
if response[1] == 0 and response[3].size() == 0:
return {"error": "DNS server is not responding"} return {"error": "DNS server is not responding"}
var http_code = response[1] var content = response.content
var body = response[3] if content.is_empty():
return {"error": "DNS server is not responding"}
if http_code != 200:
return {"error": "Domain not found or not approved"}
var json = JSON.new() var json = JSON.new()
var parse_result = json.parse(body.get_string_from_utf8()) var parse_result = json.parse(content.get_string_from_utf8())
if parse_result != OK: if parse_result != OK:
return {"error": "Invalid JSON response from DNS server"} return {"error": "Invalid JSON response from DNS server"}
return json.data var data = json.data
# Check if the response indicates an error (like 404)
if data is Dictionary and data.has("error"):
return {"error": "Domain not found or not approved"}
return data
static func fetch_content_via_gurt(ip: String, path: String = "/") -> Dictionary: static func fetch_content_via_gurt(ip: String, path: String = "/") -> Dictionary:
var client = GurtProtocolClient.new() var client = GurtProtocolClient.new()

View File

@@ -62,16 +62,16 @@ static func apply_element_styles(node: Control, element: HTMLParser.HTMLElement,
node.size_flags_stretch_ratio = percentage_value node.size_flags_stretch_ratio = percentage_value
else: else:
node.custom_minimum_size.x = width node.custom_minimum_size.x = width
var should_center_h = styles.has("mx-auto") or styles.has("justify-self-center") or (styles.has("text-align") and styles["text-align"] == "center") node.size_flags_horizontal = Control.SIZE_SHRINK_BEGIN
node.size_flags_horizontal = Control.SIZE_SHRINK_CENTER if should_center_h else Control.SIZE_SHRINK_BEGIN node.set_meta("size_flags_horizontal_set", true)
if height != null: if height != null:
if SizingUtils.is_percentage(height): if SizingUtils.is_percentage(height):
node.size_flags_vertical = Control.SIZE_EXPAND_FILL node.size_flags_vertical = Control.SIZE_EXPAND_FILL
else: else:
node.custom_minimum_size.y = height node.custom_minimum_size.y = height
var should_center_v = styles.has("my-auto") or styles.has("align-self-center") node.size_flags_vertical = Control.SIZE_SHRINK_BEGIN
node.size_flags_vertical = Control.SIZE_SHRINK_CENTER if should_center_v else Control.SIZE_SHRINK_BEGIN node.set_meta("size_flags_vertical_set", true)
node.set_meta("size_flags_set_by_style_manager", true) node.set_meta("size_flags_set_by_style_manager", true)
elif node is VBoxContainer or node is HBoxContainer or node is Container: elif node is VBoxContainer or node is HBoxContainer or node is Container:
@@ -97,6 +97,9 @@ static func apply_element_styles(node: Control, element: HTMLParser.HTMLElement,
else: else:
# regular controls # regular controls
SizingUtils.apply_regular_control_sizing(node, width, height, styles) SizingUtils.apply_regular_control_sizing(node, width, height, styles)
# Apply centering for FlexContainers
apply_flexcontainer_centering(node, styles)
if label and label != node: if label and label != node:
label.anchors_preset = Control.PRESET_FULL_RECT label.anchors_preset = Control.PRESET_FULL_RECT
@@ -145,7 +148,7 @@ static func apply_element_styles(node: Control, element: HTMLParser.HTMLElement,
if needs_styling: if needs_styling:
# If node is a MarginContainer wrapper, get the actual content node for styling # If node is a MarginContainer wrapper, get the actual content node for styling
var content_node = node var content_node = node
if node is MarginContainer and node.name.begins_with("MarginWrapper_"): if node is MarginContainer and node.has_meta("is_margin_wrapper"):
if node.get_child_count() > 0: if node.get_child_count() > 0:
content_node = node.get_child(0) content_node = node.get_child(0)
@@ -168,7 +171,7 @@ static func apply_element_styles(node: Control, element: HTMLParser.HTMLElement,
target_node_for_bg.call_deferred("add_background_rect") target_node_for_bg.call_deferred("add_background_rect")
else: else:
var content_node = node var content_node = node
if node is MarginContainer and node.name.begins_with("MarginWrapper_"): if node is MarginContainer and node.has_meta("is_margin_wrapper"):
if node.get_child_count() > 0: if node.get_child_count() > 0:
content_node = node.get_child(0) content_node = node.get_child(0)
@@ -195,6 +198,7 @@ static func apply_element_styles(node: Control, element: HTMLParser.HTMLElement,
apply_transform_properties(transform_target, styles) apply_transform_properties(transform_target, styles)
return node return node
static func apply_stylebox_to_panel_container(panel_container: PanelContainer, styles: Dictionary) -> void: static func apply_stylebox_to_panel_container(panel_container: PanelContainer, styles: Dictionary) -> void:
@@ -273,12 +277,12 @@ static func clear_styling_metadata(node: Control) -> void:
static func handle_margin_wrapper(node: Control, styles: Dictionary, needs_margin: bool): static func handle_margin_wrapper(node: Control, styles: Dictionary, needs_margin: bool):
var current_wrapper = null var current_wrapper = null
if node is MarginContainer and node.name.begins_with("MarginWrapper_"): if node is MarginContainer and node.has_meta("is_margin_wrapper"):
current_wrapper = node current_wrapper = node
elif node.get_parent() and node.get_parent() is MarginContainer: elif node.get_parent() and node.get_parent() is MarginContainer:
var parent = node.get_parent() var parent = node.get_parent()
if parent.name.begins_with("MarginWrapper_"): if parent.has_meta("is_margin_wrapper"):
current_wrapper = parent current_wrapper = parent
if needs_margin: if needs_margin:
@@ -323,6 +327,7 @@ static func remove_margin_wrapper(margin_container: MarginContainer, original_no
static func apply_margin_wrapper(node: Control, styles: Dictionary) -> Control: static func apply_margin_wrapper(node: Control, styles: Dictionary) -> Control:
var margin_container = MarginContainer.new() var margin_container = MarginContainer.new()
margin_container.name = "MarginWrapper_" + node.name margin_container.name = "MarginWrapper_" + node.name
margin_container.set_meta("is_margin_wrapper", true)
var has_explicit_width = styles.has("width") var has_explicit_width = styles.has("width")
var has_explicit_height = styles.has("height") var has_explicit_height = styles.has("height")
@@ -406,11 +411,16 @@ static func apply_styles_to_label(label: Control, styles: Dictionary, element: H
if not FontManager.loaded_fonts.has(font_family): if not FontManager.loaded_fonts.has(font_family):
# Font not loaded yet, use sans-serif as fallback # Font not loaded yet, use sans-serif as fallback
var fallback_font = FontManager.get_font("sans-serif") var fallback_font = FontManager.get_font("sans-serif")
apply_font_to_label(label, fallback_font) apply_font_to_label(label, fallback_font, styles)
if font_resource: if font_resource:
apply_font_to_label(label, font_resource) apply_font_to_label(label, font_resource, styles)
else:
# No custom font family, but check if we need to apply font weight
if styles.has("font-thin") or styles.has("font-extralight") or styles.has("font-light") or styles.has("font-normal") or styles.has("font-medium") or styles.has("font-semibold") or styles.has("font-extrabold") or styles.has("font-black"):
var default_font = FontManager.get_font("sans-serif")
apply_font_to_label(label, default_font, styles)
# Apply font size # Apply font size
if styles.has("font-size"): if styles.has("font-size"):
font_size = int(styles["font-size"]) font_size = int(styles["font-size"])
@@ -487,15 +497,6 @@ static func apply_styles_to_label(label: Control, styles: Dictionary, element: H
label.text = styled_text label.text = styled_text
static func apply_flex_container_properties(node: FlexContainer, styles: Dictionary) -> void:
FlexUtils.apply_flex_container_properties(node, styles)
static func apply_flex_item_properties(node: Control, styles: Dictionary) -> void:
FlexUtils.apply_flex_item_properties(node, styles)
static func parse_flex_value(val):
return FlexUtils.parse_flex_value(val)
static func apply_body_styles(body: HTMLParser.HTMLElement, parser: HTMLParser, website_container: Control, website_background: Control) -> void: static func apply_body_styles(body: HTMLParser.HTMLElement, parser: HTMLParser, website_container: Control, website_background: Control) -> void:
var styles = parser.get_element_styles_with_inheritance(body, "", []) var styles = parser.get_element_styles_with_inheritance(body, "", [])
@@ -553,8 +554,35 @@ static func apply_body_styles(body: HTMLParser.HTMLElement, parser: HTMLParser,
static func parse_radius(radius_str: String) -> int: static func parse_radius(radius_str: String) -> int:
return SizeUtils.parse_radius(radius_str) return SizeUtils.parse_radius(radius_str)
static func apply_font_to_label(label: RichTextLabel, font_resource: Font) -> void: static func apply_font_to_label(label: RichTextLabel, font_resource: Font, styles: Dictionary = {}) -> void:
label.add_theme_font_override("normal_font", font_resource) # Create normal font with appropriate weight
var normal_font = SystemFont.new()
normal_font.font_names = font_resource.font_names if font_resource is SystemFont else ["Arial"]
# Set weight based on styles
var font_weight = 400 # Default normal weight
if styles.has("font-thin"):
font_weight = 100
elif styles.has("font-extralight"):
font_weight = 200
elif styles.has("font-light"):
font_weight = 300
elif styles.has("font-normal"):
font_weight = 400
elif styles.has("font-medium"):
font_weight = 500
elif styles.has("font-semibold"):
font_weight = 600
elif styles.has("font-bold"):
font_weight = 700
elif styles.has("font-extrabold"):
font_weight = 800
elif styles.has("font-black"):
font_weight = 900
normal_font.font_weight = font_weight
label.add_theme_font_override("normal_font", normal_font)
var bold_font = SystemFont.new() var bold_font = SystemFont.new()
bold_font.font_names = font_resource.font_names if font_resource is SystemFont else ["Arial"] bold_font.font_names = font_resource.font_names if font_resource is SystemFont else ["Arial"]
@@ -761,3 +789,19 @@ static func await_and_restore_transform(node: Control, target_scale: Vector2, ta
node.scale = target_scale node.scale = target_scale
node.rotation = target_rotation node.rotation = target_rotation
node.pivot_offset = node.size / 2 node.pivot_offset = node.size / 2
static func apply_flexcontainer_centering(node: Control, styles: Dictionary) -> void:
if not node is FlexContainer:
return
var should_center_h = styles.has("mx-auto") or styles.has("justify-self-center") or (styles.has("text-align") and styles["text-align"] == "center")
var should_center_v = styles.has("my-auto") or styles.has("align-self-center")
if should_center_h and not node.has_meta("size_flags_horizontal_set"):
node.size_flags_horizontal = Control.SIZE_SHRINK_CENTER
if should_center_v and not node.has_meta("size_flags_vertical_set"):
node.size_flags_vertical = Control.SIZE_SHRINK_CENTER
if should_center_h or should_center_v:
node.set_meta("size_flags_set_by_style_manager", true)

View File

@@ -14,9 +14,14 @@ static func apply_flex_container_properties(node, styles: Dictionary) -> void:
# Flex wrap # Flex wrap
if styles.has("flex-wrap"): if styles.has("flex-wrap"):
match styles["flex-wrap"]: match styles["flex-wrap"]:
"nowrap": node.flex_wrap = FlexContainer.FlexWrap.NoWrap "nowrap":
"wrap": node.flex_wrap = FlexContainer.FlexWrap.Wrap node.flex_wrap = FlexContainer.FlexWrap.NoWrap
"wrap-reverse": node.flex_wrap = FlexContainer.FlexWrap.WrapReverse "wrap":
node.flex_wrap = FlexContainer.FlexWrap.Wrap
# this is probably not needed but i dont feel like testing it
node.flex_property_changed("flex_wrap", FlexContainer.FlexWrap.Wrap)
"wrap-reverse":
node.flex_wrap = FlexContainer.FlexWrap.WrapReverse
# Justify content # Justify content
if styles.has("justify-content"): if styles.has("justify-content"):
match styles["justify-content"]: match styles["justify-content"]:

View File

@@ -195,6 +195,13 @@ static func trigger_element_restyle(element: HTMLParser.HTMLElement, dom_parser:
var dom_node = dom_parser.parse_result.dom_nodes.get(element_id, null) var dom_node = dom_parser.parse_result.dom_nodes.get(element_id, null)
if not dom_node: if not dom_node:
return return
# Check if element has the "hidden" class before styling
var has_hidden_class = false
var current_style = element.get_attribute("style", "")
if current_style.length() > 0:
var style_classes = CSSParser.smart_split_utility_classes(current_style)
has_hidden_class = "hidden" in style_classes
# margins, wrappers, etc. # margins, wrappers, etc.
var updated_dom_node = StyleManager.apply_element_styles(dom_node, element, dom_parser) var updated_dom_node = StyleManager.apply_element_styles(dom_node, element, dom_parser)
@@ -204,9 +211,15 @@ static func trigger_element_restyle(element: HTMLParser.HTMLElement, dom_parser:
dom_parser.parse_result.dom_nodes[element_id] = updated_dom_node dom_parser.parse_result.dom_nodes[element_id] = updated_dom_node
dom_node = updated_dom_node dom_node = updated_dom_node
# Apply visibility state to the correct node (wrapper or content)
if has_hidden_class:
dom_node.visible = false
else:
dom_node.visible = true
# Find node # Find node
var actual_element_node = dom_node var actual_element_node = dom_node
if dom_node is MarginContainer and dom_node.name.begins_with("MarginWrapper_"): if dom_node is MarginContainer and dom_node.has_meta("is_margin_wrapper"):
if dom_node.get_child_count() > 0: if dom_node.get_child_count() > 0:
actual_element_node = dom_node.get_child(0) actual_element_node = dom_node.get_child(0)
@@ -223,7 +236,7 @@ static func trigger_element_restyle(element: HTMLParser.HTMLElement, dom_parser:
static func update_element_text_content(dom_node: Control, element: HTMLParser.HTMLElement, dom_parser: HTMLParser) -> void: static func update_element_text_content(dom_node: Control, element: HTMLParser.HTMLElement, dom_parser: HTMLParser) -> void:
# Get node # Get node
var content_node = dom_node var content_node = dom_node
if dom_node is MarginContainer and dom_node.name.begins_with("MarginWrapper_"): if dom_node is MarginContainer and dom_node.has_meta("is_margin_wrapper"):
if dom_node.get_child_count() > 0: if dom_node.get_child_count() > 0:
content_node = dom_node.get_child(0) content_node = dom_node.get_child(0)

View File

@@ -515,6 +515,12 @@ static func add_element_methods(vm: LuauVM, lua_api: LuaAPI) -> void:
vm.lua_pushcallable(LuaDOMUtils._element_create_tween_wrapper, "element.createTween") vm.lua_pushcallable(LuaDOMUtils._element_create_tween_wrapper, "element.createTween")
vm.lua_setfield(-2, "createTween") vm.lua_setfield(-2, "createTween")
vm.lua_pushcallable(LuaDOMUtils._element_show_wrapper, "element.show")
vm.lua_setfield(-2, "show")
vm.lua_pushcallable(LuaDOMUtils._element_hide_wrapper, "element.hide")
vm.lua_setfield(-2, "hide")
_add_classlist_support(vm, lua_api) _add_classlist_support(vm, lua_api)
vm.lua_newtable() vm.lua_newtable()
@@ -881,6 +887,24 @@ static func _element_index_wrapper(vm: LuauVM) -> int:
# Fallback to empty array # Fallback to empty array
vm.lua_newtable() vm.lua_newtable()
return 1 return 1
"visible":
if lua_api:
# Get element ID and find the element
vm.lua_getfield(1, "_element_id")
var element_id: String = vm.lua_tostring(-1)
vm.lua_pop(1)
var element = lua_api.dom_parser.find_by_id(element_id) if element_id != "body" else lua_api.dom_parser.find_first("body")
if element:
# Check if element has display: none (hidden class)
var class_attr = element.get_attribute("class")
var is_hidden = "hidden" in class_attr or element.get_attribute("style").contains("display:none") or element.get_attribute("style").contains("display: none")
vm.lua_pushboolean(not is_hidden)
return 1
# Fallback to true (visible by default)
vm.lua_pushboolean(true)
return 1
_: _:
# Check for DOM traversal properties first # Check for DOM traversal properties first
if lua_api: if lua_api:
@@ -1034,6 +1058,48 @@ static func _element_newindex_wrapper(vm: LuauVM) -> int:
emit_dom_operation(lua_api, operation) emit_dom_operation(lua_api, operation)
return 0 return 0
"visible":
var is_visible: bool = vm.lua_toboolean(3)
vm.lua_getfield(1, "_element_id")
var element_id: String = vm.lua_tostring(-1)
vm.lua_pop(1)
var element = lua_api.dom_parser.find_by_id(element_id) if element_id != "body" else lua_api.dom_parser.find_first("body")
if element:
var class_attr = element.get_attribute("class")
var classes = class_attr.split(" ") if not class_attr.is_empty() else []
if is_visible:
# Remove hidden class if present
var hidden_index = classes.find("hidden")
if hidden_index >= 0:
classes.remove_at(hidden_index)
var new_class_attr = " ".join(classes).strip_edges()
element.set_attribute("class", new_class_attr)
# Update visual element
var operation = {
"type": "remove_class",
"element_id": element_id,
"class_name": "hidden"
}
emit_dom_operation(lua_api, operation)
else:
# Add hidden class if not present
if not "hidden" in classes:
classes.append("hidden")
var new_class_attr = " ".join(classes).strip_edges()
element.set_attribute("class", new_class_attr)
# Update visual element
var operation = {
"type": "add_class",
"element_id": element_id,
"class_name": "hidden"
}
emit_dom_operation(lua_api, operation)
return 0
_: _:
# Store in table normally # Store in table normally
vm.lua_pushvalue(2) vm.lua_pushvalue(2)
@@ -1041,6 +1107,71 @@ static func _element_newindex_wrapper(vm: LuauVM) -> int:
vm.lua_rawset(1) vm.lua_rawset(1)
return 0 return 0
static func _element_show_wrapper(vm: LuauVM) -> int:
var lua_api = vm.get_meta("lua_api") as LuaAPI
if not lua_api:
return 0
vm.luaL_checktype(1, vm.LUA_TTABLE)
vm.lua_getfield(1, "_element_id")
var element_id: String = vm.lua_tostring(-1)
vm.lua_pop(1)
var element = lua_api.dom_parser.find_by_id(element_id) if element_id != "body" else lua_api.dom_parser.find_first("body")
if element:
var class_attr = element.get_attribute("class")
var classes = class_attr.split(" ") if not class_attr.is_empty() else []
# Remove hidden class if present
var hidden_index = classes.find("hidden")
if hidden_index >= 0:
classes.remove_at(hidden_index)
var new_class_attr = " ".join(classes).strip_edges()
element.set_attribute("class", new_class_attr)
# Update visual element
var operation = {
"type": "remove_class",
"element_id": element_id,
"class_name": "hidden"
}
emit_dom_operation(lua_api, operation)
return 0
static func _element_hide_wrapper(vm: LuauVM) -> int:
var lua_api = vm.get_meta("lua_api") as LuaAPI
if not lua_api:
return 0
vm.luaL_checktype(1, vm.LUA_TTABLE)
vm.lua_getfield(1, "_element_id")
var element_id: String = vm.lua_tostring(-1)
vm.lua_pop(1)
var element = lua_api.dom_parser.find_by_id(element_id) if element_id != "body" else lua_api.dom_parser.find_first("body")
if element:
var class_attr = element.get_attribute("class")
var classes = class_attr.split(" ") if not class_attr.is_empty() else []
# Add hidden class if not present
if not "hidden" in classes:
classes.append("hidden")
var new_class_attr = " ".join(classes).strip_edges()
element.set_attribute("class", new_class_attr)
# Update visual element
var operation = {
"type": "add_class",
"element_id": element_id,
"class_name": "hidden"
}
emit_dom_operation(lua_api, operation)
return 0
static func _element_create_tween_wrapper(vm: LuauVM) -> int: static func _element_create_tween_wrapper(vm: LuauVM) -> int:
var lua_api = vm.get_meta("lua_api") as LuaAPI var lua_api = vm.get_meta("lua_api") as LuaAPI
if not lua_api: if not lua_api:

View File

@@ -126,6 +126,8 @@ static func _response_ok_handler(vm: LuauVM) -> int:
return 1 return 1
static func make_http_request(url: String, method: String, headers: PackedStringArray, body: String) -> Dictionary: static func make_http_request(url: String, method: String, headers: PackedStringArray, body: String) -> Dictionary:
if url.begins_with("gurt://"):
return make_gurt_request(url, method, headers, body)
var http_client = HTTPClient.new() var http_client = HTTPClient.new()
var response_data = { var response_data = {
"status": 0, "status": 0,
@@ -269,3 +271,63 @@ static func make_http_request(url: String, method: String, headers: PackedString
http_client.close() http_client.close()
return response_data return response_data
static var _gurt_client: GurtProtocolClient = null
static func make_gurt_request(url: String, method: String, headers: PackedStringArray, body: String) -> Dictionary:
var response_data = {
"status": 0,
"status_text": "Network Error",
"headers": {},
"body": ""
}
# Reuse existing client or create new one
if _gurt_client == null:
_gurt_client = GurtProtocolClient.new()
if not _gurt_client.create_client(10):
response_data.status = 0
response_data.status_text = "Connection Failed"
return response_data
var client = _gurt_client
# Convert headers array to dictionary
var headers_dict = {}
for header in headers:
var parts = header.split(":", 1)
if parts.size() == 2:
headers_dict[parts[0].strip_edges()] = parts[1].strip_edges()
# Prepare request options
var options = {
"method": method
}
if not headers_dict.is_empty():
options["headers"] = headers_dict
if not body.is_empty():
options["body"] = body
var response = client.request(url, options)
# Keep connection alive for reuse instead of disconnecting after every request
# client.disconnect()
if not response:
response_data.status = 0
response_data.status_text = "No Response"
return response_data
response_data.status = response.status_code
response_data.status_text = response.status_message if response.status_message else "OK"
response_data.headers = response.headers if response.headers else {}
var body_content = response.body if response.body else ""
if body_content is PackedByteArray:
response_data.body = body_content.get_string_from_utf8()
else:
response_data.body = str(body_content)
return response_data

View File

@@ -55,8 +55,6 @@ func _ready():
DisplayServer.window_set_min_size(MIN_SIZE) DisplayServer.window_set_min_size(MIN_SIZE)
get_viewport().size_changed.connect(_on_viewport_size_changed) get_viewport().size_changed.connect(_on_viewport_size_changed)
call_deferred("render")
func _on_viewport_size_changed(): func _on_viewport_size_changed():
recalculate_percentage_elements(website_container) recalculate_percentage_elements(website_container)
@@ -305,7 +303,7 @@ func create_element_node(element: HTMLParser.HTMLElement, parser: HTMLParser) ->
return null return null
final_node = StyleManager.apply_element_styles(final_node, element, parser) final_node = StyleManager.apply_element_styles(final_node, element, parser)
# Flex item properties may still apply # Flex item properties may still apply
StyleManager.apply_flex_item_properties(final_node, styles) FlexUtils.apply_flex_item_properties(final_node, styles)
return final_node return final_node
if is_flex_container: if is_flex_container:
@@ -335,6 +333,9 @@ func create_element_node(element: HTMLParser.HTMLElement, parser: HTMLParser) ->
elif not element.text_content.is_empty(): elif not element.text_content.is_empty():
var new_node = await create_element_node_internal(element, parser) var new_node = await create_element_node_internal(element, parser)
container_for_children.add_child(new_node) container_for_children.add_child(new_node)
# For flex divs, we're done - no additional node creation needed
elif element.tag_name == "div":
pass
else: else:
final_node = await create_element_node_internal(element, parser) final_node = await create_element_node_internal(element, parser)
if not final_node: if not final_node:
@@ -359,10 +360,10 @@ func create_element_node(element: HTMLParser.HTMLElement, parser: HTMLParser) ->
flex_container_node = first_child flex_container_node = first_child
if flex_container_node is FlexContainer: if flex_container_node is FlexContainer:
StyleManager.apply_flex_container_properties(flex_container_node, styles) FlexUtils.apply_flex_container_properties(flex_container_node, styles)
# Apply flex ITEM properties # Apply flex ITEM properties
StyleManager.apply_flex_item_properties(final_node, styles) FlexUtils.apply_flex_item_properties(final_node, styles)
# Skip ul/ol and non-flex forms, they handle their own children # Skip ul/ol and non-flex forms, they handle their own children
var skip_general_processing = false var skip_general_processing = false
@@ -473,6 +474,11 @@ func create_element_node_internal(element: HTMLParser.HTMLElement, parser: HTMLP
"div": "div":
var styles = parser.get_element_styles_with_inheritance(element, "", []) var styles = parser.get_element_styles_with_inheritance(element, "", [])
var hover_styles = parser.get_element_styles_with_inheritance(element, "hover", []) var hover_styles = parser.get_element_styles_with_inheritance(element, "hover", [])
var is_flex_container = styles.has("display") and ("flex" in styles["display"])
# For flex divs, don't create div scene - the AutoSizingFlexContainer handles it
if is_flex_container:
return null
# Create div container # Create div container
if BackgroundUtils.needs_background_wrapper(styles) or hover_styles.size() > 0: if BackgroundUtils.needs_background_wrapper(styles) or hover_styles.size() > 0:

Binary file not shown.

View File

@@ -18,6 +18,7 @@ godot = "0.1"
tokio = { version = "1.0", features = ["rt"] } tokio = { version = "1.0", features = ["rt"] }
url = "2.5" url = "2.5"
serde_json = "1.0"
[profile.release] [profile.release]
opt-level = "z" opt-level = "z"

View File

@@ -1,6 +1,6 @@
use godot::prelude::*; use godot::prelude::*;
use gurt::prelude::*; use gurt::prelude::*;
use gurt::{GurtMethod, GurtClientConfig}; use gurt::{GurtMethod, GurtClientConfig, GurtRequest};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use std::sync::Arc; use std::sync::Arc;
use std::cell::RefCell; use std::cell::RefCell;
@@ -175,21 +175,27 @@ impl GurtProtocolClient {
} }
}; };
let url = format!("gurt://{}:{}{}", host, port, path); let body = options.get("body").unwrap_or("".to_variant()).to::<String>();
let response = match runtime.block_on(async { let headers_dict = options.get("headers").unwrap_or(Dictionary::new().to_variant()).to::<Dictionary>();
match method {
GurtMethod::GET => client.get(&url).await, let mut request = GurtRequest::new(method, path.to_string())
GurtMethod::POST => client.post(&url, "").await, .with_header("Host", host)
GurtMethod::PUT => client.put(&url, "").await, .with_header("User-Agent", "GURT-Client/1.0.0");
GurtMethod::DELETE => client.delete(&url).await,
GurtMethod::HEAD => client.head(&url).await, for key_variant in headers_dict.keys_array().iter_shared() {
GurtMethod::OPTIONS => client.options(&url).await, let key = key_variant.to::<String>();
GurtMethod::PATCH => client.patch(&url, "").await, if let Some(value_variant) = headers_dict.get(key_variant) {
_ => { let value = value_variant.to::<String>();
godot_print!("Unsupported method: {:?}", method); request = request.with_header(key, value);
return Err(GurtError::invalid_message("Unsupported method"));
}
} }
}
if !body.is_empty() {
request = request.with_string_body(&body);
}
let response = match runtime.block_on(async {
client.send_request(host, port, request).await
}) { }) {
Ok(resp) => resp, Ok(resp) => resp,
Err(e) => { Err(e) => {

View File

@@ -9,6 +9,8 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
use tokio_rustls::{TlsConnector, rustls::{ClientConfig as TlsClientConfig, RootCertStore, pki_types::ServerName}}; use tokio_rustls::{TlsConnector, rustls::{ClientConfig as TlsClientConfig, RootCertStore, pki_types::ServerName}};
use std::sync::Arc; use std::sync::Arc;
use std::collections::HashMap;
use std::sync::Mutex;
use url::Url; use url::Url;
use tracing::debug; use tracing::debug;
@@ -19,6 +21,19 @@ pub struct GurtClientConfig {
pub handshake_timeout: Duration, pub handshake_timeout: Duration,
pub user_agent: String, pub user_agent: String,
pub max_redirects: usize, pub max_redirects: usize,
pub enable_connection_pooling: bool,
pub max_connections_per_host: usize,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ConnectionKey {
host: String,
port: u16,
}
struct PooledTlsConnection {
connection: tokio_rustls::client::TlsStream<TcpStream>,
last_used: std::time::Instant,
} }
impl Default for GurtClientConfig { impl Default for GurtClientConfig {
@@ -29,6 +44,8 @@ impl Default for GurtClientConfig {
handshake_timeout: Duration::from_secs(DEFAULT_HANDSHAKE_TIMEOUT), handshake_timeout: Duration::from_secs(DEFAULT_HANDSHAKE_TIMEOUT),
user_agent: format!("GURT-Client/{}", crate::GURT_VERSION), user_agent: format!("GURT-Client/{}", crate::GURT_VERSION),
max_redirects: 5, max_redirects: 5,
enable_connection_pooling: true,
max_connections_per_host: 4,
} }
} }
} }
@@ -72,18 +89,69 @@ impl PooledConnection {
pub struct GurtClient { pub struct GurtClient {
config: GurtClientConfig, config: GurtClientConfig,
connection_pool: Arc<Mutex<HashMap<ConnectionKey, Vec<PooledTlsConnection>>>>,
} }
impl GurtClient { impl GurtClient {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
config: GurtClientConfig::default(), config: GurtClientConfig::default(),
connection_pool: Arc::new(Mutex::new(HashMap::new())),
} }
} }
pub fn with_config(config: GurtClientConfig) -> Self { pub fn with_config(config: GurtClientConfig) -> Self {
Self { Self {
config, config,
connection_pool: Arc::new(Mutex::new(HashMap::new())),
}
}
async fn get_pooled_connection(&self, host: &str, port: u16) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
if !self.config.enable_connection_pooling {
return self.perform_handshake(host, port).await;
}
let key = ConnectionKey {
host: host.to_string(),
port,
};
if let Ok(mut pool) = self.connection_pool.lock() {
if let Some(connections) = pool.get_mut(&key) {
connections.retain(|conn| conn.last_used.elapsed().as_secs() < 30);
if let Some(pooled_conn) = connections.pop() {
debug!("Reusing pooled connection for {}:{}", host, port);
return Ok(pooled_conn.connection);
}
}
}
debug!("Creating new connection for {}:{}", host, port);
self.perform_handshake(host, port).await
}
fn return_connection_to_pool(&self, host: &str, port: u16, connection: tokio_rustls::client::TlsStream<TcpStream>) {
if !self.config.enable_connection_pooling {
return;
}
let key = ConnectionKey {
host: host.to_string(),
port,
};
if let Ok(mut pool) = self.connection_pool.lock() {
let connections = pool.entry(key).or_insert_with(Vec::new);
if connections.len() < self.config.max_connections_per_host {
connections.push(PooledTlsConnection {
connection,
last_used: std::time::Instant::now(),
});
debug!("Returned connection to pool");
}
} }
} }
@@ -241,19 +309,66 @@ impl GurtClient {
async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> { async fn send_request_internal(&self, host: &str, port: u16, request: GurtRequest) -> Result<GurtResponse> {
debug!("Sending {} {} to {}:{}", request.method, request.path, host, port); debug!("Sending {} {} to {}:{}", request.method, request.path, host, port);
let tls_stream = self.perform_handshake(host, port).await?; let mut tls_stream = self.get_pooled_connection(host, port).await?;
let mut conn = PooledConnection::with_tls(tls_stream);
let request_data = request.to_string(); let request_data = request.to_string();
conn.connection.write_all(request_data.as_bytes()).await?; tls_stream.write_all(request_data.as_bytes()).await
.map_err(|e| GurtError::connection(format!("Failed to write request: {}", e)))?;
let response_bytes = timeout( let mut buffer = Vec::new();
self.config.request_timeout, let mut temp_buffer = [0u8; 8192];
self.read_response_data(&mut conn)
).await
.map_err(|_| GurtError::timeout("Request timeout"))??;
let response = GurtResponse::parse_bytes(&response_bytes)?; let start_time = std::time::Instant::now();
let mut headers_parsed = false;
let mut expected_body_length: Option<usize> = None;
let mut headers_end_pos: Option<usize> = None;
loop {
if start_time.elapsed() > self.config.request_timeout {
return Err(GurtError::timeout("Request timeout"));
}
match timeout(Duration::from_millis(100), tls_stream.read(&mut temp_buffer)).await {
Ok(Ok(0)) => break, // Connection closed
Ok(Ok(n)) => {
buffer.extend_from_slice(&temp_buffer[..n]);
if !headers_parsed {
if let Some(pos) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
headers_end_pos = Some(pos + 4);
headers_parsed = true;
let headers_section = std::str::from_utf8(&buffer[..pos])
.map_err(|e| GurtError::invalid_message(format!("Invalid UTF-8 in headers: {}", e)))?;
for line in headers_section.lines().skip(1) {
if line.to_lowercase().starts_with("content-length:") {
if let Some(length_str) = line.split(':').nth(1) {
expected_body_length = length_str.trim().parse().ok();
}
}
}
}
}
if headers_parsed {
if let (Some(headers_end), Some(expected_len)) = (headers_end_pos, expected_body_length) {
if buffer.len() >= headers_end + expected_len {
break;
}
} else if expected_body_length.is_none() && headers_parsed {
break;
}
}
},
Ok(Err(e)) => return Err(GurtError::connection(format!("Read error: {}", e))),
Err(_) => continue,
}
}
let response = GurtResponse::parse_bytes(&buffer)?;
self.return_connection_to_pool(host, port, tls_stream);
Ok(response) Ok(response)
} }
@@ -410,6 +525,7 @@ impl Clone for GurtClient {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
config: self.config.clone(), config: self.config.clone(),
connection_pool: self.connection_pool.clone(),
} }
} }
} }