Skip to content
Snippets Groups Projects
Commit e6b29fa8 authored by geant-release-service's avatar geant-release-service
Browse files

Finished release 0.3.

parents f262efa1 c90d90d9
No related branches found
No related tags found
No related merge requests found
Showing
with 291 additions and 11 deletions
...@@ -2,7 +2,7 @@ from setuptools import setup, find_packages ...@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name="stripe-checkout", name="stripe-checkout",
version="0.2", version="0.3",
author="GEANT", author="GEANT",
author_email="swd@geant.org", author_email="swd@geant.org",
description="Stripe custom checkout support service", description="Stripe custom checkout support service",
......
...@@ -136,3 +136,5 @@ LOGGING = { ...@@ -136,3 +136,5 @@ LOGGING = {
"level": "INFO", "level": "INFO",
}, },
} }
STRIPE_WEBHOOK_ALLOWED_IPS = ["*"]
...@@ -30,3 +30,18 @@ STATIC_URL = os.getenv("STATIC_URL", "/static/") # noqa: F405 ...@@ -30,3 +30,18 @@ STATIC_URL = os.getenv("STATIC_URL", "/static/") # noqa: F405
STATIC_ROOT = os.getenv("STATIC_ROOT", "staticfiles/") # noqa: F405 STATIC_ROOT = os.getenv("STATIC_ROOT", "staticfiles/") # noqa: F405
SESSION_COOKIE_SECURE = True SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True CSRF_COOKIE_SECURE = True
STRIPE_WEBHOOK_ALLOWED_IPS = [
"3.18.12.63",
"3.130.192.231",
"13.235.14.237",
"13.235.122.149",
"18.211.135.69",
"35.154.171.200",
"52.15.183.38",
"54.88.130.119",
"54.88.130.237",
"54.187.174.169",
"54.187.205.235",
"54.187.216.72",
]
from django.contrib import admin from django.contrib import admin
from .models import PricedItem, Event, Order from .models import ExchangeRate, PricedItem, Event, Order
@admin.register(PricedItem) @admin.register(PricedItem)
...@@ -18,3 +18,8 @@ class EventAdmin(admin.ModelAdmin): ...@@ -18,3 +18,8 @@ class EventAdmin(admin.ModelAdmin):
@admin.register(Order) @admin.register(Order)
class OrderAdmin(admin.ModelAdmin): class OrderAdmin(admin.ModelAdmin):
list_display = ["visitor_id", "stripe_id", "paid"] list_display = ["visitor_id", "stripe_id", "paid"]
@admin.register(ExchangeRate)
class ExchangeRateAdmin(admin.ModelAdmin):
list_display = ["date", "rate"]
import csv
import datetime
from django.core.management.base import BaseCommand
import requests
from stripe_checkout.stripe_checkout.models import ExchangeRate
from django.utils import timezone
URL = (
"https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/"
"monthly_csv_{year}-{month}.csv"
)
class Command(BaseCommand):
def handle(self, *args, **options):
date = timezone.now()
month_str = date.strftime("%Y-%m")
exchange_rate = ExchangeRate.objects.filter(date__month=date.month).first()
if exchange_rate:
self.stdout.write(
f"Exchange rate for {month_str}"
f" already exists ({exchange_rate.rate:.4f})"
)
return
data = self.get_data(date)
rate = self.extract_exchange_rate(data)
ExchangeRate.objects.create(rate=rate, date=date.date())
self.stdout.write(
self.style.SUCCESS(
f"Updated exchange rate for {date.strftime('%Y-%m')}: {rate:.4f}"
)
)
def get_data(self, date: datetime.date):
url = URL.format(month=date.month, year=date.year)
result = requests.get(url)
result.raise_for_status()
return result.content.decode()
def extract_exchange_rate(self, content: str):
reader = csv.DictReader(content.split("\n"))
for row in reader:
if row["Currency Code"] == "EUR":
return 1 / float(row["Currency Units per £1"])
raise ValueError("")
# Generated by Django 5.1.4 on 2025-01-14 14:38
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"stripe_checkout",
"0003_rename_stripe_id_priceditem_stripe_product_id_and_more",
),
]
operations = [
migrations.AlterField(
model_name="order",
name="items",
field=models.ManyToManyField(
related_name="order", to="stripe_checkout.priceditem"
),
),
]
# Generated by Django 5.1.4 on 2025-01-14 15:39
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("stripe_checkout", "0004_alter_order_items"),
]
operations = [
migrations.CreateModel(
name="ExchangeRate",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("rate", models.FloatField()),
("date", models.DateField()),
],
),
migrations.AddField(
model_name="order",
name="exchange_rate",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="orders",
to="stripe_checkout.exchangerate",
),
),
]
...@@ -56,11 +56,23 @@ class PricedItem(models.Model): ...@@ -56,11 +56,23 @@ class PricedItem(models.Model):
) )
class ExchangeRate(models.Model):
rate = models.FloatField()
date = models.DateField()
class Order(models.Model): class Order(models.Model):
visitor_id = models.CharField(max_length=40, blank=False) visitor_id = models.CharField(max_length=40, blank=False)
stripe_id = models.CharField(max_length=255, blank=False, unique=True) stripe_id = models.CharField(max_length=255, blank=False, unique=True)
items = models.ManyToManyField(to=PricedItem, related_name="order") items = models.ManyToManyField(to=PricedItem, related_name="order")
paid = models.BooleanField(default=False, db_default=False) paid = models.BooleanField(default=False, db_default=False)
exchange_rate = models.ForeignKey(
ExchangeRate,
on_delete=models.SET_NULL,
blank=True,
null=True,
related_name="orders",
)
class Event(models.Model): class Event(models.Model):
......
...@@ -9,6 +9,7 @@ import stripe.error ...@@ -9,6 +9,7 @@ import stripe.error
from stripe.error import StripeError # noqa F401 from stripe.error import StripeError # noqa F401
from django.conf import settings from django.conf import settings
from stripe_checkout.stripe_checkout.models import ExchangeRate
from stripe_checkout.stripe_checkout.visit import Visitor from stripe_checkout.stripe_checkout.visit import Visitor
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -40,7 +41,11 @@ def get_or_create_customer(visitor: Visitor) -> Optional[str]: ...@@ -40,7 +41,11 @@ def get_or_create_customer(visitor: Visitor) -> Optional[str]:
def create_invoice( def create_invoice(
shopping_cart: ShoppingCart, customer_id, purchase_order=None, vat_number=None shopping_cart: ShoppingCart,
customer_id,
purchase_order=None,
vat_number=None,
gbp_exchange_rate: Optional[ExchangeRate] = None,
): ):
stripe.api_key = settings.STRIPE_API_KEY stripe.api_key = settings.STRIPE_API_KEY
custom_fields = [] custom_fields = []
...@@ -48,6 +53,12 @@ def create_invoice( ...@@ -48,6 +53,12 @@ def create_invoice(
custom_fields.append({"name": "Purchase Order", "value": purchase_order}) custom_fields.append({"name": "Purchase Order", "value": purchase_order})
if vat_number: if vat_number:
custom_fields.append({"name": "VAT number", "value": vat_number}) custom_fields.append({"name": "VAT number", "value": vat_number})
if gbp_exchange_rate:
rate = gbp_exchange_rate.rate
vat = shopping_cart.vat * rate
custom_fields.append(
{"name": "GBP VAT Rate", "value": f"GBP {vat:.2f} ({rate:.4f})"}
)
invoice = stripe.Invoice.create( invoice = stripe.Invoice.create(
customer=customer_id, customer=customer_id,
......
import functools
from django.conf import settings
from django.http import HttpResponseForbidden
ALLOW_ALL = "*"
def whitelist_ips(func=None, by_setting=None):
if not func:
return functools.partial(whitelist_ips, by_setting=by_setting)
@functools.wraps(func)
def _whitelisted_view(request):
allowed_ips = _read_whitelist_setting(by_setting, [ALLOW_ALL])
if not allowed_ips & {ALLOW_ALL, get_client_ip(request)}:
return HttpResponseForbidden("forbidden")
return func(request)
return _whitelisted_view
def _read_whitelist_setting(setting: str, default=None) -> set[str]:
result = getattr(settings, setting, default)
assert isinstance(result, list) and all(
isinstance(s, str) for s in result
), f"{setting} must be a list of strings"
return set(result)
def get_client_ip(request):
# cf. https://stackoverflow.com/a/4581997
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
ip = x_forwarded_for.split(",")[0]
else:
ip = request.META.get("REMOTE_ADDR")
return ip
import logging import logging
from typing import Union from typing import Union
import requests
from django import forms from django import forms
from django.http import Http404, HttpResponse from django.http import Http404, HttpResponse
from django.shortcuts import redirect, render from django.shortcuts import redirect, render
from django.views.decorators.http import require_POST, require_http_methods, require_GET
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
import requests from django.views.decorators.http import require_GET, require_http_methods, require_POST
from stripe_checkout.stripe_checkout.shopping_cart import ShoppingCart
from .models import Event
from . import stripe from . import stripe
from .models import Event, ExchangeRate
from .shopping_cart import ShoppingCart
from .utils import whitelist_ips
from .visit import VisitorAPI from .visit import VisitorAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) ...@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
class PaymentDetailsForm(forms.Form): class PaymentDetailsForm(forms.Form):
purchase_order = forms.CharField( purchase_order = forms.CharField(
label="Purchase Order (optional)", max_length=10, required=False label="Purchase Order (optional)", max_length=100, required=False
) )
vat_number = forms.CharField( vat_number = forms.CharField(
label="VAT Number (optional)", max_length=100, required=False label="VAT Number (optional)", max_length=100, required=False
...@@ -60,11 +61,14 @@ def checkout(request, visitor_id): ...@@ -60,11 +61,14 @@ def checkout(request, visitor_id):
def create_invoice(visitor, data): def create_invoice(visitor, data):
shopping_cart = get_shopping_cart(visitor) shopping_cart = get_shopping_cart(visitor)
customer = stripe.get_or_create_customer(visitor) customer = stripe.get_or_create_customer(visitor)
exchange_rate = ExchangeRate.objects.order_by("-date").first()
return stripe.create_invoice( return stripe.create_invoice(
shopping_cart, shopping_cart,
customer, customer,
purchase_order=data["purchase_order"], purchase_order=data["purchase_order"],
vat_number=data["vat_number"], vat_number=data["vat_number"],
gbp_exchange_rate=exchange_rate,
) )
...@@ -85,6 +89,7 @@ def checkout_success(request, visitor_id): ...@@ -85,6 +89,7 @@ def checkout_success(request, visitor_id):
@csrf_exempt @csrf_exempt
@require_POST @require_POST
@whitelist_ips(by_setting="STRIPE_WEBHOOK_ALLOWED_IPS")
def stripe_event(request): def stripe_event(request):
try: try:
event = stripe.read_event(request.body, request.headers.get("Stripe-Signature")) event = stripe.read_event(request.body, request.headers.get("Stripe-Signature"))
......
...@@ -9,9 +9,9 @@ import responses ...@@ -9,9 +9,9 @@ import responses
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.test import Client from django.test import Client
from django.utils import timezone
from stripe_checkout import config from stripe_checkout import config
from stripe_checkout.stripe_checkout.models import ItemKind, PricedItem from stripe_checkout.stripe_checkout.models import ExchangeRate, ItemKind, PricedItem
VISIT_RESPONSES_DIR = pathlib.Path(__file__).parent / "visit-responses" VISIT_RESPONSES_DIR = pathlib.Path(__file__).parent / "visit-responses"
...@@ -240,3 +240,8 @@ def config_file(stripe_api_key, visit_api_key, visit_expo_id, tmp_path): ...@@ -240,3 +240,8 @@ def config_file(stripe_api_key, visit_api_key, visit_expo_id, tmp_path):
@pytest.fixture @pytest.fixture
def client(setup_django, mock_visit, mock_stripe): def client(setup_django, mock_visit, mock_stripe):
return Client() return Client()
@pytest.fixture
def default_exchange_rate():
return ExchangeRate.objects.create(date=timezone.now().date(), rate=0.8)
import re
import pytest
import responses
from django.core.management import call_command
from django.utils import timezone
from stripe_checkout.stripe_checkout.models import ExchangeRate
def fake_csv():
return "\n".join(
[
"Currency Code,Currency Units per £1",
"EUR, 1.250",
]
)
@pytest.fixture(autouse=True)
def mock_rates_api():
responses.add(
responses.GET,
re.compile(
r"https://www.trade-tariff.service.gov.uk/api/v2/exchange_rates/files/.+$"
),
body=fake_csv(),
)
@responses.activate
@pytest.mark.django_db
def test_fetches_data():
call_command("getexchangerate")
exchange_rates = ExchangeRate.objects.all()
assert len(exchange_rates) == 1
assert exchange_rates[0].date == timezone.now().date()
assert exchange_rates[0].rate == 0.8
@responses.activate
@pytest.mark.django_db
def test_fetches_data_only_once():
call_command("getexchangerate")
call_command("getexchangerate")
exchange_rates = ExchangeRate.objects.all()
assert len(exchange_rates) == 1
...@@ -2,6 +2,7 @@ import json ...@@ -2,6 +2,7 @@ import json
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import responses import responses
import stripe
from stripe_checkout.stripe_checkout.models import Event, Order from stripe_checkout.stripe_checkout.models import Event, Order
...@@ -35,6 +36,17 @@ def test_create_invoice(client, visitor_id): ...@@ -35,6 +36,17 @@ def test_create_invoice(client, visitor_id):
assert Order.objects.count() == 1 assert Order.objects.count() == 1
@responses.activate
@pytest.mark.django_db
def test_exchange_rate(client, default_exchange_rate, visitor_id):
client.post(f"/checkout/{visitor_id}/", data={"payment_method": "invoice"})
call_args = stripe.Invoice.create.call_args[1]
assert call_args["custom_fields"][0] == {
"name": "GBP VAT Rate",
"value": "GBP 1.60 (0.8000)",
}
@responses.activate @responses.activate
@pytest.mark.django_db @pytest.mark.django_db
def test_event_webhook(client): def test_event_webhook(client):
...@@ -50,3 +62,23 @@ def test_event_webhook(client): ...@@ -50,3 +62,23 @@ def test_event_webhook(client):
) )
assert rv.status_code == 200 assert rv.status_code == 200
assert Event.objects.exists() assert Event.objects.exists()
@responses.activate
@pytest.mark.django_db
def test_event_webhook_disallowed_when_not_whitelisted(client, settings):
settings.STRIPE_WEBHOOK_ALLOWED_IPS = ["1.1.1.1"]
with patch(
"stripe.Webhook.construct_event", side_effect=lambda b, *_: json.loads(b)
):
rv = client.post(
"/stripe-event-webhook/",
json.dumps(
{
"type": "invoice.paid",
"data": {"object": {"id": "stripe-invoice"}},
}
),
content_type="application/json",
)
assert rv.status_code == 403
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment