rapidforms.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from django.contrib.contenttypes.models import ContentType
  2. from django.core.exceptions import ValidationError
  3. from django.db import transaction
  4. __author__ = 'marcos'
  5. import collections
  6. from django import forms
  7. from rapid.widgets import RapidSelector, RapidRelationReadOnly, rapidAlternativesWidget
  8. from rapid.wrappers import FieldData, ModelData, InstanceData
  class RapidAlternativesField(forms.Field):
      """Form field that embeds one sub-form per alternative model.

      Rendering is delegated to ``rapidAlternativesWidget``, which receives a
      mapping ``{content_type_pk: (ModelData, sub_form, selected)}``. The value
      handed to ``to_python`` is expected to be one of those sub-forms
      (presumably the one picked by the selector field - depends on the widget).
      """

      def __init__(self, field_name, alternatives, selector_name, form, request, instance=None, *args, **kwargs):
          """Build one sub-form per alternative and the widget that shows them.

          :param field_name: name of this field; each sub-form is prefixed with
              ``<field_name>_<content type pk>``.
          :param alternatives: model classes the user can choose between.
          :param selector_name: name of the field in ``form.base_fields`` whose
              ``initial`` (a ContentType pk) marks the selected alternative.
          :param form: parent form; only its ``base_fields`` are read.
          :param request: current request; on POST the sub-forms are bound to
              ``request.POST`` and ``request.FILES``.
          :param instance: object the selected sub-form should edit; ignored
              unless it is an instance of the selected alternative model.
          """
          selected_ct = None
          selector_initial = form.base_fields[selector_name].initial
          if selector_initial:
              selected_ct = ContentType.objects.get(pk=selector_initial)
          is_post = request.method == 'POST'
          choices = {}
          for alternative in alternatives:
              md = ModelData(alternative)
              ct = md.content_type()
              sub_form_class = for_model(request, alternative)
              prefix = field_name + '_' + str(ct.pk)
              # Compare ContentType with ContentType: a ContentType instance
              # never compares equal to the model class itself.
              selected = selected_ct is not None and selected_ct == ct
              # Only hand over an object that the sub-form's model can edit.
              inst = instance if selected and isinstance(instance, alternative) else None
              if is_post:
                  sub_form = sub_form_class(request.POST, request.FILES, prefix=prefix, instance=inst)
              else:
                  sub_form = sub_form_class(prefix=prefix, instance=inst)
              choices[ct.pk] = (md, sub_form, selected)
          kwargs['widget'] = rapidAlternativesWidget(choices, selector_name)
          # ``required`` is forms.Field's first positional parameter, so it is
          # given by keyword; the field name is not a forms.Field argument.
          kwargs.setdefault('required', True)
          super(RapidAlternativesField, self).__init__(*args, **kwargs)

      def to_python(self, value):
          """Turn the submitted sub-form into an unsaved model instance.

          :param value: a sub-form, or ``None``.
          :returns: ``None`` for ``None``; otherwise the result of
              ``value.save(commit=False)`` with the sub-form's ``save_m2m``
              attached, so the caller can finish saving it later.
          :raises ValidationError: if the sub-form is not valid.
          """
          # Local import: this module does not bind the usual ``_`` alias.
          from django.utils.translation import ugettext as _
          if value is None:
              return None
          if not value.is_valid():
              raise ValidationError(_("Invalid value"), code='invalid')
          obj = value.save(commit=False)
          obj.save_m2m = value.save_m2m
          return obj
  def for_model(request, model, default_relations=()):
      """Build a ``ModelForm`` class for *model* that uses the rapid widgets.

      Widgets are chosen as follows:

      * each relation in *default_relations*, plus each ``name:pk`` pair in the
        comma-separated ``default`` GET parameter, is pre-filled with that pk
        and rendered with ``RapidRelationReadOnly``;
      * every other local relation field whose related model grants the
        ``select`` permission to *request* is rendered with ``RapidSelector``.

      Each entry of ``ModelData(model).rapid_alternative_data()`` becomes a
      ``RapidAlternativesField`` that takes the place of the entry's
      ``fk_field`` in the form. ``save(commit=True)`` saves the chosen
      alternative object first and stores its pk in ``fk_field``.

      :param request: current request. Its GET, method, POST/FILES and
          permissions are read, and the returned class keeps a reference to it.
      :param model: model class edited by the form.
      :param default_relations: iterable of ``(field_name, pk)`` pairs.
      :returns: a new ``forms.ModelForm`` subclass.
      :raises ValueError: if a ``default`` entry is not of the form ``name:int``.
      """
      # Work on a tuple so extending it never touches the caller's sequence.
      default_relations = tuple(default_relations)
      default_param = request.GET.get('default')
      if default_param:
          # Empty segments (e.g. from a trailing comma) are ignored.
          pairs = (item.split(":") for item in default_param.split(",") if item)
          default_relations += tuple((name, int(pk)) for name, pk in pairs)
      default_relation_names = [name for name, _pk in default_relations]

      model_data = ModelData(model)
      widgets = {}
      for name, _pk in default_relations:
          related = FieldData(getattr(model, name).field, request)
          widgets[name] = RapidRelationReadOnly(related.related_model())
      for field_data in model_data.local_fields():
          if not field_data.is_relation():
              continue
          if unicode(field_data.bare_name()) in default_relation_names:
              continue
          if field_data.related_model().has_permission(request, 'select'):
              widgets[field_data.bare_name()] = RapidSelector(field_data)

      # Meta assigns attributes named ``model`` and ``widgets``. A class body
      # cannot read a same-named variable of the enclosing function, hence
      # these aliases.
      form_model = model
      form_widgets = widgets

      class CForm(forms.ModelForm):
          def __init__(self, *args, **kwargs):
              # Copy so the caller's dict is never mutated; ModelForm also
              # accepts ``initial=None``.
              initial = dict(kwargs.get('initial') or {})
              initial.update(default_relations)
              if initial:
                  kwargs['initial'] = initial
              instance = kwargs.get('instance')
              original_targets = {}
              form_class = type(self)
              for name, relation in model_data.rapid_alternative_data():
                  ct = model_data.field_by_name(relation.ct_field).field
                  fk = model_data.field_by_name(relation.fk_field).field
                  # NOTE(review): the selected sub-form probably should edit the
                  # related alternative object; the parent instance is passed
                  # here (as save() used to) - confirm the intended argument.
                  field = RapidAlternativesField(name, ct.alternatives, ct.name, self, request, instance)
                  setattr(form_class, name, field)
                  # base_fields is class state shared by every instantiation.
                  # After the first one, ``fk.name`` has already been replaced
                  # by ``name``, so swap either key to install this call's field
                  # instead of keeping a stale one.
                  form_class.base_fields = collections.OrderedDict(
                      (name, field) if key in (fk.name, name) else (key, value)
                      for key, value in form_class.base_fields.iteritems())
                  if instance is not None and instance.pk is not None:
                      # Validation writes submitted values into the instance
                      # (which save(commit=False) returns), so record the
                      # current target now.
                      original_targets[name] = (getattr(instance, ct.attname),
                                                getattr(instance, fk.attname))
              self._rapid_original_targets = original_targets
              super(CForm, self).__init__(*args, **kwargs)

          @staticmethod
          def _rapid_delete_target(content_type_id, object_id):
              """Delete the object identified by (content type id, pk), if it exists."""
              if content_type_id is None or object_id is None:
                  return
              target_model = ContentType.objects.get_for_id(content_type_id).model_class()
              if target_model is None:
                  return
              stale = target_model._default_manager.filter(pk=object_id).first()
              if stale is not None:
                  stale.delete()

          @transaction.atomic
          def save(self, commit=True):
              """Save the instance together with its alternative objects.

              With ``commit=False`` this defers to ``ModelForm.save`` unchanged,
              and the alternative objects are NOT saved. With ``commit=True``:

              * if an existing instance's alternative type changed, the object
                it previously pointed to is deleted;
              * each chosen alternative is saved, along with its m2m data, and
                its pk is stored in the relation's ``fk_field``;
              * finally the instance and the form's m2m data are saved.
              """
              if not commit:
                  return super(CForm, self).save(commit)
              obj = super(CForm, self).save(commit=False)
              for name, relation in model_data.rapid_alternative_data():
                  ct = model_data.field_by_name(relation.ct_field).field
                  original = self._rapid_original_targets.get(name)
                  if original is not None and original[0] != getattr(obj, ct.attname):
                      self._rapid_delete_target(*original)
                  target = self.cleaned_data[name]
                  target.save()
                  if hasattr(target, 'save_m2m'):
                      target.save_m2m()
                  setattr(obj, relation.fk_field, target.pk)
              obj.save()
              self.save_m2m()
              return obj

          class Meta:
              model = form_model
              fields = '__all__'
              widgets = form_widgets

      return CForm