Skip to content
Snippets Groups Projects
Commit e6b29fa8 authored by geant-release-service's avatar geant-release-service
Browse files

Finished release 0.3.

parents f262efa1 c90d90d9
Branches
Tags 0.3
No related merge requests found
Showing
with 291 additions and 11 deletions
......@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
name="stripe-checkout",
version="0.2",
version="0.3",
author="GEANT",
author_email="swd@geant.org",
description="Stripe custom checkout support service",
......
......@@ -136,3 +136,5 @@ LOGGING = {
"level": "INFO",
},
}
STRIPE_WEBHOOK_ALLOWED_IPS = ["*"]
......@@ -30,3 +30,18 @@ STATIC_URL = os.getenv("STATIC_URL", "/static/") # noqa: F405
STATIC_ROOT = os.getenv("STATIC_ROOT", "staticfiles/") # noqa: F405
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
STRIPE_WEBHOOK_ALLOWED_IPS = [
"3.18.12.63",
"3.130.192.231",
"13.235.14.237",
"13.235.122.149",
"18.211.135.69",
"35.154.171.200",
"52.15.183.38",
"54.88.130.119",
"54.88.130.237",
"54.187.174.169",
"54.187.205.235",
"54.187.216.72",
]
from django.contrib import admin
from .models import PricedItem, Event, Order
from .models import ExchangeRate, PricedItem, Event, Order
@admin.register(PricedItem)
......@@ -18,3 +18,8 @@ class EventAdmin(admin.ModelAdmin):
@admin.register(Order)
class OrderAdmin(admin.ModelAdmin):
list_display = ["visitor_id", "stripe_id", "paid"]
@admin.register(ExchangeRate)
class ExchangeRateAdmin(admin.ModelAdmin):
list_display = ["date", "rate"]
import csv
import datetime
from django.core.management.base import BaseCommand
import requests
from stripe_checkout.stripe_checkout.models import ExchangeRate
from django.utils import timezone
URL = (
"https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/"
"monthly_csv_{year}-{month}.csv"
)
class Command(BaseCommand):
def handle(self, *args, **options):
date = timezone.now()
month_str = date.strftime("%Y-%m")
exchange_rate = ExchangeRate.objects.filter(date__month=date.month).first()
if exchange_rate:
self.stdout.write(
f"Exchange rate for {month_str}"
f" already exists ({exchange_rate.rate:.4f})"
)
return
data = self.get_data(date)
rate = self.extract_exchange_rate(data)
ExchangeRate.objects.create(rate=rate, date=date.date())
self.stdout.write(
self.style.SUCCESS(
f"Updated exchange rate for {date.strftime('%Y-%m')}: {rate:.4f}"
)
)
def get_data(self, date: datetime.date):
url = URL.format(month=date.month, year=date.year)
result = requests.get(url)
result.raise_for_status()
return result.content.decode()
def extract_exchange_rate(self, content: str):
reader = csv.DictReader(content.split("\n"))
for row in reader:
if row["Currency Code"] == "EUR":
return 1 / float(row["Currency Units per £1"])
raise ValueError("")
# Generated by Django 5.1.4 on 2025-01-14 14:38
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"stripe_checkout",
"0003_rename_stripe_id_priceditem_stripe_product_id_and_more",
),
]
operations = [
migrations.AlterField(
model_name="order",
name="items",
field=models.ManyToManyField(
related_name="order", to="stripe_checkout.priceditem"
),
),
]
# Generated by Django 5.1.4 on 2025-01-14 15:39
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("stripe_checkout", "0004_alter_order_items"),
]
operations = [
migrations.CreateModel(
name="ExchangeRate",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("rate", models.FloatField()),
("date", models.DateField()),
],
),
migrations.AddField(
model_name="order",
name="exchange_rate",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="orders",
to="stripe_checkout.exchangerate",
),
),
]
......@@ -56,11 +56,23 @@ class PricedItem(models.Model):
)
class ExchangeRate(models.Model):
rate = models.FloatField()
date = models.DateField()
class Order(models.Model):
visitor_id = models.CharField(max_length=40, blank=False)
stripe_id = models.CharField(max_length=255, blank=False, unique=True)
items = models.ManyToManyField(to=PricedItem, related_name="order")
paid = models.BooleanField(default=False, db_default=False)
exchange_rate = models.ForeignKey(
ExchangeRate,
on_delete=models.SET_NULL,
blank=True,
null=True,
related_name="orders",
)
class Event(models.Model):
......
......@@ -9,6 +9,7 @@ import stripe.error
from stripe.error import StripeError # noqa F401
from django.conf import settings
from stripe_checkout.stripe_checkout.models import ExchangeRate
from stripe_checkout.stripe_checkout.visit import Visitor
if TYPE_CHECKING:
......@@ -40,7 +41,11 @@ def get_or_create_customer(visitor: Visitor) -> Optional[str]:
def create_invoice(
shopping_cart: ShoppingCart, customer_id, purchase_order=None, vat_number=None
shopping_cart: ShoppingCart,
customer_id,
purchase_order=None,
vat_number=None,
gbp_exchange_rate: Optional[ExchangeRate] = None,
):
stripe.api_key = settings.STRIPE_API_KEY
custom_fields = []
......@@ -48,6 +53,12 @@ def create_invoice(
custom_fields.append({"name": "Purchase Order", "value": purchase_order})
if vat_number:
custom_fields.append({"name": "VAT number", "value": vat_number})
if gbp_exchange_rate:
rate = gbp_exchange_rate.rate
vat = shopping_cart.vat * rate
custom_fields.append(
{"name": "GBP VAT Rate", "value": f"GBP {vat:.2f} ({rate:.4f})"}
)
invoice = stripe.Invoice.create(
customer=customer_id,
......
import functools
from django.conf import settings
from django.http import HttpResponseForbidden
ALLOW_ALL = "*"
def whitelist_ips(func=None, by_setting=None):
if not func:
return functools.partial(whitelist_ips, by_setting=by_setting)
@functools.wraps(func)
def _whitelisted_view(request):
allowed_ips = _read_whitelist_setting(by_setting, [ALLOW_ALL])
if not allowed_ips & {ALLOW_ALL, get_client_ip(request)}:
return HttpResponseForbidden("forbidden")
return func(request)
return _whitelisted_view
def _read_whitelist_setting(setting: str, default=None) -> set[str]:
result = getattr(settings, setting, default)
assert isinstance(result, list) and all(
isinstance(s, str) for s in result
), f"{setting} must be a list of strings"
return set(result)
def get_client_ip(request):
# cf. https://stackoverflow.com/a/4581997
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
ip = x_forwarded_for.split(",")[0]
else:
ip = request.META.get("REMOTE_ADDR")
return ip
import logging
from typing import Union
import requests
from django import forms
from django.http import Http404, HttpResponse
from django.shortcuts import redirect, render
from django.views.decorators.http import require_POST, require_http_methods, require_GET
from django.views.decorators.csrf import csrf_exempt
import requests
from django.views.decorators.http import require_GET, require_http_methods, require_POST
from stripe_checkout.stripe_checkout.shopping_cart import ShoppingCart
from .models import Event
from . import stripe
from .models import Event, ExchangeRate
from .shopping_cart import ShoppingCart
from .utils import whitelist_ips
from .visit import VisitorAPI
logger = logging.getLogger(__name__)
......@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
class PaymentDetailsForm(forms.Form):
purchase_order = forms.CharField(
label="Purchase Order (optional)", max_length=10, required=False
label="Purchase Order (optional)", max_length=100, required=False
)
vat_number = forms.CharField(
label="VAT Number (optional)", max_length=100, required=False
......@@ -60,11 +61,14 @@ def checkout(request, visitor_id):
def create_invoice(visitor, data):
shopping_cart = get_shopping_cart(visitor)
customer = stripe.get_or_create_customer(visitor)
exchange_rate = ExchangeRate.objects.order_by("-date").first()
return stripe.create_invoice(
shopping_cart,
customer,
purchase_order=data["purchase_order"],
vat_number=data["vat_number"],
gbp_exchange_rate=exchange_rate,
)
......@@ -85,6 +89,7 @@ def checkout_success(request, visitor_id):
@csrf_exempt
@require_POST
@whitelist_ips(by_setting="STRIPE_WEBHOOK_ALLOWED_IPS")
def stripe_event(request):
try:
event = stripe.read_event(request.body, request.headers.get("Stripe-Signature"))
......
......@@ -9,9 +9,9 @@ import responses
from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import Client
from django.utils import timezone
from stripe_checkout import config
from stripe_checkout.stripe_checkout.models import ItemKind, PricedItem
from stripe_checkout.stripe_checkout.models import ExchangeRate, ItemKind, PricedItem
VISIT_RESPONSES_DIR = pathlib.Path(__file__).parent / "visit-responses"
......@@ -240,3 +240,8 @@ def config_file(stripe_api_key, visit_api_key, visit_expo_id, tmp_path):
@pytest.fixture
def client(setup_django, mock_visit, mock_stripe):
return Client()
@pytest.fixture
def default_exchange_rate():
return ExchangeRate.objects.create(date=timezone.now().date(), rate=0.8)
import re
import pytest
import responses
from django.core.management import call_command
from django.utils import timezone
from stripe_checkout.stripe_checkout.models import ExchangeRate
def fake_csv():
return "\n".join(
[
"Currency Code,Currency Units per £1",
"EUR, 1.250",
]
)
@pytest.fixture(autouse=True)
def mock_rates_api():
responses.add(
responses.GET,
re.compile(
r"https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/.+$"
),
body=fake_csv(),
)
@responses.activate
@pytest.mark.django_db
def test_fetches_data():
call_command("getexchangerate")
exchange_rates = ExchangeRate.objects.all()
assert len(exchange_rates) == 1
assert exchange_rates[0].date == timezone.now().date()
assert exchange_rates[0].rate == 0.8
@responses.activate
@pytest.mark.django_db
def test_fetches_data_only_once():
call_command("getexchangerate")
call_command("getexchangerate")
exchange_rates = ExchangeRate.objects.all()
assert len(exchange_rates) == 1
......@@ -2,6 +2,7 @@ import json
from unittest.mock import patch
import pytest
import responses
import stripe
from stripe_checkout.stripe_checkout.models import Event, Order
......@@ -35,6 +36,17 @@ def test_create_invoice(client, visitor_id):
assert Order.objects.count() == 1
@responses.activate
@pytest.mark.django_db
def test_exchange_rate(client, default_exchange_rate, visitor_id):
client.post(f"/checkout/{visitor_id}/", data={"payment_method": "invoice"})
call_args = stripe.Invoice.create.call_args[1]
assert call_args["custom_fields"][0] == {
"name": "GBP VAT Rate",
"value": "GBP 1.60 (0.8000)",
}
@responses.activate
@pytest.mark.django_db
def test_event_webhook(client):
......@@ -50,3 +62,23 @@ def test_event_webhook(client):
)
assert rv.status_code == 200
assert Event.objects.exists()
@responses.activate
@pytest.mark.django_db
def test_event_webhook_disallowed_when_not_whitelisted(client, settings):
settings.STRIPE_WEBHOOK_ALLOWED_IPS = ["1.1.1.1"]
with patch(
"stripe.Webhook.construct_event", side_effect=lambda b, *_: json.loads(b)
):
rv = client.post(
"/stripe-event-webhook/",
json.dumps(
{
"type": "invoice.paid",
"data": {"object": {"id": "stripe-invoice"}},
}
),
content_type="application/json",
)
assert rv.status_code == 403
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment