import { ApolloReactFeature } from '@thesisedu/feature-apollo-react'
import {
  ClassAndStudentGradesDocument,
  ClassAndStudentGradesQuery,
} from '@thesisedu/feature-assignments-react/dist/schema'
import { AfterSegmentEnableChangedHook } from '@thesisedu/feature-courses-react'

import { CourseAssignmentsReactFeature } from '../CourseAssignmentsReactFeature'
import { getCachedSegmentAssignmentId } from '../cache/segmentAssignmentCache'
import { ClassFragment } from '../schema'

/**
 * Registers a mutate hook that keeps the Apollo cache in sync after one or more segments are
 * enabled or disabled for a class.
 *
 * For every affected segment that has a cached assignment ID, the cached `Assignment.excluded`
 * field is set to the inverse of the new enabled state. If at least one assignment was updated,
 * the class grades are re-fetched from the network (`network-only`) and copied onto the cached
 * `Class`: its `averageGrade` and the `grade` on each student edge.
 *
 * The hook passes its incoming value through unchanged.
 *
 * @param feature - The feature whose hook manager the hook is registered on.
 */
export default function (feature: CourseAssignmentsReactFeature) {
  feature.hookManager.registerMutateHook<AfterSegmentEnableChangedHook>(
    'feature-courses-react:after-segment-enable-changed',
    async (_, { classId, segmentId, enabled }) => {
      const apolloFeature = feature.root.requireFeature<ApolloReactFeature>(
        ApolloReactFeature.package,
      )
      const client = apolloFeature.client
      const segmentIds = Array.isArray(segmentId) ? segmentId : [segmentId]
      let foundValidAssignment = false
      for (const currentSegmentId of segmentIds) {
        // Only segments whose assignment ID was recorded in the segment-assignment cache can be
        // updated here; segments without a cached assignment are skipped.
        const cachedAssignmentId = getCachedSegmentAssignmentId(classId, currentSegmentId)
        if (cachedAssignmentId) {
          foundValidAssignment = true
          client.cache.modify({
            id: client.cache.identify({ __typename: 'Assignment', id: cachedAssignmentId }),
            fields: {
              excluded() {
                return !enabled
              },
            },
          })
        }
      }

      // Refresh the class grades, since changing which assignments are excluded affects them.
      if (foundValidAssignment) {
        const result = await client.query<ClassAndStudentGradesQuery>({
          query: ClassAndStudentGradesDocument,
          variables: { id: classId },
          fetchPolicy: 'network-only',
        })
        const resultClass = result.data.node?.__typename === 'Class' ? result.data.node : undefined
        if (resultClass) {
          // Index the freshly fetched edges by student ID so each cached edge is matched in O(1).
          const fetchedEdgesByStudentId = new Map(
            resultClass.students.edges.map(edge => [edge.node.id, edge] as const),
          )
          client.cache.modify({
            id: client.cache.identify({ __typename: 'Class', id: classId }),
            fields: {
              averageGrade() {
                return resultClass.averageGrade
              },
              students(students: ClassFragment['students'], { readField }) {
                if (!students?.edges) return students
                const newEdges = students.edges.map(edge => {
                  // Cached student nodes may be normalized references ({ __ref }), so read the
                  // ID through readField instead of accessing edge.node.id directly.
                  const studentId = readField<string>('id', edge.node)
                  const matchingEdge =
                    studentId === undefined ? undefined : fetchedEdgesByStudentId.get(studentId)
                  return matchingEdge ? { ...edge, grade: matchingEdge.grade } : edge
                })

                return {
                  ...students,
                  edges: newEdges,
                }
              },
            },
          })
        }
      }

      return _
    },
  )
}
