rapidforms.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # coding=utf-8
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. from __future__ import unicode_literals
  6. from django.contrib.contenttypes.models import ContentType
  7. from django.core.exceptions import ValidationError
  8. from django.db import transaction
  9. import gettext
  10. from django.utils.encoding import force_text
  11. _ = gettext.gettext
  12. __author__ = 'marcos'
  13. import collections
  14. from django import forms
  15. from rapid.widgets import RapidSelector, RapidRelationReadOnly, rapid_alternatives_widget
  16. from rapid.wrappers import FieldData, ModelData
  17. class RapidAlternativesField(forms.Field):
  18. def __init__(self, field_name, alternatives, selector_name, form, request, instance=None, *args, **kwargs):
  19. model = None
  20. if instance:
  21. model = getattr(instance, selector_name)
  22. if form.base_fields[selector_name].initial:
  23. model = ContentType.objects.get_for_id(form.base_fields[selector_name].initial)
  24. alt = []
  25. for a in alternatives:
  26. md = ModelData(a)
  27. ft = for_model(request, a)
  28. prefix = field_name + '_' + str(md.content_type().pk)
  29. selected = ContentType.objects.get_for_model(a) == model
  30. inst = getattr(instance, field_name) if selected else None
  31. if request.method == 'POST':
  32. fm = ft(request.POST, request.FILES, prefix=prefix, instance=inst)
  33. else:
  34. fm = ft(prefix=prefix, instance=inst)
  35. alt.append((md.content_type().pk, (md, fm, selected)))
  36. alt = dict(alt)
  37. kwargs['widget'] = rapid_alternatives_widget(alt, selector_name)
  38. # noinspection PyArgumentList
  39. super(RapidAlternativesField, self).__init__(field_name, *args, **kwargs)
  40. def to_python(self, value):
  41. if value is None:
  42. return None
  43. if value.is_valid():
  44. obj = value.save(commit=False)
  45. obj.save_m2m = value.save_m2m
  46. return obj
  47. raise ValidationError(_("Invalid value"), code='invalid')
  48. def for_model(request, model, default_relations=()):
  49. default_relations = list(default_relations)
  50. default_relations_request = request.GET.get('default')
  51. widgets = []
  52. default_relations_fields = []
  53. if default_relations_request:
  54. default_relations_fields = default_relations_request.split(",")
  55. default_relations += ((x, int(y)) for (x, y) in (f.split(":") for f in default_relations_fields))
  56. default_relations_fields = [x for x, y in default_relations]
  57. for (x, y) in default_relations:
  58. f = FieldData(getattr(model, x).field, request)
  59. widgets.append((x, RapidRelationReadOnly(f.related_model())))
  60. for f in ModelData(model).local_fields():
  61. if f.is_relation() and force_text(f.bare_name()) not in default_relations_fields:
  62. if f.related_model().has_permission(request, 'select'):
  63. widgets.append((f.bare_name(), RapidSelector(f)))
  64. # ModelForm.Meta has attributes with the same names, thus I'll rename them
  65. form_model = model
  66. form_widgets = dict(widgets)
  67. # noinspection PyTypeChecker
  68. class CForm(forms.ModelForm):
  69. def __init__(self, *args, **kwargs):
  70. initial = kwargs.get('initial', {})
  71. for (k, v) in default_relations:
  72. initial[k] = v
  73. if initial:
  74. kwargs['initial'] = initial
  75. instance = kwargs.get('instance')
  76. for n, fd in ModelData(model).rapid_alternative_data():
  77. ct = ModelData(model).field_by_name(fd.ct_field).field
  78. fk = ModelData(model).field_by_name(fd.fk_field).field
  79. # noinspection PyTypeChecker
  80. fl = RapidAlternativesField(n, ct.alternatives, ct.name, self, request, instance)
  81. # noinspection PyArgumentList
  82. type(self.__class__).__setattr__(self.__class__, n, fl)
  83. nd = collections.OrderedDict()
  84. for k, v in self.__class__.base_fields.items():
  85. if k == fk.name:
  86. nd[n] = fl
  87. else:
  88. nd[k] = v
  89. self.__class__.base_fields = nd
  90. super(CForm, self).__init__(*args, **kwargs)
  91. @transaction.atomic
  92. def save(self, commit=True):
  93. if not commit:
  94. return super(CForm, self).save(commit)
  95. else:
  96. obj = super(CForm, self).save(commit=False)
  97. for n, fd in ModelData(model).rapid_alternative_data():
  98. if self.instance:
  99. old_t = getattr(self.instance, fd.ct_field)
  100. new_t = getattr(obj, fd.ct_field)
  101. if old_t != new_t:
  102. getattr(self.instance, fd.bare_name).delete()
  103. fob = self.cleaned_data[n]
  104. fob.save()
  105. if hasattr(fob, 'save_m2m'):
  106. fob.save_m2m()
  107. setattr(obj, fd.fk_field, fob.pk)
  108. obj.save()
  109. self.save_m2m()
  110. return obj
  111. class Meta(object):
  112. model = form_model
  113. fields = '__all__'
  114. widgets = form_widgets
  115. return CForm